Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
raise ModelBehaviorError("Agent tool called with invalid input")

resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
resolved_run_config = run_config
if resolved_run_config is None and isinstance(context, ToolContext):
resolved_run_config = context.run_config
if isinstance(context, ToolContext):
# Use a fresh ToolContext to avoid sharing approval state with parent runs.
nested_context = ToolContext(
Expand All @@ -600,6 +603,7 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
tool_arguments=context.tool_arguments,
tool_call=context.tool_call,
agent=context.agent,
run_config=resolved_run_config,
)
if should_capture_tool_input:
nested_context.tool_input = params_data
Expand Down Expand Up @@ -697,7 +701,7 @@ def _apply_nested_approvals(
starting_agent=cast(Agent[Any], self),
input=resume_state or resolved_input,
context=None if resume_state is not None else cast(Any, nested_context),
run_config=run_config,
run_config=resolved_run_config,
max_turns=resolved_max_turns,
hooks=hooks,
previous_response_id=None
Expand Down Expand Up @@ -761,7 +765,7 @@ async def dispatch_stream_events() -> None:
starting_agent=cast(Agent[Any], self),
input=resume_state or resolved_input,
context=None if resume_state is not None else cast(Any, nested_context),
run_config=run_config,
run_config=resolved_run_config,
max_turns=resolved_max_turns,
hooks=hooks,
previous_response_id=None
Expand Down
1 change: 1 addition & 0 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
tool_call.call_id,
tool_call=tool_call,
agent=agent,
run_config=config,
)
agent_hooks = agent.hooks
if config.trace_include_sensitive_data:
Expand Down
12 changes: 12 additions & 0 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
if TYPE_CHECKING:
from .agent import AgentBase
from .items import TResponseInputItem
from .run_config import RunConfig
from .run_context import _ApprovalRecord


Expand Down Expand Up @@ -48,6 +49,9 @@ class ToolContext(RunContextWrapper[TContext]):
agent: AgentBase[Any] | None = None
"""The active agent for this tool call, when available."""

run_config: RunConfig | None = None
"""The active run config for this tool call, when available."""

