Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 186 additions & 84 deletions src/strands/tools/_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
"""

import json
import logging
import random
import weakref
from collections.abc import Callable
from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any

from .._async import run_async
from ..tools.executors._executor import ToolExecutor
from ..types._events import ToolInterruptEvent
from ..types._events import ToolInterruptEvent, TypedEvent
from ..types.content import ContentBlock, Message
from ..types.exceptions import ConcurrencyException
from ..types.tools import ToolResult, ToolUse
Expand All @@ -24,19 +25,27 @@
from ..agent import Agent
from ..experimental.bidi.agent import BidiAgent

logger = logging.getLogger(__name__)

class _ToolCaller:
"""Call tool as a function."""

def __init__(self, agent: "Agent | BidiAgent") -> None:
"""Initialize instance.
class _ToolExecutor:
"""Callable wrapper for tools that provides streaming methods.

This class enables three execution modes for tools:
1. Synchronous: result = executor(x=5)
2. Sync streaming: for event in executor.stream(x=5)
3. Async streaming: async for event in executor.stream_async(x=5)
"""

def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None:
"""Initialize tool executor.

Args:
agent: Agent reference that will accept tool results.
agent: Agent reference that owns the tools.
tool_name: Name of the tool to execute.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent_ref = weakref.ref(agent)
self._tool_name = tool_name

@property
def _agent(self) -> "Agent | BidiAgent":
Expand All @@ -46,104 +55,161 @@ def _agent(self) -> "Agent | BidiAgent":
raise ReferenceError("Agent has been garbage collected")
return agent

def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.
def __call__(
self,
user_message_override: str | None = None,
record_direct_tool_call: bool | None = None,
**kwargs: Any,
) -> ToolResult:
"""Synchronous tool execution (existing behavior - backward compatible).

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

Args:
name: The name of the attribute (tool) being accessed.
user_message_override: Optional custom message to record.
record_direct_tool_call: Whether to record in message history.
**kwargs: Tool parameters.

Returns:
A function that when called will execute the named tool.
ToolResult from execution.

Raises:
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
ConcurrencyException: If invocation lock cannot be acquired.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

def caller(
user_message_override: str | None = None,
record_direct_tool_call: bool | None = None,
**kwargs: Any,
) -> Any:
"""Call a tool directly by name.

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
attribute if provided.
**kwargs: Keyword arguments to pass to the tool.

Returns:
The result returned by the tool.

Raises:
AttributeError: If the tool doesn't exist.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

should_lock = should_record_direct_tool_call

from ..agent import Agent # Locally imported to avoid circular reference

acquired_lock = (
should_lock
and isinstance(self._agent, Agent)
and self._agent._invocation_lock.acquire_lock(blocking=False)
)
if should_lock and not acquired_lock:
raise ConcurrencyException(
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
)
if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

try:
normalized_name = self._find_normalized_tool_name(name)
should_lock = should_record_direct_tool_call

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs
from ..agent import Agent # Locally imported to avoid circular reference

async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")
acquired_lock = (
should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False)
)
if should_lock and not acquired_lock:
raise ConcurrencyException(
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
)

tool_result = tool_results[0]
try:
normalized_name = self._find_normalized_tool_name(self._tool_name)

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._record_tool_execution(tool_use, tool_result, user_message_override)
# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs

return tool_result
async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")

tool_result = run_async(acall)
tool_result = tool_results[0]

# TODO: https://github.com/strands-agents/sdk-python/issues/1311
if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)
if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._record_tool_execution(tool_use, tool_result, user_message_override)

return tool_result

finally:
if acquired_lock and isinstance(self._agent, Agent):
self._agent._invocation_lock.release()
tool_result = run_async(acall)

# TODO: https://github.com/strands-agents/sdk-python/issues/1311
if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)

return tool_result

finally:
if acquired_lock and isinstance(self._agent, Agent):
self._agent._invocation_lock.release()

def stream(self, **kwargs: Any) -> Iterator[TypedEvent]:
"""Synchronous streaming via async-to-sync wrapper.

This method provides synchronous streaming by wrapping stream_async()
with run_async(). Note that due to Python's async/sync boundary constraints,
events are buffered before yielding. For true streaming, use stream_async().

Args:
**kwargs: Tool parameters.

Yields:
Tool execution events.

Raises:
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
"""

async def async_generator() -> AsyncIterator[TypedEvent]:
async for event in self.stream_async(**kwargs):
yield event

# Run async generator in sync context
async def collect_events() -> list[TypedEvent]:
events = []
async for event in async_generator():
events.append(event)
return events

events = run_async(collect_events)
yield from events

async def stream_async(self, **kwargs: Any) -> AsyncIterator[TypedEvent]:
"""Asynchronous streaming from ToolExecutor._stream().

This method yields events directly from tool execution without recording
to message history. Designed for observability and real-time progress.

Args:
**kwargs: Tool parameters.

Yields:
Tool execution events from ToolExecutor._stream().

Raises:
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

return caller
normalized_name = self._find_normalized_tool_name(self._tool_name)

logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", normalized_name)

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs

# Stream events directly without recording to message history
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")
yield event

def _find_normalized_tool_name(self, name: str) -> str:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
Expand Down Expand Up @@ -246,3 +312,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di

properties = tool_spec["inputSchema"]["json"]["properties"]
return {k: v for k, v in input_params.items() if k in properties}


class _ToolCaller:
"""Call tool as a function."""

def __init__(self, agent: "Agent | BidiAgent") -> None:
"""Initialize instance.

Args:
agent: Agent reference that will accept tool results.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent_ref = weakref.ref(agent)

@property
def _agent(self) -> "Agent | BidiAgent":
"""Return the agent, raising ReferenceError if it has been garbage collected."""
agent = self._agent_ref()
if agent is None:
raise ReferenceError("Agent has been garbage collected")
return agent

def __getattr__(self, name: str) -> _ToolExecutor:
"""Return tool executor with streaming methods.

This method enables the tool calling interface by returning a callable
object that provides both synchronous execution and streaming methods.

Args:
name: Tool name.

Returns:
Tool executor instance.
"""
return _ToolExecutor(self._agent, name)
Loading