Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,35 @@ async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event
# Apply temp-scoped state to the in-memory session BEFORE trimming the
# event delta, so that subsequent agents within the same invocation can
# read temp values (e.g. output_key='temp:my_key' in SequentialAgent).
self._apply_temp_state(session, event)
event = self._trim_temp_delta_state(event)
self._update_session_state(session, event)
session.events.append(event)
return event

def _apply_temp_state(self, session: Session, event: Event) -> None:
"""Applies temp-scoped state delta to the in-memory session state.

Temp state is ephemeral: it lives in the session's in-memory state for
the duration of the current invocation but is NOT persisted to storage
(the event delta is trimmed separately by _trim_temp_delta_state).
"""
if not event.actions or not event.actions.state_delta:
return
for key, value in event.actions.state_delta.items():
if key.startswith(State.TEMP_PREFIX):
session.state[key] = value

def _trim_temp_delta_state(self, event: Event) -> Event:
"""Removes temporary state delta keys from the event."""
"""Removes temporary state delta keys from the event.

This prevents temp-scoped state from being persisted, while the
in-memory session state (updated by _apply_temp_state) retains the
values for the duration of the current invocation.
"""
if not event.actions or not event.actions.state_delta:
return event

Expand All @@ -128,6 +150,4 @@ def _update_session_state(self, session: Session, event: Event) -> None:
if not event.actions or not event.actions.state_delta:
return
for key, value in event.actions.state_delta.items():
if key.startswith(State.TEMP_PREFIX):
continue
session.state.update({key: value})
3 changes: 3 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
return event

# Apply temp state to in-memory session before trimming, so that
# subsequent agents within the same invocation can read temp values.
self._apply_temp_state(session, event)
# Trim temp state before persisting
event = self._trim_temp_delta_state(event)

Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
return event

# Apply temp state to in-memory session before trimming, so that
# subsequent agents within the same invocation can read temp values.
self._apply_temp_state(session, event)
# Trim temp state before persisting
event = self._trim_temp_delta_state(event)
event_timestamp = event.timestamp
Expand Down
43 changes: 34 additions & 9 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service):
)
await session_service.append_event(session=session, event=event)

# Refetch session and check state and event
session_got = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id='s1'
)
# Check session state does not contain temp keys
assert session_got.state.get('sk') == 'v2'
assert 'temp:k1' not in session_got.state
# Temp state IS available in the in-memory session (same invocation)
assert session.state.get('temp:k1') == 'v1'
assert session.state.get('sk') == 'v2'

# Check event as stored in session does not contain temp keys in state_delta
assert 'temp:k1' not in session_got.events[0].actions.state_delta
assert session_got.events[0].actions.state_delta.get('sk') == 'v2'
assert 'temp:k1' not in event.actions.state_delta
assert event.actions.state_delta.get('sk') == 'v2'


@pytest.mark.asyncio
async def test_temp_state_visible_across_sequential_events(session_service):
"""Temp state set by one event should be readable before the next event.

This simulates a SequentialAgent where agent-1 writes output_key='temp:out'
and agent-2 needs to read it from session.state within the same invocation.
"""
app_name = 'my_app'
user_id = 'u1'
session = await session_service.create_session(
app_name=app_name, user_id=user_id, session_id='s_seq'
)

# Agent-1 writes temp state
event1 = Event(
invocation_id='inv1',
author='agent1',
actions=EventActions(state_delta={'temp:output': 'result_from_a1'}),
)
await session_service.append_event(session=session, event=event1)

# Agent-2 should be able to read temp state from the same session object
assert session.state.get('temp:output') == 'result_from_a1'

# But the event delta should NOT contain the temp key (not persisted)
assert 'temp:output' not in event1.actions.state_delta


@pytest.mark.asyncio
Expand Down