From 9da2d585e67dbf4a4792098e9296a834609c2884 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 08:49:23 +0000 Subject: [PATCH 1/5] add support for supplying a progress callback to call_tool requests --- src/mcp/client/session.py | 54 +++++++++++++---- tests/client/test_session.py | 109 +++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f71..027b08706a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -35,6 +35,13 @@ async def __call__( ) -> None: ... +class ProgressFnT(Protocol): + async def __call__( + self, + params: types.ProgressNotificationParams, + ) -> None: ... + + class MessageHandlerFnT(Protocol): async def __call__( self, @@ -91,6 +98,9 @@ class ClientSession( types.ServerNotification, ] ): + _progress_id: int + _in_progress: dict[types.ProgressToken, ProgressFnT] + def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -114,6 +124,8 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler + self._progress_id = 0 + self._in_progress = {} async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -259,19 +271,37 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" - return await self.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), - ) - ), - types.CallToolResult, - request_read_timeout_seconds=read_timeout_seconds, - ) + if progress_callback is None: + progress_id = None + call_params = types.CallToolRequestParams(name=name, arguments=arguments) + else: + progress_id = self._progress_id + self._progress_id = progress_id + 1 + + call_meta = types.RequestParams.Meta(progressToken=progress_id) + call_params = types.CallToolRequestParams( + name=name, arguments=arguments, _meta=call_meta + ) + self._in_progress[progress_id] = progress_callback + + try: + return await self.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=call_params, + ) + ), + types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + ) + finally: + if progress_id is not None: + self._in_progress.pop(progress_id, None) async def list_prompts(self) -> types.ListPromptsResult: """Send a prompts/list request.""" @@ -384,5 +414,9 @@ async def _received_notification( match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) + case types.ProgressNotification(params=params): + if params.progressToken in self._in_progress: + progress_callback = self._in_progress[params.progressToken] + await progress_callback(params) case _: pass diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 6abcf70cbc..1733e8a9a8 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -250,3 +250,112 @@ async def mock_server(): # Assert that the default client info was sent assert received_client_info == DEFAULT_CLIENT_INFO + + +@pytest.mark.anyio +async def test_client_session_progress(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, types.CallToolRequest) + assert request.root.params.meta + assert request.root.params.meta.progressToken is not None + + progress_token = request.root.params.meta.progressToken + + notifications = [ + types.ServerNotification( + root=types.ProgressNotification( + params=types.ProgressNotificationParams( + progressToken=progress_token, progress=1 + ), + method="notifications/progress", + ) + ), + types.ServerNotification( + root=types.ProgressNotification( + params=types.ProgressNotificationParams( + progressToken=progress_token, progress=2 + ), + method="notifications/progress", + ) + ), + ] + result = ServerResult(types.CallToolResult(content=[])) + + async with server_to_client_send: + for notification in notifications: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + progress_count = 0 + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + async def progress_callback(params: types.ProgressNotificationParams): + nonlocal progress_count + progress_count = progress_count + 1 + + result = await session.call_tool( + "tool_with_progress", progress_callback=progress_callback + ) + + # Assert the result + assert isinstance(result, types.CallToolResult) + assert len(result.content) == 0 + assert progress_count == 2 From 80913e7a02564b5e3df8ba12c1c541aadefcdceb Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 08:57:49 +0000 Subject: [PATCH 2/5] tidy up notification to reduce copy paste --- tests/client/test_session.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 1733e8a9a8..08f2281491 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -261,6 +261,8 @@ async def test_client_session_progress(): SessionMessage ](1) + send_notification_count = 10 + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message @@ -273,24 +275,16 @@ async def mock_server(): assert request.root.params.meta.progressToken is not None progress_token = request.root.params.meta.progressToken - notifications = [ types.ServerNotification( root=types.ProgressNotification( params=types.ProgressNotificationParams( - progressToken=progress_token, progress=1 + progressToken=progress_token, progress=i ), method="notifications/progress", ) - ), - types.ServerNotification( - root=types.ProgressNotification( - params=types.ProgressNotificationParams( - progressToken=progress_token, progress=2 - ), - method="notifications/progress", - ) - ), + ) + for i in range(send_notification_count) ] result = ServerResult(types.CallToolResult(content=[])) @@ -358,4 +352,4 @@ async def progress_callback(params: types.ProgressNotificationParams): # Assert the result assert isinstance(result, types.CallToolResult) assert len(result.content) == 0 - assert progress_count == 2 + assert progress_count == send_notification_count From 76ae778fea790678e6c6f0394f2f92ab0a3dc96a Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 13:48:51 +0000 Subject: [PATCH 3/5] moved core logic for progress call back to BaseSession --- src/mcp/client/session.py | 57 ++++++++-------------------------- src/mcp/shared/session.py | 64 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 49 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 027b08706a..5239a249bd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -35,13 +35,6 @@ async def __call__( ) -> None: ... -class ProgressFnT(Protocol): - async def __call__( - self, - params: types.ProgressNotificationParams, - ) -> None: ... - - class MessageHandlerFnT(Protocol): async def __call__( self, @@ -98,9 +91,6 @@ class ClientSession( types.ServerNotification, ] ): - _progress_id: int - _in_progress: dict[types.ProgressToken, ProgressFnT] - def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -124,8 +114,6 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler - self._progress_id = 0 - self._in_progress = {} async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -274,34 +262,17 @@ async def call_tool( progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" - - if progress_callback is None: - progress_id = None - call_params = types.CallToolRequestParams(name=name, arguments=arguments) - else: - progress_id = self._progress_id - self._progress_id = progress_id + 1 - - call_meta = types.RequestParams.Meta(progressToken=progress_id) - call_params = types.CallToolRequestParams( - name=name, arguments=arguments, _meta=call_meta - ) - self._in_progress[progress_id] = progress_callback - - try: - return await self.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=call_params, - ) - ), - types.CallToolResult, - request_read_timeout_seconds=read_timeout_seconds, - ) - finally: - if progress_id is not None: - self._in_progress.pop(progress_id, None) + return await self.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name=name, arguments=arguments), + ) + ), + types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + ) async def list_prompts(self) -> types.ListPromptsResult: """Send a prompts/list request.""" @@ -414,9 +385,5 @@ async def _received_notification( match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) - case types.ProgressNotification(params=params): - if params.progressToken in self._in_progress: - progress_callback = self._in_progress[params.progressToken] - await progress_callback(params) case _: pass diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cce8b1184e..8bed4d6153 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar import anyio import httpx @@ -24,6 +24,9 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ProgressNotification, + ProgressNotificationParams, + ProgressToken, RequestParams, ServerNotification, ServerRequest, @@ -39,6 +42,14 @@ "ReceiveNotificationT", ClientNotification, ServerNotification ) + +class ProgressFnT(Protocol): + async def __call__( + self, + params: ProgressNotificationParams, + ) -> None: ... + + RequestId = str | int @@ -168,7 +179,9 @@ class BaseSession( RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] ] _request_id: int + _progress_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _in_progress: dict[ProgressToken, ProgressFnT] def __init__( self, @@ -187,6 +200,8 @@ def __init__( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._progress_id = 0 + self._in_progress = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -214,12 +229,16 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. + If progress_callback is provided any progress notifications sent from the + receiver will be passed back to the sender + Do not use this method to emit notifications! Use send_notification() instead. """ @@ -227,6 +246,27 @@ async def send_request( request_id = self._request_id self._request_id = request_id + 1 + progress_id = None + send_request = None + + if progress_callback is not None: + if request.root.params is not None: + progress_id = self._progress_id + self._progress_id = progress_id + 1 + new_params = request.root.params.model_copy( + update={"meta": RequestParams.Meta(progressToken=progress_id)} + ) + new_root = request.root.model_copy(update={"params": new_params}) + send_request = request.model_copy(update={"root": new_root}) + self._in_progress[progress_id] = progress_callback + else: + raise ValueError( + f"{type(request.root).__name__} does not support progress" + ) + + if send_request is None: + send_request = request + response_stream, response_stream_reader = anyio.create_memory_object_stream[ JSONRPCResponse | JSONRPCError ](1) @@ -236,11 +276,11 @@ async def send_request( jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), + **send_request.model_dump( + by_alias=True, mode="json", exclude_none=True + ), ) - # TODO: Support progress callbacks - await self._write_stream.send( SessionMessage( message=JSONRPCMessage(jsonrpc_request), metadata=metadata @@ -276,6 +316,8 @@ async def send_request( finally: self._response_streams.pop(request_id, None) + if progress_id is not None: + self._in_progress.pop(progress_id, None) await response_stream.aclose() await response_stream_reader.aclose() @@ -364,6 +406,20 @@ async def _receive_loop(self) -> None: if cancelled_id in self._in_flight: await self._in_flight[cancelled_id].cancel() else: + match notification.root: + case ProgressNotification(params=params): + if params.progressToken in self._in_progress: + progress_callback = self._in_progress[ + params.progressToken + ] + await progress_callback(params) + else: + logging.warning( + "Unknown progress token %s", + params.progressToken, + ) + case _: + pass await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: From 2637c0470ef92d2b4a4ece6f1dde95bc2053f885 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 14:22:08 +0000 Subject: [PATCH 4/5] add exception in case both progress_call back and request.params.meta.progressToken is set --- src/mcp/shared/session.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 8bed4d6153..e0fcc8e17a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -251,14 +251,23 @@ async def send_request( if progress_callback is not None: if request.root.params is not None: - progress_id = self._progress_id - self._progress_id = progress_id + 1 - new_params = request.root.params.model_copy( - update={"meta": RequestParams.Meta(progressToken=progress_id)} - ) - new_root = request.root.model_copy(update={"params": new_params}) - send_request = request.model_copy(update={"root": new_root}) - self._in_progress[progress_id] = progress_callback + if ( + request.root.params.meta is None + or request.root.params.meta.progressToken is None + ): + progress_id = self._progress_id + self._progress_id = progress_id + 1 + new_params = request.root.params.model_copy( + update={"meta": RequestParams.Meta(progressToken=progress_id)} + ) + new_root = request.root.model_copy(update={"params": new_params}) + send_request = request.model_copy(update={"root": new_root}) + self._in_progress[progress_id] = progress_callback + else: + raise ValueError( + "Request has progressToken and progress_callback provided " + "via send_request method only one or other is supported" + ) else: raise ValueError( f"{type(request.root).__name__} does not support progress" From 1004793d3c13054bea85cf08c5dfc2649e4674bd Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 14:58:16 +0000 Subject: [PATCH 5/5] update logic for setting progressToken now properly handles none initial request params --- src/mcp/shared/session.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index e0fcc8e17a..449b489f55 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -250,28 +250,30 @@ async def send_request( send_request = None if progress_callback is not None: - if request.root.params is not None: + if request.root.params is None: + progress_id = self._progress_id + new_params = RequestParams( + _meta=RequestParams.Meta(progressToken=progress_id) + ) + else: if ( request.root.params.meta is None or request.root.params.meta.progressToken is None ): progress_id = self._progress_id - self._progress_id = progress_id + 1 new_params = request.root.params.model_copy( update={"meta": RequestParams.Meta(progressToken=progress_id)} ) - new_root = request.root.model_copy(update={"params": new_params}) - send_request = request.model_copy(update={"root": new_root}) - self._in_progress[progress_id] = progress_callback else: raise ValueError( "Request has progressToken and progress_callback provided " "via send_request method only one or other is supported" ) - else: - raise ValueError( - f"{type(request.root).__name__} does not support progress" - ) + + new_root = request.root.model_copy(update={"params": new_params}) + send_request = request.model_copy(update={"root": new_root}) + self._progress_id = progress_id + 1 + self._in_progress[progress_id] = progress_callback if send_request is None: send_request = request