Skip to content
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
51 changes: 34 additions & 17 deletions src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,39 @@ def _is_timestamp_compacted(ts: float) -> bool:
return [event for _, _, event in processed_items]


def filter_rewound_events(events: list[Event]) -> list[Event]:
"""Returns events with those annulled by a rewind removed.

Iterates backward; when a rewind marker is found, skips all events
back to the rewind_before_invocation_id.

Args:
events: The full event list from the session.

Returns:
A new list with rewound events removed, in the original order.
"""
# Pre-compute the first occurrence index of each invocation_id for O(1) lookup.
first_occurrence: dict[str, int] = {}
for idx, event in enumerate(events):
if event.invocation_id not in first_occurrence:
first_occurrence[event.invocation_id] = idx

filtered = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_id = event.actions.rewind_before_invocation_id
if rewind_id in first_occurrence and first_occurrence[rewind_id] < i:
i = first_occurrence[rewind_id]
else:
filtered.append(event)
i -= 1
filtered.reverse()
return filtered


def _get_contents(
current_branch: Optional[str],
events: list[Event],
Expand All @@ -430,23 +463,7 @@ def _get_contents(
accumulated_output_transcription = ''

# Filter out events that are annulled by a rewind.
# By iterating backward, when a rewind event is found, we skip all events
# from that point back to the `rewind_before_invocation_id`, thus removing
# them from the history used for the LLM request.
rewind_filtered_events = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_invocation_id = event.actions.rewind_before_invocation_id
for j in range(0, i, 1):
if events[j].invocation_id == rewind_invocation_id:
i = j
break
else:
rewind_filtered_events.append(event)
i -= 1
rewind_filtered_events.reverse()
rewind_filtered_events = filter_rewound_events(events)

# Parse the events, leaving the contents and the function calls and
# responses from the current agent.
Expand Down
5 changes: 3 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,8 @@ def _find_agent_to_run(
# the agent that returned the corresponding function call regardless the
# type of the agent. e.g. a remote a2a agent may surface a credential
# request as a special long-running function tool call.
event = find_matching_function_call(session.events)
filtered_events = contents.filter_rewound_events(session.events)
event = find_matching_function_call(filtered_events)
if event and event.author:
return root_agent.find_agent(event.author)

Expand All @@ -1124,7 +1125,7 @@ def _event_filter(event: Event) -> bool:
return False
return True

for event in filter(_event_filter, reversed(session.events)):
for event in filter(_event_filter, reversed(filtered_events)):
if event.author == root_agent.name:
# Found root agent.
return root_agent
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from google.adk.cli.utils.agent_loader import AgentLoader
from google.adk.errors.session_not_found_error import SessionNotFoundError
from google.adk.events.event import Event
from google.adk.events.event import EventActions
from google.adk.flows.llm_flows.contents import filter_rewound_events
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
Expand Down Expand Up @@ -642,6 +644,96 @@ def test_is_transferable_across_agent_tree_with_non_llm_agent(self):
assert result is False


def test_find_agent_to_run_ignores_rewound_sub_agent_event():
"""After a rewind, events from the rewound invocation are ignored."""
root_agent = MockLlmAgent("root_agent")
sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent1]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

# sub_agent1 was the last active agent during inv1
sub_agent_event = Event(
invocation_id="inv1",
author="sub_agent1",
content=types.Content(
role="model", parts=[types.Part(text="Sub agent response")]
),
)
# Rewind event that annuls inv1 and everything after it
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[sub_agent_event, rewind_event],
)

result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent


def test_find_agent_to_run_ignores_rewound_function_call():
"""After a rewind, a function call from the rewound invocation is not matched."""
root_agent = MockLlmAgent("root_agent")
sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent2]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

function_call = types.FunctionCall(id="func_789", name="test_func", args={})
function_response = types.FunctionResponse(
id="func_789", name="test_func", response={}
)

# sub_agent2 issued a function call in inv1
call_event = Event(
invocation_id="inv1",
author="sub_agent2",
content=types.Content(
role="model", parts=[types.Part(function_call=function_call)]
),
)
# User provides the function response, also in inv1
response_event = Event(
invocation_id="inv1",
author="user",
content=types.Content(
role="user", parts=[types.Part(function_response=function_response)]
),
)
# Rewind event that annuls inv1
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[call_event, response_event, rewind_event],
)

# The rewound function call should not be matched; root_agent is returned
result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent


@pytest.mark.asyncio
async def test_run_config_custom_metadata_propagates_to_events():
session_service = InMemorySessionService()
Expand Down