Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import ConnectionClosedOK

from . import _output_schema_processor
from . import functions
from ...agents.base_agent import BaseAgent
from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext
Expand All @@ -50,8 +52,6 @@
from ...tools.tool_context import ToolContext
from ...utils import model_name_utils
from ...utils.context_utils import Aclosing
from . import _output_schema_processor
from . import functions
from .audio_cache_manager import AudioCacheManager
from .functions import build_auth_request_event

Expand Down
35 changes: 33 additions & 2 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,35 @@ def __build_full_text_response(
live_session_id=self._gemini_session.session_id,
)

def _to_generate_content_usage_metadata(
self, usage_metadata: types.UsageMetadata
) -> types.GenerateContentResponseUsageMetadata:
"""Converts live API usage metadata to GenerateContentResponse usage metadata.

The live API names output tokens `response_token_count`/
`response_tokens_details`, whereas `GenerateContentResponseUsageMetadata`
names them `candidates_token_count`/`candidates_tokens_details`.

Args:
usage_metadata: The live API usage metadata.

Returns:
The converted usage metadata.
"""
return types.GenerateContentResponseUsageMetadata(
prompt_token_count=usage_metadata.prompt_token_count,
cached_content_token_count=usage_metadata.cached_content_token_count,
candidates_token_count=usage_metadata.response_token_count,
total_token_count=usage_metadata.total_token_count,
thoughts_token_count=usage_metadata.thoughts_token_count,
tool_use_prompt_token_count=usage_metadata.tool_use_prompt_token_count,
prompt_tokens_details=usage_metadata.prompt_tokens_details,
cache_tokens_details=usage_metadata.cache_tokens_details,
candidates_tokens_details=usage_metadata.response_tokens_details,
tool_use_prompt_tokens_details=usage_metadata.tool_use_prompt_tokens_details,
traffic_type=usage_metadata.traffic_type,
)

async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""Receives the model response using the llm server connection.

Expand All @@ -235,9 +264,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
logger.debug('Got LLM Live message: %s', message)
live_session_id = self._gemini_session.session_id
if message.usage_metadata:
# Tracks token usage data per model.
# Remap live token usage to GenerateContentResponse usage metadata.
yield LlmResponse(
usage_metadata=message.usage_metadata,
usage_metadata=self._to_generate_content_usage_metadata(
message.usage_metadata
),
model_version=self._model_version,
live_session_id=live_session_id,
)
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/telemetry/_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from opentelemetry import trace
import opentelemetry.context as context_api

from ..events import event as event_lib
from . import _metrics
from . import tracing
from ..events import event as event_lib

if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
Expand Down
68 changes: 66 additions & 2 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,12 @@ async def mock_receive_generator():
content_response = next((r for r in responses if r.content), None)
assert content_response is not None

# The live API's `response_token_count`/`response_tokens_details` are remapped
# to `candidates_token_count`/`candidates_tokens_details`.
expected_usage = types.GenerateContentResponseUsageMetadata(
prompt_token_count=10,
cached_content_token_count=5,
candidates_token_count=None,
candidates_token_count=20,
total_token_count=35,
thoughts_token_count=2,
prompt_tokens_details=[
Expand All @@ -243,12 +245,74 @@ async def mock_receive_generator():
cache_tokens_details=[
types.ModalityTokenCount(modality='text', token_count=5)
],
candidates_tokens_details=None,
candidates_tokens_details=[
types.ModalityTokenCount(modality='text', token_count=20)
],
)
assert usage_response.usage_metadata == expected_usage
assert content_response.content == mock_content


async def test_receive_usage_metadata_remaps_output_tokens(
gemini_connection, mock_gemini_session
):
"""Test that live API output tokens are remapped to candidates_token_count."""
usage_metadata = types.UsageMetadata(
prompt_token_count=10,
cached_content_token_count=5,
response_token_count=20,
total_token_count=35,
thoughts_token_count=2,
tool_use_prompt_token_count=3,
prompt_tokens_details=[
types.ModalityTokenCount(modality='text', token_count=10)
],
cache_tokens_details=[
types.ModalityTokenCount(modality='text', token_count=5)
],
response_tokens_details=[
types.ModalityTokenCount(modality='text', token_count=20)
],
)

mock_message = mock.AsyncMock()
mock_message.usage_metadata = usage_metadata
mock_message.server_content = None
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None

async def mock_receive_generator():
yield mock_message

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

usage_response = next((r for r in responses if r.usage_metadata), None)
assert usage_response is not None
result = usage_response.usage_metadata
assert isinstance(result, types.GenerateContentResponseUsageMetadata)
# Output tokens are remapped from response_* to candidates_*.
assert result.candidates_token_count == 20
assert result.candidates_tokens_details == [
types.ModalityTokenCount(modality='text', token_count=20)
]
# Shared fields are carried over unchanged.
assert result.prompt_token_count == 10
assert result.cached_content_token_count == 5
assert result.total_token_count == 35
assert result.thoughts_token_count == 2
assert result.tool_use_prompt_token_count == 3
assert result.prompt_tokens_details == [
types.ModalityTokenCount(modality='text', token_count=10)
]
assert result.cache_tokens_details == [
types.ModalityTokenCount(modality='text', token_count=5)
]


async def test_receive_populates_live_session_id(
gemini_connection, mock_gemini_session
):
Expand Down
Loading