From 09c2ae74f8a4d7631e7572ec3ee0c90fbb0b259b Mon Sep 17 00:00:00 2001 From: Kyle Stratis Date: Mon, 10 Nov 2025 15:01:48 -0500 Subject: [PATCH 1/5] Updates client session to handle all server notifications --- src/mcp/client/session.py | 95 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8f071021d..f70f9311e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -48,6 +48,45 @@ async def __call__( ) -> None: ... +class CancelledFnT(Protocol): + async def __call__( + self, + params: types.CancelledNotificationParams, + ) -> None: ... + + +class ProgressNotificationFnT(Protocol): + async def __call__( + self, + params: types.ProgressNotificationParams, + ) -> None: ... + + +class ResourceUpdatedFnT(Protocol): + async def __call__( + self, + params: types.ResourceUpdatedNotificationParams, + ) -> None: ... + + +class ResourceListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + +class ToolListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + +class PromptListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + class MessageHandlerFnT(Protocol): async def __call__( self, @@ -96,6 +135,36 @@ async def _default_logging_callback( pass +async def _default_cancelled_callback( + params: types.CancelledNotificationParams, +) -> None: + pass + + +async def _default_progress_callback( + params: types.ProgressNotificationParams, +) -> None: + pass + + +async def _default_resource_updated_callback( + params: types.ResourceUpdatedNotificationParams, +) -> None: + pass + + +async def _default_resource_list_changed_callback() -> None: + pass + + +async def _default_tool_list_changed_callback() -> None: + pass + + +async def _default_prompt_list_changed_callback() -> None: + pass + + ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -117,6 +186,12 @@ def __init__( elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + cancelled_callback: CancelledFnT | None = None, + progress_notification_callback: ProgressNotificationFnT | None = None, + resource_updated_callback: ResourceUpdatedFnT | None = None, + resource_list_changed_callback: ResourceListChangedFnT | None = None, + tool_list_changed_callback: ToolListChangedFnT | None = None, + prompt_list_changed_callback: PromptListChangedFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, ) -> None: @@ -132,6 +207,12 @@ def __init__( self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback + self._cancelled_callback = cancelled_callback or _default_cancelled_callback + self._progress_notification_callback = progress_notification_callback or _default_progress_callback + self._resource_updated_callback = resource_updated_callback or _default_resource_updated_callback + self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback + self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback + self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None @@ -549,5 +630,15 @@ async def _received_notification(self, notification: types.ServerNotification) - match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) - case _: - pass + case types.CancelledNotification(params=params): + await self._cancelled_callback(params) + case types.ProgressNotification(params=params): + await self._progress_notification_callback(params) + case types.ResourceUpdatedNotification(params=params): + await self._resource_updated_callback(params) + case types.ResourceListChangedNotification(): + await self._resource_list_changed_callback() + case types.ToolListChangedNotification(): + await self._tool_list_changed_callback() + case types.PromptListChangedNotification(): + await self._prompt_list_changed_callback() From 5f4cee1a401a2b1fc166c8d1f2029e2220818c3e Mon Sep 17 00:00:00 2001 From: Kyle Stratis Date: Mon, 10 Nov 2025 15:19:33 -0500 Subject: [PATCH 2/5] Adds tests for new feature --- src/mcp/shared/memory.py | 24 +- tests/client/test_notification_callbacks.py | 429 ++++++++++++++++++++ 2 files changed, 452 insertions(+), 1 deletion(-) create mode 100644 tests/client/test_notification_callbacks.py diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 265d07c37..1153ab7df 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,7 +13,19 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + ProgressNotificationFnT, + PromptListChangedFnT, + ResourceListChangedFnT, + ResourceUpdatedFnT, + SamplingFnT, + ToolListChangedFnT, +) from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage @@ -53,6 +65,11 @@ async def create_connected_server_and_client_session( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + progress_notification_callback: ProgressNotificationFnT | None = None, + resource_updated_callback: ResourceUpdatedFnT | None = None, + resource_list_changed_callback: ResourceListChangedFnT | None = None, + tool_list_changed_callback: ToolListChangedFnT | None = None, + prompt_list_changed_callback: PromptListChangedFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, @@ -88,6 +105,11 @@ async def create_connected_server_and_client_session( sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, logging_callback=logging_callback, + progress_notification_callback=progress_notification_callback, + resource_updated_callback=resource_updated_callback, + resource_list_changed_callback=resource_list_changed_callback, + tool_list_changed_callback=tool_list_changed_callback, + prompt_list_changed_callback=prompt_list_changed_callback, message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, diff --git a/tests/client/test_notification_callbacks.py b/tests/client/test_notification_callbacks.py new file mode 100644 index 000000000..be2d3a731 --- /dev/null +++ b/tests/client/test_notification_callbacks.py @@ -0,0 +1,429 @@ +""" +Tests for client notification callbacks. + +This module tests all notification types that can be sent from the server to the client, +ensuring that the callback mechanism works correctly for each notification type. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import pytest +from pydantic import AnyUrl + +import mcp.types as types +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.shared.session import RequestResponder +from mcp.types import TextContent + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest + + +class ProgressNotificationCollector: + """Collector for ProgressNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notifications: list[types.ProgressNotificationParams] = [] + + async def __call__(self, params: types.ProgressNotificationParams) -> None: + """Collect a progress notification.""" + self.notifications.append(params) + + +class ResourceUpdatedCollector: + """Collector for ResourceUpdatedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notifications: list[types.ResourceUpdatedNotificationParams] = [] + + async def __call__(self, params: types.ResourceUpdatedNotificationParams) -> None: + """Collect a resource updated notification.""" + self.notifications.append(params) + + +class ResourceListChangedCollector: + """Collector for ResourceListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a resource list changed notification.""" + self.notification_count += 1 + + +class ToolListChangedCollector: + """Collector for ToolListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a tool list changed notification.""" + self.notification_count += 1 + + +class PromptListChangedCollector: + """Collector for PromptListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a prompt list changed notification.""" + self.notification_count += 1 + + +@pytest.fixture +def progress_collector() -> ProgressNotificationCollector: + """Create a progress notification collector.""" + return ProgressNotificationCollector() + + +@pytest.fixture +def resource_updated_collector() -> ResourceUpdatedCollector: + """Create a resource updated collector.""" + return ResourceUpdatedCollector() + + +@pytest.fixture +def resource_list_changed_collector() -> ResourceListChangedCollector: + """Create a resource list changed collector.""" + return ResourceListChangedCollector() + + +@pytest.fixture +def tool_list_changed_collector() -> ToolListChangedCollector: + """Create a tool list changed collector.""" + return ToolListChangedCollector() + + +@pytest.fixture +def prompt_list_changed_collector() -> PromptListChangedCollector: + """Create a prompt list changed collector.""" + return PromptListChangedCollector() + + +@pytest.mark.anyio +async def test_progress_notification_callback(progress_collector: ProgressNotificationCollector) -> None: + """Test that progress notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float, message: str) -> bool: + """Send a progress notification to the client.""" + # Get the progress token from the request metadata + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + message=message, + ) + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + progress_notification_callback=progress_collector, + message_handler=message_handler, + ) as client_session: + # Call tool with progress token in metadata + result = await client_session.call_tool( + "send_progress", + {"progress": 50.0, "total": 100.0, "message": "Halfway there"}, + meta={"progressToken": "test-token-123"}, + ) + assert result.isError is False + + # Verify the progress notification was received + assert len(progress_collector.notifications) == 1 + notification = progress_collector.notifications[0] + assert notification.progressToken == "test-token-123" + assert notification.progress == 50.0 + assert notification.total == 100.0 + assert notification.message == "Halfway there" + + +@pytest.mark.anyio +async def test_resource_updated_callback(resource_updated_collector: ResourceUpdatedCollector) -> None: + """Test that resource updated notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("update_resource") + async def update_resource_tool(uri: str) -> bool: + """Send a resource updated notification to the client.""" + await server.get_context().session.send_resource_updated(AnyUrl(uri)) + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + resource_updated_callback=resource_updated_collector, + message_handler=message_handler, + ) as client_session: + # Trigger resource update notification + result = await client_session.call_tool("update_resource", {"uri": "file:///test/resource.txt"}) + assert result.isError is False + + # Verify the notification was received + assert len(resource_updated_collector.notifications) == 1 + notification = resource_updated_collector.notifications[0] + assert str(notification.uri) == "file:///test/resource.txt" + + +@pytest.mark.anyio +async def test_resource_list_changed_callback( + resource_list_changed_collector: ResourceListChangedCollector, +) -> None: + """Test that resource list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_resource_list") + async def change_resource_list_tool() -> bool: + """Send a resource list changed notification to the client.""" + await server.get_context().session.send_resource_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + resource_list_changed_callback=resource_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger resource list changed notification + result = await client_session.call_tool("change_resource_list", {}) + assert result.isError is False + + # Verify the notification was received + assert resource_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +async def test_tool_list_changed_callback(tool_list_changed_collector: ToolListChangedCollector) -> None: + """Test that tool list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_tool_list") + async def change_tool_list_tool() -> bool: + """Send a tool list changed notification to the client.""" + await server.get_context().session.send_tool_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + tool_list_changed_callback=tool_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger tool list changed notification + result = await client_session.call_tool("change_tool_list", {}) + assert result.isError is False + + # Verify the notification was received + assert tool_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +async def test_prompt_list_changed_callback(prompt_list_changed_collector: PromptListChangedCollector) -> None: + """Test that prompt list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_prompt_list") + async def change_prompt_list_tool() -> bool: + """Send a prompt list changed notification to the client.""" + await server.get_context().session.send_prompt_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + prompt_list_changed_callback=prompt_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger prompt list changed notification + result = await client_session.call_tool("change_prompt_list", {}) + assert result.isError is False + + # Verify the notification was received + assert prompt_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "notification_type,callback_param,collector_fixture,tool_name,tool_args,verification", + [ + ( + "progress", + "progress_notification_callback", + "progress_collector", + "send_progress", + {"progress": 75.0, "total": 100.0, "message": "Almost done"}, + lambda c: ( # type: ignore[misc] + len(c.notifications) == 1 # type: ignore[attr-defined] + and c.notifications[0].progress == 75.0 # type: ignore[attr-defined] + and c.notifications[0].total == 100.0 # type: ignore[attr-defined] + and c.notifications[0].message == "Almost done" # type: ignore[attr-defined] + ), + ), + ( + "resource_updated", + "resource_updated_callback", + "resource_updated_collector", + "update_resource", + {"uri": "file:///test/data.json"}, + lambda c: ( # type: ignore[misc] + len(c.notifications) == 1 # type: ignore[attr-defined] + and str(c.notifications[0].uri) == "file:///test/data.json" # type: ignore[attr-defined] + ), + ), + ( + "resource_list_changed", + "resource_list_changed_callback", + "resource_list_changed_collector", + "change_resource_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ( + "tool_list_changed", + "tool_list_changed_callback", + "tool_list_changed_collector", + "change_tool_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ( + "prompt_list_changed", + "prompt_list_changed_callback", + "prompt_list_changed_collector", + "change_prompt_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ], +) +async def test_notification_callback_parametrized( + notification_type: str, + callback_param: str, + collector_fixture: str, + tool_name: str, + tool_args: dict[str, Any], + verification: Callable[[Any], bool], + request: "FixtureRequest", +) -> None: + """Parametrized test for all notification callbacks.""" + from mcp.server.fastmcp import FastMCP + + # Get the collector from the fixture + collector = request.getfixturevalue(collector_fixture) + + server = FastMCP("test") + + # Define all tools (simpler than dynamic tool creation) + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float, message: str) -> bool: + """Send a progress notification to the client.""" + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + message=message, + ) + return True + + @server.tool("update_resource") + async def update_resource_tool(uri: str) -> bool: + """Send a resource updated notification to the client.""" + await server.get_context().session.send_resource_updated(AnyUrl(uri)) + return True + + @server.tool("change_resource_list") + async def change_resource_list_tool() -> bool: + """Send a resource list changed notification to the client.""" + await server.get_context().session.send_resource_list_changed() + return True + + @server.tool("change_tool_list") + async def change_tool_list_tool() -> bool: + """Send a tool list changed notification to the client.""" + await server.get_context().session.send_tool_list_changed() + return True + + @server.tool("change_prompt_list") + async def change_prompt_list_tool() -> bool: + """Send a prompt list changed notification to the client.""" + await server.get_context().session.send_prompt_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + # Create session with the appropriate callback + session_kwargs: dict[str, Any] = {callback_param: collector, "message_handler": message_handler} + + async with create_session(server._mcp_server, **session_kwargs) as client_session: # type: ignore[arg-type] + # Call the appropriate tool + meta = {"progressToken": "param-test-token"} if notification_type == "progress" else None + result = await client_session.call_tool(tool_name, tool_args, meta=meta) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Verify using the provided verification function + assert verification(collector), f"Verification failed for {notification_type}" From 4734bf7ac89cedb6c81774b21e9821bd37e3ec54 Mon Sep 17 00:00:00 2001 From: Kyle Stratis Date: Mon, 10 Nov 2025 17:40:41 -0500 Subject: [PATCH 3/5] Adds client support for unknown custom and known custom notifications --- README.md | 64 ++++ .../clients/custom_notifications_example.py | 362 ++++++++++++++++++ .../clients/server_notification_client.py | 44 +++ src/mcp/client/session.py | 42 +- src/mcp/shared/memory.py | 6 + tests/client/test_custom_notifications.py | 218 +++++++++++ 6 files changed, 735 insertions(+), 1 deletion(-) create mode 100644 examples/snippets/clients/custom_notifications_example.py create mode 100644 examples/snippets/clients/server_notification_client.py create mode 100644 tests/client/test_custom_notifications.py diff --git a/README.md b/README.md index 5dbc4bd9d..3f1b10d77 100644 --- a/README.md +++ b/README.md @@ -2153,6 +2153,70 @@ if __name__ == "__main__": _Full example: [examples/snippets/clients/streamable_basic.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/streamable_basic.py)_ +### Handling Server Notifications + +Servers may send notifications, which derive from the `ServerNotification` class. To handle these, follow the following steps: + +1. For each notification type you want to support, write a callback function that follows implements the matching protocol, such as `ToolListChangedFnT` for the tool list changed notification. +2. Pass that function to the appropriate parameter when instantiating your client, e.g. `tool_list_changed_callback` for the tool list changed notification. This will be called every time your client receives the matching notification. + +You can also use this pattern with the `UnknownNotificationFnT` protocol to handle notification types that aren't anticipated in the SDK or by your code. This would handle custom notification types from the server. + + +```python +# Snippets demonstrating handling known and custom server notifications + +import asyncio + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.types import ServerNotification + +# Create dummy server parameters for stdio connection +server_params = StdioServerParameters( + command="uv", + args=["run"], + env={}, +) + + +# Create a custom handler for the resource list changed notification +async def custom_resource_list_changed_handler() -> None: + """Custom handler for resource list changed notifications.""" + print("RESOURCE LIST CHANGED") + + +# Create a fallback handler for custom notifications we aren't aware of. +async def fallback_notification_handler(notification: ServerNotification) -> None: + """Fallback handler for unknown notifications.""" + print(f"UNKNOWN notification caught: {notification.root.method}") + + +async def run(): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read, + write, + resource_list_changed_callback=custom_resource_list_changed_handler, + unknown_notification_callback=fallback_notification_handler, + ) as session: + # Initialize the connection + await session.initialize() + + # Do client stuff here + + +if __name__ == "__main__": + asyncio.run(run()) +``` + +_Full example: [examples/snippets/clients/server_notification_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/server_notification_client.py)_ + + +If your client expects to connect to a server that sends custom notifications, you can create your handler or handlers, then pass them in a dictionary where the key is the notification literal and the value is a reference to the handler function. This dictionary is then passed in to the `custom_notification_handlers` parameter of the `ClientSession` constructor. + +For a runnable example, see [examples/snippets/clients/custom_notifications_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/custom_notifications_example.py). + ### Client Display Utilities When building MCP clients, the SDK provides utilities to help display human-readable names for tools, resources, and prompts: diff --git a/examples/snippets/clients/custom_notifications_example.py b/examples/snippets/clients/custom_notifications_example.py new file mode 100644 index 000000000..4bda4d7d3 --- /dev/null +++ b/examples/snippets/clients/custom_notifications_example.py @@ -0,0 +1,362 @@ +""" +Example demonstrating how to handle custom server notifications in MCP clients. + +This example shows multiple workflows: +1. Overriding standard MCP notification handlers (logging, progress, etc.) +2. Registering custom notification handlers by method name +3. Generic handler for any unknown notification type (fallback) + +The key feature demonstrated is custom_notification_handlers, which allows you to +register handlers for specific notification methods that your application defines. +""" + +import asyncio +from typing import Any, Literal + +import mcp.types as types +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import create_connected_server_and_client_session + +# Create a FastMCP server that sends various notifications +server = FastMCP("Notification Demo Server") + + +@server.tool("send_logging_notification") +async def send_log(message: str, level: Literal["debug", "info", "warning", "error"]) -> str: + """Sends a logging notification to demonstrate known notification handling.""" + await server.get_context().log(level=level, message=message, logger_name="demo") + return f"Sent {level} log: {message}" + + +@server.tool("send_progress_notification") +async def send_progress(progress: float, total: float, message: str) -> str: + """Sends a progress notification to demonstrate known notification handling.""" + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + message=message, + ) + return f"Sent progress: {progress}/{total} - {message}" + return "No progress token provided" + + +@server.tool("trigger_resource_list_change") +async def trigger_resource_change() -> str: + """Sends a resource list changed notification.""" + await server.get_context().session.send_resource_list_changed() + return "Sent resource list changed notification" + + +def create_notification_handlers() -> tuple[Any, Any, Any, Any, list[dict[str, Any]]]: + """Create notification handlers that share a common log.""" + notifications_log: list[dict[str, Any]] = [] + + async def unknown_notification_handler(notification: types.ServerNotification) -> None: + """Handler for unknown/custom notifications.""" + print(f"UNKNOWN notification caught: {notification.root.method}") + notifications_log.append({"type": "unknown", "method": notification.root.method}) + + async def custom_logging_handler(params: types.LoggingMessageNotificationParams) -> None: + """Custom handler for logging notifications.""" + print(f"LOG (custom handler): [{params.level}] {params.data}") + notifications_log.append({"type": "logging", "level": params.level, "message": params.data}) + + async def custom_progress_handler(params: types.ProgressNotificationParams) -> None: + """Custom handler for progress notifications.""" + print(f"PROGRESS: {params.progress}/{params.total} - {params.message or 'No message'}") + notifications_log.append({"type": "progress", "progress": params.progress, "total": params.total}) + + async def custom_resource_list_changed_handler() -> None: + """Custom handler for resource list changed notifications.""" + print("RESOURCE LIST CHANGED") + notifications_log.append({"type": "resource_list_changed"}) + + return ( + unknown_notification_handler, + custom_logging_handler, + custom_progress_handler, + custom_resource_list_changed_handler, + notifications_log, + ) + + +async def example_1_override_standard_handlers() -> None: + """Example 1: Override standard MCP notification handlers.""" + print("\n" + "=" * 70) + print("Example 1: Using Custom Handlers for Known Notification Types") + print("=" * 70) + print("\nWe're overriding the default handlers with custom ones.\n") + + # Create handlers for example 1 + ( + _, # unknown_handler not used in this example + logging_handler, + progress_handler, + resource_handler, + notifications_log, + ) = create_notification_handlers() + + async with create_connected_server_and_client_session( + server, + logging_callback=logging_handler, + progress_notification_callback=progress_handler, + resource_list_changed_callback=resource_handler, + ) as client: + print("Client connected with custom notification handlers\n") + + # Send various notifications + print("Sending logging notification...") + result1 = await client.call_tool( + "send_logging_notification", {"message": "Hello from server!", "level": "info"} + ) + print(f" Tool returned: {result1.content[0].text}\n") # type: ignore[attr-defined] + + await asyncio.sleep(0.1) # Give notifications time to process + + print("Sending progress notification...") + result2 = await client.call_tool( + "send_progress_notification", + {"progress": 75.0, "total": 100.0, "message": "Processing..."}, + meta={"progressToken": "demo-token"}, + ) + print(f" Tool returned: {result2.content[0].text}\n") # type: ignore[attr-defined] + + await asyncio.sleep(0.1) + + print("Sending resource list changed notification...") + result3 = await client.call_tool("trigger_resource_list_change", {}) + print(f" Tool returned: {result3.content[0].text}\n") # type: ignore[attr-defined] + + await asyncio.sleep(0.1) + + print(f"\nTotal notifications handled: {len(notifications_log)}") + print("\nNotifications received:") + for i, notif in enumerate(notifications_log, 1): + print(f" {i}. {notif}") + + +async def example_2_custom_notification_handlers() -> None: + """Example 2: Register handlers for custom notification types by method name.""" + print("\n" + "=" * 70) + print("Example 2: Custom Notification Handlers by Method Name") + print("=" * 70) + print("\nThis shows how to register handlers for SPECIFIC custom notification") + print("types that your application defines. These handlers are checked FIRST,") + print("before the standard notification types.\n") + + # Define handlers for specific custom notification methods + custom_notifications_received: list[dict[str, Any]] = [] + + async def analytics_notification_handler(notification: types.ServerNotification) -> None: + """Handler for custom analytics notifications from our app.""" + print(f"ANALYTICS notification: {notification.root.method}") + custom_notifications_received.append( + { + "handler": "analytics", + "method": notification.root.method, + "data": notification.root, + } + ) + + async def telemetry_notification_handler(notification: types.ServerNotification) -> None: + """Handler for custom telemetry notifications from our app.""" + print(f"TELEMETRY notification: {notification.root.method}") + custom_notifications_received.append( + { + "handler": "telemetry", + "method": notification.root.method, + "data": notification.root, + } + ) + + async def custom_app_notification_handler(notification: types.ServerNotification) -> None: + """Handler for general custom app notifications.""" + print(f"CUSTOM APP notification: {notification.root.method}") + custom_notifications_received.append( + { + "handler": "custom_app", + "method": notification.root.method, + "data": notification.root, + } + ) + + # Register custom handlers by their notification method names + # In a real app, your server would send notifications with these method names + custom_handlers = { + "notifications/custom/analytics": analytics_notification_handler, + "notifications/custom/telemetry": telemetry_notification_handler, + "notifications/custom/myapp/status": custom_app_notification_handler, + "notifications/custom/myapp/alert": custom_app_notification_handler, + } + + print("Custom notification handlers registered for:") + for method in custom_handlers: + print(f" • {method}") + + # Also create an unknown handler as fallback + (unknown_handler2, _, _, _, _) = create_notification_handlers() + + async with create_connected_server_and_client_session( + server, + custom_notification_handlers=custom_handlers, + unknown_notification_callback=unknown_handler2, + ) as client: + print("\nClient connected with custom notification handlers") + print("\nIn a real application:") + print(" • Your server sends notifications with method names like") + print(" 'notifications/custom/analytics'") + print(" • The client automatically routes them to the registered handlers") + print(" • Unknown notifications fall back to the unknown_notification_callback") + print("\nFor this demo, we'll just call a regular tool since we can't easily") + print("send truly custom notifications from FastMCP without extending the protocol.\n") + + await client.call_tool("send_logging_notification", {"message": "Regular operation", "level": "info"}) + await asyncio.sleep(0.1) + + print(f"\nCustom notification handlers ready: {len(custom_handlers)} registered") + print(f"Custom notifications received: {len(custom_notifications_received)}") + print("\nExample usage in your own code:") + print(""" + # Server side (in your MCP server): + await session.send_notification( + ServerNotification( + root=Notification( + method="notifications/custom/analytics", + params={"event": "user_action", "data": {...}} + ) + ) + ) + + # Client side (this file): + custom_handlers = { + "notifications/custom/analytics": your_analytics_handler, + } + """) + + +async def example_3_unknown_notification_fallback() -> None: + """Example 3: Unknown notification fallback handler.""" + print("\n" + "=" * 70) + print("Example 3: Unknown Notification Fallback (Conceptual)") + print("=" * 70) + print("\nThe unknown_notification_callback catches any notification that") + print("doesn't match registered custom handlers OR known MCP types.") + print("\nThis example shows that KNOWN notifications are NOT sent to the") + print("unknown handler - they go to their specific handlers instead.\n") + + # Create handlers for example 3 + ( + unknown_handler3, + logging_handler3, + _, + _, + notifications_log3, + ) = create_notification_handlers() + + async with create_connected_server_and_client_session( + server, + unknown_notification_callback=unknown_handler3, + logging_callback=logging_handler3, + ) as client: + print("Client connected with unknown notification fallback handler\n") + + print("Sending a standard logging notification (a KNOWN type)...") + result = await client.call_tool( + "send_logging_notification", {"message": "This uses the custom handler", "level": "debug"} + ) + print(f" Tool returned: {result.content[0].text}\n") # type: ignore[attr-defined] + + await asyncio.sleep(0.1) + + print(f"\nKnown notifications (logging): {len([n for n in notifications_log3 if n['type'] == 'logging'])}") + print(f"Unknown notifications: {len([n for n in notifications_log3 if n['type'] == 'unknown'])}") + print("\n✓ The known notification was handled by the logging_callback,") + print(" NOT by the unknown_notification_callback (as expected!)") + print("\nIn a real application with custom notification types:") + print(" • Notifications like 'notifications/custom/myapp' would be") + print(" sent to the unknown handler if not in custom_notification_handlers") + print(" • Known MCP types (logging, progress, etc.) are never 'unknown'") + + +async def example_4_selective_override() -> None: + """Example 4: Selective override of specific handlers.""" + print("\n" + "=" * 70) + print("Example 4: Real-World Pattern - Selective Override") + print("=" * 70) + print("\nOverride only specific notification handlers while using defaults") + print("for others. Perfect for production monitoring scenarios.\n") + + # Create handlers for example 4 + ( + _, + logging_handler4, + _, + _, + notifications_log4, + ) = create_notification_handlers() + + # Only override logging, let other notifications use defaults + async with create_connected_server_and_client_session( + server, + logging_callback=logging_handler4, # Custom + # progress_notification_callback uses default + # resource_list_changed_callback uses default + ) as client: + print("Client connected with selective handler overrides\n") + + print("Sending multiple notifications...") + await client.call_tool("send_logging_notification", {"message": "Custom handler!", "level": "warning"}) + await asyncio.sleep(0.05) + + await client.call_tool( + "send_progress_notification", + {"progress": 50.0, "total": 100.0, "message": "Default handler"}, + meta={"progressToken": "token-2"}, + ) + await asyncio.sleep(0.05) + + await client.call_tool("trigger_resource_list_change", {}) + await asyncio.sleep(0.1) + + print("") # newline after notifications + + print(f"\nCustom-handled notifications: {len([n for n in notifications_log4 if n['type'] == 'logging'])}") + print("Default-handled notifications: Progress and ResourceListChanged\n") + + +async def main() -> None: + """Run all examples demonstrating custom notification handling.""" + print("=" * 70) + print("MCP Custom Notification Handling Demo") + print("=" * 70) + + # Run all examples + await example_1_override_standard_handlers() + await example_2_custom_notification_handlers() + await example_3_unknown_notification_fallback() + await example_4_selective_override() + + # Summary + print("=" * 70) + print("Demo Complete!") + print("=" * 70) + print("\nKey Takeaways:") + print(" 1. Override standard MCP notification handlers (logging, progress, etc.)") + print(" 2. Register custom notification handlers by method name") + print(" → Use custom_notification_handlers dict") + print(" → These are checked FIRST, before standard types") + print(" 3. Unknown notification callback catches unrecognized types") + print(" → Fallback for anything not in custom handlers or standard types") + print(" 4. Selectively override only the handlers you need") + print(" 5. Default handlers are used when no custom handler is provided") + print("\nNotification Processing Order:") + print(" 1. Custom notification handlers (by method name)") + print(" 2. Standard MCP notification types") + print(" 3. Unknown notification callback (fallback)") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/snippets/clients/server_notification_client.py b/examples/snippets/clients/server_notification_client.py new file mode 100644 index 000000000..e63a6991a --- /dev/null +++ b/examples/snippets/clients/server_notification_client.py @@ -0,0 +1,44 @@ +# Snippets demonstrating handling known and custom server notifications + +import asyncio + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.types import ServerNotification + +# Create dummy server parameters for stdio connection +server_params = StdioServerParameters( + command="uv", + args=["run"], + env={}, +) + + +# Create a custom handler for the resource list changed notification +async def custom_resource_list_changed_handler() -> None: + """Custom handler for resource list changed notifications.""" + print("RESOURCE LIST CHANGED") + + +# Create a fallback handler for custom notifications we aren't aware of. +async def fallback_notification_handler(notification: ServerNotification) -> None: + """Fallback handler for unknown notifications.""" + print(f"UNKNOWN notification caught: {notification.root.method}") + + +async def run(): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read, + write, + resource_list_changed_callback=custom_resource_list_changed_handler, + unknown_notification_callback=fallback_notification_handler, + ) as session: + # Initialize the connection + await session.initialize() + + # Do client stuff here + + +if __name__ == "__main__": + asyncio.run(run()) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index f70f9311e..24b69462c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -87,6 +87,20 @@ async def __call__( ) -> None: ... +class UnknownNotificationFnT(Protocol): + async def __call__( + self, + notification: types.ServerNotification, + ) -> None: ... + + +class CustomNotificationHandlerFnT(Protocol): + async def __call__( + self, + notification: types.ServerNotification, + ) -> None: ... + + class MessageHandlerFnT(Protocol): async def __call__( self, @@ -165,6 +179,12 @@ async def _default_prompt_list_changed_callback() -> None: pass +async def _default_unknown_notification_callback( + notification: types.ServerNotification, +) -> None: + pass + + ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -192,6 +212,8 @@ def __init__( resource_list_changed_callback: ResourceListChangedFnT | None = None, tool_list_changed_callback: ToolListChangedFnT | None = None, prompt_list_changed_callback: PromptListChangedFnT | None = None, + unknown_notification_callback: UnknownNotificationFnT | None = None, + custom_notification_handlers: dict[str, CustomNotificationHandlerFnT] | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, ) -> None: @@ -213,6 +235,8 @@ def __init__( self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback + self._unknown_notification_callback = unknown_notification_callback or _default_unknown_notification_callback + self._custom_notification_handlers = custom_notification_handlers or {} self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None @@ -625,7 +649,19 @@ async def _handle_incoming( await self._message_handler(req) async def _received_notification(self, notification: types.ServerNotification) -> None: - """Handle notifications from the server.""" + """Handle notifications from the server. + + Notifications are handled in the following order: + 1. Custom notification handlers (registered by method name) + 2. Known notification types (LoggingMessage, Progress, etc.) + 3. Unknown notification handler (fallback) + """ + # Check if there's a custom handler registered for this notification method + notification_method = notification.root.method + if notification_method in self._custom_notification_handlers: + await self._custom_notification_handlers[notification_method](notification) + return + # Process specific notification types match notification.root: case types.LoggingMessageNotification(params=params): @@ -642,3 +678,7 @@ async def _received_notification(self, notification: types.ServerNotification) - await self._tool_list_changed_callback() case types.PromptListChangedNotification(): await self._prompt_list_changed_callback() + case _: # type: ignore[misc] + # Handle unknown/custom notifications that may exist at runtime + # but aren't part of the known ServerNotification type union + await self._unknown_notification_callback(notification) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 1153ab7df..7725cb776 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,6 +15,7 @@ import mcp.types as types from mcp.client.session import ( ClientSession, + CustomNotificationHandlerFnT, ElicitationFnT, ListRootsFnT, LoggingFnT, @@ -25,6 +26,7 @@ ResourceUpdatedFnT, SamplingFnT, ToolListChangedFnT, + UnknownNotificationFnT, ) from mcp.server import Server from mcp.server.fastmcp import FastMCP @@ -70,6 +72,8 @@ async def create_connected_server_and_client_session( resource_list_changed_callback: ResourceListChangedFnT | None = None, tool_list_changed_callback: ToolListChangedFnT | None = None, prompt_list_changed_callback: PromptListChangedFnT | None = None, + unknown_notification_callback: UnknownNotificationFnT | None = None, + custom_notification_handlers: dict[str, CustomNotificationHandlerFnT] | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, @@ -110,6 +114,8 @@ async def create_connected_server_and_client_session( resource_list_changed_callback=resource_list_changed_callback, tool_list_changed_callback=tool_list_changed_callback, prompt_list_changed_callback=prompt_list_changed_callback, + unknown_notification_callback=unknown_notification_callback, + custom_notification_handlers=custom_notification_handlers, message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, diff --git a/tests/client/test_custom_notifications.py b/tests/client/test_custom_notifications.py new file mode 100644 index 000000000..088a3efd7 --- /dev/null +++ b/tests/client/test_custom_notifications.py @@ -0,0 +1,218 @@ +""" +Tests for custom notification handlers in ClientSession. + +This module tests both workflows for handling custom/unknown notifications: +1. Generic unknown notification handler (fallback) +2. Type-specific custom notification handlers (registry) +""" + +from typing import Any + +import pytest + +import mcp.types as types +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.shared.session import RequestResponder + + +class UnknownNotificationCollector: + """Collector for unknown/custom notifications.""" + + def __init__(self) -> None: + self.notifications: list[types.ServerNotification] = [] + + async def __call__(self, notification: types.ServerNotification) -> None: + """Collect unknown notifications.""" + self.notifications.append(notification) + + +@pytest.fixture +def unknown_collector() -> UnknownNotificationCollector: + """Create a collector for unknown notifications.""" + return UnknownNotificationCollector() + + +@pytest.fixture +def message_handler() -> Any: + """Message handler that re-raises exceptions.""" + + async def handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + return handler + + +@pytest.mark.anyio +async def test_unknown_notification_callback_not_called_for_known_types( + unknown_collector: UnknownNotificationCollector, + message_handler: Any, +) -> None: + """Test that the unknown notification handler is NOT called for known notification types.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + # Register a tool that sends a known notification (logging) + @server.tool("send_logging") + async def send_logging_tool() -> bool: + """Send a logging notification to the client.""" + # Logging notifications are handled by the specific logging_callback, + # not the unknown_notification_callback + return True + + async with create_session( + server._mcp_server, + unknown_notification_callback=unknown_collector, + message_handler=message_handler, + ) as client_session: + # Call the tool + result = await client_session.call_tool("send_logging", {}) + assert result.isError is False + + # The unknown notification collector should NOT have been called + # because logging is a known notification type + assert len(unknown_collector.notifications) == 0 + + +@pytest.mark.anyio +async def test_custom_notification_handler_takes_priority( + unknown_collector: UnknownNotificationCollector, + message_handler: Any, +) -> None: + """Test that custom notification handlers are checked before unknown handler.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + # Track which handler was called + custom_handler_called: list[str] = [] + + async def custom_handler(notification: types.ServerNotification) -> None: + """Custom handler for a specific notification method.""" + custom_handler_called.append(notification.root.method) + + # Register a custom handler for a specific notification method + custom_handlers = { + "notifications/custom/test": custom_handler, + } + + @server.tool("trigger_notification") + async def trigger_tool() -> bool: + """Tool that returns success.""" + return True + + async with create_session( + server._mcp_server, + custom_notification_handlers=custom_handlers, + unknown_notification_callback=unknown_collector, + message_handler=message_handler, + ) as client_session: + # Call the tool + result = await client_session.call_tool("trigger_notification", {}) + assert result.isError is False + + # Neither handler should have been called for known notification types + assert len(custom_handler_called) == 0 + assert len(unknown_collector.notifications) == 0 + + +@pytest.mark.anyio +async def test_unknown_notification_callback_with_default( + message_handler: Any, +) -> None: + """Test that the default unknown notification callback does nothing.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + @server.tool("test_tool") + async def test_tool() -> bool: + """Simple test tool.""" + return True + + # Don't pass an unknown_notification_callback - use the default + async with create_session( + server._mcp_server, + message_handler=message_handler, + ) as client_session: + # This should work fine with the default handler + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + + +@pytest.mark.anyio +async def test_custom_handlers_empty_dict( + unknown_collector: UnknownNotificationCollector, + message_handler: Any, +) -> None: + """Test that an empty custom handlers dict works correctly.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + @server.tool("test_tool") + async def test_tool() -> bool: + """Simple test tool.""" + return True + + # Pass an empty custom handlers dict + async with create_session( + server._mcp_server, + custom_notification_handlers={}, + unknown_notification_callback=unknown_collector, + message_handler=message_handler, + ) as client_session: + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + + # No unknown notifications should have been received + assert len(unknown_collector.notifications) == 0 + + +@pytest.mark.anyio +async def test_multiple_custom_handlers( + message_handler: Any, +) -> None: + """Test that multiple custom notification handlers can be registered.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + # Track which handlers were called + handler_calls: dict[str, int] = {} + + async def create_custom_handler(name: str) -> Any: + """Factory function to create a custom handler.""" + + async def handler(notification: types.ServerNotification) -> None: + handler_calls[name] = handler_calls.get(name, 0) + 1 + + return handler + + # Register multiple custom handlers + custom_handlers = { + "notifications/custom/type1": await create_custom_handler("type1"), + "notifications/custom/type2": await create_custom_handler("type2"), + "notifications/custom/type3": await create_custom_handler("type3"), + } + + @server.tool("test_tool") + async def test_tool() -> bool: + """Simple test tool.""" + return True + + async with create_session( + server._mcp_server, + custom_notification_handlers=custom_handlers, + message_handler=message_handler, + ) as client_session: + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + + # No handlers should have been called yet (no matching notifications) + assert len(handler_calls) == 0 From 64ea38aea660b5f55984f3d12bd974a2651480a8 Mon Sep 17 00:00:00 2001 From: Kyle Stratis Date: Wed, 12 Nov 2025 00:44:33 -0500 Subject: [PATCH 4/5] Removes handling for custom and unknown notifications --- README.md | 14 - .../clients/custom_notifications_example.py | 362 ------------------ .../clients/server_notification_client.py | 8 - src/mcp/client/session.py | 64 +--- src/mcp/shared/memory.py | 6 - tests/client/test_custom_notifications.py | 218 ----------- tests/client/test_notification_callbacks.py | 79 ++++ 7 files changed, 84 insertions(+), 667 deletions(-) delete mode 100644 examples/snippets/clients/custom_notifications_example.py delete mode 100644 tests/client/test_custom_notifications.py diff --git a/README.md b/README.md index 3f1b10d77..8abcfe877 100644 --- a/README.md +++ b/README.md @@ -2160,8 +2160,6 @@ Servers may send notifications, which derive from the `ServerNotification` class 1. For each notification type you want to support, write a callback function that follows implements the matching protocol, such as `ToolListChangedFnT` for the tool list changed notification. 2. Pass that function to the appropriate parameter when instantiating your client, e.g. `tool_list_changed_callback` for the tool list changed notification. This will be called every time your client receives the matching notification. -You can also use this pattern with the `UnknownNotificationFnT` protocol to handle notification types that aren't anticipated in the SDK or by your code. This would handle custom notification types from the server. - ```python # Snippets demonstrating handling known and custom server notifications @@ -2170,7 +2168,6 @@ import asyncio from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.types import ServerNotification # Create dummy server parameters for stdio connection server_params = StdioServerParameters( @@ -2186,19 +2183,12 @@ async def custom_resource_list_changed_handler() -> None: print("RESOURCE LIST CHANGED") -# Create a fallback handler for custom notifications we aren't aware of. -async def fallback_notification_handler(notification: ServerNotification) -> None: - """Fallback handler for unknown notifications.""" - print(f"UNKNOWN notification caught: {notification.root.method}") - - async def run(): async with stdio_client(server_params) as (read, write): async with ClientSession( read, write, resource_list_changed_callback=custom_resource_list_changed_handler, - unknown_notification_callback=fallback_notification_handler, ) as session: # Initialize the connection await session.initialize() @@ -2213,10 +2203,6 @@ if __name__ == "__main__": _Full example: [examples/snippets/clients/server_notification_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/server_notification_client.py)_ -If your client expects to connect to a server that sends custom notifications, you can create your handler or handlers, then pass them in a dictionary where the key is the notification literal and the value is a reference to the handler function. This dictionary is then passed in to the `custom_notification_handlers` parameter of the `ClientSession` constructor. - -For a runnable example, see [examples/snippets/clients/custom_notifications_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/custom_notifications_example.py). - ### Client Display Utilities When building MCP clients, the SDK provides utilities to help display human-readable names for tools, resources, and prompts: diff --git a/examples/snippets/clients/custom_notifications_example.py b/examples/snippets/clients/custom_notifications_example.py deleted file mode 100644 index 4bda4d7d3..000000000 --- a/examples/snippets/clients/custom_notifications_example.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Example demonstrating how to handle custom server notifications in MCP clients. - -This example shows multiple workflows: -1. Overriding standard MCP notification handlers (logging, progress, etc.) -2. Registering custom notification handlers by method name -3. Generic handler for any unknown notification type (fallback) - -The key feature demonstrated is custom_notification_handlers, which allows you to -register handlers for specific notification methods that your application defines. -""" - -import asyncio -from typing import Any, Literal - -import mcp.types as types -from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import create_connected_server_and_client_session - -# Create a FastMCP server that sends various notifications -server = FastMCP("Notification Demo Server") - - -@server.tool("send_logging_notification") -async def send_log(message: str, level: Literal["debug", "info", "warning", "error"]) -> str: - """Sends a logging notification to demonstrate known notification handling.""" - await server.get_context().log(level=level, message=message, logger_name="demo") - return f"Sent {level} log: {message}" - - -@server.tool("send_progress_notification") -async def send_progress(progress: float, total: float, message: str) -> str: - """Sends a progress notification to demonstrate known notification handling.""" - ctx = server.get_context() - if ctx.request_context.meta and ctx.request_context.meta.progressToken: - await ctx.session.send_progress_notification( - progress_token=ctx.request_context.meta.progressToken, - progress=progress, - total=total, - message=message, - ) - return f"Sent progress: {progress}/{total} - {message}" - return "No progress token provided" - - -@server.tool("trigger_resource_list_change") -async def trigger_resource_change() -> str: - """Sends a resource list changed notification.""" - await server.get_context().session.send_resource_list_changed() - return "Sent resource list changed notification" - - -def create_notification_handlers() -> tuple[Any, Any, Any, Any, list[dict[str, Any]]]: - """Create notification handlers that share a common log.""" - notifications_log: list[dict[str, Any]] = [] - - async def unknown_notification_handler(notification: types.ServerNotification) -> None: - """Handler for unknown/custom notifications.""" - print(f"UNKNOWN notification caught: {notification.root.method}") - notifications_log.append({"type": "unknown", "method": notification.root.method}) - - async def custom_logging_handler(params: types.LoggingMessageNotificationParams) -> None: - """Custom handler for logging notifications.""" - print(f"LOG (custom handler): [{params.level}] {params.data}") - notifications_log.append({"type": "logging", "level": params.level, "message": params.data}) - - async def custom_progress_handler(params: types.ProgressNotificationParams) -> None: - """Custom handler for progress notifications.""" - print(f"PROGRESS: {params.progress}/{params.total} - {params.message or 'No message'}") - notifications_log.append({"type": "progress", "progress": params.progress, "total": params.total}) - - async def custom_resource_list_changed_handler() -> None: - """Custom handler for resource list changed notifications.""" - print("RESOURCE LIST CHANGED") - notifications_log.append({"type": "resource_list_changed"}) - - return ( - unknown_notification_handler, - custom_logging_handler, - custom_progress_handler, - custom_resource_list_changed_handler, - notifications_log, - ) - - -async def example_1_override_standard_handlers() -> None: - """Example 1: Override standard MCP notification handlers.""" - print("\n" + "=" * 70) - print("Example 1: Using Custom Handlers for Known Notification Types") - print("=" * 70) - print("\nWe're overriding the default handlers with custom ones.\n") - - # Create handlers for example 1 - ( - _, # unknown_handler not used in this example - logging_handler, - progress_handler, - resource_handler, - notifications_log, - ) = create_notification_handlers() - - async with create_connected_server_and_client_session( - server, - logging_callback=logging_handler, - progress_notification_callback=progress_handler, - resource_list_changed_callback=resource_handler, - ) as client: - print("Client connected with custom notification handlers\n") - - # Send various notifications - print("Sending logging notification...") - result1 = await client.call_tool( - "send_logging_notification", {"message": "Hello from server!", "level": "info"} - ) - print(f" Tool returned: {result1.content[0].text}\n") # type: ignore[attr-defined] - - await asyncio.sleep(0.1) # Give notifications time to process - - print("Sending progress notification...") - result2 = await client.call_tool( - "send_progress_notification", - {"progress": 75.0, "total": 100.0, "message": "Processing..."}, - meta={"progressToken": "demo-token"}, - ) - print(f" Tool returned: {result2.content[0].text}\n") # type: ignore[attr-defined] - - await asyncio.sleep(0.1) - - print("Sending resource list changed notification...") - result3 = await client.call_tool("trigger_resource_list_change", {}) - print(f" Tool returned: {result3.content[0].text}\n") # type: ignore[attr-defined] - - await asyncio.sleep(0.1) - - print(f"\nTotal notifications handled: {len(notifications_log)}") - print("\nNotifications received:") - for i, notif in enumerate(notifications_log, 1): - print(f" {i}. {notif}") - - -async def example_2_custom_notification_handlers() -> None: - """Example 2: Register handlers for custom notification types by method name.""" - print("\n" + "=" * 70) - print("Example 2: Custom Notification Handlers by Method Name") - print("=" * 70) - print("\nThis shows how to register handlers for SPECIFIC custom notification") - print("types that your application defines. These handlers are checked FIRST,") - print("before the standard notification types.\n") - - # Define handlers for specific custom notification methods - custom_notifications_received: list[dict[str, Any]] = [] - - async def analytics_notification_handler(notification: types.ServerNotification) -> None: - """Handler for custom analytics notifications from our app.""" - print(f"ANALYTICS notification: {notification.root.method}") - custom_notifications_received.append( - { - "handler": "analytics", - "method": notification.root.method, - "data": notification.root, - } - ) - - async def telemetry_notification_handler(notification: types.ServerNotification) -> None: - """Handler for custom telemetry notifications from our app.""" - print(f"TELEMETRY notification: {notification.root.method}") - custom_notifications_received.append( - { - "handler": "telemetry", - "method": notification.root.method, - "data": notification.root, - } - ) - - async def custom_app_notification_handler(notification: types.ServerNotification) -> None: - """Handler for general custom app notifications.""" - print(f"CUSTOM APP notification: {notification.root.method}") - custom_notifications_received.append( - { - "handler": "custom_app", - "method": notification.root.method, - "data": notification.root, - } - ) - - # Register custom handlers by their notification method names - # In a real app, your server would send notifications with these method names - custom_handlers = { - "notifications/custom/analytics": analytics_notification_handler, - "notifications/custom/telemetry": telemetry_notification_handler, - "notifications/custom/myapp/status": custom_app_notification_handler, - "notifications/custom/myapp/alert": custom_app_notification_handler, - } - - print("Custom notification handlers registered for:") - for method in custom_handlers: - print(f" • {method}") - - # Also create an unknown handler as fallback - (unknown_handler2, _, _, _, _) = create_notification_handlers() - - async with create_connected_server_and_client_session( - server, - custom_notification_handlers=custom_handlers, - unknown_notification_callback=unknown_handler2, - ) as client: - print("\nClient connected with custom notification handlers") - print("\nIn a real application:") - print(" • Your server sends notifications with method names like") - print(" 'notifications/custom/analytics'") - print(" • The client automatically routes them to the registered handlers") - print(" • Unknown notifications fall back to the unknown_notification_callback") - print("\nFor this demo, we'll just call a regular tool since we can't easily") - print("send truly custom notifications from FastMCP without extending the protocol.\n") - - await client.call_tool("send_logging_notification", {"message": "Regular operation", "level": "info"}) - await asyncio.sleep(0.1) - - print(f"\nCustom notification handlers ready: {len(custom_handlers)} registered") - print(f"Custom notifications received: {len(custom_notifications_received)}") - print("\nExample usage in your own code:") - print(""" - # Server side (in your MCP server): - await session.send_notification( - ServerNotification( - root=Notification( - method="notifications/custom/analytics", - params={"event": "user_action", "data": {...}} - ) - ) - ) - - # Client side (this file): - custom_handlers = { - "notifications/custom/analytics": your_analytics_handler, - } - """) - - -async def example_3_unknown_notification_fallback() -> None: - """Example 3: Unknown notification fallback handler.""" - print("\n" + "=" * 70) - print("Example 3: Unknown Notification Fallback (Conceptual)") - print("=" * 70) - print("\nThe unknown_notification_callback catches any notification that") - print("doesn't match registered custom handlers OR known MCP types.") - print("\nThis example shows that KNOWN notifications are NOT sent to the") - print("unknown handler - they go to their specific handlers instead.\n") - - # Create handlers for example 3 - ( - unknown_handler3, - logging_handler3, - _, - _, - notifications_log3, - ) = create_notification_handlers() - - async with create_connected_server_and_client_session( - server, - unknown_notification_callback=unknown_handler3, - logging_callback=logging_handler3, - ) as client: - print("Client connected with unknown notification fallback handler\n") - - print("Sending a standard logging notification (a KNOWN type)...") - result = await client.call_tool( - "send_logging_notification", {"message": "This uses the custom handler", "level": "debug"} - ) - print(f" Tool returned: {result.content[0].text}\n") # type: ignore[attr-defined] - - await asyncio.sleep(0.1) - - print(f"\nKnown notifications (logging): {len([n for n in notifications_log3 if n['type'] == 'logging'])}") - print(f"Unknown notifications: {len([n for n in notifications_log3 if n['type'] == 'unknown'])}") - print("\n✓ The known notification was handled by the logging_callback,") - print(" NOT by the unknown_notification_callback (as expected!)") - print("\nIn a real application with custom notification types:") - print(" • Notifications like 'notifications/custom/myapp' would be") - print(" sent to the unknown handler if not in custom_notification_handlers") - print(" • Known MCP types (logging, progress, etc.) are never 'unknown'") - - -async def example_4_selective_override() -> None: - """Example 4: Selective override of specific handlers.""" - print("\n" + "=" * 70) - print("Example 4: Real-World Pattern - Selective Override") - print("=" * 70) - print("\nOverride only specific notification handlers while using defaults") - print("for others. Perfect for production monitoring scenarios.\n") - - # Create handlers for example 4 - ( - _, - logging_handler4, - _, - _, - notifications_log4, - ) = create_notification_handlers() - - # Only override logging, let other notifications use defaults - async with create_connected_server_and_client_session( - server, - logging_callback=logging_handler4, # Custom - # progress_notification_callback uses default - # resource_list_changed_callback uses default - ) as client: - print("Client connected with selective handler overrides\n") - - print("Sending multiple notifications...") - await client.call_tool("send_logging_notification", {"message": "Custom handler!", "level": "warning"}) - await asyncio.sleep(0.05) - - await client.call_tool( - "send_progress_notification", - {"progress": 50.0, "total": 100.0, "message": "Default handler"}, - meta={"progressToken": "token-2"}, - ) - await asyncio.sleep(0.05) - - await client.call_tool("trigger_resource_list_change", {}) - await asyncio.sleep(0.1) - - print("") # newline after notifications - - print(f"\nCustom-handled notifications: {len([n for n in notifications_log4 if n['type'] == 'logging'])}") - print("Default-handled notifications: Progress and ResourceListChanged\n") - - -async def main() -> None: - """Run all examples demonstrating custom notification handling.""" - print("=" * 70) - print("MCP Custom Notification Handling Demo") - print("=" * 70) - - # Run all examples - await example_1_override_standard_handlers() - await example_2_custom_notification_handlers() - await example_3_unknown_notification_fallback() - await example_4_selective_override() - - # Summary - print("=" * 70) - print("Demo Complete!") - print("=" * 70) - print("\nKey Takeaways:") - print(" 1. Override standard MCP notification handlers (logging, progress, etc.)") - print(" 2. Register custom notification handlers by method name") - print(" → Use custom_notification_handlers dict") - print(" → These are checked FIRST, before standard types") - print(" 3. Unknown notification callback catches unrecognized types") - print(" → Fallback for anything not in custom handlers or standard types") - print(" 4. Selectively override only the handlers you need") - print(" 5. Default handlers are used when no custom handler is provided") - print("\nNotification Processing Order:") - print(" 1. Custom notification handlers (by method name)") - print(" 2. Standard MCP notification types") - print(" 3. Unknown notification callback (fallback)") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/snippets/clients/server_notification_client.py b/examples/snippets/clients/server_notification_client.py index e63a6991a..a51277cf8 100644 --- a/examples/snippets/clients/server_notification_client.py +++ b/examples/snippets/clients/server_notification_client.py @@ -4,7 +4,6 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.types import ServerNotification # Create dummy server parameters for stdio connection server_params = StdioServerParameters( @@ -20,19 +19,12 @@ async def custom_resource_list_changed_handler() -> None: print("RESOURCE LIST CHANGED") -# Create a fallback handler for custom notifications we aren't aware of. -async def fallback_notification_handler(notification: ServerNotification) -> None: - """Fallback handler for unknown notifications.""" - print(f"UNKNOWN notification caught: {notification.root.method}") - - async def run(): async with stdio_client(server_params) as (read, write): async with ClientSession( read, write, resource_list_changed_callback=custom_resource_list_changed_handler, - unknown_notification_callback=fallback_notification_handler, ) as session: # Initialize the connection await session.initialize() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 50b92ac83..b7d8f4cc9 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -48,13 +48,6 @@ async def __call__( ) -> None: ... # pragma: no branch -class CancelledFnT(Protocol): - async def __call__( - self, - params: types.CancelledNotificationParams, - ) -> None: ... - - class ProgressNotificationFnT(Protocol): async def __call__( self, @@ -87,20 +80,6 @@ async def __call__( ) -> None: ... -class UnknownNotificationFnT(Protocol): - async def __call__( - self, - notification: types.ServerNotification, - ) -> None: ... - - -class CustomNotificationHandlerFnT(Protocol): - async def __call__( - self, - notification: types.ServerNotification, - ) -> None: ... - - class MessageHandlerFnT(Protocol): async def __call__( self, @@ -149,12 +128,6 @@ async def _default_logging_callback( pass -async def _default_cancelled_callback( - params: types.CancelledNotificationParams, -) -> None: - pass - - async def _default_progress_callback( params: types.ProgressNotificationParams, ) -> None: @@ -179,12 +152,6 @@ async def _default_prompt_list_changed_callback() -> None: pass -async def _default_unknown_notification_callback( - notification: types.ServerNotification, -) -> None: - pass - - ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -206,14 +173,11 @@ def __init__( elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, - cancelled_callback: CancelledFnT | None = None, progress_notification_callback: ProgressNotificationFnT | None = None, resource_updated_callback: ResourceUpdatedFnT | None = None, resource_list_changed_callback: ResourceListChangedFnT | None = None, tool_list_changed_callback: ToolListChangedFnT | None = None, prompt_list_changed_callback: PromptListChangedFnT | None = None, - unknown_notification_callback: UnknownNotificationFnT | None = None, - custom_notification_handlers: dict[str, CustomNotificationHandlerFnT] | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, ) -> None: @@ -229,14 +193,11 @@ def __init__( self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback - self._cancelled_callback = cancelled_callback or _default_cancelled_callback self._progress_notification_callback = progress_notification_callback or _default_progress_callback self._resource_updated_callback = resource_updated_callback or _default_resource_updated_callback self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback - self._unknown_notification_callback = unknown_notification_callback or _default_unknown_notification_callback - self._custom_notification_handlers = custom_notification_handlers or {} self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None @@ -651,25 +612,10 @@ async def _handle_incoming( await self._message_handler(req) async def _received_notification(self, notification: types.ServerNotification) -> None: - """Handle notifications from the server. - - Notifications are handled in the following order: - 1. Custom notification handlers (registered by method name) - 2. Known notification types (LoggingMessage, Progress, etc.) - 3. Unknown notification handler (fallback) - """ - # Check if there's a custom handler registered for this notification method - notification_method = notification.root.method - if notification_method in self._custom_notification_handlers: - await self._custom_notification_handlers[notification_method](notification) - return - - # Process specific notification types + """Handle notifications from the server.""" match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) - case types.CancelledNotification(params=params): - await self._cancelled_callback(params) case types.ProgressNotification(params=params): await self._progress_notification_callback(params) case types.ResourceUpdatedNotification(params=params): @@ -680,7 +626,7 @@ async def _received_notification(self, notification: types.ServerNotification) - await self._tool_list_changed_callback() case types.PromptListChangedNotification(): await self._prompt_list_changed_callback() - case _: # type: ignore[misc] - # Handle unknown/custom notifications that may exist at runtime - # but aren't part of the known ServerNotification type union - await self._unknown_notification_callback(notification) + case _: + # CancelledNotification is handled separately in shared/session.py + # and should never reach this point. This case is defensive. + pass diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index c12b5892d..53ab3b107 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,7 +15,6 @@ import mcp.types as types from mcp.client.session import ( ClientSession, - CustomNotificationHandlerFnT, ElicitationFnT, ListRootsFnT, LoggingFnT, @@ -26,7 +25,6 @@ ResourceUpdatedFnT, SamplingFnT, ToolListChangedFnT, - UnknownNotificationFnT, ) from mcp.server import Server from mcp.server.fastmcp import FastMCP @@ -72,8 +70,6 @@ async def create_connected_server_and_client_session( resource_list_changed_callback: ResourceListChangedFnT | None = None, tool_list_changed_callback: ToolListChangedFnT | None = None, prompt_list_changed_callback: PromptListChangedFnT | None = None, - unknown_notification_callback: UnknownNotificationFnT | None = None, - custom_notification_handlers: dict[str, CustomNotificationHandlerFnT] | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, @@ -114,8 +110,6 @@ async def create_connected_server_and_client_session( resource_list_changed_callback=resource_list_changed_callback, tool_list_changed_callback=tool_list_changed_callback, prompt_list_changed_callback=prompt_list_changed_callback, - unknown_notification_callback=unknown_notification_callback, - custom_notification_handlers=custom_notification_handlers, message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, diff --git a/tests/client/test_custom_notifications.py b/tests/client/test_custom_notifications.py deleted file mode 100644 index 088a3efd7..000000000 --- a/tests/client/test_custom_notifications.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -Tests for custom notification handlers in ClientSession. - -This module tests both workflows for handling custom/unknown notifications: -1. Generic unknown notification handler (fallback) -2. Type-specific custom notification handlers (registry) -""" - -from typing import Any - -import pytest - -import mcp.types as types -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) -from mcp.shared.session import RequestResponder - - -class UnknownNotificationCollector: - """Collector for unknown/custom notifications.""" - - def __init__(self) -> None: - self.notifications: list[types.ServerNotification] = [] - - async def __call__(self, notification: types.ServerNotification) -> None: - """Collect unknown notifications.""" - self.notifications.append(notification) - - -@pytest.fixture -def unknown_collector() -> UnknownNotificationCollector: - """Create a collector for unknown notifications.""" - return UnknownNotificationCollector() - - -@pytest.fixture -def message_handler() -> Any: - """Message handler that re-raises exceptions.""" - - async def handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - return handler - - -@pytest.mark.anyio -async def test_unknown_notification_callback_not_called_for_known_types( - unknown_collector: UnknownNotificationCollector, - message_handler: Any, -) -> None: - """Test that the unknown notification handler is NOT called for known notification types.""" - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test-server") - - # Register a tool that sends a known notification (logging) - @server.tool("send_logging") - async def send_logging_tool() -> bool: - """Send a logging notification to the client.""" - # Logging notifications are handled by the specific logging_callback, - # not the unknown_notification_callback - return True - - async with create_session( - server._mcp_server, - unknown_notification_callback=unknown_collector, - message_handler=message_handler, - ) as client_session: - # Call the tool - result = await client_session.call_tool("send_logging", {}) - assert result.isError is False - - # The unknown notification collector should NOT have been called - # because logging is a known notification type - assert len(unknown_collector.notifications) == 0 - - -@pytest.mark.anyio -async def test_custom_notification_handler_takes_priority( - unknown_collector: UnknownNotificationCollector, - message_handler: Any, -) -> None: - """Test that custom notification handlers are checked before unknown handler.""" - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test-server") - - # Track which handler was called - custom_handler_called: list[str] = [] - - async def custom_handler(notification: types.ServerNotification) -> None: - """Custom handler for a specific notification method.""" - custom_handler_called.append(notification.root.method) - - # Register a custom handler for a specific notification method - custom_handlers = { - "notifications/custom/test": custom_handler, - } - - @server.tool("trigger_notification") - async def trigger_tool() -> bool: - """Tool that returns success.""" - return True - - async with create_session( - server._mcp_server, - custom_notification_handlers=custom_handlers, - unknown_notification_callback=unknown_collector, - message_handler=message_handler, - ) as client_session: - # Call the tool - result = await client_session.call_tool("trigger_notification", {}) - assert result.isError is False - - # Neither handler should have been called for known notification types - assert len(custom_handler_called) == 0 - assert len(unknown_collector.notifications) == 0 - - -@pytest.mark.anyio -async def test_unknown_notification_callback_with_default( - message_handler: Any, -) -> None: - """Test that the default unknown notification callback does nothing.""" - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test-server") - - @server.tool("test_tool") - async def test_tool() -> bool: - """Simple test tool.""" - return True - - # Don't pass an unknown_notification_callback - use the default - async with create_session( - server._mcp_server, - message_handler=message_handler, - ) as client_session: - # This should work fine with the default handler - result = await client_session.call_tool("test_tool", {}) - assert result.isError is False - - -@pytest.mark.anyio -async def test_custom_handlers_empty_dict( - unknown_collector: UnknownNotificationCollector, - message_handler: Any, -) -> None: - """Test that an empty custom handlers dict works correctly.""" - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test-server") - - @server.tool("test_tool") - async def test_tool() -> bool: - """Simple test tool.""" - return True - - # Pass an empty custom handlers dict - async with create_session( - server._mcp_server, - custom_notification_handlers={}, - unknown_notification_callback=unknown_collector, - message_handler=message_handler, - ) as client_session: - result = await client_session.call_tool("test_tool", {}) - assert result.isError is False - - # No unknown notifications should have been received - assert len(unknown_collector.notifications) == 0 - - -@pytest.mark.anyio -async def test_multiple_custom_handlers( - message_handler: Any, -) -> None: - """Test that multiple custom notification handlers can be registered.""" - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test-server") - - # Track which handlers were called - handler_calls: dict[str, int] = {} - - async def create_custom_handler(name: str) -> Any: - """Factory function to create a custom handler.""" - - async def handler(notification: types.ServerNotification) -> None: - handler_calls[name] = handler_calls.get(name, 0) + 1 - - return handler - - # Register multiple custom handlers - custom_handlers = { - "notifications/custom/type1": await create_custom_handler("type1"), - "notifications/custom/type2": await create_custom_handler("type2"), - "notifications/custom/type3": await create_custom_handler("type3"), - } - - @server.tool("test_tool") - async def test_tool() -> bool: - """Simple test tool.""" - return True - - async with create_session( - server._mcp_server, - custom_notification_handlers=custom_handlers, - message_handler=message_handler, - ) as client_session: - result = await client_session.call_tool("test_tool", {}) - assert result.isError is False - - # No handlers should have been called yet (no matching notifications) - assert len(handler_calls) == 0 diff --git a/tests/client/test_notification_callbacks.py b/tests/client/test_notification_callbacks.py index be2d3a731..ff24a0c3e 100644 --- a/tests/client/test_notification_callbacks.py +++ b/tests/client/test_notification_callbacks.py @@ -427,3 +427,82 @@ async def message_handler( # Verify using the provided verification function assert verification(collector), f"Verification failed for {notification_type}" + + +@pytest.mark.anyio +async def test_all_default_callbacks_with_notifications() -> None: + """Test that all default notification callbacks work (they do nothing). + + This single test covers multiple default callbacks by not providing + custom callbacks and triggering various notification types. + """ + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float) -> bool: + """Send a progress notification.""" + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + ) + return True + + @server.tool("send_resource_updated") + async def send_resource_updated_tool(uri: str) -> bool: + """Send a resource updated notification.""" + from pydantic import AnyUrl + + await server.get_context().session.send_resource_updated(uri=AnyUrl(uri)) + return True + + @server.tool("send_resource_list_changed") + async def send_resource_list_changed_tool() -> bool: + """Send a resource list changed notification.""" + await server.get_context().session.send_resource_list_changed() + return True + + @server.tool("send_tool_list_changed") + async def send_tool_list_changed_tool() -> bool: + """Send a tool list changed notification.""" + await server.get_context().session.send_tool_list_changed() + return True + + @server.tool("send_prompt_list_changed") + async def send_prompt_list_changed_tool() -> bool: + """Send a prompt list changed notification.""" + await server.get_context().session.send_prompt_list_changed() + return True + + # Create session WITHOUT custom callbacks - all will use defaults + async with create_session(server._mcp_server) as client_session: + # Test progress notification with default callback + result1 = await client_session.call_tool( + "send_progress", + {"progress": 50.0, "total": 100.0}, + meta={"progressToken": "test-token"}, + ) + assert result1.isError is False + + # Test resource updated with default callback + result2 = await client_session.call_tool( + "send_resource_updated", + {"uri": "file:///test.txt"}, + ) + assert result2.isError is False + + # Test resource list changed with default callback + result3 = await client_session.call_tool("send_resource_list_changed", {}) + assert result3.isError is False + + # Test tool list changed with default callback + result4 = await client_session.call_tool("send_tool_list_changed", {}) + assert result4.isError is False + + # Test prompt list changed with default callback + result5 = await client_session.call_tool("send_prompt_list_changed", {}) + assert result5.isError is False From f6d18c4a4224b1c4054bfe095f2fa99c16a2d8ab Mon Sep 17 00:00:00 2001 From: Kyle Stratis Date: Wed, 12 Nov 2025 22:25:37 -0500 Subject: [PATCH 5/5] Adds comment about handling of progress notifications --- src/mcp/client/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b7d8f4cc9..06a345100 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -131,6 +131,8 @@ async def _default_logging_callback( async def _default_progress_callback( params: types.ProgressNotificationParams, ) -> None: + """Note: Default progress handling happens in the BaseSession class. This callback will only be called after the + default progress handling has completed.""" pass