Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
It allows MCP tools to be seamlessly integrated and used within the agent ecosystem.
"""

import asyncio
import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -110,10 +111,22 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
"""
logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])

result = await self.mcp_client.call_tool_async(
tool_use_id=tool_use["toolUseId"],
name=self.mcp_tool.name, # Use original MCP name for server communication
arguments=tool_use["input"],
read_timeout_seconds=self.timeout,
)
yield ToolResultEvent(result)
result, exception = await self._invoke_tool(tool_use)
yield ToolResultEvent(result, exception=exception)

async def _invoke_tool(self, tool_use: ToolUse) -> tuple[Any, Exception | None]:
"""Invoke the MCP tool and return (result, exception).

Returns both the MCPToolResult and the original exception (if any),
so callers can access the exception via ToolResultEvent.exception —
matching the pattern used by decorated tools.
"""
try:
coro = self.mcp_client._create_call_tool_coroutine(
self.mcp_tool.name, tool_use["input"], self.timeout
)
future = self.mcp_client._invoke_on_background_thread(coro)
call_tool_result = await asyncio.wrap_future(future)
return self.mcp_client._handle_tool_result(tool_use["toolUseId"], call_tool_result), None
except Exception as e:
return self.mcp_client._handle_tool_execution_error(tool_use["toolUseId"], e), e
2 changes: 2 additions & 0 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(exception)}"}],
isError=True,
)

def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
Expand Down Expand Up @@ -703,6 +704,7 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
status=status,
toolUseId=tool_use_id,
content=mapped_contents,
isError=call_tool_result.isError,
)

if call_tool_result.structuredContent:
Expand Down
5 changes: 5 additions & 0 deletions src/strands/tools/mcp/mcp_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class MCPToolResult(ToolResult):
that provides structured results beyond the standard text/image/document content.

Attributes:
isError: Flag indicating whether this result represents an error.
Set to True when the MCP tool reports a failure via CallToolResult.isError
(application-level error) or when a Python exception occurs during tool
execution (protocol/client error). Set to False or omitted on success.
structuredContent: Optional JSON object containing structured data returned
by the MCP tool. This allows MCP tools to return complex data structures
that can be processed programmatically by agents or other tools.
Expand All @@ -63,5 +67,6 @@ class MCPToolResult(ToolResult):
performance metrics, or business-specific tracking information).
"""

isError: NotRequired[bool]
structuredContent: NotRequired[dict[str, Any]]
metadata: NotRequired[dict[str, Any]]
127 changes: 106 additions & 21 deletions tests/strands/tools/mcp/test_mcp_agent_tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import timedelta
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from mcp.types import Tool as MCPTool
Expand All @@ -8,16 +8,6 @@
from strands.types._events import ToolResultEvent


@pytest.fixture
def mock_mcp_tool():
mock_tool = MagicMock(spec=MCPTool)
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.inputSchema = {"type": "object", "properties": {}}
mock_tool.outputSchema = None # MCP tools can have optional outputSchema
return mock_tool


@pytest.fixture
def mock_mcp_client():
mock_server = MagicMock(spec=MCPClient)
Expand All @@ -26,9 +16,32 @@ def mock_mcp_client():
"toolUseId": "test-123",
"content": [{"text": "Success result"}],
}
# Mock internal methods used by MCPAgentTool.stream()
mock_server._create_call_tool_coroutine.return_value = MagicMock()
mock_server._handle_tool_result.return_value = {
"status": "success",
"toolUseId": "test-123",
"content": [{"text": "Success result"}],
}
mock_server._handle_tool_execution_error.return_value = {
"status": "error",
"toolUseId": "test-123",
"content": [{"text": "error"}],
"isError": True,
}
return mock_server


@pytest.fixture
def mock_mcp_tool():
mock_tool = MagicMock(spec=MCPTool)
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.inputSchema = {"type": "object", "properties": {}}
mock_tool.outputSchema = None # MCP tools can have optional outputSchema
return mock_tool


@pytest.fixture
def mcp_agent_tool(mock_mcp_tool, mock_mcp_client):
return MCPAgentTool(mock_mcp_tool, mock_mcp_client)
Expand Down Expand Up @@ -84,13 +97,25 @@ def test_tool_spec_without_output_schema(mock_mcp_tool, mock_mcp_client):
async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
mock_result = mock_mcp_client._handle_tool_result.return_value

with patch("asyncio.wrap_future") as mock_wrap_future:
# Make wrap_future return a coroutine that resolves to the mock call_tool result
async def mock_awaitable(_):
return MagicMock() # call_tool_result (raw MCP response)

assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
mock_wrap_future.side_effect = mock_awaitable

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))

assert len(tru_events) == 1
event = tru_events[0]
assert event.exception is None
assert event.tool_result == mock_result
mock_mcp_client._create_call_tool_coroutine.assert_called_once_with(
"test_tool", {"param": "value"}, None
)
mock_mcp_client._handle_tool_result.assert_called_once()


def test_timeout_initialization(mock_mcp_tool, mock_mcp_client):
Expand All @@ -110,10 +135,70 @@ async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist):
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(agent_tool.stream(tool_use, {}))
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
mock_result = mock_mcp_client._handle_tool_result.return_value

with patch("asyncio.wrap_future") as mock_wrap_future:

async def mock_awaitable(_):
return MagicMock()

assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
mock_wrap_future.side_effect = mock_awaitable

tru_events = await alist(agent_tool.stream(tool_use, {}))

assert len(tru_events) == 1
assert tru_events[0].exception is None
assert tru_events[0].tool_result == mock_result
mock_mcp_client._create_call_tool_coroutine.assert_called_once_with(
"test_tool", {"param": "value"}, timeout
)


@pytest.mark.asyncio
async def test_stream_propagates_exception(mock_mcp_tool, mock_mcp_client, alist):
"""Test that stream() passes the original exception via ToolResultEvent.exception.

