Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/agents/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)

from .util import (
MCPToolCallResultCallback,
MCPToolCallResultContext,
MCPToolMetaContext,
MCPToolMetaResolver,
MCPUtil,
Expand Down Expand Up @@ -50,6 +52,8 @@
"MCPServerManager",
"LocalMCPApprovalCallable",
"MCPUtil",
"MCPToolCallResultCallback",
"MCPToolCallResultContext",
"MCPToolMetaContext",
"MCPToolMetaResolver",
"ToolFilter",
Expand Down
26 changes: 26 additions & 0 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..util._types import MaybeAwaitable
from .util import (
HttpClientFactory,
MCPToolCallResultCallback,
MCPToolMetaResolver,
ToolFilter,
ToolFilterContext,
Expand Down Expand Up @@ -229,6 +230,7 @@ def __init__(
require_approval: RequireApprovalSetting = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
"""
Args:
Expand All @@ -248,13 +250,17 @@ def __init__(
SDK default) will be used.
tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for
tool calls. It is invoked by the Agents SDK before calling `call_tool`.
tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The
callback receives result metadata and the model-visible tool output, but cannot
change the output returned to the model.
"""
self.use_structured_content = use_structured_content
self._needs_approval_policy = self._normalize_needs_approval(
require_approval=require_approval
)
self._failure_error_function = failure_error_function
self.tool_meta_resolver = tool_meta_resolver
self.tool_call_result_callback = tool_call_result_callback

@abc.abstractmethod
async def connect(self):
Expand Down Expand Up @@ -544,6 +550,7 @@ def __init__(
require_approval: RequireApprovalSetting = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
"""
Args:
Expand Down Expand Up @@ -576,12 +583,16 @@ def __init__(
SDK default) will be used.
tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for
tool calls. It is invoked by the Agents SDK before calling `call_tool`.
tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The
callback receives result metadata and the model-visible tool output, but cannot
change the output returned to the model.
"""
super().__init__(
use_structured_content=use_structured_content,
require_approval=require_approval,
failure_error_function=failure_error_function,
tool_meta_resolver=tool_meta_resolver,
tool_call_result_callback=tool_call_result_callback,
)
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
Expand Down Expand Up @@ -1108,6 +1119,7 @@ def __init__(
require_approval: RequireApprovalSetting = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
"""Create a new MCP server based on the stdio transport.

Expand Down Expand Up @@ -1145,6 +1157,9 @@ def __init__(
SDK default) will be used.
tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for
tool calls. It is invoked by the Agents SDK before calling `call_tool`.
tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The
callback receives result metadata and the model-visible tool output, but cannot
change the output returned to the model.
"""
super().__init__(
cache_tools_list=cache_tools_list,
Expand All @@ -1157,6 +1172,7 @@ def __init__(
require_approval=require_approval,
failure_error_function=failure_error_function,
tool_meta_resolver=tool_meta_resolver,
tool_call_result_callback=tool_call_result_callback,
)

self.params = StdioServerParameters(
Expand Down Expand Up @@ -1229,6 +1245,7 @@ def __init__(
require_approval: RequireApprovalSetting = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.

Expand Down Expand Up @@ -1268,6 +1285,9 @@ def __init__(
SDK default) will be used.
tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for
tool calls. It is invoked by the Agents SDK before calling `call_tool`.
tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The
callback receives result metadata and the model-visible tool output, but cannot
change the output returned to the model.
"""
super().__init__(
cache_tools_list=cache_tools_list,
Expand All @@ -1280,6 +1300,7 @@ def __init__(
require_approval=require_approval,
failure_error_function=failure_error_function,
tool_meta_resolver=tool_meta_resolver,
tool_call_result_callback=tool_call_result_callback,
)

self.params = params
Expand Down Expand Up @@ -1365,6 +1386,7 @@ def __init__(
require_approval: RequireApprovalSetting = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
"""Create a new MCP server based on the Streamable HTTP transport.

Expand Down Expand Up @@ -1405,6 +1427,9 @@ def __init__(
SDK default) will be used.
tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for
tool calls. It is invoked by the Agents SDK before calling `call_tool`.
tool_call_result_callback: Optional callback invoked after an MCP tool call returns. The
callback receives result metadata and the model-visible tool output, but cannot
change the output returned to the model.
"""
super().__init__(
cache_tools_list=cache_tools_list,
Expand All @@ -1417,6 +1442,7 @@ def __init__(
require_approval=require_approval,
failure_error_function=failure_error_function,
tool_meta_resolver=tool_meta_resolver,
tool_call_result_callback=tool_call_result_callback,
)

self.params = params
Expand Down
88 changes: 87 additions & 1 deletion src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import inspect
import json
from collections import Counter
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Protocol, Union

import httpx
Expand Down Expand Up @@ -149,13 +150,50 @@ class MCPToolMetaContext:
"""The parsed tool arguments."""


@dataclass(frozen=True)
class MCPToolCallResultContext:
"""Context information available when an MCP tool call returns a result."""

run_context: RunContextWrapper[Any]
"""The current run context."""

server_name: str
"""The name of the MCP server."""

tool_name: str
"""The original MCP tool name invoked on the server."""

tool_display_name: str
"""The public tool name exposed through the Agents SDK."""

arguments: Mapping[str, Any]
"""The parsed tool arguments."""

result_meta: Mapping[str, Any] | None
"""The MCP tool result `_meta` payload, if present."""

structured_content: Mapping[str, Any] | None
"""The MCP tool result `structuredContent` payload, if present."""

is_error: bool | None
"""The MCP tool result `isError` flag, if present."""

tool_output: ToolOutput
"""The model-visible tool output produced by the Agents SDK."""


if TYPE_CHECKING:
MCPToolMetaResolver = Callable[
[MCPToolMetaContext],
MaybeAwaitable[dict[str, Any] | None],
]
MCPToolCallResultCallback = Callable[
[MCPToolCallResultContext],
MaybeAwaitable[None],
]
else:
MCPToolMetaResolver = Callable[..., Any]
MCPToolCallResultCallback = Callable[..., Any]
"""A function that produces MCP request metadata for tool calls.

Args:
Expand All @@ -164,6 +202,7 @@ class MCPToolMetaContext:
Returns:
A dict to send as MCP `_meta`, or None to omit metadata.
"""
"""A callback that observes MCP tool call results without changing tool output."""


def create_static_tool_filter(
Expand Down Expand Up @@ -541,6 +580,43 @@ def _merge_mcp_meta(
merged.update(copy.deepcopy(explicit_meta))
return merged

@staticmethod
def _copy_mapping_proxy(value: Any) -> Mapping[str, Any] | None:
if not isinstance(value, dict):
return None
return MappingProxyType(copy.deepcopy(value))

@classmethod
async def _maybe_call_tool_result_callback(
cls,
*,
server: MCPServer,
context: RunContextWrapper[Any],
tool_name: str,
tool_display_name: str,
arguments: dict[str, Any],
result: Any,
tool_output: ToolOutput,
) -> None:
callback = getattr(server, "tool_call_result_callback", None)
if callback is None:
return

callback_context = MCPToolCallResultContext(
run_context=context,
server_name=server.name,
tool_name=tool_name,
tool_display_name=tool_display_name,
arguments=MappingProxyType(copy.deepcopy(arguments)),
result_meta=cls._copy_mapping_proxy(getattr(result, "meta", None)),
structured_content=cls._copy_mapping_proxy(getattr(result, "structuredContent", None)),
is_error=getattr(result, "isError", None),
tool_output=copy.deepcopy(tool_output),
)
callback_result = callback(callback_context)
if inspect.isawaitable(callback_result):
await callback_result

@classmethod
async def _resolve_meta(
cls,
Expand Down Expand Up @@ -688,6 +764,16 @@ async def invoke_mcp_tool(
else:
tool_output = tool_output_list

await cls._maybe_call_tool_result_callback(
server=server,
context=context,
tool_name=tool.name,
tool_display_name=tool_name_for_display,
arguments=json_data,
result=result,
tool_output=tool_output,
)

current_span = get_current_span()
if current_span:
if isinstance(current_span.span_data, FunctionSpanData):
Expand Down
4 changes: 3 additions & 1 deletion tests/mcp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from agents.mcp import MCPServer
from agents.mcp.server import _UNSET, _MCPServerWithClientSession, _UnsetType
from agents.mcp.util import MCPToolMetaResolver, ToolFilter
from agents.mcp.util import MCPToolCallResultCallback, MCPToolMetaResolver, ToolFilter
from agents.tool import ToolErrorFunction

tee = shutil.which("tee") or ""
Expand Down Expand Up @@ -76,12 +76,14 @@ def __init__(
require_approval: object | None = None,
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
tool_meta_resolver: MCPToolMetaResolver | None = None,
tool_call_result_callback: MCPToolCallResultCallback | None = None,
):
super().__init__(
use_structured_content=False,
require_approval=require_approval, # type: ignore[arg-type]
failure_error_function=failure_error_function,
tool_meta_resolver=tool_meta_resolver,
tool_call_result_callback=tool_call_result_callback,
)
self.tools: list[MCPTool] = tools or []
self.tool_calls: list[str] = []
Expand Down
82 changes: 82 additions & 0 deletions tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,88 @@ def resolve_meta(context):
assert args == {"foo": "bar"}


@pytest.mark.asyncio
async def test_mcp_tool_call_result_callback_observes_result_without_mutating_output():
captured: dict[str, Any] = {}

def tool_call_result_callback(context):
captured["run_context"] = context.run_context
captured["server_name"] = context.server_name
captured["tool_name"] = context.tool_name
captured["tool_display_name"] = context.tool_display_name
captured["arguments"] = dict(context.arguments)
captured["result_meta"] = dict(context.result_meta or {})
captured["structured_content"] = dict(context.structured_content or {})
captured["is_error"] = context.is_error
captured["tool_output"] = context.tool_output

with pytest.raises(TypeError):
context.arguments["mutated"] = True

if isinstance(context.tool_output, dict):
context.tool_output["text"] = "mutated"

class ResultMCPServer(FakeMCPServer):
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any] | None,
meta: dict[str, Any] | None = None,
) -> CallToolResult:
self.tool_calls.append(tool_name)
self.tool_metas.append(meta)
return CallToolResult(
content=[TextContent(text="model summary", type="text")],
structuredContent={"rows": 1245},
isError=False,
_meta={"frontend": {"type": "chart"}},
)

server = ResultMCPServer(
server_name="analytics",
tool_call_result_callback=tool_call_result_callback,
)
ctx: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={"frontend_events": []})
tool = MCPTool(name="query_sales", inputSchema={})

output = await MCPUtil.invoke_mcp_tool(
server,
tool,
ctx,
'{"sql": "select 1"}',
tool_display_name="analytics__query_sales",
)

assert output == {"type": "text", "text": "model summary"}
assert captured == {
"run_context": ctx,
"server_name": "analytics",
"tool_name": "query_sales",
"tool_display_name": "analytics__query_sales",
"arguments": {"sql": "select 1"},
"result_meta": {"frontend": {"type": "chart"}},
"structured_content": {"rows": 1245},
"is_error": False,
"tool_output": {"type": "text", "text": "mutated"},
}


@pytest.mark.asyncio
async def test_mcp_tool_call_result_callback_can_be_async():
captured: list[str] = []

async def tool_call_result_callback(context):
captured.append(context.tool_name)

server = FakeMCPServer(tool_call_result_callback=tool_call_result_callback)
ctx = RunContextWrapper(context=None)
tool = MCPTool(name="test_tool_1", inputSchema={})

await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")

assert captured == ["test_tool_1"]


@pytest.mark.asyncio
async def test_to_function_tool_passes_static_mcp_meta():
server = FakeMCPServer()
Expand Down
Loading