diff --git a/docs/tools.md b/docs/tools.md index 3dc860efd5..b8245cd5de 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -309,6 +309,25 @@ for tool in agent.tools: 3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc. 4. You can pass the decorated functions to the list of tools. +You can also decorate instance methods. Access the tool from an instance before passing it to +`Agent.tools`; the implicit `self` parameter is bound to that instance and omitted from the tool +schema. + +```python +class CustomerTools: + def __init__(self, tenant_id: str) -> None: + self.tenant_id = tenant_id + + @function_tool + def lookup_customer(self, customer_id: str) -> str: + """Look up a customer by ID.""" + return f"{self.tenant_id}:{customer_id}" + + +customer_tools = CustomerTools("tenant_123") +agent = Agent(name="Assistant", tools=[customer_tools.lookup_customer]) +``` + ??? note "Expand to see output" ``` diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 8fe52df320..8dcc148f12 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -40,6 +40,8 @@ class FuncSchema: strict_json_schema: bool = True """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" + omitted_parameter_names: tuple[str, ...] = () + """Parameter names that are supplied by the SDK instead of model-generated JSON.""" def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: """ @@ -52,6 +54,8 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: # Use enumerate() so we can skip the first parameter if it's context. for idx, (name, param) in enumerate(self.signature.parameters.items()): + if name in self.omitted_parameter_names: + continue # If the function takes a RunContextWrapper and this is the first parameter, skip it. if self.takes_context and idx == 0: continue @@ -228,6 +232,7 @@ def function_schema( description_override: str | None = None, use_docstring_info: bool = True, strict_json_schema: bool = True, + skip_first_parameter: bool = False, ) -> FuncSchema: """ Given a Python function, extracts a `FuncSchema` from it, capturing the name, description, @@ -246,6 +251,8 @@ def function_schema( the schema adheres to the "strict" standard the OpenAI API expects. We **strongly** recommend setting this to True, as it increases the likelihood of the LLM producing correct JSON input. + skip_first_parameter: If True, omit the first signature parameter from the tool schema and + call arguments. This is used for instance methods decorated with `@function_tool`. Returns: A `FuncSchema` object containing the function's name, description, parameter descriptions, @@ -288,22 +295,29 @@ def function_schema( params = list(sig.parameters.items()) takes_context = False filtered_params = [] + omitted_parameter_names: list[str] = [] + + params_to_check = params + if skip_first_parameter and params: + omitted_parameter_names.append(params[0][0]) + params_to_check = params[1:] - if params: - first_name, first_param = params[0] + if params_to_check: + first_name, first_param = params_to_check[0] # Prefer the evaluated type hint if available ann = type_hints.get(first_name, first_param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann if origin is RunContextWrapper or origin is ToolContext: takes_context = True # Mark that the function takes context + omitted_parameter_names.append(first_name) else: filtered_params.append((first_name, first_param)) else: filtered_params.append((first_name, first_param)) # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext. - for name, param in params[1:]: + for name, param in params_to_check[1:]: ann = type_hints.get(name, param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann @@ -421,4 +435,5 @@ def function_schema( signature=sig, takes_context=takes_context, strict_json_schema=strict_json_schema, + omitted_parameter_names=tuple(omitted_parameter_names), ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 42c41397cb..abf778ba2d 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -389,6 +389,60 @@ class FunctionTool: _emit_tool_origin: bool = field(default=True, kw_only=True, repr=False) """Whether runtime item generation should emit tool origin metadata for this tool.""" + _method_tool_factory: Callable[[Any], FunctionTool] | None = field( + default=None, + kw_only=True, + repr=False, + ) + """Internal descriptor hook used for instance methods decorated with `@function_tool`.""" + + _staticmethod_tool_factory: Callable[[], FunctionTool] | None = field( + default=None, + kw_only=True, + repr=False, + ) + """Internal fallback for class-scoped tools wrapped in `staticmethod`.""" + + _method_tool_bound_to_class: bool = field(default=False, kw_only=True, repr=False) + """Whether Python installed this tool directly on a class via `__set_name__`.""" + + def __set_name__(self, owner: type[Any], name: str) -> None: + if self._staticmethod_tool_factory is not None: + self._method_tool_bound_to_class = True + + def __getattribute__(self, name: str) -> Any: + if not name.startswith("_") and name not in {"__class__", "__dict__"}: + object.__getattribute__(self, "_maybe_apply_staticmethod_tool")() + return object.__getattribute__(self, name) + + def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool: + if instance is None or self._method_tool_factory is None: + return self + return self._method_tool_factory(instance) + + def _maybe_apply_staticmethod_tool(self) -> None: + try: + staticmethod_tool_factory = object.__getattribute__(self, "_staticmethod_tool_factory") + method_tool_bound_to_class = object.__getattribute__( + self, "_method_tool_bound_to_class" + ) + except AttributeError: + return + + if staticmethod_tool_factory is None or method_tool_bound_to_class: + return + + # `staticmethod` does not forward `__set_name__` to the wrapped FunctionTool. + # Rebuild as a normal tool before exposing schema or invocation state. + object.__setattr__(self, "_staticmethod_tool_factory", None) + staticmethod_tool = staticmethod_tool_factory() + for tool_field in dataclasses.fields(FunctionTool): + object.__setattr__(self, tool_field.name, getattr(staticmethod_tool, tool_field.name)) + + bind_to_function_tool = getattr(self.on_invoke_tool, "__agents_bind_function_tool__", None) + if callable(bind_to_function_tool): + self.on_invoke_tool = bind_to_function_tool(self) + @property def qualified_name(self) -> str: """Return the public qualified name used to identify this function tool.""" @@ -1836,8 +1890,27 @@ def function_tool( explicitly loads it. """ - def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: + def _is_instance_method_tool(the_func: ToolFunction[...]) -> bool: + parameters = tuple(inspect.signature(the_func).parameters.values()) + if not parameters: + return False + + parent_name = the_func.__qualname__.rsplit(".", 1)[0] + return "." in the_func.__qualname__ and not parent_name.endswith("") + + def _create_function_tool( + the_func: ToolFunction[...], + *, + method_tool_instance: Any | None = None, + treat_as_instance_method: bool | None = None, + enable_method_binding: bool = True, + ) -> FunctionTool: is_sync_function_tool = not inspect.iscoroutinefunction(the_func) + is_instance_method_tool = ( + _is_instance_method_tool(the_func) + if treat_as_instance_method is None + else treat_as_instance_method + ) schema = function_schema( func=the_func, name_override=name_override, @@ -1845,9 +1918,15 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: docstring_style=docstring_style, use_docstring_info=use_docstring_info, strict_json_schema=strict_mode, + skip_first_parameter=is_instance_method_tool, ) async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: + if is_instance_method_tool and method_tool_instance is None: + raise UserError( + f"Instance method tool {schema.name} must be accessed from an instance" + ) + tool_name = ctx.tool_name json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input) _log_function_tool_invocation(tool_name=tool_name, input_json=input) @@ -1866,16 +1945,16 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: if not _debug.DONT_LOG_TOOL_DATA: logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}") + leading_args: list[Any] = [] + if is_instance_method_tool: + leading_args.append(method_tool_instance) + if schema.takes_context: + leading_args.append(ctx) + if not is_sync_function_tool: - if schema.takes_context: - result = await the_func(ctx, *args, **kwargs_dict) - else: - result = await the_func(*args, **kwargs_dict) + result = await the_func(*leading_args, *args, **kwargs_dict) else: - if schema.takes_context: - result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict) - else: - result = await asyncio.to_thread(the_func, *args, **kwargs_dict) + result = await asyncio.to_thread(the_func, *leading_args, *args, **kwargs_dict) if _debug.DONT_LOG_TOOL_DATA: logger.debug(f"Tool {tool_name} completed.") @@ -1906,6 +1985,18 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: defer_loading=defer_loading, sync_invoker=is_sync_function_tool, ) + if enable_method_binding and is_instance_method_tool and method_tool_instance is None: + function_tool._method_tool_factory = lambda instance: _create_function_tool( + the_func, + method_tool_instance=instance, + treat_as_instance_method=True, + enable_method_binding=False, + ) + function_tool._staticmethod_tool_factory = lambda: _create_function_tool( + the_func, + treat_as_instance_method=False, + enable_method_binding=False, + ) return function_tool # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 60ae2558cc..83b3cda2ff 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -155,6 +155,145 @@ async def test_simple_function(): ) +@pytest.mark.asyncio +async def test_instance_method_function_tool_binds_self(): + class AccountTools: + def __init__(self, prefix: str) -> None: + self.prefix = prefix + + @function_tool + def lookup(self, account_id: str) -> str: + """Look up an account.""" + return f"{self.prefix}:{account_id}" + + tools = AccountTools("acct") + tool = tools.lookup + + assert isinstance(AccountTools.lookup, FunctionTool) + assert tool.name == "lookup" + assert "self" not in tool.params_json_schema["properties"] + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"account_id": "123"}', + ) + + assert result == "acct:123" + + +@pytest.mark.asyncio +async def test_instance_method_function_tool_binds_non_self_receiver_name(): + class AccountTools: + def __init__(self, prefix: str) -> None: + self.prefix = prefix + + @function_tool + def lookup(this, account_id: str) -> str: + """Look up an account.""" + return f"{this.prefix}:{account_id}" + + tools = AccountTools("acct") + tool = tools.lookup + + assert "this" not in tool.params_json_schema["properties"] + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"account_id": "123"}', + ) + + assert result == "acct:123" + + +@pytest.mark.asyncio +async def test_function_tool_does_not_treat_self_named_argument_as_method(): + def lookup(self: str, account_id: str) -> str: + """Look up an account.""" + return f"{self}:{account_id}" + + tool = function_tool(lookup) + + assert "self" in tool.params_json_schema["properties"] + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"self": "acct", "account_id": "123"}', + ) + + assert result == "acct:123" + + +@pytest.mark.asyncio +async def test_staticmethod_function_tool_keeps_first_parameter(): + class AccountTools: + @staticmethod + @function_tool + def lookup(account_id: str) -> str: + """Look up an account.""" + return f"acct:{account_id}" + + tool = AccountTools.lookup + + assert isinstance(tool, FunctionTool) + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"account_id": "123"}', + ) + + assert result == "acct:123" + + +@pytest.mark.asyncio +async def test_staticmethod_function_tool_allows_self_named_parameter(): + class AccountTools: + @staticmethod + @function_tool + def lookup(self: str, account_id: str) -> str: + """Look up an account.""" + return f"{self}:{account_id}" + + tool = AccountTools.lookup + + assert isinstance(tool, FunctionTool) + assert "self" in tool.params_json_schema["properties"] + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"self": "acct", "account_id": "123"}', + ) + + assert result == "acct:123" + + +@pytest.mark.asyncio +async def test_instance_method_function_tool_supports_context_after_self(): + class AccountTools: + @function_tool + def lookup(self, ctx: ToolContext[str], account_id: str) -> str: + """Look up an account with context.""" + return f"{ctx.context}:{account_id}" + + tools = AccountTools() + tool = tools.lookup + + assert "self" not in tool.params_json_schema["properties"] + assert "ctx" not in tool.params_json_schema["properties"] + assert "account_id" in tool.params_json_schema["properties"] + + result = await tool.on_invoke_tool( + ToolContext("tenant", tool_name=tool.name, tool_call_id="1", tool_arguments=""), + '{"account_id": "123"}', + ) + + assert result == "tenant:123" + + @pytest.mark.asyncio async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None: calls = {"to_thread": 0, "func": 0}