def __init__(
self,
context: TContext,
Expand All @@ -58,6 +62,7 @@ def __init__(
tool_call: ResponseFunctionToolCall | None = None,
*,
agent: AgentBase[Any] | None = None,
run_config: RunConfig | None = None,
turn_input: list[TResponseInputItem] | None = None,
_approvals: dict[str, _ApprovalRecord] | None = None,
tool_input: Any | None = None,
Expand Down Expand Up @@ -86,6 +91,7 @@ def __init__(
)
self.tool_call = tool_call
self.agent = agent
self.run_config = run_config

@classmethod
def from_agent_context(
Expand All @@ -94,6 +100,8 @@ def from_agent_context(
tool_call_id: str,
tool_call: ResponseFunctionToolCall | None = None,
agent: AgentBase[Any] | None = None,
*,
run_config: RunConfig | None = None,
) -> ToolContext:
"""
Create a ToolContext from a RunContextWrapper.
Expand All @@ -109,13 +117,17 @@ def from_agent_context(
tool_agent = agent
if tool_agent is None and isinstance(context, ToolContext):
tool_agent = context.agent
tool_run_config = run_config
if tool_run_config is None and isinstance(context, ToolContext):
tool_run_config = context.run_config

tool_context = cls(
tool_name=tool_name,
tool_call_id=tool_call_id,
tool_arguments=tool_args,
tool_call=tool_call,
agent=tool_agent,
run_config=tool_run_config,
**base_values,
)
return tool_context
161 changes: 161 additions & 0 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,167 @@ async def extractor(result) -> str:
assert output == "custom output"


@pytest.mark.asyncio
async def test_agent_as_tool_inherits_parent_run_config_when_not_set(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent = Agent(name="inherits_config_agent")
parent_run_config = RunConfig(model="gpt-4.1-mini")

class DummyResult:
def __init__(self) -> None:
self.final_output = "ok"

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
assert isinstance(context, ToolContext)
assert run_config is parent_run_config
assert context.run_config is parent_run_config
return DummyResult()

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="inherits_config_tool",
tool_description="inherit config",
),
)
tool_context = ToolContext(
context=None,
tool_name="inherits_config_tool",
tool_call_id="call_inherit",
tool_arguments='{"input":"hello"}',
run_config=parent_run_config,
)

output = await tool.on_invoke_tool(tool_context, '{"input":"hello"}')

assert output == "ok"


@pytest.mark.asyncio
async def test_agent_as_tool_explicit_run_config_overrides_parent_context(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent = Agent(name="override_config_agent")
parent_run_config = RunConfig(model="gpt-4.1-mini")
explicit_run_config = RunConfig(model="gpt-4.1")

class DummyResult:
def __init__(self) -> None:
self.final_output = "ok"

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
assert isinstance(context, ToolContext)
assert run_config is explicit_run_config
assert context.run_config is explicit_run_config
return DummyResult()

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="override_config_tool",
tool_description="override config",
run_config=explicit_run_config,
),
)
tool_context = ToolContext(
context=None,
tool_name="override_config_tool",
tool_call_id="call_override",
tool_arguments='{"input":"hello"}',
run_config=parent_run_config,
)

output = await tool.on_invoke_tool(tool_context, '{"input":"hello"}')

assert output == "ok"


@pytest.mark.asyncio
async def test_agent_as_tool_inherits_trace_include_sensitive_data_setting(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent = Agent(name="trace_config_agent")
parent_run_config = RunConfig(trace_include_sensitive_data=False)

class DummyResult:
def __init__(self) -> None:
self.final_output = "ok"

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
assert isinstance(context, ToolContext)
assert run_config is parent_run_config
assert run_config.trace_include_sensitive_data is False
return DummyResult()

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="trace_config_tool",
tool_description="inherits trace config",
),
)
tool_context = ToolContext(
context=None,
tool_name="trace_config_tool",
tool_call_id="call_trace",
tool_arguments='{"input":"hello"}',
run_config=parent_run_config,
)

output = await tool.on_invoke_tool(tool_context, '{"input":"hello"}')

assert output == "ok"


@pytest.mark.asyncio
async def test_agent_as_tool_structured_input_sets_tool_input(
monkeypatch: pytest.MonkeyPatch,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,32 @@ async def _fake_tool(context: ToolContext[str], value: str) -> str:
assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_function_tool_context_includes_run_config() -> None:
async def _tool_with_run_config(context: ToolContext[str]) -> str:
assert context.run_config is not None
return str(context.run_config.model)

tool = function_tool(
_tool_with_run_config,
name_override="tool_with_run_config",
failure_error_function=None,
)
agent = Agent(name="test", tools=[tool])
response = ModelResponse(
output=[get_function_tool_call("tool_with_run_config", "{}", call_id="call-1")],
usage=Usage(),
response_id=None,
)
run_config = RunConfig(model="gpt-4.1-mini")

result = await get_execute_result(agent, response, run_config=run_config)

assert len(result.generated_items) == 2
assert_item_is_function_tool_call_output(result.generated_items[1], "gpt-4.1-mini")
assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_handoff_output_leads_to_handoff_next_step():
agent_1 = Agent(name="test_1")
Expand Down
53 changes: 53 additions & 0 deletions tests/test_tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openai.types.responses import ResponseFunctionToolCall

from agents import Agent
from agents.run_config import RunConfig
from agents.run_context import RunContextWrapper
from agents.tool_context import ToolContext
from tests.utils.hitl import make_context_wrapper
Expand Down Expand Up @@ -103,3 +104,55 @@ def test_tool_context_from_tool_context_inherits_agent() -> None:
)

assert derived_context.agent is agent


def test_tool_context_from_tool_context_inherits_run_config() -> None:
original_call = ResponseFunctionToolCall(
type="function_call",
name="test_tool",
call_id="call-3",
arguments="{}",
)
derived_call = ResponseFunctionToolCall(
type="function_call",
name="test_tool",
call_id="call-4",
arguments="{}",
)
parent_run_config = RunConfig(model="gpt-4.1-mini")
parent_context: ToolContext[dict[str, object]] = ToolContext(
context={},
tool_name="test_tool",
tool_call_id="call-3",
tool_arguments="{}",
tool_call=original_call,
run_config=parent_run_config,
)

derived_context = ToolContext.from_agent_context(
parent_context,
tool_call_id="call-4",
tool_call=derived_call,
)

assert derived_context.run_config is parent_run_config


def test_tool_context_from_agent_context_prefers_explicit_run_config() -> None:
tool_call = ResponseFunctionToolCall(
type="function_call",
name="test_tool",
call_id="call-1",
arguments="{}",
)
ctx = make_context_wrapper()
explicit_run_config = RunConfig(model="gpt-4.1")

tool_ctx = ToolContext.from_agent_context(
ctx,
tool_call_id="call-1",
tool_call=tool_call,
run_config=explicit_run_config,
)

assert tool_ctx.run_config is explicit_run_config