Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 53 additions & 35 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,47 +251,61 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
await event_source.response.aclose()
break

async def _send_error_response(self, ctx: RequestContext, error: Exception) -> None:
"""Send an error response to the client."""
error_data = ErrorData(code=32000, message=str(error))
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=ctx.session_message.message.root.id, error=error_data)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await ctx.read_stream_writer.send(session_message)

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_request_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
try:
async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return

if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error(
ctx.read_stream_writer,
message.root.id,
)
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error(
ctx.read_stream_writer,
message.root.id,
)
return

response.raise_for_status()
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type(
content_type,
ctx.read_stream_writer,
)
except Exception as exc:
if is_initialization:
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type(
content_type,
ctx.read_stream_writer,
)
raise exc
else:
await self._send_error_response(ctx, exc)

async def _handle_json_response(
self,
Expand Down Expand Up @@ -323,6 +337,7 @@ async def _handle_sse_response(
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
finished = False
async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
Expand All @@ -333,11 +348,14 @@ async def _handle_sse_response(
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
finished = True
await response.aclose()
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
if not finished:
raise Exception("SSE stream ended without completing")
except Exception as exc:
logger.exception("Error handling SSE response")
await self._send_error_response(ctx, exc)

async def _handle_unexpected_content_type(
self,
Expand Down