diff --git a/contributing/samples/adk_concurrent_agent_tool_call/README.md b/contributing/samples/adk_concurrent_agent_tool_call/README.md new file mode 100644 index 0000000000..061e519695 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/README.md @@ -0,0 +1,88 @@ +# Concurrent Agent Tool Call Tests + +This sample directory contains tests for concurrency issues that can occur when multiple agents or runners share toolsets and execute tools concurrently. The tests verify that closing one runner or completing one AgentTool call does not interrupt tools being executed by other runners or AgentTool calls that share the same toolset. + +## Structure + +- **`mock_tools.py`**: Common mock tools and toolsets used by all tests + - `MockTool`: A mock tool that waits for a `done_event` before completing + - `MockMcpToolset`: A mock MCP toolset with a closed event for testing concurrency + +- **`runner_shared_toolset/`**: Tests concurrent runner behavior with shared toolsets + - Tests the scenario where two `InMemoryRunner` instances share the same agent and toolset + - Verifies that closing one runner doesn't interrupt tools being executed by the other runner + +- **`agent_tool_parallel/`**: Tests parallel AgentTool call behavior + - Tests the scenario where a root agent calls a sub-agent via `AgentTool` multiple times in parallel + - Verifies that `AgentToolManager` properly handles parallel execution of `AgentTool` calls that share the same agent + +## Problem Statement + +Both test scenarios address similar concurrency issues: + +1. **Runner Shared Toolset**: When multiple `Runner` instances share the same agent (and thus the same toolset), closing one runner should not interrupt tools being executed by other runners. + +2. **AgentTool Parallel Calls**: When a root agent calls a sub-agent via `AgentTool` multiple times in parallel, each `AgentTool` call creates a `Runner` that uses the same sub-agent. When one `AgentTool` call completes and its runner closes, other parallel calls should not be interrupted. + +## Running the Tests + +### Runner Shared Toolset Test + +```bash +# Run the test script directly +python -m contributing.samples.adk_concurrent_agent_tool_call.runner_shared_toolset.main + +# Or use the ADK CLI +adk run contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset +``` + +### AgentTool Parallel Call Test + +```bash +# Run the test script directly +python -m contributing.samples.adk_concurrent_agent_tool_call.agent_tool_parallel.main + +# Or use the ADK CLI +adk run contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel +``` + +## Common Components + +### MockTool + +A mock tool that waits for a `done_event` before completing. It checks if the toolset's `closed_event` is set during execution and raises an error if interrupted. + +### MockMcpToolset + +A mock MCP toolset that simulates a stateful protocol. It creates a new `MockTool` instance on each `get_tools()` call (not cached), which is important for testing the concurrency scenario. + +## Expected Behavior + +Both tests should verify: + +- Tools should start executing concurrently +- When one runner/AgentTool call completes, other parallel executions should continue +- All parallel executions should complete successfully without being interrupted +- No "interrupted" errors should appear in the events + +## Key Testing Points + +1. **Concurrent Tool Execution**: Verifies that multiple runners/AgentTool calls can execute tools from the same toolset simultaneously +2. **Toolset Closure Handling**: Ensures that closing one runner doesn't affect tools being executed by other runners +3. **State Management**: Tests that shared toolset state is properly managed across multiple runners/AgentTool calls +4. **Error Detection**: Checks for interruption errors in parallel executions + +## Implementation Details + +Both tests use monkey patching to track when tools are called: + +- Patches `functions.__call_tool_async` to track running tools +- Uses `asyncio.Event` to synchronize tool execution +- Monitors events to detect any interruption errors + +## Related Components + +- **AgentTool**: The tool that wraps an agent and allows it to be called as a tool +- **AgentToolManager**: Manages runner registration and toolset cleanup for `AgentTool` +- **Runner**: The execution engine that orchestrates agent execution +- **BaseToolset**: Base class for toolsets that can be shared across multiple runners diff --git a/contributing/samples/adk_concurrent_agent_tool_call/__init__.py b/contributing/samples/adk_concurrent_agent_tool_call/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/__init__.py b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/__init__.py new file mode 100644 index 0000000000..8630a7f719 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent + +__all__ = ["agent"] diff --git a/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/agent.py b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/agent.py new file mode 100644 index 0000000000..3ac2679950 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/agent.py @@ -0,0 +1,66 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=g-importing-member + +import os +import sys + +SAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..") +) +if SAMPLES_DIR not in sys.path: + sys.path.append(SAMPLES_DIR) + +from adk_concurrent_agent_tool_call.mock_tools import MockMcpToolset +from google.adk import Agent +from google.adk.tools.agent_tool import AgentTool + +# Create a MCP toolset for the sub-agent +sub_agent_mcp_toolset = MockMcpToolset() + +sub_agent_system_prompt = """ +You are a helpful sub-agent that can use tools to help users. +When asked to use the mcp_tool, you should call it. +""" + +# Create a sub-agent that uses the MCP toolset +sub_agent = Agent( + model="gemini-2.5-flash", + name="sub_agent", + description=( + "A sub-agent that uses a MCP toolset for testing parallel AgentTool" + " calls." + ), + instruction=sub_agent_system_prompt, + tools=[sub_agent_mcp_toolset], +) + +# Create the root agent that uses AgentTool to call the sub-agent +root_agent_system_prompt = """ +You are a helpful assistant that can call sub-agents as tools. +When asked to use the sub_agent tool, you should call it. +You can call multiple sub_agent tools in parallel if needed. +""" + +root_agent = Agent( + model="gemini-2.5-flash", + name="root_agent", + description=( + "A root agent that calls sub-agents via AgentTool for testing parallel" + " execution." + ), + instruction=root_agent_system_prompt, + tools=[AgentTool(agent=sub_agent)], +) diff --git a/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/main.py b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/main.py new file mode 100644 index 0000000000..238993fed6 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/agent_tool_parallel/main.py @@ -0,0 +1,234 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=g-importing-member + +import os +import sys + +SAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..") +) +if SAMPLES_DIR not in sys.path: + sys.path.append(SAMPLES_DIR) + +import asyncio +import time +from typing import Any + +from adk_concurrent_agent_tool_call.agent_tool_parallel import agent +from adk_concurrent_agent_tool_call.mock_tools import MockMcpTool +from google.adk.agents.run_config import RunConfig +from google.adk.events.event import Event +from google.adk.runners import InMemoryRunner +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types + +# Track running tools using monkey patch +running_tools: dict[str, MockMcpTool] = {} + +# Track running tasks using monkey patch +running_tasks: dict[str, asyncio.Task[Any]] = {} + + +async def main(): + """Tests parallel AgentTool call behavior with shared agents. + + This test verifies the scenario where: + 1. Root agent calls sub-agent via AgentTool multiple times in parallel + 2. Each AgentTool call creates a runner that uses the same sub-agent + 3. Each sub-agent runner calls tools from the shared toolset concurrently + 4. When one AgentTool call completes and its runner closes, other parallel + calls should not be interrupted + + This demonstrates that AgentToolManager properly handles parallel execution + of AgentTool calls that share the same agent. + """ + app_name = "adk_agent_tool_parallel_app" + user_id = "adk_agent_tool_parallel_user" + + trigger_count = 0 + + # Event to wait for both tool call requests to be made + tool_call_request_event = asyncio.Event() + + def trigger_tool_call_request(): + """Trigger the tool call request event.""" + nonlocal trigger_count + trigger_count += 1 + if trigger_count >= 2: + tool_call_request_event.set() + + # Create runner with root agent + runner = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + + session = await runner.session_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Monkey patch __call_tool_async to track running tools + from google.adk.flows.llm_flows import functions + + original_call_tool_async = functions.__call_tool_async + + async def patched_call_tool_async( + tool: BaseTool, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Patched version that tracks running tools.""" + if isinstance(tool, MockMcpTool): + running_tools[tool_context.state["task_id"]] = tool + print(f"Tool {tool.name} started for session {tool_context.session.id}") + trigger_tool_call_request() + return await original_call_tool_async(tool, args, tool_context) + + functions.__call_tool_async = patched_call_tool_async + + # Monkey patch AgentTool.run_async to track running task + from google.adk.tools.agent_tool import AgentTool + + original_run_async = AgentTool.run_async + + async def patched_run_async( + self: AgentTool, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Patched version that tracks running task.""" + task = asyncio.create_task( + original_run_async(self, args=args, tool_context=tool_context) + ) + + task_id = task.__hash__() + tool_context.state["task_id"] = task_id + running_tasks[task_id] = task + return await task + + AgentTool.run_async = patched_run_async + + events: list[Event] = [] + try: + + async def run_agent(): + nonlocal events + # Run agent with a prompt that triggers parallel AgentTool calls + print("Starting agent with parallel AgentTool calls...") + content = types.Content( + role="user", + parts=[ + types.Part.from_text( + text=( + "Please call the sub_agent tool twice in parallel to" + " help me." + ) + ) + ], + ) + + async for event in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=False), + ): + events.append(event) + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f"Session {session.id}: {event.author}: {part.text}") + if part.function_call: + print( + f"Session {session.id}: {event.author}: function_call" + f" {part.function_call.name}" + ) + if part.function_response: + print( + f"Session {session.id}: {event.author}: function_response" + f" {part.function_response.name}" + ) + return events + + runner_task = asyncio.create_task(run_agent()) + + # Wait for both tools to start if they haven't already + assert not runner_task.done(), "Runner should not be done" + await tool_call_request_event.wait() + + print(f"Running tools: {list(running_tools.keys())}") + + # Get the running tools + tool1_tuple = list(running_tools.items())[0] + tool2_tuple = list(running_tools.items())[1] + + tool1_task = running_tasks[tool1_tuple[0]] + tool2_task = running_tasks[tool2_tuple[0]] + + # Complete tool1 + print("Waiting for agent tool 1 to complete...") + tool1_tuple[1].done_event.set() + await tool1_task + print("Tool1 completed ✓") + + await asyncio.sleep(0.1) + + print("Waiting for agent tool 2 to complete...") + tool2_tuple[1].done_event.set() + await tool2_task + print("Tool2 completed ✓") + + await runner_task + print(f"Agent completed with {len(events)} events ✓") + + # Check if any tools were interrupted + has_error = any( + event.content + and event.content.parts + and any( + "interrupted" in str(part.function_response) + or "interrupted" in str(part.text) + for part in event.content.parts + ) + for event in events + ) + + if has_error: + print("⚠️ Some tools were interrupted during parallel execution") + else: + print("✅ All parallel AgentTool calls completed successfully") + + finally: + # Restore original function + functions.__call_tool_async = original_call_tool_async + AgentTool.run_async = original_run_async + print("Monkey patch restored ✓") + + +if __name__ == "__main__": + start_time = time.time() + print( + "Script start time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time)), + ) + print("=" * 50) + print("Testing parallel AgentTool calls with shared agents") + print("=" * 50) + asyncio.run(main()) + end_time = time.time() + print("=" * 50) + print( + "Script end time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time)), + ) + print("Total script execution time:", f"{end_time - start_time:.2f} seconds") diff --git a/contributing/samples/adk_concurrent_agent_tool_call/mock_tools.py b/contributing/samples/adk_concurrent_agent_tool_call/mock_tools.py new file mode 100644 index 0000000000..98ae4346bf --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/mock_tools.py @@ -0,0 +1,94 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=g-importing-member + +import asyncio +from typing import Any + +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.base_toolset import BaseToolset +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration + + +class MockMcpTool(BaseTool): + """A mock tool that waits for a done event before completing.""" + + def __init__( + self, + name: str, + closed_event: asyncio.Event, + ): + super().__init__(name=name, description=f"Mock tool {name}") + self.closed_event = closed_event + self.done_event = asyncio.Event() + + def _get_declaration(self) -> FunctionDeclaration: + return FunctionDeclaration( + name=self.name, + description=self.description, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> dict[str, str]: + """Runs the tool, checking if toolset is closed during execution.""" + # Check if toolset is closed before starting + if self.closed_event.is_set(): + raise RuntimeError(f"Tool {self.name} cannot run: toolset is closed") + + closed_event_task = asyncio.create_task(self.closed_event.wait()) + done_event_task = asyncio.create_task(self.done_event.wait()) + + await asyncio.wait( + [closed_event_task, done_event_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Check if toolset was closed during execution + if self.closed_event.is_set(): + raise RuntimeError( + f"Tool {self.name} interrupted: toolset was closed during execution" + ) + + # Tool completed successfully + return {"result": f"Tool {self.name} completed successfully"} + + +class MockMcpToolset(BaseToolset): + """A mock MCP toolset with a closed event. + This toolset is used to test concurrency scenarios with shared toolsets. + """ + + def __init__(self): + super().__init__() + self.closed_event = asyncio.Event() + + async def get_tools(self, readonly_context=None) -> list[BaseTool]: + """Returns a single mock tool.""" + # Note that if you cache the tool, there is no such issue since the tool is reused. + # e.g. `return [self._tool]` + # However, MCP is a stateful protocol, so the tool should not be reused. + return [ + MockMcpTool( + name="mcp_tool", + closed_event=self.closed_event, + ) + ] + + async def close(self) -> None: + """Closes the toolset by setting the closed event.""" + print(f"Closing toolset {self.__hash__()}") + self.closed_event.set() diff --git a/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/__init__.py b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/__init__.py new file mode 100644 index 0000000000..8630a7f719 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent + +__all__ = ["agent"] diff --git a/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/agent.py b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/agent.py new file mode 100644 index 0000000000..54213b3673 --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/agent.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=g-importing-member + +import os +import sys + +SAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..") +) +if SAMPLES_DIR not in sys.path: + sys.path.append(SAMPLES_DIR) + +from adk_concurrent_agent_tool_call.mock_tools import MockMcpToolset +from google.adk import Agent + +# Create a MCP toolset +mcp_toolset = MockMcpToolset() + +system_prompt = """ +You are a helpful assistant that can use tools to help users. +When asked to use the mcp_tool, you should call it. +""" + +root_agent = Agent( + model="gemini-2.5-flash", + name="parallel_agent", + description=( + "An agent that uses a MCP toolset for testing runner close behavior." + ), + instruction=system_prompt, + tools=[mcp_toolset], +) diff --git a/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/main.py b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/main.py new file mode 100644 index 0000000000..39a4d5076c --- /dev/null +++ b/contributing/samples/adk_concurrent_agent_tool_call/runner_shared_toolset/main.py @@ -0,0 +1,240 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +SAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..") +) +if SAMPLES_DIR not in sys.path: + sys.path.append(SAMPLES_DIR) + +import asyncio +import time +from typing import Any + +from adk_concurrent_agent_tool_call.mock_tools import MockMcpTool +from adk_concurrent_agent_tool_call.runner_shared_toolset import agent +from google.adk.agents.run_config import RunConfig +from google.adk.runners import InMemoryRunner +from google.adk.sessions.session import Session +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types + +# Track running tools using monkey patch +running_tools: dict[str, MockMcpTool] = {} + + +async def main(): + """Tests runner close behavior with shared toolsets. + + This test verifies the scenario where: + 1. Runner1 and Runner2 both use the same agent with a shared toolset + 2. Both runners call tools concurrently + 3. Runner1's tool completes and runner1 closes (which closes the shared toolset) + 4. Runner2's tool should not be interrupted when toolset is closed + + This demonstrates the issue: when a toolset is closed, all tools using that + toolset are affected, even if they're being used by different runners. + """ + app_name = "adk_parallel_agent_app" + user_id_1 = "adk_parallel_user_1" + user_id_2 = "adk_parallel_user_2" + + trigger_count = 0 + + # Event to wait for both tool call requests to be made + tool_call_request_event = asyncio.Event() + + def trigger_tool_call_request(): + """Trigger the tool call request event.""" + nonlocal trigger_count + trigger_count += 1 + if trigger_count >= 2: + tool_call_request_event.set() + + # Create two runners with the same agent (sharing the same toolset) + runner1 = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + runner2 = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + + session_1 = await runner1.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + session_2 = await runner2.session_service.create_session( + app_name=app_name, user_id=user_id_2 + ) + + # Monkey patch __call_tool_async to track running tools + from google.adk.flows.llm_flows import functions + + original_call_tool_async = functions.__call_tool_async + + async def patched_call_tool_async( + tool: BaseTool, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Patched version that tracks running tools.""" + if isinstance(tool, MockMcpTool): + running_tools[tool_context.session.id] = tool + print(f"Tool {tool.name} started for session {tool_context.session.id}") + trigger_tool_call_request() + return await original_call_tool_async(tool, args, tool_context) + + functions.__call_tool_async = patched_call_tool_async + + try: + + async def run_agent_prompt( + runner: InMemoryRunner, session: Session, prompt_text: str + ): + """Run agent with a prompt and collect events.""" + content = types.Content( + role="user", parts=[types.Part.from_text(text=prompt_text)] + ) + events = [] + async for event in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=False), + ): + events.append(event) + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print( + f"Runner {runner.__hash__()} Session {session.id}:" + f" {event.author}: {part.text}" + ) + if part.function_call: + print( + f"Runner {runner.__hash__()} Session {session.id}:" + f" {event.author}: function_call {part.function_call.name}" + ) + if part.function_response: + print( + f"Runner {runner.__hash__()} Session {session.id}:" + f" {event.author}: function_response" + f" {part.function_response.name}" + ) + return events + + # Start tool execution in runner1 + print("Starting runner tool execution...") + runner1_task = asyncio.create_task( + run_agent_prompt( + runner1, session_1, "Please use the mcp_tool to help me." + ) + ) + + # Start tool execution in runner2 + runner2_task = asyncio.create_task( + run_agent_prompt( + runner2, session_2, "Please use the mcp_tool to help me." + ) + ) + + # Verify both runners are running + assert not runner1_task.done(), "Runner1 should not be done" + assert not runner2_task.done(), "Runner2 should not be done" + + # Wait to both tools are running + await tool_call_request_event.wait() + + print(f"Running tools: {list(running_tools.keys())}") + + # Get the running tools + runner1_tool = running_tools.get(session_1.id) + runner2_tool = running_tools.get(session_2.id) + + if runner1_tool: + print(f"Completing runner1's tool (session {session_1.id})...") + runner1_tool.done_event.set() + + # Verify runner1 completed + print("Waiting for runner1 to complete...") + runner1_events = await runner1_task + print(f"Runner1 completed with {len(runner1_events)} events ✓") + + # We are closing the runner1 here, this will close the toolset and interrupt the runner2's tool. + # This may happen when you call 2 concurrent AgentTools of which origins are the same agent. + await runner1.close() + + # Verify toolset was closed + # assert agent.mcp_toolset.closed_event.is_set() + # print("Toolset closed event is set ✓") + + # Complete runner2's tool if it's still running + if runner2_tool: + print(f"Completing runner2's tool (session {session_2.id})...") + runner2_tool.done_event.set() + + # Wait for runner2's task to complete + print("Waiting for runner2 to complete...") + runner2_events = await runner2_task + print(f"Runner2 completed with {len(runner2_events)} events") + + # Check if runner2's tool was interrupted + has_error = any( + event.content + and event.content.parts + and any( + "interrupted" in str(part.function_response) + or "interrupted" in str(part.text) + for part in event.content.parts + ) + for event in runner2_events + ) + + if has_error: + print("Runner2's tool was interrupted by toolset close") + else: + print( + "Runner2's tool completed normally (may have finished before close) ✓" + ) + + # Clean up runner2 + await runner2.close() + print("All runners closed ✓") + + finally: + # Restore original function + functions.__call_tool_async = original_call_tool_async + print("Monkey patch restored ✓") + + +if __name__ == "__main__": + start_time = time.time() + print( + "Script start time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time)), + ) + print("=" * 50) + print("Testing runner close with shared toolsets") + print("=" * 50) + asyncio.run(main()) + end_time = time.time() + print("=" * 50) + print( + "Script end time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time)), + ) + print("Total script execution time:", f"{end_time - start_time:.2f} seconds") diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1773729719..2997d02435 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1456,11 +1456,17 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): except Exception as e: logger.error('Error closing toolset %s: %s', type(toolset).__name__, e) - async def close(self): - """Closes the runner.""" + async def close(self, cleanup_toolsets: bool = True): + """Closes the runner. + + Args: + cleanup_toolsets: Whether to cleanup toolsets. + Default is True. + """ logger.info('Closing runner...') - # Close Toolsets - await self._cleanup_toolsets(self._collect_toolset(self.agent)) + if cleanup_toolsets: + # Close Toolsets + await self._cleanup_toolsets(self._collect_toolset(self.agent)) # Close Plugins if self.plugin_manager: diff --git a/src/google/adk/tools/_agent_tool_manager.py b/src/google/adk/tools/_agent_tool_manager.py new file mode 100644 index 0000000000..20603b127e --- /dev/null +++ b/src/google/adk/tools/_agent_tool_manager.py @@ -0,0 +1,130 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..agents.base_agent import BaseAgent + from ..runners import Runner + + +class AgentToolManager: + """Manages the relationship between runners and agents used by AgentTool. + + This class prevents premature cleanup of agent toolsets when multiple + AgentTools using the same agent are running concurrently. It tracks + active runners per agent and ensures that agent toolsets are only cleaned + up when no runners are using that agent. + + The manager uses a lock to ensure thread-safe registration and + unregistration of runners. When unregistering a runner, the lock is held + until the returned async generator is fully consumed, ensuring that cleanup + operations can complete safely without race conditions. + """ + + def __init__(self): + """Initializes the AgentToolManager.""" + # Maps agent to a set of active runners using that agent + self._runners_by_agent: dict[int, set[Runner]] = {} + # Lock to ensure thread-safe access to _runners_by_agent + self._lock = asyncio.Lock() + + async def register_runner(self, agent: BaseAgent, runner: Runner) -> None: + """Registers a runner for the given agent. + + This method should be called at the start of AgentTool.run_async() + when a runner is created. The runner is tracked to prevent premature + cleanup of the agent's toolsets. + + Args: + agent: The agent instance used by the runner. + runner: The runner instance to register. + """ + async with self._lock: + # TODO: can we use the name of the agent as the key? + if id(agent) not in self._runners_by_agent: + self._runners_by_agent[id(agent)] = set() + self._runners_by_agent[id(agent)].add(runner) + + @asynccontextmanager + async def unregister_runner(self, agent: BaseAgent, runner: Runner): + """Unregisters a runner for the given agent. + + This method should be called before cleaning up a runner at the end + of AgentTool.run_async(). It returns an async context manager that yields + whether the runner should be cleaned up (i.e., if it's the last runner + using the agent). The lock is held until the context manager is fully consumed, + ensuring that cleanup operations can complete safely. + + Usage: + async with manager.unregister_runner(agent, runner) as should_cleanup: + if should_cleanup: + await runner.close() + + Args: + agent: The agent instance used by the runner. + runner: The runner instance to unregister. + + Yields: + True if this was the last runner using the agent and cleanup should + proceed, False if other runners are still using the agent and cleanup + should be skipped. + """ + async with self._lock: + yield self._unregister(agent, runner) + + def _unregister(self, agent: BaseAgent, runner: Runner) -> bool: + """Unregisters a runner and determines if cleanup should proceed. + + Args: + agent: The agent instance used by the runner. + runner: The runner instance to unregister. + + Returns: + True if cleanup should proceed (no other runners using the agent), + False if cleanup should be skipped (other runners still using the agent). + """ + if id(agent) not in self._runners_by_agent: + # Agent not registered, safe to cleanup + return True + + runners = self._runners_by_agent[id(agent)] + if runner not in runners: + # Runner not registered, safe to cleanup + return True + + runners.remove(runner) + + # If no runners left for this agent, cleanup is safe + if not runners: + del self._runners_by_agent[id(agent)] + return True + + # Other runners still using this agent, skip cleanup + return False + + +_agent_tool_manager_instance: AgentToolManager | None = None + + +def get_agent_tool_manager() -> AgentToolManager: + """Gets the singleton AgentToolManager instance, initializing it if needed.""" + global _agent_tool_manager_instance + if _agent_tool_manager_instance is None: + _agent_tool_manager_instance = AgentToolManager() + return _agent_tool_manager_instance diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 46d8616619..2e1c053447 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -25,6 +25,7 @@ from ..agents.common_configs import AgentRefConfig from ..memory.in_memory_memory_service import InMemoryMemoryService from ..utils.context_utils import Aclosing +from ._agent_tool_manager import get_agent_tool_manager from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool from .tool_configs import BaseToolConfig @@ -61,6 +62,7 @@ def __init__( self.agent = agent self.skip_summarization: bool = skip_summarization self.include_plugins = include_plugins + self._agent_tool_manager = get_agent_tool_manager() super().__init__(name=agent.name, description=agent.description) @@ -158,6 +160,8 @@ async def run_async( plugins=plugins, ) + await self._agent_tool_manager.register_runner(self.agent, runner) + state_dict = { k: v for k, v in tool_context.state.to_dict().items() @@ -184,7 +188,10 @@ async def run_async( # Clean up runner resources (especially MCP sessions) # to avoid "Attempted to exit cancel scope in a different task" errors - await runner.close() + async with self._agent_tool_manager.unregister_runner( + self.agent, runner + ) as should_cleanup_toolsets: + await runner.close(cleanup_toolsets=should_cleanup_toolsets) if not last_content: return '' diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index a9723b4347..04c1d037ed 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -131,7 +131,7 @@ def run_async( ) return _empty_async_generator() - async def close(self): + async def close(self, cleanup_toolsets: bool = False): """Mock close method.""" pass diff --git a/tests/unittests/tools/test_agent_tool_manager.py b/tests/unittests/tools/test_agent_tool_manager.py new file mode 100644 index 0000000000..52315609a2 --- /dev/null +++ b/tests/unittests/tools/test_agent_tool_manager.py @@ -0,0 +1,333 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from unittest import mock + +from google.adk.agents.llm_agent import Agent +from google.adk.runners import Runner +from google.adk.tools._agent_tool_manager import AgentToolManager +from google.adk.tools._agent_tool_manager import get_agent_tool_manager +import pytest + +from .. import testing_utils + + +@pytest.fixture +def manager(): + """Creates a fresh AgentToolManager instance for each test.""" + return AgentToolManager() + + +@pytest.fixture +def agent(): + """Creates a test agent.""" + return Agent( + name='test_agent', + model=testing_utils.MockModel.create(responses=['test']), + ) + + +@pytest.fixture +def runner(agent): + """Creates a test runner.""" + return testing_utils.InMemoryRunner(agent) + + +@pytest.mark.asyncio +async def test_register_runner(manager, agent, runner): + """Test basic runner registration.""" + await manager.register_runner(agent, runner) + + # Verify runner is registered + async with manager._lock: + assert id(agent) in manager._runners_by_agent + assert runner in manager._runners_by_agent[id(agent)] + + +@pytest.mark.asyncio +async def test_unregister_runner_single_runner(manager, agent, runner): + """Test unregistering the only runner for an agent.""" + await manager.register_runner(agent, runner) + + async with manager.unregister_runner(agent, runner) as should_cleanup: + assert should_cleanup is True + + # Verify runner is removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_unregister_runner_multiple_runners(manager, agent): + """Test unregistering one runner when multiple runners exist.""" + runner1 = testing_utils.InMemoryRunner(agent) + runner2 = testing_utils.InMemoryRunner(agent) + + await manager.register_runner(agent, runner1) + await manager.register_runner(agent, runner2) + + # Unregister first runner - should not cleanup + async with manager.unregister_runner(agent, runner1) as should_cleanup: + assert should_cleanup is False + + # Verify runner1 is removed but runner2 remains + async with manager._lock: + assert id(agent) in manager._runners_by_agent + assert runner1 not in manager._runners_by_agent[id(agent)] + assert runner2 in manager._runners_by_agent[id(agent)] + + # Unregister second runner - should cleanup + async with manager.unregister_runner(agent, runner2) as should_cleanup: + assert should_cleanup is True + + # Verify agent is completely removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_unregister_unregistered_runner(manager, agent, runner): + """Test unregistering a runner that was never registered.""" + async with manager.unregister_runner(agent, runner) as should_cleanup: + # Should allow cleanup for unregistered runner + assert should_cleanup is True + + +@pytest.mark.asyncio +async def test_unregister_unregistered_agent(manager, agent, runner): + """Test unregistering from an agent that was never registered.""" + # Register runner for a different agent + other_agent = Agent( + name='other_agent', + model=testing_utils.MockModel.create(responses=['test']), + ) + await manager.register_runner(other_agent, runner) + + # Try to unregister from unregistered agent + async with manager.unregister_runner(agent, runner) as should_cleanup: + assert should_cleanup is True + + # Verify other agent is still registered + async with manager._lock: + assert id(other_agent) in manager._runners_by_agent + assert runner in manager._runners_by_agent[id(other_agent)] + + +@pytest.mark.asyncio +async def test_multiple_agents(manager): + """Test managing runners for multiple different agents.""" + agent1 = Agent( + name='agent1', model=testing_utils.MockModel.create(responses=['test']) + ) + agent2 = Agent( + name='agent2', model=testing_utils.MockModel.create(responses=['test']) + ) + runner1 = testing_utils.InMemoryRunner(agent1) + runner2 = testing_utils.InMemoryRunner(agent2) + + await manager.register_runner(agent1, runner1) + await manager.register_runner(agent2, runner2) + + # Verify both agents are tracked separately + async with manager._lock: + assert id(agent1) in manager._runners_by_agent + assert id(agent2) in manager._runners_by_agent + assert runner1 in manager._runners_by_agent[id(agent1)] + assert runner2 in manager._runners_by_agent[id(agent2)] + + # Unregister one agent + async with manager.unregister_runner(agent1, runner1) as should_cleanup: + assert should_cleanup is True + + # Verify only agent2 remains + async with manager._lock: + assert id(agent1) not in manager._runners_by_agent + assert id(agent2) in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_concurrent_registration(manager, agent): + """Test concurrent registration of multiple runners.""" + num_runners = 10 + runners = [testing_utils.InMemoryRunner(agent) for _ in range(num_runners)] + + # Register all runners concurrently + await asyncio.gather( + *[manager.register_runner(agent, runner) for runner in runners] + ) + + # Verify all runners are registered + async with manager._lock: + assert id(agent) in manager._runners_by_agent + assert len(manager._runners_by_agent[id(agent)]) == num_runners + for runner in runners: + assert runner in manager._runners_by_agent[id(agent)] + + +@pytest.mark.asyncio +async def test_concurrent_unregistration(manager, agent): + """Test concurrent unregistration of multiple runners.""" + num_runners = 10 + runners = [testing_utils.InMemoryRunner(agent) for _ in range(num_runners)] + + # Register all runners + await asyncio.gather( + *[manager.register_runner(agent, runner) for runner in runners] + ) + + # Unregister all runners concurrently + async def unregister_runner(runner): + async with manager.unregister_runner(agent, runner) as should_cleanup: + return should_cleanup + + cleanup_results = await asyncio.gather( + *[unregister_runner(runner) for runner in runners] + ) + + # Only the last runner should trigger cleanup + cleanup_count = sum(1 for result in cleanup_results if result is True) + assert cleanup_count == 1 + + # Verify agent is removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_concurrent_register_and_unregister(manager, agent): + """Test concurrent registration and unregistration.""" + num_operations = 20 + runners = [testing_utils.InMemoryRunner(agent) for _ in range(num_operations)] + + async def register_and_unregister(runner): + await manager.register_runner(agent, runner) + async with manager.unregister_runner(agent, runner) as should_cleanup: + return should_cleanup + + # Run register/unregister operations concurrently + results = await asyncio.gather( + *[register_and_unregister(runner) for runner in runners] + ) + + # All operations should complete without errors + # The cleanup results depend on timing, but at least one should be True + assert any(results) + + # Verify final state - agent should be removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_lock_prevents_race_condition(manager, agent): + """Test that the lock prevents race conditions during unregistration.""" + runner1 = testing_utils.InMemoryRunner(agent) + runner2 = testing_utils.InMemoryRunner(agent) + + await manager.register_runner(agent, runner1) + await manager.register_runner(agent, runner2) + + # Create a barrier to synchronize unregistration attempts + barrier = asyncio.Barrier(2) + + async def unregister_with_barrier(runner): + await barrier.wait() # Wait for both to reach this point + async with manager.unregister_runner(agent, runner) as should_cleanup: + return should_cleanup + + # Unregister both runners concurrently + results = await asyncio.gather( + unregister_with_barrier(runner1), unregister_with_barrier(runner2) + ) + + # Exactly one should return True (the last one to complete) + cleanup_count = sum(1 for result in results if result is True) + assert cleanup_count == 1 + + # Verify agent is removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_unregister_runner_context_manager_holds_lock(manager, agent): + """Test that unregister_runner context manager holds lock until exit.""" + runner = testing_utils.InMemoryRunner(agent) + await manager.register_runner(agent, runner) + + lock_acquired_during_context = False + + async def try_acquire_lock(): + nonlocal lock_acquired_during_context + try: + # Try to acquire lock with timeout + await asyncio.wait_for(manager._lock.acquire(), timeout=0.1) + lock_acquired_during_context = True + manager._lock.release() + except asyncio.TimeoutError: + # Lock is held, which is expected + pass + + async with manager.unregister_runner(agent, runner) as should_cleanup: + # Try to acquire lock from another task + await try_acquire_lock() + assert should_cleanup is True + + # After context exits, lock should be released + async with manager._lock: + assert id(agent) not in manager._runners_by_agent + + +@pytest.mark.asyncio +async def test_get_agent_tool_manager_singleton(): + """Test that get_agent_tool_manager returns a singleton.""" + manager1 = get_agent_tool_manager() + manager2 = get_agent_tool_manager() + + assert manager1 is manager2 + assert isinstance(manager1, AgentToolManager) + + +@pytest.mark.asyncio +async def test_register_same_runner_twice(manager, agent, runner): + """Test registering the same runner twice for the same agent.""" + await manager.register_runner(agent, runner) + await manager.register_runner(agent, runner) + + # Runner should only appear once in the set + async with manager._lock: + assert id(agent) in manager._runners_by_agent + assert runner in manager._runners_by_agent[id(agent)] + assert len(manager._runners_by_agent[id(agent)]) == 1 + + +@pytest.mark.asyncio +async def test_unregister_same_runner_twice(manager, agent, runner): + """Test unregistering the same runner twice.""" + await manager.register_runner(agent, runner) + + # First unregistration should return True + async with manager.unregister_runner(agent, runner) as should_cleanup: + assert should_cleanup is True + + # Second unregistration should also return True (runner not found) + async with manager.unregister_runner(agent, runner) as should_cleanup: + assert should_cleanup is True + + # Verify agent is removed + async with manager._lock: + assert id(agent) not in manager._runners_by_agent