From 2418ad251f523e828db32735c70d636f95a7f96c Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Wed, 18 Mar 2026 15:55:29 -0400 Subject: [PATCH 1/4] feat: add AgentAsTool --- src/strands/agent/__init__.py | 2 + src/strands/agent/agent.py | 29 +++ src/strands/agent/agent_as_tool.py | 173 ++++++++++++++++++ tests/strands/agent/test_agent.py | 50 +++++ tests/strands/agent/test_agent_as_tool.py | 213 ++++++++++++++++++++++ tests_integ/test_agent_as_tool.py | 36 ++++ 6 files changed, 503 insertions(+) create mode 100644 src/strands/agent/agent_as_tool.py create mode 100644 tests/strands/agent/test_agent_as_tool.py create mode 100644 tests_integ/test_agent_as_tool.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c901e800f..a4911b34e 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -11,6 +11,7 @@ from ..event_loop._retry import ModelRetryStrategy from .agent import Agent +from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -24,6 +25,7 @@ "Agent", "AgentBase", "AgentResult", + "AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..ad2b7a14b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,6 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue +from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -612,6 +613,34 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu finally: await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + def as_tool( + self, + name: str | None = None, + description: str | None = None, + ) -> AgentAsTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description. + + Returns: + An AgentAsTool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + if not description: + description = self.description or f"Use the {name} tool to invoke this agent as a tool" + return AgentAsTool(self, name=name, description=description) + def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/agent_as_tool.py new file mode 100644 index 000000000..713456b91 --- /dev/null +++ b/src/strands/agent/agent_as_tool.py @@ -0,0 +1,173 @@ +"""Agent-as-tool adapter. + +This module provides the AgentAsTool class that wraps an Agent (or any AgentBase) as a tool +so it can be passed to another agent's tool list. +""" + +import logging +from typing import Any + +from typing_extensions import override + +from ..types._events import ToolResultEvent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from .base import AgentBase + +logger = logging.getLogger(__name__) + + +class AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. + + Example: + ```python + from strands import Agent + from strands.agent.agent_as_tool import AgentAsTool + + researcher = Agent(name="researcher", description="Finds information") + + # Use directly + tool = AgentAsTool(researcher, name="researcher", description="Finds information") + + # Or via convenience method + tool = researcher.as_tool() + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: AgentBase, + *, + name: str, + description: str, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = description + + @property + def agent(self) -> AgentBase: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + } + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are forwarded as ToolStreamEvents so the parent + agent's callback handler can display sub-agent progress. The final + AgentResult is yielded as a ToolResultEvent. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + """ + prompt = tool_use["input"].get("input", "") if isinstance(tool_use["input"], dict) else tool_use["input"] + tool_use_id = tool_use["toolUseId"] + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + try: + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield event + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..6bb64f870 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -16,6 +16,7 @@ import strands from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult +from strands.agent.agent_as_tool import AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2699,3 +2700,52 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher tool to invoke this agent as a tool" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..2b2fb9ca6 --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,213 @@ +"""Tests for AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent.agent_as_tool import AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import ToolResultEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + return agent + + +@pytest.fixture +def tool(mock_agent): + return AgentAsTool(mock_agent, name="test_agent", description="A test agent") + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init_sets_name(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="desc") + assert tool.tool_name == "my_tool" + + +def test_init_sets_description(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") + assert tool._description == "custom desc" + + +def test_init_stores_agent_reference(mock_agent, tool): + assert tool.agent is mock_agent + + +# --- properties --- + + +def test_tool_name(tool): + assert tool.tool_name == "test_agent" + + +def test_tool_type(tool): + assert tool.tool_type == "agent" + + +def test_tool_spec_name(tool): + assert tool.tool_spec["name"] == "test_agent" + + +def test_tool_spec_description(tool): + assert tool.tool_spec["description"] == "A test agent" + + +def test_tool_spec_input_schema(tool): + schema = tool.tool_spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + +def test_display_properties(tool): + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + # Intermediate events are yielded as-is (raw dicts); wrapping in ToolStreamEvent happens in the caller + non_result_events = [e for e in events if not isinstance(e, ToolResultEvent)] + assert len(non_result_events) == 2 + assert non_result_events[0]["data"] == "partial" + assert non_result_events[1]["data"] == "more" + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1 From 2826d268c01008c87ef93659f6e6208df59c9869 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Thu, 19 Mar 2026 14:58:55 -0400 Subject: [PATCH 2/4] feat: add preserve_context to AgentAsTool --- src/strands/agent/agent_as_tool.py | 46 ++++++- tests/strands/agent/test_agent_as_tool.py | 139 ++++++++++++++++++++++ 2 files changed, 183 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/agent_as_tool.py index 713456b91..c2aefe021 100644 --- a/src/strands/agent/agent_as_tool.py +++ b/src/strands/agent/agent_as_tool.py @@ -82,7 +82,14 @@ def tool_spec(self) -> ToolSpec: "input": { "type": "string", "description": "The input to send to the agent tool.", - } + }, + "preserve_context": { + "type": "boolean", + "description": ( + "Whether to preserve the agent's conversation context across invocations. " + "Defaults to true. Set to false to clear conversation history before this call." + ), + }, }, "required": ["input"], } @@ -110,9 +117,44 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw Yields: ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. """ - prompt = tool_use["input"].get("input", "") if isinstance(tool_use["input"], dict) else tool_use["input"] + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + preserve_context = tool_input.get("preserve_context", True) + elif isinstance(tool_input, str): + prompt = tool_input + preserve_context = True + else: + logger.warning( + "tool_name=<%s> | unexpected input type: %s", + self._tool_name, + type(tool_input), + ) + prompt = str(tool_input) + preserve_context = True + tool_use_id = tool_use["toolUseId"] + if not preserve_context: + # AgentBase is a protocol and does not guarantee a messages attribute. + # We check for it at runtime to support Agent and other implementations + # that expose a mutable messages list. + messages = getattr(self._agent, "messages", None) + if isinstance(messages, list): + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | clearing agent conversation context", + self._tool_name, + tool_use_id, + ) + messages.clear() + else: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | preserve_context=false requested" + " but agent does not expose a messages list", + self._tool_name, + tool_use_id, + ) + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) try: diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index 2b2fb9ca6..f142090a8 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -1,5 +1,6 @@ """Tests for AgentAsTool - the agent-as-tool adapter.""" +import logging from unittest.mock import MagicMock import pytest @@ -90,6 +91,8 @@ def test_tool_spec_input_schema(tool): assert schema["type"] == "object" assert "input" in schema["properties"] assert schema["properties"]["input"]["type"] == "string" + assert "preserve_context" in schema["properties"] + assert schema["properties"]["preserve_context"]["type"] == "boolean" assert schema["required"] == ["input"] @@ -211,3 +214,139 @@ class MyOutput(BaseModel): result_events = [e for e in events if isinstance(e, ToolResultEvent)] assert result_events[0]["tool_result"]["status"] == "success" assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + """When tool_use input is a plain string rather than a dict.""" + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +# --- preserve_context --- + + +class _FakeAgent: + """Minimal fake agent with a real messages list for preserve_context tests.""" + + def __init__(self): + self.name = "fake_agent" + self.messages: list = [] + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + return _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + +@pytest.mark.asyncio +async def test_stream_clears_context_when_preserve_context_false(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello", "preserve_context": False}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert agent.messages == [] + + +@pytest.mark.asyncio +async def test_stream_preserves_context_by_default(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(agent.messages) >= 1 + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello", "preserve_context": True}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(agent.messages) >= 1 + + +@pytest.mark.asyncio +async def test_stream_preserve_context_false_warns_when_no_messages_attr(caplog): + """Agent without a messages attribute should log a warning.""" + + class _NoMessagesAgent: + name = "bare_agent" + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + return _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + agent = _NoMessagesAgent() + tool = AgentAsTool(agent, name="bare_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "bare_agent", + "input": {"input": "hello", "preserve_context": False}, + } + + with caplog.at_level(logging.WARNING, logger="strands.agent.agent_as_tool"): + async for _ in tool.stream(tool_use, {}): + pass + + assert "preserve_context=false requested" in caplog.text From 9ca227c1dad9b092b00f6cd527f5b01c3e487dd0 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Fri, 20 Mar 2026 14:22:41 -0400 Subject: [PATCH 3/4] fix: move preserve_context to class var; reset to original agent state; yield AgentAsToolStreamEvents; small fixes --- src/strands/__init__.py | 2 + src/strands/agent/__init__.py | 2 +- .../{agent_as_tool.py => _agent_as_tool.py} | 102 ++++--- src/strands/agent/agent.py | 12 +- src/strands/types/_events.py | 26 ++ tests/strands/agent/test_agent.py | 5 +- tests/strands/agent/test_agent_as_tool.py | 279 +++++++++++------- tests/strands/types/test__events.py | 37 +++ 8 files changed, 317 insertions(+), 148 deletions(-) rename src/strands/agent/{agent_as_tool.py => _agent_as_tool.py} (63%) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..dc3a0c7ff 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,6 +1,7 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types +from .agent._agent_as_tool import AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -11,6 +12,7 @@ __all__ = [ "Agent", + "AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index a4911b34e..d0254852d 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,8 +10,8 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy +from ._agent_as_tool import AgentAsTool from .agent import Agent -from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/_agent_as_tool.py similarity index 63% rename from src/strands/agent/agent_as_tool.py rename to src/strands/agent/_agent_as_tool.py index c2aefe021..534be5d9a 100644 --- a/src/strands/agent/agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -4,12 +4,15 @@ so it can be passed to another agent's tool list. """ +import copy import logging from typing import Any from typing_extensions import override -from ..types._events import ToolResultEvent +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolResultEvent +from ..types.content import Messages from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse from .base import AgentBase @@ -25,7 +28,7 @@ class AgentAsTool(AgentTool): Example: ```python from strands import Agent - from strands.agent.agent_as_tool import AgentAsTool + from strands.agent import AgentAsTool researcher = Agent(name="researcher", description="Finds information") @@ -35,6 +38,9 @@ class AgentAsTool(AgentTool): # Or via convenience method tool = researcher.as_tool() + # Start each invocation with a fresh conversation + tool = researcher.as_tool(preserve_context=False) + writer = Agent(name="writer", tools=[tool]) writer("Write about AI agents") ``` @@ -46,6 +52,7 @@ def __init__( *, name: str, description: str, + preserve_context: bool = True, ) -> None: r"""Initialize the agent-as-tool adapter. @@ -53,11 +60,34 @@ def __init__( agent: The agent to wrap as a tool. name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. description: Tool description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to True. Only effective when the + wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` + (e.g. ``strands.agent.Agent``). """ super().__init__() self._agent = agent self._tool_name = name self._description = description + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + # We require an Agent instance for this since AgentBase doesn't guarantee + # messages/state attributes. + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + + if not preserve_context: + from .agent import Agent + + if not isinstance(agent, Agent): + raise TypeError(f"preserve_context=False requires an Agent instance, got {type(agent).__name__}") + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) @property def agent(self) -> AgentBase: @@ -83,13 +113,6 @@ def tool_spec(self) -> ToolSpec: "type": "string", "description": "The input to send to the agent tool.", }, - "preserve_context": { - "type": "boolean", - "description": ( - "Whether to preserve the agent's conversation context across invocations. " - "Defaults to true. Set to false to clear conversation history before this call." - ), - }, }, "required": ["input"], } @@ -105,8 +128,8 @@ def tool_type(self) -> str: async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Invoke the wrapped agent via streaming and yield events. - Intermediate agent events are forwarded as ToolStreamEvents so the parent - agent's callback handler can display sub-agent progress. The final + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final AgentResult is yielded as a ToolResultEvent. Args: @@ -115,45 +138,21 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw **kwargs: Additional keyword arguments. Yields: - ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + AgentAsToolStreamEvent for intermediate events, then ToolResultEvent with the final response. """ tool_input = tool_use["input"] if isinstance(tool_input, dict): prompt = tool_input.get("input", "") - preserve_context = tool_input.get("preserve_context", True) elif isinstance(tool_input, str): prompt = tool_input - preserve_context = True else: - logger.warning( - "tool_name=<%s> | unexpected input type: %s", - self._tool_name, - type(tool_input), - ) + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) prompt = str(tool_input) - preserve_context = True tool_use_id = tool_use["toolUseId"] - if not preserve_context: - # AgentBase is a protocol and does not guarantee a messages attribute. - # We check for it at runtime to support Agent and other implementations - # that expose a mutable messages list. - messages = getattr(self._agent, "messages", None) - if isinstance(messages, list): - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | clearing agent conversation context", - self._tool_name, - tool_use_id, - ) - messages.clear() - else: - logger.warning( - "tool_name=<%s>, tool_use_id=<%s> | preserve_context=false requested" - " but agent does not expose a messages list", - self._tool_name, - tool_use_id, - ) + if not self._preserve_context: + self._reset_agent_state(tool_use_id) logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) @@ -163,7 +162,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw if "result" in event: result = event["result"] else: - yield event + yield AgentAsToolStreamEvent(tool_use, event, self) if result is None: yield ToolResultEvent( @@ -207,6 +206,29 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw } ) + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + from .agent import Agent + + # isinstance narrows the type for mypy; __init__ guarantees this when preserve_context=False + if not isinstance(self._agent, Agent): + return + + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + @override def get_display_properties(self) -> dict[str, str]: """Get properties for UI display.""" diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ad2b7a14b..8d94de45b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,7 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue -from .agent_as_tool import AgentAsTool +from ._agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -617,6 +617,7 @@ def as_tool( self, name: str | None = None, description: str | None = None, + preserve_context: bool = True, ) -> AgentAsTool: r"""Convert this agent into a tool for use by another agent. @@ -624,6 +625,11 @@ def as_tool( name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. Defaults to the agent's name. description: Tool description. Defaults to the agent's description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to True. Returns: An AgentAsTool wrapping this agent. @@ -638,8 +644,8 @@ def as_tool( if not name: name = self.name if not description: - description = self.description or f"Use the {name} tool to invoke this agent as a tool" - return AgentAsTool(self, name=name, description=description) + description = self.description or f"Use the {name} agent as a tool by providing a natural language input" + return AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..5603aedfb 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -323,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The AgentAsTool instance that produced this event. + """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "AgentAsTool": + """The AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6bb64f870..c089ba808 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,8 +15,7 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentResult -from strands.agent.agent_as_tool import AgentAsTool +from strands.agent import AgentAsTool, AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2748,4 +2747,4 @@ def test_as_tool_defaults_description_when_agent_has_none(): agent = Agent(name="researcher") tool = agent.as_tool() - assert tool.tool_spec["description"] == "Use the researcher tool to invoke this agent as a tool" + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index f142090a8..a6cc28d51 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -1,14 +1,13 @@ """Tests for AgentAsTool - the agent-as-tool adapter.""" -import logging from unittest.mock import MagicMock import pytest -from strands.agent.agent_as_tool import AgentAsTool +from strands.agent import AgentAsTool from strands.agent.agent_result import AgentResult from strands.telemetry.metrics import EventLoopMetrics -from strands.types._events import ToolResultEvent +from strands.types._events import AgentAsToolStreamEvent, ToolResultEvent, ToolStreamEvent async def _mock_stream_async(result, intermediate_events=None): @@ -53,50 +52,40 @@ def agent_result(): # --- init --- -def test_init_sets_name(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="desc") +def test_init(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent -def test_init_sets_description(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") - assert tool._description == "custom desc" +def test_init_preserve_context_defaults_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d") + assert tool._preserve_context is True -def test_init_stores_agent_reference(mock_agent, tool): - assert tool.agent is mock_agent +def test_init_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) + assert tool._preserve_context is False # --- properties --- -def test_tool_name(tool): +def test_tool_properties(tool): assert tool.tool_name == "test_agent" - - -def test_tool_type(tool): assert tool.tool_type == "agent" + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" -def test_tool_spec_name(tool): - assert tool.tool_spec["name"] == "test_agent" - - -def test_tool_spec_description(tool): - assert tool.tool_spec["description"] == "A test agent" - - -def test_tool_spec_input_schema(tool): - schema = tool.tool_spec["inputSchema"]["json"] + schema = spec["inputSchema"]["json"] assert schema["type"] == "object" assert "input" in schema["properties"] assert schema["properties"]["input"]["type"] == "string" - assert "preserve_context" in schema["properties"] - assert schema["properties"]["preserve_context"]["type"] == "boolean" assert schema["required"] == ["input"] - -def test_display_properties(tool): props = tool.get_display_properties() assert props["Agent"] == "test_agent" assert props["Type"] == "agent" @@ -142,6 +131,21 @@ async def test_stream_empty_input(tool, mock_agent, agent_result): mock_agent.stream_async.assert_called_once_with("") +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + @pytest.mark.asyncio async def test_stream_error(tool, mock_agent, tool_use): mock_agent.stream_async.side_effect = RuntimeError("boom") @@ -170,11 +174,32 @@ async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, a events = [event async for event in tool.stream(tool_use, {})] - # Intermediate events are yielded as-is (raw dicts); wrapping in ToolStreamEvent happens in the caller - non_result_events = [e for e in events if not isinstance(e, ToolResultEvent)] - assert len(non_result_events) == 2 - assert non_result_events[0]["data"] == "partial" - assert non_result_events[1]["data"] == "more" + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating AgentAsTool + assert event.agent_as_tool is tool @pytest.mark.asyncio @@ -216,72 +241,133 @@ class MyOutput(BaseModel): assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} +# --- preserve_context --- + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for preserve_context tests.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + @pytest.mark.asyncio -async def test_stream_string_input(tool, mock_agent, agent_result): - """When tool_use input is a plain string rather than a dict.""" +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + tool_use = { "toolUseId": "tool-123", - "name": "test_agent", - "input": "direct string", + "name": "fake_agent", + "input": {"input": "hello"}, } - mock_agent.stream_async.return_value = _mock_stream_async(agent_result) async for _ in tool.stream(tool_use, {}): pass - mock_agent.stream_async.assert_called_once_with("direct string") - + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 -# --- preserve_context --- +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) -class _FakeAgent: - """Minimal fake agent with a real messages list for preserve_context tests.""" - - def __init__(self): - self.name = "fake_agent" - self.messages: list = [] + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } - async def invoke_async(self, prompt=None, **kwargs): + async for _ in tool.stream(tool_use, {}): pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) - def __call__(self, prompt=None, **kwargs): + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): pass - def stream_async(self, prompt=None, **kwargs): - return _mock_stream_async( - AgentResult( - stop_reason="end_turn", - message={"role": "assistant", "content": [{"text": "ok"}]}, - metrics=EventLoopMetrics(), - state={}, - ) - ) + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 @pytest.mark.asyncio -async def test_stream_clears_context_when_preserve_context_false(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", "name": "fake_agent", - "input": {"input": "hello", "preserve_context": False}, + "input": {"input": "hello"}, } async for _ in tool.stream(tool_use, {}): pass - assert agent.messages == [] + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] @pytest.mark.asyncio -async def test_stream_preserves_context_by_default(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", @@ -292,33 +378,43 @@ async def test_stream_preserves_context_by_default(): async for _ in tool.stream(tool_use, {}): pass - assert len(agent.messages) >= 1 + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} @pytest.mark.asyncio -async def test_stream_preserves_context_when_explicitly_true(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_preserves_context_by_default(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", "name": "fake_agent", - "input": {"input": "hello", "preserve_context": True}, + "input": {"input": "hello"}, } async for _ in tool.stream(tool_use, {}): pass - assert len(agent.messages) >= 1 + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" -@pytest.mark.asyncio -async def test_stream_preserve_context_false_warns_when_no_messages_attr(caplog): - """Agent without a messages attribute should log a warning.""" +def test_preserve_context_false_requires_agent_instance(): + """preserve_context=False should raise TypeError for non-Agent instances.""" - class _NoMessagesAgent: - name = "bare_agent" + class _NotAnAgent: + name = "not_agent" async def invoke_async(self, prompt=None, **kwargs): pass @@ -327,26 +423,7 @@ def __call__(self, prompt=None, **kwargs): pass def stream_async(self, prompt=None, **kwargs): - return _mock_stream_async( - AgentResult( - stop_reason="end_turn", - message={"role": "assistant", "content": [{"text": "ok"}]}, - metrics=EventLoopMetrics(), - state={}, - ) - ) - - agent = _NoMessagesAgent() - tool = AgentAsTool(agent, name="bare_agent", description="desc") - - tool_use = { - "toolUseId": "tool-123", - "name": "bare_agent", - "input": {"input": "hello", "preserve_context": False}, - } - - with caplog.at_level(logging.WARNING, logger="strands.agent.agent_as_tool"): - async for _ in tool.stream(tool_use, {}): pass - assert "preserve_context=false requested" in caplog.text + with pytest.raises(TypeError, match="requires an Agent instance"): + AgentAsTool(_NotAnAgent(), name="bad", description="desc", preserve_context=False) diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent From 415a55d627a298eef583c4d95aaaa26f514a440c Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Fri, 20 Mar 2026 14:36:35 -0400 Subject: [PATCH 4/4] fix: make preserve_context default to true --- src/strands/agent/_agent_as_tool.py | 4 +- src/strands/agent/agent.py | 4 +- tests/strands/agent/test_agent_as_tool.py | 72 +++++++++++++++++------ 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py index 534be5d9a..a54b67df9 100644 --- a/src/strands/agent/_agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -52,7 +52,7 @@ def __init__( *, name: str, description: str, - preserve_context: bool = True, + preserve_context: bool = False, ) -> None: r"""Initialize the agent-as-tool adapter. @@ -64,7 +64,7 @@ def __init__( invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every invocation starts from the same baseline regardless of any external - interactions with the agent. Defaults to True. Only effective when the + interactions with the agent. Defaults to False. Only effective when the wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` (e.g. ``strands.agent.Agent``). """ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8d94de45b..f09399dbf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -617,7 +617,7 @@ def as_tool( self, name: str | None = None, description: str | None = None, - preserve_context: bool = True, + preserve_context: bool = False, ) -> AgentAsTool: r"""Convert this agent into a tool for use by another agent. @@ -629,7 +629,7 @@ def as_tool( invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every invocation starts from the same baseline regardless of any external - interactions with the agent. Defaults to True. + interactions with the agent. Defaults to False. Returns: An AgentAsTool wrapping this agent. diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index a6cc28d51..68128e6e5 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -25,9 +25,17 @@ def mock_agent(): return agent +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + @pytest.fixture def tool(mock_agent): - return AgentAsTool(mock_agent, name="test_agent", description="A test agent") + return AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) @pytest.fixture @@ -53,20 +61,20 @@ def agent_result(): def test_init(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) assert tool.tool_name == "my_tool" assert tool._description == "custom desc" assert tool.agent is mock_agent -def test_init_preserve_context_defaults_true(mock_agent): - tool = AgentAsTool(mock_agent, name="t", description="d") - assert tool._preserve_context is True +def test_init_preserve_context_defaults_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False -def test_init_preserve_context_false(fake_agent): - tool = AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) - assert tool._preserve_context is False +def test_init_preserve_context_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True # --- properties --- @@ -244,14 +252,6 @@ class MyOutput(BaseModel): # --- preserve_context --- -@pytest.fixture -def fake_agent(): - """A real Agent instance for preserve_context tests.""" - from strands.agent.agent import Agent - - return Agent(name="fake_agent", callback_handler=None) - - @pytest.mark.asyncio async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] @@ -383,11 +383,45 @@ async def test_stream_resets_empty_initial_state_when_preserve_context_false(fak @pytest.mark.asyncio -async def test_stream_preserves_context_by_default(fake_agent): +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] fake_agent.state.set("key", "value") tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( AgentResult( stop_reason="end_turn", @@ -411,7 +445,7 @@ async def test_stream_preserves_context_by_default(fake_agent): def test_preserve_context_false_requires_agent_instance(): - """preserve_context=False should raise TypeError for non-Agent instances.""" + """Default preserve_context=False should raise TypeError for non-Agent instances.""" class _NotAnAgent: name = "not_agent" @@ -426,4 +460,4 @@ def stream_async(self, prompt=None, **kwargs): pass with pytest.raises(TypeError, match="requires an Agent instance"): - AgentAsTool(_NotAnAgent(), name="bad", description="desc", preserve_context=False) + AgentAsTool(_NotAnAgent(), name="bad", description="desc")