diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83e0..eb22a83bb9 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,13 +106,35 @@ async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: return event + # Apply temp-scoped state to the in-memory session BEFORE trimming the + # event delta, so that subsequent agents within the same invocation can + # read temp values (e.g. output_key='temp:my_key' in SequentialAgent). + self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) self._update_session_state(session, event) session.events.append(event) return event + def _apply_temp_state(self, session: Session, event: Event) -> None: + """Applies temp-scoped state delta to the in-memory session state. + + Temp state is ephemeral: it lives in the session's in-memory state for + the duration of the current invocation but is NOT persisted to storage + (the event delta is trimmed separately by _trim_temp_delta_state). + """ + if not event.actions or not event.actions.state_delta: + return + for key, value in event.actions.state_delta.items(): + if key.startswith(State.TEMP_PREFIX): + session.state[key] = value + def _trim_temp_delta_state(self, event: Event) -> Event: - """Removes temporary state delta keys from the event.""" + """Removes temporary state delta keys from the event. + + This prevents temp-scoped state from being persisted, while the + in-memory session state (updated by _apply_temp_state) retains the + values for the duration of the current invocation. + """ if not event.actions or not event.actions.state_delta: return event @@ -128,6 +150,4 @@ def _update_session_state(self, session: Session, event: Event) -> None: if not event.actions or not event.actions.state_delta: return for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue session.state.update({key: value}) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 24f525bae0..306c8d19d4 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -522,6 +522,9 @@ async def append_event(self, session: Session, event: Event) -> Event: if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d23c8278cf..600f89c4b9 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -361,6 +361,9 @@ async def append_event(self, session: Session, event: Event) -> Event: if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) event_timestamp = event.timestamp diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 25530bed89..21637f9ff5 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service): ) await session_service.append_event(session=session, event=event) - # Refetch session and check state and event - session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' - ) - # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + # Temp state IS available in the in-memory session (same invocation) + assert session.state.get('temp:k1') == 'v1' + assert session.state.get('sk') == 'v2' + # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert 'temp:k1' not in event.actions.state_delta + assert event.actions.state_delta.get('sk') == 'v2' + + +@pytest.mark.asyncio +async def test_temp_state_visible_across_sequential_events(session_service): + """Temp state set by one event should be readable before the next event. + + This simulates a SequentialAgent where agent-1 writes output_key='temp:out' + and agent-2 needs to read it from session.state within the same invocation. + """ + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s_seq' + ) + + # Agent-1 writes temp state + event1 = Event( + invocation_id='inv1', + author='agent1', + actions=EventActions(state_delta={'temp:output': 'result_from_a1'}), + ) + await session_service.append_event(session=session, event=event1) + + # Agent-2 should be able to read temp state from the same session object + assert session.state.get('temp:output') == 'result_from_a1' + + # But the event delta should NOT contain the temp key (not persisted) + assert 'temp:output' not in event1.actions.state_delta @pytest.mark.asyncio