This ensures parity with decorated tools, where the exception is accessible
via event.exception for debugging and conditional handling.
"""
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

test_exception = RuntimeError("MCP server connection failed")
with patch("asyncio.wrap_future", side_effect=test_exception):
mock_error_result = {
"status": "error", "toolUseId": "test-123",
"content": [{"text": "Tool execution failed: MCP server connection failed"}],
"isError": True,
}
mock_mcp_client._handle_tool_execution_error.return_value = mock_error_result

tru_events = await alist(agent_tool.stream(tool_use, {}))

assert len(tru_events) == 1
event = tru_events[0]
assert event.exception is test_exception
assert event.tool_result == mock_error_result
mock_mcp_client._handle_tool_execution_error.assert_called_once_with("test-123", test_exception)


@pytest.mark.asyncio
async def test_stream_no_exception_on_success(mock_mcp_tool, mock_mcp_client, alist):
"""Test that stream() sets exception=None on successful execution."""
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

mock_result = mock_mcp_client._handle_tool_result.return_value

with patch("asyncio.wrap_future") as mock_wrap_future:

async def mock_awaitable(_):
return MagicMock()

mock_wrap_future.side_effect = mock_awaitable

tru_events = await alist(agent_tool.stream(tool_use, {}))

assert len(tru_events) == 1
assert tru_events[0].exception is None
assert tru_events[0].tool_result == mock_result
4 changes: 4 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_

assert result["status"] == expected_status
assert result["toolUseId"] == "test-123"
assert result["isError"] == is_error
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "Test message"
# No structured content should be present when not provided by MCP
Expand Down Expand Up @@ -176,6 +177,7 @@ def test_call_tool_sync_exception(mock_transport, mock_session):

assert result["status"] == "error"
assert result["toolUseId"] == "test-123"
assert result["isError"] is True
assert len(result["content"]) == 1
assert "Test exception" in result["content"][0]["text"]

Expand Down Expand Up @@ -214,6 +216,7 @@ async def mock_awaitable():

assert result["status"] == expected_status
assert result["toolUseId"] == "test-123"
assert result["isError"] == is_error
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "Test message"

Expand Down Expand Up @@ -241,6 +244,7 @@ async def test_call_tool_async_exception(mock_transport, mock_session):

assert result["status"] == "error"
assert result["toolUseId"] == "test-123"
assert result["isError"] is True
assert len(result["content"]) == 1
assert "Test exception" in result["content"][0]["text"]

Expand Down