From 4eb85eb8883ef9f82f46dffe673d499c86876f57 Mon Sep 17 00:00:00 2001 From: Emaan Khan Date: Sun, 22 Mar 2026 16:08:56 -0700 Subject: [PATCH] feat: add streaming to direct tool calls --- src/strands/tools/_caller.py | 270 +++++++++++++++------- tests/strands/tools/test_tool_executor.py | 122 ++++++++++ tests_integ/test_tool_streaming.py | 79 +++++++ 3 files changed, 387 insertions(+), 84 deletions(-) create mode 100644 tests/strands/tools/test_tool_executor.py create mode 100644 tests_integ/test_tool_streaming.py diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 0b5408f35..b1c87071e 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -8,14 +8,15 @@ """ import json +import logging import random import weakref -from collections.abc import Callable +from collections.abc import AsyncIterator, Iterator from typing import TYPE_CHECKING, Any from .._async import run_async from ..tools.executors._executor import ToolExecutor -from ..types._events import ToolInterruptEvent +from ..types._events import ToolInterruptEvent, TypedEvent from ..types.content import ContentBlock, Message from ..types.exceptions import ConcurrencyException from ..types.tools import ToolResult, ToolUse @@ -24,19 +25,27 @@ from ..agent import Agent from ..experimental.bidi.agent import BidiAgent +logger = logging.getLogger(__name__) -class _ToolCaller: - """Call tool as a function.""" - def __init__(self, agent: "Agent | BidiAgent") -> None: - """Initialize instance. +class _ToolExecutor: + """Callable wrapper for tools that provides streaming methods. + + This class enables three execution modes for tools: + 1. Synchronous: result = executor(x=5) + 2. Sync streaming: for event in executor.stream(x=5) + 3. Async streaming: async for event in executor.stream_async(x=5) + """ + + def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None: + """Initialize tool executor. Args: - agent: Agent reference that will accept tool results. + agent: Agent reference that owns the tools. + tool_name: Name of the tool to execute. """ - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with - # agent tools and thus break their execution. self._agent_ref = weakref.ref(agent) + self._tool_name = tool_name @property def _agent(self) -> "Agent | BidiAgent": @@ -46,104 +55,161 @@ def _agent(self) -> "Agent | BidiAgent": raise ReferenceError("Agent has been garbage collected") return agent - def __getattr__(self, name: str) -> Callable[..., Any]: - """Call tool as a function. + def __call__( + self, + user_message_override: str | None = None, + record_direct_tool_call: bool | None = None, + **kwargs: Any, + ) -> ToolResult: + """Synchronous tool execution (existing behavior - backward compatible). This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: - name: The name of the attribute (tool) being accessed. + user_message_override: Optional custom message to record. + record_direct_tool_call: Whether to record in message history. + **kwargs: Tool parameters. Returns: - A function that when called will execute the named tool. + ToolResult from execution. Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + ConcurrencyException: If invocation lock cannot be acquired. """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") - def caller( - user_message_override: str | None = None, - record_direct_tool_call: bool | None = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class - attribute if provided. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. - """ - if self._agent._interrupt_state.activated: - raise RuntimeError("cannot directly call tool during interrupt") - - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - - should_lock = should_record_direct_tool_call - - from ..agent import Agent # Locally imported to avoid circular reference - - acquired_lock = ( - should_lock - and isinstance(self._agent, Agent) - and self._agent._invocation_lock.acquire_lock(blocking=False) - ) - if should_lock and not acquired_lock: - raise ConcurrencyException( - "Direct tool call cannot be made while the agent is in the middle of an invocation. " - "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." - ) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - try: - normalized_name = self._find_normalized_tool_name(name) + should_lock = should_record_direct_tool_call - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs + from ..agent import Agent # Locally imported to avoid circular reference - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") + acquired_lock = ( + should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False) + ) + if should_lock and not acquired_lock: + raise ConcurrencyException( + "Direct tool call cannot be made while the agent is in the middle of an invocation. " + "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." + ) - tool_result = tool_results[0] + try: + normalized_name = self._find_normalized_tool_name(self._tool_name) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._record_tool_execution(tool_use, tool_result, user_message_override) + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs - return tool_result + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") - tool_result = run_async(acall) + tool_result = tool_results[0] - # TODO: https://github.com/strands-agents/sdk-python/issues/1311 - if isinstance(self._agent, Agent): - self._agent.conversation_manager.apply_management(self._agent) + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._record_tool_execution(tool_use, tool_result, user_message_override) return tool_result - finally: - if acquired_lock and isinstance(self._agent, Agent): - self._agent._invocation_lock.release() + tool_result = run_async(acall) + + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + if isinstance(self._agent, Agent): + self._agent.conversation_manager.apply_management(self._agent) + + return tool_result + + finally: + if acquired_lock and isinstance(self._agent, Agent): + self._agent._invocation_lock.release() + + def stream(self, **kwargs: Any) -> Iterator[TypedEvent]: + """Synchronous streaming via async-to-sync wrapper. + + This method provides synchronous streaming by wrapping stream_async() + with run_async(). Note that due to Python's async/sync boundary constraints, + events are buffered before yielding. For true streaming, use stream_async(). + + Args: + **kwargs: Tool parameters. + + Yields: + Tool execution events. + + Raises: + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + """ + + async def async_generator() -> AsyncIterator[TypedEvent]: + async for event in self.stream_async(**kwargs): + yield event + + # Run async generator in sync context + async def collect_events() -> list[TypedEvent]: + events = [] + async for event in async_generator(): + events.append(event) + return events + + events = run_async(collect_events) + yield from events + + async def stream_async(self, **kwargs: Any) -> AsyncIterator[TypedEvent]: + """Asynchronous streaming from ToolExecutor._stream(). + + This method yields events directly from tool execution without recording + to message history. Designed for observability and real-time progress. + + Args: + **kwargs: Tool parameters. + + Yields: + Tool execution events from ToolExecutor._stream(). + + Raises: + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") - return caller + normalized_name = self._find_normalized_tool_name(self._tool_name) + + logger.debug("tool_name=<%s>, streaming= | executing tool stream", normalized_name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + # Stream events directly without recording to message history + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + yield event def _find_normalized_tool_name(self, name: str) -> str: """Lookup the tool represented by name, replacing characters with underscores as necessary.""" @@ -246,3 +312,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di properties = tool_spec["inputSchema"]["json"]["properties"] return {k: v for k, v in input_params.items() if k in properties} + + +class _ToolCaller: + """Call tool as a function.""" + + def __init__(self, agent: "Agent | BidiAgent") -> None: + """Initialize instance. + + Args: + agent: Agent reference that will accept tool results. + """ + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent_ref = weakref.ref(agent) + + @property + def _agent(self) -> "Agent | BidiAgent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent + + def __getattr__(self, name: str) -> _ToolExecutor: + """Return tool executor with streaming methods. + + This method enables the tool calling interface by returning a callable + object that provides both synchronous execution and streaming methods. + + Args: + name: Tool name. + + Returns: + Tool executor instance. + """ + return _ToolExecutor(self._agent, name) diff --git a/tests/strands/tools/test_tool_executor.py b/tests/strands/tools/test_tool_executor.py new file mode 100644 index 000000000..59bb04f66 --- /dev/null +++ b/tests/strands/tools/test_tool_executor.py @@ -0,0 +1,122 @@ +"""Unit tests for _ToolExecutor.""" + +import gc +import weakref + +import pytest + +from strands import Agent, tool + + +class TestToolExecutor: + """Test _ToolExecutor class.""" + + def test_executor_is_callable(self): + """Test tool executor is callable.""" + + @tool + def test_tool(x: int) -> int: + return x * 2 + + agent = Agent(tools=[test_tool]) + executor = agent.tool.test_tool + + # Should be callable + result = executor(x=5) + assert result["status"] == "success" + + def test_executor_has_streaming_methods(self): + """Test executor has stream and stream_async methods.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + executor = agent.tool.test_tool + + assert hasattr(executor, "stream") + assert hasattr(executor, "stream_async") + assert callable(executor.stream) + assert callable(executor.stream_async) + + def test_weakref_prevents_circular_reference(self): + """Test weakref prevents agent from leaking.""" + + @tool + def test_tool(x: int) -> int: + return x + + gc.disable() + try: + agent = Agent(tools=[test_tool]) + _ = agent.tool.test_tool # Create executor to test weakref + ref = weakref.ref(agent) + + del agent + + if ref() is not None: + gc.collect() + + assert ref() is None + finally: + gc.enable() + + def test_executor_weakref_raises_on_deleted_agent(self): + """Test accessing _agent property raises ReferenceError when agent deleted.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + executor = agent.tool.test_tool + + # Delete agent + del agent + gc.collect() + + # Accessing _agent should raise ReferenceError + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = executor._agent + + @pytest.mark.asyncio + async def test_stream_async_with_interrupt_raises(self): + """Test stream_async raises when interrupt state activated.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + agent._interrupt_state.activate() + + with pytest.raises(RuntimeError, match="cannot directly call tool during interrupt"): + async for _event in agent.tool.test_tool.stream_async(x=5): + pass + + def test_find_normalized_tool_name_with_underscores(self): + """Test tool name normalization replaces underscores with hyphens.""" + + @tool(name="my-tool") + def my_tool(x: int) -> int: + return x + + agent = Agent(tools=[my_tool]) + executor = agent.tool.my_tool + + # Should find tool with hyphen name via underscore access + result = executor(x=5) + assert result["status"] == "success" + + def test_find_normalized_tool_name_not_found(self): + """Test non-existent tool raises AttributeError.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + + with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): + _ = agent.tool.nonexistent(x=5) diff --git a/tests_integ/test_tool_streaming.py b/tests_integ/test_tool_streaming.py new file mode 100644 index 000000000..a64fe4eef --- /dev/null +++ b/tests_integ/test_tool_streaming.py @@ -0,0 +1,79 @@ +"""Integration tests for direct tool call streaming (Issue #1436).""" + +import pytest + +from strands import Agent, tool + + +@tool +def simple_tool(value: int) -> int: + """Simple tool for testing.""" + return value * 2 + + +class TestToolStreaming: + """Test tool streaming methods.""" + + @pytest.mark.asyncio + async def test_async_streaming(self): + """Test async streaming captures events.""" + agent = Agent(tools=[simple_tool]) + events = [] + + async for event in agent.tool.simple_tool.stream_async(value=5): + events.append(event) + + assert len(events) > 0 + assert any(e.get("type") == "tool_result" for e in events) + + def test_sync_streaming(self): + """Test sync streaming captures events.""" + agent = Agent(tools=[simple_tool]) + events = [] + + for event in agent.tool.simple_tool.stream(value=5): + events.append(event) + + assert len(events) > 0 + assert any(e.get("type") == "tool_result" for e in events) + + def test_backward_compatibility(self): + """Test existing sync API unchanged.""" + agent = Agent(tools=[simple_tool]) + result = agent.tool.simple_tool(value=5) + + assert result["status"] == "success" + assert result["content"][0]["text"] == "10" + + @pytest.mark.asyncio + async def test_tool_not_found(self): + """Test non-existent tool raises AttributeError.""" + agent = Agent(tools=[]) + + with pytest.raises(AttributeError, match="Tool 'fake' not found"): + async for _event in agent.tool.fake.stream_async(): + pass + + @pytest.mark.asyncio + async def test_tool_error_captured_in_result(self): + """Test tool errors are captured in tool_result events.""" + + @tool + def error_tool() -> str: + raise ValueError("Test error") + + agent = Agent(tools=[error_tool]) + events = [] + + async for event in agent.tool.error_tool.stream_async(): + events.append(event) + + # Should have at least one event + assert len(events) > 0 + + # Final event should be tool_result with error status + final_event = events[-1] + assert final_event.get("type") == "tool_result" + tool_result = final_event.get("tool_result", {}) + assert tool_result.get("status") == "error" + assert "Test error" in tool_result.get("content", [{}])[0].get("text", "")