Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def _create_call_tool_coroutine(
name: str,
arguments: dict[str, Any] | None,
read_timeout_seconds: timedelta | None,
meta: dict[str, Any] | None = None,
) -> Coroutine[Any, Any, MCPCallToolResult]:
"""Create the appropriate coroutine for calling a tool.

Expand All @@ -575,6 +576,7 @@ def _create_call_tool_coroutine(
name: Name of the tool to call.
arguments: Optional arguments to pass to the tool.
read_timeout_seconds: Optional timeout for the tool call.
meta: Optional metadata to pass to the tool call per MCP spec (_meta).

Returns:
A coroutine that will execute the tool call.
Expand All @@ -595,7 +597,7 @@ async def _call_as_task() -> MCPCallToolResult:

async def _call_tool_direct() -> MCPCallToolResult:
return await cast(ClientSession, self._background_thread_session).call_tool(
name, arguments, read_timeout_seconds
name, arguments, read_timeout_seconds, meta=meta
)

return _call_tool_direct()
Expand All @@ -606,6 +608,7 @@ def call_tool_sync(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
meta: dict[str, Any] | None = None,
) -> MCPToolResult:
"""Synchronously calls a tool on the MCP server.

Expand All @@ -617,6 +620,7 @@ def call_tool_sync(
name: Name of the tool to call
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call
meta: Optional metadata to pass to the tool call per MCP spec (_meta)

Returns:
MCPToolResult: The result of the tool call
Expand All @@ -626,7 +630,7 @@ def call_tool_sync(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

try:
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds)
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result()
return self._handle_tool_result(tool_use_id, call_tool_result)
except Exception as e:
Expand All @@ -639,6 +643,7 @@ async def call_tool_async(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
meta: dict[str, Any] | None = None,
) -> MCPToolResult:
"""Asynchronously calls a tool on the MCP server.

Expand All @@ -650,6 +655,7 @@ async def call_tool_async(
name: Name of the tool to call
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call
meta: Optional metadata to pass to the tool call per MCP spec (_meta)

Returns:
MCPToolResult: The result of the tool call
Expand All @@ -659,7 +665,7 @@ async def call_tool_async(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

try:
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds)
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
future = self._invoke_on_background_thread(coro)
call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
return self._handle_tool_result(tool_use_id, call_tool_result)
Expand Down
2 changes: 1 addition & 1 deletion src/strands/tools/mcp/mcp_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwar
if hasattr(request.root, "params") and request.root.params:
# Handle Pydantic models
if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"):
params_dict = request.root.params.model_dump()
params_dict = request.root.params.model_dump(by_alias=True)
# Add _meta with tracing context
meta = params_dict.setdefault("_meta", {})
propagate.get_global_textmap().inject(meta)
Expand Down
63 changes: 54 additions & 9 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)

assert result["status"] == expected_status
assert result["toolUseId"] == "test-123"
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)

assert result["status"] == "success"
assert result["toolUseId"] == "test-123"
Expand All @@ -180,6 +180,51 @@ def test_call_tool_sync_exception(mock_transport, mock_session):
assert "Test exception" in result["content"][0]["text"]


def test_call_tool_sync_forwards_meta(mock_transport, mock_session):
"""Test that call_tool_sync forwards meta to ClientSession.call_tool."""
mock_content = MCPTextContent(type="text", text="Test message")
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])
meta = {"com.example/request_id": "abc-123"}

with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
)

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta)
assert result["status"] == "success"


@pytest.mark.asyncio
async def test_call_tool_async_forwards_meta(mock_transport, mock_session):
"""Test that call_tool_async forwards meta to ClientSession.call_tool."""
mock_content = MCPTextContent(type="text", text="Test message")
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
mock_session.call_tool.return_value = mock_result
meta = {"com.example/request_id": "abc-123"}

with MCPClient(mock_transport["transport_callable"]) as client:
with (
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
patch("asyncio.wrap_future") as mock_wrap_future,
):
mock_future = MagicMock()
mock_run_coroutine_threadsafe.return_value = mock_future

async def mock_awaitable():
return mock_result

mock_wrap_future.return_value = mock_awaitable()

result = await client.call_tool_async(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
)

mock_run_coroutine_threadsafe.assert_called_once()

assert result["status"] == "success"


@pytest.mark.asyncio
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status):
Expand Down Expand Up @@ -584,7 +629,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "inner text"
Expand All @@ -609,7 +654,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == '{"k":"v"}'
Expand All @@ -635,7 +680,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert "image" in result["content"][0]
Expand All @@ -660,7 +705,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 0 # Content should be dropped

Expand All @@ -683,7 +728,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert "key: value" in result["content"][0]["text"]
Expand All @@ -710,7 +755,7 @@ def __init__(self):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 0 # Unknown resource type should be dropped

Expand Down Expand Up @@ -762,7 +807,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)

assert result["status"] == "success"
assert result["toolUseId"] == "test-123"
Expand Down
29 changes: 27 additions & 2 deletions tests/strands/tools/mcp/test_mcp_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ class MockPydanticParams:
def __init__(self, **data):
self._data = data

def model_dump(self):
def model_dump(self, by_alias=False):
return self._data.copy()

@classmethod
Expand Down Expand Up @@ -431,6 +431,31 @@ def test_patch_mcp_client_injects_context_pydantic_model(self):
# Verify the params object is still a MockPydanticParams (or dict if fallback occurred)
assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict)

def test_patch_mcp_client_preserves_existing_meta_pydantic(self):
"""Test that instrumentation preserves existing _meta values in Pydantic models."""
mock_request = MagicMock()
mock_request.root.method = "tools/call"

# Pydantic model with existing _meta (returned via by_alias=True)
mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo")
mock_request.root.params = mock_params

with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
mcp_instrumentation()
patch_function = mock_wrap.call_args_list[0][0][2]

mock_wrapped = MagicMock()

with patch.object(propagate, "get_global_textmap") as mock_textmap:
mock_textmap_instance = MagicMock()
mock_textmap.return_value = mock_textmap_instance

patch_function(mock_wrapped, None, [mock_request], {})

# inject should be called with the existing _meta dict (not a new empty one)
inject_call_args = mock_textmap_instance.inject.call_args[0][0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of getting a positional argument here, can we instead confirm the argument name is _meta?

assert inject_call_args.get("com.example/request_id") == "abc-123"

def test_patch_mcp_client_injects_context_dict_params(self):
"""Test that the client patch injects OpenTelemetry context into dict params."""
# Create a mock request with tools/call method and dict params
Expand Down Expand Up @@ -507,7 +532,7 @@ class FailingMockPydanticParams:
def __init__(self, **data):
self._data = data

def model_dump(self):
def model_dump(self, by_alias=False):
return self._data.copy()

def model_validate(self, data):
Expand Down
81 changes: 81 additions & 0 deletions tests_integ/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,87 @@ def test_mcp_client_without_structured_content():
assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}]


def test_call_tool_sync_with_meta():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests dont actually assert that the mcp server received the metadata. The echo_server.py mcp server should be updated to return any sent metadata. Then we can assert that the actual running server is working as intended.

"""Test that call_tool_sync works correctly when meta is provided."""
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
)

with stdio_mcp_client:
result = stdio_mcp_client.call_tool_sync(
tool_use_id="test-meta-sync",
name="echo",
arguments={"to_echo": "META_TEST"},
meta={"com.example/request_id": "abc-123"},
)

assert result["status"] == "success"
assert result["content"] == [{"text": "META_TEST"}]


@pytest.mark.asyncio
async def test_call_tool_async_with_meta():
"""Test that call_tool_async works correctly when meta is provided."""
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
)

with stdio_mcp_client:
result = await stdio_mcp_client.call_tool_async(
tool_use_id="test-meta-async",
name="echo",
arguments={"to_echo": "META_ASYNC_TEST"},
meta={"com.example/request_id": "def-456"},
)

assert result["status"] == "success"
assert result["content"] == [{"text": "META_ASYNC_TEST"}]


def test_instrumentation_preserves_meta_on_tool_call():
"""Test that OTel instrumentation correctly sets _meta on outgoing tool call requests."""
captured_params = []

def spy_send_request(wrapped, instance, args, kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not need to spy in our integ tests. The echo_server.py mcp server should be updated to return any sent metadata. Then we can assert that the actual running server is working as intended.

if args:
request = args[0]
method = getattr(getattr(request, "root", None), "method", None)
if method == "tools/call" and hasattr(request.root, "params"):
params = request.root.params
if hasattr(params, "model_dump"):
captured_params.append(params.model_dump(by_alias=True))
elif isinstance(params, dict):
captured_params.append(params.copy())
return wrapped(*args, **kwargs)

stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
)

with stdio_mcp_client:
from mcp.shared.session import BaseSession
from wrapt import wrap_function_wrapper

original_send = BaseSession.send_request
wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", spy_send_request)

try:
result = stdio_mcp_client.call_tool_sync(
tool_use_id="test-instrumentation",
name="echo",
arguments={"to_echo": "INSTRUMENTATION_TEST"},
)

assert result["status"] == "success"
assert len(captured_params) > 0

params = captured_params[-1]
assert "_meta" in params
assert isinstance(params["_meta"], dict)
finally:
BaseSession.send_request = original_send


@pytest.mark.skipif(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
Expand Down
Loading