diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b7ff33280..e106cf3c2 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -190,6 +190,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): ) await read_stream_writer.aclose() await write_stream_reader.aclose() + await sse_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index a74751312..690a0c392 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -9,6 +9,7 @@ """ import typing +from collections.abc import Awaitable, Callable from typing import Any, cast import anyio @@ -65,6 +66,8 @@ async def handle_async_request( ) -> Response: assert isinstance(request.stream, AsyncByteStream) + disconnect_event = anyio.Event() + # ASGI scope. scope = { "type": "http", @@ -97,11 +100,17 @@ async def handle_async_request( content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) # ASGI callables. + async def send_disconnect() -> None: + disconnect_event.set() + async def receive() -> dict[str, Any]: nonlocal request_complete + if disconnect_event.is_set(): + return {"type": "http.disconnect"} + if request_complete: - await response_complete.wait() + await disconnect_event.wait() return {"type": "http.disconnect"} try: @@ -140,7 +149,9 @@ async def process_messages() -> None: async with asgi_receive_channel: async for message in asgi_receive_channel: if message["type"] == "http.response.start": - assert not response_started + if response_started: + # Ignore duplicate response.start from ASGI app during SSE disconnect + continue status_code = message["status"] response_headers = message.get("headers", []) response_started = True @@ -176,7 +187,7 @@ async def process_messages() -> None: return Response( status_code, headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), + stream=StreamingASGIResponseStream(content_receive_channel, send_disconnect), ) @@ -192,12 +203,18 @@ class StreamingASGIResponseStream(AsyncByteStream): def __init__( self, receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + send_disconnect: Callable[[], Awaitable[None]], ) -> None: self.receive_channel = receive_channel + self.send_disconnect = send_disconnect async def __aiter__(self) -> typing.AsyncIterator[bytes]: try: async for chunk in self.receive_channel: yield chunk finally: - await self.receive_channel.aclose() + await self.aclose() + + async def aclose(self) -> None: + await self.receive_channel.aclose() + await self.send_disconnect() diff --git a/tests/conftest.py b/tests/conftest.py index af7e47993..75da636b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,45 @@ +import anyio import pytest +import sse_starlette +from packaging import version @pytest.fixture def anyio_backend(): return "asyncio" + + +SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) +NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") + + +@pytest.fixture(autouse=True) +def reset_sse_app_status(): + """Reset sse-starlette's global AppStatus singleton before each test. + + AppStatus.should_exit_event is a global asyncio.Event that gets bound to + an event loop. This ensures each test gets a fresh Event and prevents + RuntimeError("bound to a different event loop") during parallel test + execution with pytest-xdist. + + NOTE: This fixture is only necessary for sse-starlette < 3.0.0. + Version 3.0+ eliminated the global state issue entirely by using + context-local events instead of module-level singletons, providing + automatic test isolation without manual cleanup. + + See for more details. + """ + if not NEEDS_RESET: + yield + return + + # lazy import to avoid import errors + from sse_starlette.sse import AppStatus + + # Setup: Reset before test + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] + + yield + + # Teardown: Reset after test to prevent contamination + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index fdb6ccfd8..8a4438eb1 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,14 +1,11 @@ import json -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import Any import anyio import httpx import pytest -import uvicorn +from anyio.abc import TaskGroup from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette @@ -21,7 +18,9 @@ from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport +from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -32,21 +31,10 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_SSE" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +TEST_SERVER_HOST = "testserver" +TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" # Test server implementation @@ -80,116 +68,122 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] return [TextContent(type="text", text=f"Called {name}")] -# Test fixtures -def make_server_app() -> Starlette: - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing +def create_asgi_client_factory(app: Starlette, tg: TaskGroup) -> McpHttpClientFactory: + """Factory function to create httpx clients with StreamingASGITransport""" + + def asgi_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + transport = StreamingASGITransport(app=app, task_group=tg) + return httpx.AsyncClient( + transport=transport, base_url=TEST_SERVER_BASE_URL, headers=headers, timeout=timeout, auth=auth + ) + + return asgi_client_factory + + +def create_sse_app(server: Server) -> Starlette: + """Helper to create SSE app with given server""" security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=[TEST_SERVER_HOST], + allowed_origins=[TEST_SERVER_BASE_URL], ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) +# Test fixtures +@pytest.fixture() +def server_app() -> Starlette: + """Create test Starlette app with SSE transport""" + app = create_sse_app(ServerTest()) + return app @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() +async def tg() -> AsyncGenerator[TaskGroup, None]: + async with anyio.create_task_group() as tg: + try: + yield tg + finally: + tg.cancel_scope.cancel() - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client using StreamingASGITransport""" + transport = StreamingASGITransport(app=server_app, task_group=tg) + async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: + yield client @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(server_app, tg) + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + async def connection_test() -> None: + async with http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - line_number = 0 - async for line in response.aiter_lines(): - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection(sse_client_session: ClientSession) -> None: + # Test initialization + result = await sse_client_session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await sse_client_session.send_ping() + assert isinstance(ping_result, EmptyResult) @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session +async def initialized_sse_client_session(sse_client_session: ClientSession) -> AsyncGenerator[ClientSession, None]: + session = sse_client_session + await session.initialize() + yield session @pytest.mark.anyio @@ -232,51 +226,38 @@ async def test_sse_client_timeout( pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - - @pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def mounted_server_app(server_app: Starlette) -> Starlette: + """Create a mounted server app""" + app = Starlette(routes=[Mount("/mounted_app", app=server_app)]) + return app - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def sse_client_mounted_server_app_session( + tg: TaskGroup, mounted_server_app: Starlette +) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/mounted_app/sse", + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection_mounted_app(sse_client_mounted_server_app_session: ClientSession) -> None: + session = sse_client_mounted_server_app_session + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) # Test server with request context that returns headers in the response @@ -322,54 +303,15 @@ async def handle_list_tools() -> list[Tool]: ] -def run_context_server(server_port: int) -> None: - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] - ) - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("context server process failed to terminate") +async def context_server_app() -> Starlette: + """Fixture that provides the context server app""" + app = create_sse_app(RequestContextServer()) + return app @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -378,11 +320,14 @@ async def test_request_context_propagation(context_server: None, server_url: str "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + headers=custom_headers, + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) @@ -391,9 +336,9 @@ async def test_request_context_propagation(context_server: None, server_url: str tool_result = await session.call_tool("echo_headers", {}) # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" @@ -402,19 +347,22 @@ async def test_request_context_propagation(context_server: None, server_url: str @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + # Create multiple clients with different headers for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + headers=headers, + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: await session.initialize() # Call the tool that echoes context