Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Callable,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
Expand Down Expand Up @@ -765,6 +766,7 @@ def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -


ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]]
_UNSET_FAILURE_ERROR_FUNCTION = object()


@overload
Expand Down Expand Up @@ -813,7 +815,7 @@ def function_tool(
description_override: str | None = None,
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
failure_error_function: ToolErrorFunction | None | object = _UNSET_FAILURE_ERROR_FUNCTION,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
needs_approval: bool
Expand Down Expand Up @@ -923,10 +925,18 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
try:
return await _on_invoke_tool_impl(ctx, input)
except Exception as e:
if failure_error_function is None:
resolved_failure_error_function: ToolErrorFunction | None
if failure_error_function is _UNSET_FAILURE_ERROR_FUNCTION:
resolved_failure_error_function = default_tool_error_function
else:
resolved_failure_error_function = cast(
Optional[ToolErrorFunction], failure_error_function
)

if resolved_failure_error_function is None:
raise

result = failure_error_function(ctx, e)
result = resolved_failure_error_function(ctx, e)
if inspect.isawaitable(result):
return await result

Expand Down
20 changes: 20 additions & 0 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

import agents.tool as tool_module
from agents import (
Agent,
AgentBase,
Expand Down Expand Up @@ -448,6 +449,25 @@ def boom() -> None:
assert result.startswith("handled:")


@pytest.mark.asyncio
async def test_default_failure_error_function_is_resolved_at_invoke_time(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def boom(a: int) -> None:
raise ValueError(f"boom:{a}")

tool = function_tool(boom)

def patched_default(_ctx: RunContextWrapper[Any], error: Exception) -> str:
return f"patched:{error}"

monkeypatch.setattr(tool_module, "default_tool_error_function", patched_default)

ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 7}')
result = await tool.on_invoke_tool(ctx, '{"a": 7}')
assert result == "patched:boom:7"


def test_function_tool_accepts_guardrail_arguments():
tool = function_tool(
simple_function,
Expand Down