diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 51a627c7c..8bd43b54f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -565,6 +565,7 @@ def _create_call_tool_coroutine( name: str, arguments: dict[str, Any] | None, read_timeout_seconds: timedelta | None, + meta: dict[str, Any] | None = None, ) -> Coroutine[Any, Any, MCPCallToolResult]: """Create the appropriate coroutine for calling a tool. @@ -575,6 +576,7 @@ def _create_call_tool_coroutine( name: Name of the tool to call. arguments: Optional arguments to pass to the tool. read_timeout_seconds: Optional timeout for the tool call. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). Returns: A coroutine that will execute the tool call. @@ -595,7 +597,7 @@ async def _call_as_task() -> MCPCallToolResult: async def _call_tool_direct() -> MCPCallToolResult: return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds + name, arguments, read_timeout_seconds, meta=meta ) return _call_tool_direct() @@ -606,6 +608,7 @@ def call_tool_sync( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. @@ -617,6 +620,7 @@ def call_tool_sync( name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -626,7 +630,7 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -639,6 +643,7 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. @@ -650,6 +655,7 @@ async def call_tool_async( name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -659,7 +665,7 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index d1750daa3..ad7805b0c 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -90,7 +90,7 @@ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwar if hasattr(request.root, "params") and request.root.params: # Handle Pydantic models if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): - params_dict = request.root.params.model_dump() + params_dict = request.root.params.model_dump(by_alias=True) # Add _meta with tracing context meta = params_dict.setdefault("_meta", {}) propagate.get_global_textmap().inject(meta) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 5eedd1e33..47dec78d4 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -180,6 +180,51 @@ def test_call_tool_sync_exception(mock_transport, mock_session): assert "Test exception" in result["content"][0]["text"] +def test_call_tool_sync_forwards_meta(mock_transport, mock_session): + """Test that call_tool_sync forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_call_tool_async_forwards_meta(mock_transport, mock_session): + """Test that call_tool_async forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_result = MCPCallToolResult(isError=False, content=[mock_content]) + mock_session.call_tool.return_value = mock_result + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + async def mock_awaitable(): + return mock_result + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_run_coroutine_threadsafe.assert_called_once() + + assert result["status"] == "success" + + @pytest.mark.asyncio @pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")]) async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status): @@ -584,7 +629,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -609,7 +654,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -635,7 +680,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -660,7 +705,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -683,7 +728,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -710,7 +755,7 @@ def __init__(self): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -762,7 +807,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 85d533403..e787d006f 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -328,7 +328,7 @@ class MockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() @classmethod @@ -431,6 +431,31 @@ def test_patch_mcp_client_injects_context_pydantic_model(self): # Verify the params object is still a MockPydanticParams (or dict if fallback occurred) assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict) + def test_patch_mcp_client_preserves_existing_meta_pydantic(self): + """Test that instrumentation preserves existing _meta values in Pydantic models.""" + mock_request = MagicMock() + mock_request.root.method = "tools/call" + + # Pydantic model with existing _meta (returned via by_alias=True) + mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo") + mock_request.root.params = mock_params + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + patch_function(mock_wrapped, None, [mock_request], {}) + + # inject should be called with the existing _meta dict (not a new empty one) + inject_call_args = mock_textmap_instance.inject.call_args[0][0] + assert inject_call_args.get("com.example/request_id") == "abc-123" + def test_patch_mcp_client_injects_context_dict_params(self): """Test that the client patch injects OpenTelemetry context into dict params.""" # Create a mock request with tools/call method and dict params @@ -507,7 +532,7 @@ class FailingMockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() def model_validate(self, data): diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 130b35529..46353a358 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -238,6 +238,87 @@ def test_mcp_client_without_structured_content(): assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] +def test_call_tool_sync_with_meta(): + """Test that call_tool_sync works correctly when meta is provided.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-meta-sync", + name="echo", + arguments={"to_echo": "META_TEST"}, + meta={"com.example/request_id": "abc-123"}, + ) + + assert result["status"] == "success" + assert result["content"] == [{"text": "META_TEST"}] + + +@pytest.mark.asyncio +async def test_call_tool_async_with_meta(): + """Test that call_tool_async works correctly when meta is provided.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = await stdio_mcp_client.call_tool_async( + tool_use_id="test-meta-async", + name="echo", + arguments={"to_echo": "META_ASYNC_TEST"}, + meta={"com.example/request_id": "def-456"}, + ) + + assert result["status"] == "success" + assert result["content"] == [{"text": "META_ASYNC_TEST"}] + + +def test_instrumentation_preserves_meta_on_tool_call(): + """Test that OTel instrumentation correctly sets _meta on outgoing tool call requests.""" + captured_params = [] + + def spy_send_request(wrapped, instance, args, kwargs): + if args: + request = args[0] + method = getattr(getattr(request, "root", None), "method", None) + if method == "tools/call" and hasattr(request.root, "params"): + params = request.root.params + if hasattr(params, "model_dump"): + captured_params.append(params.model_dump(by_alias=True)) + elif isinstance(params, dict): + captured_params.append(params.copy()) + return wrapped(*args, **kwargs) + + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + from mcp.shared.session import BaseSession + from wrapt import wrap_function_wrapper + + original_send = BaseSession.send_request + wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", spy_send_request) + + try: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-instrumentation", + name="echo", + arguments={"to_echo": "INSTRUMENTATION_TEST"}, + ) + + assert result["status"] == "success" + assert len(captured_params) > 0 + + params = captured_params[-1] + assert "_meta" in params + assert isinstance(params["_meta"], dict) + finally: + BaseSession.send_request = original_send + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",