From 80c1fd0a4c125351974b2739016d5f980616a9bf Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 21 May 2026 14:39:44 +0900 Subject: [PATCH] feat: add MCP tool call result callback --- src/agents/mcp/__init__.py | 4 ++ src/agents/mcp/server.py | 26 +++++++++++ src/agents/mcp/util.py | 88 +++++++++++++++++++++++++++++++++++++- tests/mcp/helpers.py | 4 +- tests/mcp/test_mcp_util.py | 82 +++++++++++++++++++++++++++++++++++ 5 files changed, 202 insertions(+), 2 deletions(-) diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index f0de5bda66..d30a19d792 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -17,6 +17,8 @@ ) from .util import ( + MCPToolCallResultCallback, + MCPToolCallResultContext, MCPToolMetaContext, MCPToolMetaResolver, MCPUtil, @@ -50,6 +52,8 @@ "MCPServerManager", "LocalMCPApprovalCallable", "MCPUtil", + "MCPToolCallResultCallback", + "MCPToolCallResultContext", "MCPToolMetaContext", "MCPToolMetaResolver", "ToolFilter", diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 8d3bdd752a..b67d37fcac 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -45,6 +45,7 @@ from ..util._types import MaybeAwaitable from .util import ( HttpClientFactory, + MCPToolCallResultCallback, MCPToolMetaResolver, ToolFilter, ToolFilterContext, @@ -229,6 +230,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): """ Args: @@ -248,6 +250,9 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The + callback receives result metadata and the model-visible tool output, but cannot + change the output returned to the model. """ self.use_structured_content = use_structured_content self._needs_approval_policy = self._normalize_needs_approval( @@ -255,6 +260,7 @@ def __init__( ) self._failure_error_function = failure_error_function self.tool_meta_resolver = tool_meta_resolver + self.tool_call_result_callback = tool_call_result_callback @abc.abstractmethod async def connect(self): @@ -544,6 +550,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): """ Args: @@ -576,12 +583,16 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The + callback receives result metadata and the model-visible tool output, but cannot + change the output returned to the model. """ super().__init__( use_structured_content=use_structured_content, require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + tool_call_result_callback=tool_call_result_callback, ) self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -1108,6 +1119,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): """Create a new MCP server based on the stdio transport. @@ -1145,6 +1157,9 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The + callback receives result metadata and the model-visible tool output, but cannot + change the output returned to the model. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1157,6 +1172,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + tool_call_result_callback=tool_call_result_callback, ) self.params = StdioServerParameters( @@ -1229,6 +1245,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -1268,6 +1285,9 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The + callback receives result metadata and the model-visible tool output, but cannot + change the output returned to the model. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1280,6 +1300,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + tool_call_result_callback=tool_call_result_callback, ) self.params = params @@ -1365,6 +1386,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -1405,6 +1427,9 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The + callback receives result metadata and the model-visible tool output, but cannot + change the output returned to the model. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1417,6 +1442,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + tool_call_result_callback=tool_call_result_callback, ) self.params = params diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index bf00cb2b79..650f4817fb 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -7,8 +7,9 @@ import inspect import json from collections import Counter -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass +from types import MappingProxyType from typing import TYPE_CHECKING, Any, Protocol, Union import httpx @@ -149,13 +150,50 @@ class MCPToolMetaContext: """The parsed tool arguments.""" +@dataclass(frozen=True) +class MCPToolCallResultContext: + """Context information available when an MCP tool call returns a result.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + server_name: str + """The name of the MCP server.""" + + tool_name: str + """The original MCP tool name invoked on the server.""" + + tool_display_name: str + """The public tool name exposed through the Agents SDK.""" + + arguments: Mapping[str, Any] + """The parsed tool arguments.""" + + result_meta: Mapping[str, Any] | None + """The MCP tool result `_meta` payload, if present.""" + + structured_content: Mapping[str, Any] | None + """The MCP tool result `structuredContent` payload, if present.""" + + is_error: bool | None + """The MCP tool result `isError` flag, if present.""" + + tool_output: ToolOutput + """The model-visible tool output produced by the Agents SDK.""" + + if TYPE_CHECKING: MCPToolMetaResolver = Callable[ [MCPToolMetaContext], MaybeAwaitable[dict[str, Any] | None], ] + MCPToolCallResultCallback = Callable[ + [MCPToolCallResultContext], + MaybeAwaitable[None], + ] else: MCPToolMetaResolver = Callable[..., Any] + MCPToolCallResultCallback = Callable[..., Any] """A function that produces MCP request metadata for tool calls. Args: @@ -164,6 +202,7 @@ class MCPToolMetaContext: Returns: A dict to send as MCP `_meta`, or None to omit metadata. """ +"""A callback that observes MCP tool call results without changing tool output.""" def create_static_tool_filter( @@ -541,6 +580,43 @@ def _merge_mcp_meta( merged.update(copy.deepcopy(explicit_meta)) return merged + @staticmethod + def _copy_mapping_proxy(value: Any) -> Mapping[str, Any] | None: + if not isinstance(value, dict): + return None + return MappingProxyType(copy.deepcopy(value)) + + @classmethod + async def _maybe_call_tool_result_callback( + cls, + *, + server: MCPServer, + context: RunContextWrapper[Any], + tool_name: str, + tool_display_name: str, + arguments: dict[str, Any], + result: Any, + tool_output: ToolOutput, + ) -> None: + callback = getattr(server, "tool_call_result_callback", None) + if callback is None: + return + + callback_context = MCPToolCallResultContext( + run_context=context, + server_name=server.name, + tool_name=tool_name, + tool_display_name=tool_display_name, + arguments=MappingProxyType(copy.deepcopy(arguments)), + result_meta=cls._copy_mapping_proxy(getattr(result, "meta", None)), + structured_content=cls._copy_mapping_proxy(getattr(result, "structuredContent", None)), + is_error=getattr(result, "isError", None), + tool_output=copy.deepcopy(tool_output), + ) + callback_result = callback(callback_context) + if inspect.isawaitable(callback_result): + await callback_result + @classmethod async def _resolve_meta( cls, @@ -688,6 +764,16 @@ async def invoke_mcp_tool( else: tool_output = tool_output_list + await cls._maybe_call_tool_result_callback( + server=server, + context=context, + tool_name=tool.name, + tool_display_name=tool_name_for_display, + arguments=json_data, + result=result, + tool_output=tool_output, + ) + current_span = get_current_span() if current_span: if isinstance(current_span.span_data, FunctionSpanData): diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index ef820fad99..efadf2afb6 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -20,7 +20,7 @@ from agents.mcp import MCPServer from agents.mcp.server import _UNSET, _MCPServerWithClientSession, _UnsetType -from agents.mcp.util import MCPToolMetaResolver, ToolFilter +from agents.mcp.util import MCPToolCallResultCallback, MCPToolMetaResolver, ToolFilter from agents.tool import ToolErrorFunction tee = shutil.which("tee") or "" @@ -76,12 +76,14 @@ def __init__( require_approval: object | None = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + tool_call_result_callback: MCPToolCallResultCallback | None = None, ): super().__init__( use_structured_content=False, require_approval=require_approval, # type: ignore[arg-type] failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + tool_call_result_callback=tool_call_result_callback, ) self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index e41884375d..b8f04ccc96 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -531,6 +531,88 @@ def resolve_meta(context): assert args == {"foo": "bar"} +@pytest.mark.asyncio +async def test_mcp_tool_call_result_callback_observes_result_without_mutating_output(): + captured: dict[str, Any] = {} + + def tool_call_result_callback(context): + captured["run_context"] = context.run_context + captured["server_name"] = context.server_name + captured["tool_name"] = context.tool_name + captured["tool_display_name"] = context.tool_display_name + captured["arguments"] = dict(context.arguments) + captured["result_meta"] = dict(context.result_meta or {}) + captured["structured_content"] = dict(context.structured_content or {}) + captured["is_error"] = context.is_error + captured["tool_output"] = context.tool_output + + with pytest.raises(TypeError): + context.arguments["mutated"] = True + + if isinstance(context.tool_output, dict): + context.tool_output["text"] = "mutated" + + class ResultMCPServer(FakeMCPServer): + async def call_tool( + self, + tool_name: str, + arguments: dict[str, Any] | None, + meta: dict[str, Any] | None = None, + ) -> CallToolResult: + self.tool_calls.append(tool_name) + self.tool_metas.append(meta) + return CallToolResult( + content=[TextContent(text="model summary", type="text")], + structuredContent={"rows": 1245}, + isError=False, + _meta={"frontend": {"type": "chart"}}, + ) + + server = ResultMCPServer( + server_name="analytics", + tool_call_result_callback=tool_call_result_callback, + ) + ctx: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={"frontend_events": []}) + tool = MCPTool(name="query_sales", inputSchema={}) + + output = await MCPUtil.invoke_mcp_tool( + server, + tool, + ctx, + '{"sql": "select 1"}', + tool_display_name="analytics__query_sales", + ) + + assert output == {"type": "text", "text": "model summary"} + assert captured == { + "run_context": ctx, + "server_name": "analytics", + "tool_name": "query_sales", + "tool_display_name": "analytics__query_sales", + "arguments": {"sql": "select 1"}, + "result_meta": {"frontend": {"type": "chart"}}, + "structured_content": {"rows": 1245}, + "is_error": False, + "tool_output": {"type": "text", "text": "mutated"}, + } + + +@pytest.mark.asyncio +async def test_mcp_tool_call_result_callback_can_be_async(): + captured: list[str] = [] + + async def tool_call_result_callback(context): + captured.append(context.tool_name) + + server = FakeMCPServer(tool_call_result_callback=tool_call_result_callback) + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}") + + assert captured == ["test_tool_1"] + + @pytest.mark.asyncio async def test_to_function_tool_passes_static_mcp_meta(): server = FakeMCPServer()