diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index 9b7ef9e121..be1714321b 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -406,6 +406,39 @@ def _is_timestamp_compacted(ts: float) -> bool: return [event for _, _, event in processed_items] +def filter_rewound_events(events: list[Event]) -> list[Event]: + """Returns events with those annulled by a rewind removed. + + Iterates backward; when a rewind marker is found, skips all events + back to the rewind_before_invocation_id. + + Args: + events: The full event list from the session. + + Returns: + A new list with rewound events removed, in the original order. + """ + # Pre-compute the first occurrence index of each invocation_id for O(1) lookup. + first_occurrence: dict[str, int] = {} + for idx, event in enumerate(events): + if event.invocation_id not in first_occurrence: + first_occurrence[event.invocation_id] = idx + + filtered = [] + i = len(events) - 1 + while i >= 0: + event = events[i] + if event.actions and event.actions.rewind_before_invocation_id: + rewind_id = event.actions.rewind_before_invocation_id + if rewind_id in first_occurrence and first_occurrence[rewind_id] < i: + i = first_occurrence[rewind_id] + else: + filtered.append(event) + i -= 1 + filtered.reverse() + return filtered + + def _get_contents( current_branch: Optional[str], events: list[Event], @@ -430,23 +463,7 @@ def _get_contents( accumulated_output_transcription = '' # Filter out events that are annulled by a rewind. - # By iterating backward, when a rewind event is found, we skip all events - # from that point back to the `rewind_before_invocation_id`, thus removing - # them from the history used for the LLM request. - rewind_filtered_events = [] - i = len(events) - 1 - while i >= 0: - event = events[i] - if event.actions and event.actions.rewind_before_invocation_id: - rewind_invocation_id = event.actions.rewind_before_invocation_id - for j in range(0, i, 1): - if events[j].invocation_id == rewind_invocation_id: - i = j - break - else: - rewind_filtered_events.append(event) - i -= 1 - rewind_filtered_events.reverse() + rewind_filtered_events = filter_rewound_events(events) # Parse the events, leaving the contents and the function calls and # responses from the current agent. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index d623075273..48d55565ac 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1112,7 +1112,8 @@ def _find_agent_to_run( # the agent that returned the corresponding function call regardless the # type of the agent. e.g. a remote a2a agent may surface a credential # request as a special long-running function tool call. - event = find_matching_function_call(session.events) + filtered_events = contents.filter_rewound_events(session.events) + event = find_matching_function_call(filtered_events) if event and event.author: return root_agent.find_agent(event.author) @@ -1124,7 +1125,7 @@ def _event_filter(event: Event) -> bool: return False return True - for event in filter(_event_filter, reversed(session.events)): + for event in filter(_event_filter, reversed(filtered_events)): if event.author == root_agent.name: # Found root agent. return root_agent diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index cc3abc65e5..071c3b3ff2 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -31,6 +31,8 @@ from google.adk.cli.utils.agent_loader import AgentLoader from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.events.event import Event +from google.adk.events.event import EventActions +from google.adk.flows.llm_flows.contents import filter_rewound_events from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -642,6 +644,96 @@ def test_is_transferable_across_agent_tree_with_non_llm_agent(self): assert result is False +def test_find_agent_to_run_ignores_rewound_sub_agent_event(): + """After a rewind, events from the rewound invocation are ignored.""" + root_agent = MockLlmAgent("root_agent") + sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=root_agent) + root_agent.sub_agents = [sub_agent1] + + runner = Runner( + app_name="test_app", + agent=root_agent, + session_service=InMemorySessionService(), + artifact_service=InMemoryArtifactService(), + ) + + # sub_agent1 was the last active agent during inv1 + sub_agent_event = Event( + invocation_id="inv1", + author="sub_agent1", + content=types.Content( + role="model", parts=[types.Part(text="Sub agent response")] + ), + ) + # Rewind event that annuls inv1 and everything after it + rewind_event = Event( + invocation_id="inv2", + author="user", + actions=EventActions(rewind_before_invocation_id="inv1"), + ) + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[sub_agent_event, rewind_event], + ) + + result = runner._find_agent_to_run(session, root_agent) + assert result == root_agent + + +def test_find_agent_to_run_ignores_rewound_function_call(): + """After a rewind, a function call from the rewound invocation is not matched.""" + root_agent = MockLlmAgent("root_agent") + sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=root_agent) + root_agent.sub_agents = [sub_agent2] + + runner = Runner( + app_name="test_app", + agent=root_agent, + session_service=InMemorySessionService(), + artifact_service=InMemoryArtifactService(), + ) + + function_call = types.FunctionCall(id="func_789", name="test_func", args={}) + function_response = types.FunctionResponse( + id="func_789", name="test_func", response={} + ) + + # sub_agent2 issued a function call in inv1 + call_event = Event( + invocation_id="inv1", + author="sub_agent2", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + # User provides the function response, also in inv1 + response_event = Event( + invocation_id="inv1", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + # Rewind event that annuls inv1 + rewind_event = Event( + invocation_id="inv2", + author="user", + actions=EventActions(rewind_before_invocation_id="inv1"), + ) + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, response_event, rewind_event], + ) + + # The rewound function call should not be matched; root_agent is returned + result = runner._find_agent_to_run(session, root_agent) + assert result == root_agent + + @pytest.mark.asyncio async def test_run_config_custom_metadata_propagates_to_events(): session_service = InMemorySessionService()