diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..dc3a0c7ff 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,6 +1,7 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types +from .agent._agent_as_tool import AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -11,6 +12,7 @@ __all__ = [ "Agent", + "AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c901e800f..d0254852d 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,6 +10,7 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy +from ._agent_as_tool import AgentAsTool from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -24,6 +25,7 @@ "Agent", "AgentBase", "AgentResult", + "AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py new file mode 100644 index 000000000..8954be8a1 --- /dev/null +++ b/src/strands/agent/_agent_as_tool.py @@ -0,0 +1,285 @@ +"""Agent-as-tool adapter. + +This module provides the AgentAsTool class that wraps an Agent (or any AgentBase) as a tool +so it can be passed to another agent's tool list. +""" + +import copy +import logging +from typing import Any + +from typing_extensions import override + +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent +from ..types.content import Messages +from ..types.interrupt import InterruptResponseContent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from .base import AgentBase + +logger = logging.getLogger(__name__) + + +class AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. + + Example: + ```python + from strands import Agent + from strands.agent import AgentAsTool + + researcher = Agent(name="researcher", description="Finds information") + + # Use directly + tool = AgentAsTool(researcher, name="researcher", description="Finds information") + + # Or via convenience method + tool = researcher.as_tool() + + # Start each invocation with a fresh conversation + tool = researcher.as_tool(preserve_context=False) + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: AgentBase, + *, + name: str, + description: str, + preserve_context: bool = False, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. Only effective when the + wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` + (e.g. ``strands.agent.Agent``). + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = description + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + # We require an Agent instance for this since AgentBase doesn't guarantee + # messages/state attributes. + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + + if not preserve_context: + from .agent import Agent + + if not isinstance(agent, Agent): + raise TypeError(f"preserve_context=False requires an Agent instance, got {type(agent).__name__}") + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) + + @property + def agent(self) -> AgentBase: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + }, + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final + AgentResult is yielded as a ToolResultEvent. + + When the sub-agent encounters a hook interrupt (e.g. from BeforeToolCallEvent), + the interrupts are propagated to the parent agent via ToolInterruptEvent. On + resume, interrupt responses are forwarded to the sub-agent automatically. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + AgentAsToolStreamEvent for intermediate events, ToolInterruptEvent if the + sub-agent is interrupted, or ToolResultEvent with the final response. + """ + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + elif isinstance(tool_input, str): + prompt = tool_input + else: + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) + prompt = str(tool_input) + + tool_use_id = tool_use["toolUseId"] + + # Determine if we are resuming the sub-agent from an interrupt. + if self._is_sub_agent_interrupted(): + prompt = self._build_interrupt_responses() + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + self._tool_name, + tool_use_id, + ) + elif not self._preserve_context: + self._reset_agent_state(tool_use_id) + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + try: + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield AgentAsToolStreamEvent(tool_use, event, self) + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + # Propagate sub-agent interrupts to the parent agent. + if result.stop_reason == "interrupt" and result.interrupts: + yield ToolInterruptEvent(tool_use, list(result.interrupts)) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + from .agent import Agent + + # isinstance narrows the type for mypy; __init__ guarantees this when preserve_context=False + if not isinstance(self._agent, Agent): + return + + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + + def _is_sub_agent_interrupted(self) -> bool: + """Check whether the wrapped agent is in an activated interrupt state.""" + from .agent import Agent + + if not isinstance(self._agent, Agent): + return False + return self._agent._interrupt_state.activated + + def _build_interrupt_responses(self) -> list[InterruptResponseContent]: + """Build interrupt response payloads from the sub-agent's interrupt state. + + The parent agent's ``_interrupt_state.resume()`` sets ``.response`` on the shared + ``Interrupt`` objects (registered by the executor), so we re-package them in the + format expected by ``Agent.stream_async``. + + Returns: + List of interrupt response content blocks for resuming the sub-agent. + """ + from .agent import Agent + + if not isinstance(self._agent, Agent): + return [] + + return [ + {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} + for interrupt in self._agent._interrupt_state.interrupts.values() + if interrupt.response is not None + ] + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..f09399dbf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,6 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue +from ._agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -612,6 +613,40 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu finally: await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + def as_tool( + self, + name: str | None = None, + description: str | None = None, + preserve_context: bool = False, + ) -> AgentAsTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + + Returns: + An AgentAsTool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + if not description: + description = self.description or f"Use the {name} agent as a tool by providing a natural language input" + return AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) + def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 0da6b5715..e8f21ca7c 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -226,6 +226,12 @@ async def _stream( # ToolStreamEvent and the last event is just the result. if isinstance(event, ToolInterruptEvent): + # Register any interrupts not already in the agent's state. + # For normal hooks this is a no-op (already registered by _Interruptible.interrupt()). + # For sub-agent interrupts propagated via AgentAsTool, this is where they get + # registered so that _interrupt_state.resume() can locate them by ID. + for interrupt in event.interrupts: + agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt) yield event return diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..5603aedfb 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -323,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The AgentAsTool instance that produced this event. + """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "AgentAsTool": + """The AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..c089ba808 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,7 +15,7 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentResult +from strands.agent import AgentAsTool, AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2699,3 +2699,52 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..abae6dffd --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,617 @@ +"""Tests for AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent import AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + return agent + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + +@pytest.fixture +def tool(mock_agent): + return AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) + assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent + + +def test_init_preserve_context_defaults_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False + + +def test_init_preserve_context_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True + + +# --- properties --- + + +def test_tool_properties(tool): + assert tool.tool_name == "test_agent" + assert tool.tool_type == "agent" + + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" + + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating AgentAsTool + assert event.agent_as_tool is tool + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +# --- preserve_context --- + + +@pytest.mark.asyncio +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 + + +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) + + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 + + +@pytest.mark.asyncio +async def test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] + + +@pytest.mark.asyncio +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} + + +@pytest.mark.asyncio +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" + + +def test_preserve_context_false_requires_agent_instance(): + """Default preserve_context=False should raise TypeError for non-Agent instances.""" + + class _NotAnAgent: + name = "not_agent" + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + pass + + with pytest.raises(TypeError, match="requires an Agent instance"): + AgentAsTool(_NotAnAgent(), name="bad", description="desc") + + +# --- interrupt propagation --- + + +@pytest.fixture +def interrupt_result(): + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval") + return AgentResult( + stop_reason="interrupt", + message={"role": "assistant", "content": [{"text": "pending"}]}, + metrics=EventLoopMetrics(), + state={}, + interrupts=[interrupt], + ) + + +@pytest.mark.asyncio +async def test_stream_interrupt_yields_tool_interrupt_event(tool, mock_agent, tool_use, interrupt_result): + """When the sub-agent returns an interrupt result, AgentAsTool should yield ToolInterruptEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts == interrupt_result.interrupts + assert events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_interrupt_no_tool_result_appended(tool, mock_agent, tool_use, interrupt_result): + """ToolInterruptEvent should not produce a ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events == [] + + +@pytest.mark.asyncio +async def test_stream_interrupt_forwards_intermediate_events(tool, mock_agent, tool_use, interrupt_result): + """Intermediate events should still be yielded before the interrupt.""" + intermediate = [{"data": "partial"}] + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + interrupt_events = [e for e in events if isinstance(e, ToolInterruptEvent)] + assert len(stream_events) == 1 + assert len(interrupt_events) == 1 + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_forwards_responses(fake_agent): + """On resume, AgentAsTool should forward interrupt responses to the sub-agent.""" + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + + # Put the sub-agent in an activated interrupt state with the response already set + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "approved"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + + events = [event async for event in tool.stream(tool_use, {})] + + # Should have called stream_async with interrupt responses, not the original prompt + call_args = fake_agent.stream_async.call_args + agent_input = call_args[0][0] + assert isinstance(agent_input, list) + assert len(agent_input) == 1 + assert agent_input[0]["interruptResponse"]["interruptId"] == "interrupt-1" + assert agent_input[0]["interruptResponse"]["response"] == "APPROVE" + + # Should produce a normal result + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_skips_state_reset(fake_agent): + """When resuming from interrupt with preserve_context=False, state reset should be skipped.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("key", "value") + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Simulate the sub-agent being in interrupt state after a previous invocation + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + # Mutate messages to simulate sub-agent progress before interrupt + fake_agent.messages.append({"role": "assistant", "content": [{"text": "working on it"}]}) + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "done"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + async for _ in tool.stream(tool_use, {}): + pass + + # Messages should NOT have been reset — the sub-agent needs its conversation history intact + assert len(fake_agent.messages) == 2 + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_false_for_mock(tool): + """_is_sub_agent_interrupted returns False for non-Agent instances.""" + assert tool._is_sub_agent_interrupted() is False + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): + """_is_sub_agent_interrupted returns True when the sub-agent's interrupt state is activated.""" + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + assert tool._is_sub_agent_interrupted() is False + + fake_agent._interrupt_state.activate() + assert tool._is_sub_agent_interrupted() is True + + +@pytest.mark.asyncio +async def test_build_interrupt_responses(fake_agent): + """_build_interrupt_responses packages sub-agent interrupts into response content blocks.""" + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + interrupt_a = Interrupt(id="id-a", name="a", reason="r", response="yes") + interrupt_b = Interrupt(id="id-b", name="b", reason="r", response=None) + fake_agent._interrupt_state.interrupts = {"id-a": interrupt_a, "id-b": interrupt_b} + + responses = tool._build_interrupt_responses() + + # Only interrupt_a has a response + assert len(responses) == 1 + assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}} diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 4a5479503..c7ad10232 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -464,6 +464,57 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul assert tru_results == exp_results +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_registers_on_agent( + executor, agent, tool_results, invocation_state, alist +): + """ToolInterruptEvent from a tool should register interrupts in the agent's _interrupt_state.""" + # Create a tool that yields a ToolInterruptEvent with an interrupt NOT pre-registered on the agent + # (simulates AgentAsTool propagating sub-agent interrupts). + foreign_interrupt = Interrupt(id="sub-agent-interrupt-1", name="approval", reason="need approval") + + @strands.tool(name="agent_tool") + def agent_tool_func(): + return "unused" + + async def mock_stream(_tool_use, _invocation_state, **_kwargs): + yield ToolInterruptEvent(_tool_use, [foreign_interrupt]) + + agent_tool_func.stream = mock_stream + agent.tool_registry.register_tool(agent_tool_func) + + tool_use: ToolUse = {"name": "agent_tool", "toolUseId": "test_tool_id", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + events = await alist(stream) + + # Should yield the interrupt event + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + + # The interrupt should now be registered on the agent's _interrupt_state + assert "sub-agent-interrupt-1" in agent._interrupt_state.interrupts + assert agent._interrupt_state.interrupts["sub-agent-interrupt-1"] is foreign_interrupt + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_does_not_overwrite_existing( + executor, agent, tool_results, invocation_state, alist +): + """setdefault should not overwrite interrupts already in the agent's state (normal hook case).""" + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + # The interrupt_tool hook registered the interrupt via _Interruptible.interrupt(). + # The executor's setdefault should have been a no-op for this pre-registered interrupt. + registered = agent._interrupt_state.interrupts + assert len(registered) == 1 + interrupt = next(iter(registered.values())) + assert interrupt.name == "test_name" + assert interrupt.reason == "test reason" + + @pytest.mark.asyncio async def test_executor_stream_updates_invocation_state_with_agent( executor, agent, tool_results, invocation_state, weather_tool, alist diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1