From 82f8d5ea71ea8bd14fce804255fe83ae43231509 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 12:07:14 +0100 Subject: [PATCH] fix: stabilize function call IDs across streaming events MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When models don't provide function call IDs, ADK generates client-side IDs via populate_client_function_call_id(). In streaming mode, partial and final events for the same logical function call each get a fresh uuid4, causing an ID mismatch that breaks HITL (human-in-the-loop) workflows and SSE consumers that correlate function calls across chunks. Root cause: _finalize_model_response_event creates a new Event object for each llm_response chunk, and populate_client_function_call_id generates a brand-new ID every time without knowledge of prior IDs. Fix: Add an optional function_call_id_cache dict that maps (name, index) keys to previously generated IDs. The streaming loop in _run_async creates the cache before iteration and threads it through _postprocess_async → _finalize_model_response_event → populate_client_function_call_id, ensuring the same logical function call gets a stable ID across all streaming events. The cache is keyed by (name:index) to correctly handle multiple calls to the same function within a single response. Fixes #4609 --- .../adk/flows/llm_flows/base_llm_flow.py | 23 +- src/google/adk/flows/llm_flows/functions.py | 19 +- .../test_streaming_function_call_ids.py | 196 ++++++++++++++++++ 3 files changed, 232 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_streaming_function_call_ids.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 5368ca93cc..d3304e7712 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -81,6 +81,7 @@ def _finalize_model_response_event( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_id_cache: Optional[dict[str, str]] = None, ) -> Event: """Finalize and build the model response event from LLM response. @@ -91,6 +92,9 @@ def _finalize_model_response_event( llm_request: The original LLM request. llm_response: The LLM response from the model. model_response_event: The base event to populate. + function_call_id_cache: Optional dict mapping function call names to + previously generated IDs. Used to keep IDs stable across partial + and final streaming events. Returns: The finalized Event with LLM response data merged in. @@ -103,7 +107,9 @@ def _finalize_model_response_event( if finalized_event.content: function_calls = finalized_event.get_function_calls() if function_calls: - functions.populate_client_function_call_id(finalized_event) + functions.populate_client_function_call_id( + finalized_event, function_call_id_cache + ) finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls( function_calls, llm_request.tools_dict @@ -827,6 +833,9 @@ async def _run_one_step_async( author=invocation_context.agent.name, branch=invocation_context.branch, ) + # Cache maps function call names to generated IDs so that partial and + # final streaming events for the same call share a stable ID. + function_call_id_cache: dict[str, str] = {} async with Aclosing( self._call_llm_async( invocation_context, llm_request, model_response_event @@ -840,6 +849,7 @@ async def _run_one_step_async( llm_request, llm_response, model_response_event, + function_call_id_cache, ) ) as agen: async for event in agen: @@ -886,6 +896,7 @@ async def _postprocess_async( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_id_cache: Optional[dict[str, str]] = None, ) -> AsyncGenerator[Event, None]: """Postprocess after calling the LLM. @@ -894,6 +905,9 @@ async def _postprocess_async( llm_request: The original LLM request. llm_response: The LLM response from the LLM call. model_response_event: A mutable event for the LLM response. + function_call_id_cache: Optional dict mapping function call names to + previously generated IDs. Keeps IDs stable across partial and final + streaming events. Yields: A generator of events. @@ -917,7 +931,8 @@ async def _postprocess_async( # Builds the event. model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event + llm_request, llm_response, model_response_event, + function_call_id_cache, ) yield model_response_event @@ -1197,9 +1212,11 @@ def _finalize_model_response_event( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_id_cache: Optional[dict[str, str]] = None, ) -> Event: return _finalize_model_response_event( - llm_request, llm_response, model_response_event + llm_request, llm_response, model_response_event, + function_call_id_cache, ) async def _resolve_toolset_auth( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6082e1a745..f9e86e6f79 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -181,12 +181,25 @@ def generate_client_function_call_id() -> str: return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}' -def populate_client_function_call_id(model_response_event: Event) -> None: +def populate_client_function_call_id( + model_response_event: Event, + function_call_id_cache: Optional[dict[str, str]] = None, +) -> None: if not model_response_event.get_function_calls(): return - for function_call in model_response_event.get_function_calls(): + for idx, function_call in enumerate( + model_response_event.get_function_calls() + ): if not function_call.id: - function_call.id = generate_client_function_call_id() + # Use (name, index) as cache key so that two calls to the same + # function in a single response keep separate stable IDs. + cache_key = f'{function_call.name}:{idx}' + if function_call_id_cache is not None and cache_key in function_call_id_cache: + function_call.id = function_call_id_cache[cache_key] + else: + function_call.id = generate_client_function_call_id() + if function_call_id_cache is not None: + function_call_id_cache[cache_key] = function_call.id def remove_client_function_call_id(content: Optional[types.Content]) -> None: diff --git a/tests/unittests/flows/llm_flows/test_streaming_function_call_ids.py b/tests/unittests/flows/llm_flows/test_streaming_function_call_ids.py new file mode 100644 index 0000000000..a56921478f --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_streaming_function_call_ids.py @@ -0,0 +1,196 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that function call IDs stay stable across streaming events.""" + +from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event +from google.adk.flows.llm_flows.functions import populate_client_function_call_id +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +import pytest + + +def _make_fc_response(name: str, args: dict | None = None, partial: bool = False) -> LlmResponse: + """Create an LlmResponse containing a single function call.""" + fc = types.FunctionCall(name=name, args=args or {}) + return LlmResponse( + content=types.Content(role='model', parts=[types.Part(function_call=fc)]), + partial=partial, + ) + + +def _make_multi_fc_response(calls: list[tuple[str, dict]], partial: bool = False) -> LlmResponse: + """Create an LlmResponse containing multiple function calls.""" + parts = [ + types.Part(function_call=types.FunctionCall(name=name, args=args)) + for name, args in calls + ] + return LlmResponse( + content=types.Content(role='model', parts=parts), + partial=partial, + ) + + +class TestPopulateClientFunctionCallIdWithCache: + """Tests for populate_client_function_call_id with ID caching.""" + + def test_generates_id_and_stores_in_cache(self): + event = Event(author='agent') + event.content = types.Content( + role='model', + parts=[types.Part(function_call=types.FunctionCall(name='get_weather', args={}))], + ) + cache: dict[str, str] = {} + populate_client_function_call_id(event, cache) + fc = event.get_function_calls()[0] + assert fc.id.startswith('adk-') + assert 'get_weather:0' in cache + assert cache['get_weather:0'] == fc.id + + def test_reuses_cached_id(self): + cache: dict[str, str] = {'get_weather:0': 'adk-cached-id-123'} + + event = Event(author='agent') + event.content = types.Content( + role='model', + parts=[types.Part(function_call=types.FunctionCall(name='get_weather', args={}))], + ) + populate_client_function_call_id(event, cache) + assert event.get_function_calls()[0].id == 'adk-cached-id-123' + + def test_no_cache_generates_new_id_each_time(self): + event1 = Event(author='agent') + event1.content = types.Content( + role='model', + parts=[types.Part(function_call=types.FunctionCall(name='get_weather', args={}))], + ) + event2 = Event(author='agent') + event2.content = types.Content( + role='model', + parts=[types.Part(function_call=types.FunctionCall(name='get_weather', args={}))], + ) + populate_client_function_call_id(event1) + populate_client_function_call_id(event2) + assert event1.get_function_calls()[0].id != event2.get_function_calls()[0].id + + def test_multiple_calls_same_name_get_separate_ids(self): + event = Event(author='agent') + event.content = types.Content( + role='model', + parts=[ + types.Part(function_call=types.FunctionCall(name='search', args={'q': 'a'})), + types.Part(function_call=types.FunctionCall(name='search', args={'q': 'b'})), + ], + ) + cache: dict[str, str] = {} + populate_client_function_call_id(event, cache) + fcs = event.get_function_calls() + assert fcs[0].id != fcs[1].id + assert cache['search:0'] == fcs[0].id + assert cache['search:1'] == fcs[1].id + + def test_skips_function_calls_that_already_have_ids(self): + event = Event(author='agent') + event.content = types.Content( + role='model', + parts=[types.Part(function_call=types.FunctionCall( + name='get_weather', args={}, id='server-provided-id'))], + ) + cache: dict[str, str] = {} + populate_client_function_call_id(event, cache) + assert event.get_function_calls()[0].id == 'server-provided-id' + assert len(cache) == 0 + + +class TestFinalizeModelResponseEventWithCache: + """Tests that _finalize_model_response_event preserves IDs via cache.""" + + def test_partial_and_final_share_same_function_call_id(self): + model_response_event = Event( + author='agent', + invocation_id='inv-1', + ) + llm_request = LlmRequest(model='mock', contents=[]) + cache: dict[str, str] = {} + + # Partial event + partial_response = _make_fc_response('get_weather', partial=True) + partial_event = _finalize_model_response_event( + llm_request, partial_response, model_response_event, cache, + ) + partial_id = partial_event.get_function_calls()[0].id + assert partial_id.startswith('adk-') + + # Final event — same function call must get the same ID + final_response = _make_fc_response('get_weather', partial=False) + final_event = _finalize_model_response_event( + llm_request, final_response, model_response_event, cache, + ) + final_id = final_event.get_function_calls()[0].id + assert final_id == partial_id + + def test_without_cache_ids_differ(self): + model_response_event = Event( + author='agent', + invocation_id='inv-1', + ) + llm_request = LlmRequest(model='mock', contents=[]) + + partial_response = _make_fc_response('get_weather', partial=True) + partial_event = _finalize_model_response_event( + llm_request, partial_response, model_response_event, + ) + partial_id = partial_event.get_function_calls()[0].id + + final_response = _make_fc_response('get_weather', partial=False) + final_event = _finalize_model_response_event( + llm_request, final_response, model_response_event, + ) + final_id = final_event.get_function_calls()[0].id + + # Without cache, IDs are different (this is the bug scenario) + assert final_id != partial_id + + def test_multi_function_call_streaming_preserves_all_ids(self): + model_response_event = Event( + author='agent', + invocation_id='inv-1', + ) + llm_request = LlmRequest(model='mock', contents=[]) + cache: dict[str, str] = {} + + # Partial with two function calls + partial_response = _make_multi_fc_response( + [('search', {'q': 'weather'}), ('lookup', {'id': '42'})], + partial=True, + ) + partial_event = _finalize_model_response_event( + llm_request, partial_response, model_response_event, cache, + ) + partial_ids = [fc.id for fc in partial_event.get_function_calls()] + + # Final with same two function calls + final_response = _make_multi_fc_response( + [('search', {'q': 'weather'}), ('lookup', {'id': '42'})], + partial=False, + ) + final_event = _finalize_model_response_event( + llm_request, final_response, model_response_event, cache, + ) + final_ids = [fc.id for fc in final_event.get_function_calls()] + + assert partial_ids == final_ids + assert partial_ids[0] != partial_ids[1] # different calls have different IDs