diff --git a/src/agents/tool.py b/src/agents/tool.py index fd5b8e011..db1df63d6 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -12,6 +12,7 @@ Callable, Generic, Literal, + Optional, Protocol, TypeVar, Union, @@ -765,6 +766,7 @@ def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) - ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]] +_UNSET_FAILURE_ERROR_FUNCTION = object() @overload @@ -813,7 +815,7 @@ def function_tool( description_override: str | None = None, docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, - failure_error_function: ToolErrorFunction | None = default_tool_error_function, + failure_error_function: ToolErrorFunction | None | object = _UNSET_FAILURE_ERROR_FUNCTION, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, needs_approval: bool @@ -923,10 +925,18 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: try: return await _on_invoke_tool_impl(ctx, input) except Exception as e: - if failure_error_function is None: + resolved_failure_error_function: ToolErrorFunction | None + if failure_error_function is _UNSET_FAILURE_ERROR_FUNCTION: + resolved_failure_error_function = default_tool_error_function + else: + resolved_failure_error_function = cast( + Optional[ToolErrorFunction], failure_error_function + ) + + if resolved_failure_error_function is None: raise - result = failure_error_function(ctx, e) + result = resolved_failure_error_function(ctx, e) if inspect.isawaitable(result): return await result diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 01d1e3c6b..734e1f616 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict +import agents.tool as tool_module from agents import ( Agent, AgentBase, @@ -448,6 +449,25 @@ def boom() -> None: assert result.startswith("handled:") +@pytest.mark.asyncio +async def test_default_failure_error_function_is_resolved_at_invoke_time( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def boom(a: int) -> None: + raise ValueError(f"boom:{a}") + + tool = function_tool(boom) + + def patched_default(_ctx: RunContextWrapper[Any], error: Exception) -> str: + return f"patched:{error}" + + monkeypatch.setattr(tool_module, "default_tool_error_function", patched_default) + + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 7}') + result = await tool.on_invoke_tool(ctx, '{"a": 7}') + assert result == "patched:boom:7" + + def test_function_tool_accepts_guardrail_arguments(): tool = function_tool( simple_function,