diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 6878c9a5c6..0f2f6cc31a 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -28,6 +28,8 @@ from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK +from . import _output_schema_processor +from . import functions from ...agents.base_agent import BaseAgent from ...agents.callback_context import CallbackContext from ...agents.invocation_context import InvocationContext @@ -50,8 +52,6 @@ from ...tools.tool_context import ToolContext from ...utils import model_name_utils from ...utils.context_utils import Aclosing -from . import _output_schema_processor -from . import functions from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index a892a3ce0a..33d0a9d46a 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -217,6 +217,35 @@ def __build_full_text_response( live_session_id=self._gemini_session.session_id, ) + def _to_generate_content_usage_metadata( + self, usage_metadata: types.UsageMetadata + ) -> types.GenerateContentResponseUsageMetadata: + """Converts live API usage metadata to GenerateContentResponse usage metadata. + + The live API names output tokens `response_token_count`/ + `response_tokens_details`, whereas `GenerateContentResponseUsageMetadata` + names them `candidates_token_count`/`candidates_tokens_details`. + + Args: + usage_metadata: The live API usage metadata. + + Returns: + The converted usage metadata. + """ + return types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage_metadata.prompt_token_count, + cached_content_token_count=usage_metadata.cached_content_token_count, + candidates_token_count=usage_metadata.response_token_count, + total_token_count=usage_metadata.total_token_count, + thoughts_token_count=usage_metadata.thoughts_token_count, + tool_use_prompt_token_count=usage_metadata.tool_use_prompt_token_count, + prompt_tokens_details=usage_metadata.prompt_tokens_details, + cache_tokens_details=usage_metadata.cache_tokens_details, + candidates_tokens_details=usage_metadata.response_tokens_details, + tool_use_prompt_tokens_details=usage_metadata.tool_use_prompt_tokens_details, + traffic_type=usage_metadata.traffic_type, + ) + async def receive(self) -> AsyncGenerator[LlmResponse, None]: """Receives the model response using the llm server connection. @@ -235,9 +264,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: logger.debug('Got LLM Live message: %s', message) live_session_id = self._gemini_session.session_id if message.usage_metadata: - # Tracks token usage data per model. + # Remap live token usage to GenerateContentResponse usage metadata. yield LlmResponse( - usage_metadata=message.usage_metadata, + usage_metadata=self._to_generate_content_usage_metadata( + message.usage_metadata + ), model_version=self._model_version, live_session_id=live_session_id, ) diff --git a/src/google/adk/telemetry/_instrumentation.py b/src/google/adk/telemetry/_instrumentation.py index 8ce2797628..ea5dac4bff 100644 --- a/src/google/adk/telemetry/_instrumentation.py +++ b/src/google/adk/telemetry/_instrumentation.py @@ -26,9 +26,9 @@ from opentelemetry import trace import opentelemetry.context as context_api -from ..events import event as event_lib from . import _metrics from . import tracing +from ..events import event as event_lib if TYPE_CHECKING: from ..agents.base_agent import BaseAgent diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index e800f2bcfd..928d9fe9bf 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -231,10 +231,12 @@ async def mock_receive_generator(): content_response = next((r for r in responses if r.content), None) assert content_response is not None + # The live API's `response_token_count`/`response_tokens_details` are remapped + # to `candidates_token_count`/`candidates_tokens_details`. expected_usage = types.GenerateContentResponseUsageMetadata( prompt_token_count=10, cached_content_token_count=5, - candidates_token_count=None, + candidates_token_count=20, total_token_count=35, thoughts_token_count=2, prompt_tokens_details=[ @@ -243,12 +245,74 @@ async def mock_receive_generator(): cache_tokens_details=[ types.ModalityTokenCount(modality='text', token_count=5) ], - candidates_tokens_details=None, + candidates_tokens_details=[ + types.ModalityTokenCount(modality='text', token_count=20) + ], ) assert usage_response.usage_metadata == expected_usage assert content_response.content == mock_content +async def test_receive_usage_metadata_remaps_output_tokens( + gemini_connection, mock_gemini_session +): + """Test that live API output tokens are remapped to candidates_token_count.""" + usage_metadata = types.UsageMetadata( + prompt_token_count=10, + cached_content_token_count=5, + response_token_count=20, + total_token_count=35, + thoughts_token_count=2, + tool_use_prompt_token_count=3, + prompt_tokens_details=[ + types.ModalityTokenCount(modality='text', token_count=10) + ], + cache_tokens_details=[ + types.ModalityTokenCount(modality='text', token_count=5) + ], + response_tokens_details=[ + types.ModalityTokenCount(modality='text', token_count=20) + ], + ) + + mock_message = mock.AsyncMock() + mock_message.usage_metadata = usage_metadata + mock_message.server_content = None + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + usage_response = next((r for r in responses if r.usage_metadata), None) + assert usage_response is not None + result = usage_response.usage_metadata + assert isinstance(result, types.GenerateContentResponseUsageMetadata) + # Output tokens are remapped from response_* to candidates_*. + assert result.candidates_token_count == 20 + assert result.candidates_tokens_details == [ + types.ModalityTokenCount(modality='text', token_count=20) + ] + # Shared fields are carried over unchanged. + assert result.prompt_token_count == 10 + assert result.cached_content_token_count == 5 + assert result.total_token_count == 35 + assert result.thoughts_token_count == 2 + assert result.tool_use_prompt_token_count == 3 + assert result.prompt_tokens_details == [ + types.ModalityTokenCount(modality='text', token_count=10) + ] + assert result.cache_tokens_details == [ + types.ModalityTokenCount(modality='text', token_count=5) + ] + + async def test_receive_populates_live_session_id( gemini_connection, mock_gemini_session ):