diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 70a17f4001..fabe6f278c 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -27,6 +27,7 @@ import json import logging import mimetypes +import os import random import time from types import MappingProxyType @@ -38,6 +39,13 @@ import uuid import weakref +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.adk.version import __version__ from google.api_core import client_options from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable @@ -54,16 +62,8 @@ from opentelemetry import trace import pyarrow as pa -from ..agents.callback_context import CallbackContext -from ..models.llm_request import LlmRequest -from ..models.llm_response import LlmResponse -from ..tools.base_tool import BaseTool -from ..tools.tool_context import ToolContext -from ..version import __version__ -from .base_plugin import BasePlugin - if TYPE_CHECKING: - from ..agents.invocation_context import InvocationContext + from google.adk.agents.invocation_context import InvocationContext logger: logging.Logger = logging.getLogger("google_adk." + __name__) tracer = trace.get_tracer( @@ -498,16 +498,35 @@ class BigQueryLoggerConfig: # dropped or altered). Safe to leave enabled; a version label on the # table ensures the diff runs at most once per schema version. auto_schema_upgrade: bool = True + # Automatically create per-event-type BigQuery views that unnest + # JSON columns into typed, queryable columns. + create_views: bool = True # ============================================================================== # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== +# NOTE: These contextvars are module-global, not plugin-instance-scoped. +# This is safe in practice for two reasons: +# 1. PluginManager enforces name-uniqueness, preventing two BQ plugin +# instances on the same Runner. +# 2. Concurrent asyncio tasks (e.g. two Runners in asyncio.gather) each +# get an isolated contextvar copy, so they don't interfere. +# The only problematic case would be two plugin instances interleaved +# within the *same* asyncio task without task boundaries — which the +# framework's PluginManager already prevents. _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) +# Tracks the invocation_id that owns the current span stack so that +# ensure_invocation_span() can distinguish "same invocation re-entry" +# (idempotent) from "stale records from a previous invocation" (clear). +_active_invocation_id_ctx: contextvars.ContextVar[Optional[str]] = ( + contextvars.ContextVar("_bq_analytics_active_invocation_id", default=None) +) + @dataclass class _SpanRecord: @@ -553,12 +572,13 @@ def _get_records() -> list[_SpanRecord]: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _root_agent_name_ctx.get() is None: - try: - root_agent = callback_context._invocation_context.agent.root_agent - _root_agent_name_ctx.set(root_agent.name) - except (AttributeError, ValueError): - pass + # Always refresh root_agent_name — it can change between + # invocations (e.g. different root agents in the same task). + try: + root_agent = callback_context._invocation_context.agent.root_agent + _root_agent_name_ctx.set(root_agent.name) + except (AttributeError, ValueError): + pass # Ensure records stack is initialized TraceManager._get_records() @@ -600,7 +620,16 @@ def push_span( # Create the span without attaching it to the ambient context. # This avoids re-parenting framework spans like ``call_llm`` # or ``execute_tool``. See #4561. - span = tracer.start_span(span_name) + # + # If the internal stack already has a span, create the new span + # as a child so it shares the same trace_id. Without this, each + # ``start_span`` would be an independent root with its own + # trace_id — causing trace_id fracture (see #4645). + records = TraceManager._get_records() + parent_ctx = None + if records and records[-1].span.get_span_context().is_valid: + parent_ctx = trace.set_span_in_context(records[-1].span) + span = tracer.start_span(span_name, context=parent_ctx) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -614,7 +643,6 @@ def push_span( start_time_ns=time.time_ns(), ) - records = TraceManager._get_records() new_records = list(records) + [record] _span_records_ctx.set(new_records) @@ -651,6 +679,49 @@ def attach_current_span( return span_id_str + @staticmethod + def ensure_invocation_span( + callback_context: CallbackContext, + ) -> None: + """Ensures a root span exists on the plugin stack for this invocation. + + Must be called before any events are logged so that every event in + the invocation shares the same trace_id. + + * If the stack has entries for the *current* invocation → no-op + (idempotent within the same invocation). + * If the stack has entries from a *different* invocation → clear + stale records and re-initialise (safety net for abnormal exit). + * If the ambient OTel span is valid → ``attach_current_span`` + (reuse the runner's span without owning it). + * Otherwise → ``push_span("invocation")`` (create a new root + span that will be popped in ``after_run_callback``). + """ + current_inv = callback_context.invocation_id + active_inv = _active_invocation_id_ctx.get() + + records = _span_records_ctx.get() + if records: + if active_inv == current_inv: + return # Already initialised for this invocation. + # Stale records from a previous invocation that wasn't cleaned + # up (e.g. exception skipped after_run_callback). Clear and + # re-init. + logger.debug( + "Clearing %d stale span records from previous invocation.", + len(records), + ) + TraceManager.clear_stack() + + _active_invocation_id_ctx.set(current_inv) + + # Check for a valid ambient span (e.g. the Runner's invocation span). + ambient = trace.get_current_span() + if ambient.get_span_context().is_valid: + TraceManager.attach_current_span(callback_context) + else: + TraceManager.push_span(callback_context, "invocation") + @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: """Ends the current span and pops it from the stack. @@ -679,6 +750,17 @@ def pop_span() -> tuple[Optional[str], Optional[int]]: return record.span_id, duration_ms + @staticmethod + def clear_stack() -> None: + """Clears all span records. Safety net for cross-invocation cleanup.""" + records = _span_records_ctx.get() + if records: + # End any owned spans to avoid OTel resource leaks. + for record in reversed(records): + if record.owns_span: + record.span.end() + _span_records_ctx.set([]) + @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: """Gets current span_id and parent span_id.""" @@ -1581,6 +1663,115 @@ def _get_events_schema() -> list[bigquery.SchemaField]: ] +# ============================================================================== +# ANALYTICS VIEW DEFINITIONS +# ============================================================================== + +# Columns included in every per-event-type view. +_VIEW_COMMON_COLUMNS = ( + "timestamp", + "event_type", + "agent", + "session_id", + "invocation_id", + "user_id", + "trace_id", + "span_id", + "parent_span_id", + "status", + "error_message", + "is_truncated", +) + +# Per-event-type column extractions. Each value is a list of +# ``"SQL_EXPR AS alias"`` strings that will be appended after the +# common columns in the view SELECT. +_EVENT_VIEW_DEFS: dict[str, list[str]] = { + "USER_MESSAGE_RECEIVED": [], + "LLM_REQUEST": [ + "JSON_VALUE(attributes, '$.model') AS model", + "content AS request_content", + "JSON_QUERY(attributes, '$.llm_config') AS llm_config", + "JSON_QUERY(attributes, '$.tools') AS tools", + ], + "LLM_RESPONSE": [ + "JSON_QUERY(content, '$.response') AS response", + ( + "CAST(JSON_VALUE(content, '$.usage.prompt')" + " AS INT64) AS usage_prompt_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.completion')" + " AS INT64) AS usage_completion_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.total')" + " AS INT64) AS usage_total_tokens" + ), + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ( + "CAST(JSON_VALUE(latency_ms," + " '$.time_to_first_token_ms') AS INT64) AS ttft_ms" + ), + "JSON_VALUE(attributes, '$.model_version') AS model_version", + "JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata", + ], + "LLM_ERROR": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_STARTING": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + ], + "TOOL_COMPLETED": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.result') AS tool_result", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_ERROR": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "AGENT_STARTING": [ + "JSON_VALUE(content, '$.text_summary') AS agent_instruction", + ], + "AGENT_COMPLETED": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "INVOCATION_STARTING": [], + "INVOCATION_COMPLETED": [], + "STATE_DELTA": [ + "JSON_QUERY(attributes, '$.state_delta') AS state_delta", + ], + "HITL_CREDENTIAL_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_CONFIRMATION_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_INPUT_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], +} + +_VIEW_SQL_TEMPLATE = """\ +CREATE OR REPLACE VIEW `{project}.{dataset}.{view_name}` AS +SELECT + {columns} +FROM + `{project}.{dataset}.{table}` +WHERE + event_type = '{event_type}' +""" + + # ============================================================================== # MAIN PLUGIN # ============================================================================== @@ -1592,7 +1783,7 @@ class _LoopState: batch_processor: BatchProcessor -@dataclass +@dataclass(kw_only=True) class EventData: """Typed container for structured fields passed to _log_event.""" @@ -1606,6 +1797,7 @@ class EventData: status: str = "OK" error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) + trace_id_override: Optional[str] = None class BigQueryAgentAnalyticsPlugin(BasePlugin): @@ -1650,6 +1842,7 @@ def __init__( self.location = location self._started = False + self._startup_error: Optional[Exception] = None self._is_shutting_down = False self._setup_lock = None self.client = None @@ -1660,6 +1853,7 @@ def __init__( self.parser: Optional[HybridContentParser] = None self._schema = None self.arrow_schema = None + self._init_pid = os.getpid() def _cleanup_stale_loop_states(self) -> None: """Removes entries for event loops that have been closed.""" @@ -1912,6 +2106,8 @@ def _ensure_schema_exists(self) -> None: existing_table = self.client.get_table(self.full_table_id) if self.config.auto_schema_upgrade: self._maybe_upgrade_schema(existing_table) + if self.config.create_views: + self._create_analytics_views() except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) @@ -1921,10 +2117,13 @@ def _ensure_schema_exists(self) -> None: ) tbl.clustering_fields = self.config.clustering_fields tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} + table_ready = False try: self.client.create_table(tbl) + table_ready = True except cloud_exceptions.Conflict: - pass + # Another process created it concurrently — still usable. + table_ready = True except Exception as e: logger.error( "Could not create table %s: %s", @@ -1932,6 +2131,8 @@ def _ensure_schema_exists(self) -> None: e, exc_info=True, ) + if table_ready and self.config.create_views: + self._create_analytics_views() except Exception as e: logger.error( "Error checking for table %s: %s", @@ -1980,6 +2181,50 @@ def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: exc_info=True, ) + def _create_analytics_views(self) -> None: + """Creates per-event-type BigQuery views (idempotent). + + Each view filters the events table by ``event_type`` and + extracts JSON columns into typed, queryable columns. Uses + ``CREATE OR REPLACE VIEW`` so it is safe to call repeatedly. + Errors are logged but never raised. + """ + for event_type, extra_cols in _EVENT_VIEW_DEFS.items(): + view_name = "v_" + event_type.lower() + columns = ",\n ".join(list(_VIEW_COMMON_COLUMNS) + extra_cols) + sql = _VIEW_SQL_TEMPLATE.format( + project=self.project_id, + dataset=self.dataset_id, + view_name=view_name, + columns=columns, + table=self.table_id, + event_type=event_type, + ) + try: + self.client.query(sql).result() + except Exception as e: + logger.error( + "Failed to create view %s: %s", + view_name, + e, + exc_info=True, + ) + + async def create_analytics_views(self) -> None: + """Public async helper to (re-)create all analytics views. + + Useful when views need to be refreshed explicitly, for example + after a schema upgrade. Ensures the plugin is initialized + before attempting view creation. + """ + await self._ensure_started() + if not self._started: + raise RuntimeError( + "Plugin initialization failed; cannot create analytics views." + ) from self._startup_error + loop = asyncio.get_running_loop() + await loop.run_in_executor(self._executor, self._create_analytics_views) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -2031,13 +2276,39 @@ def __getstate__(self): state["offloader"] = None state["parser"] = None state["_started"] = False + state["_startup_error"] = None state["_is_shutting_down"] = False + state["_init_pid"] = 0 return state def __setstate__(self, state): """Custom unpickling to restore state.""" + # Backfill keys that may be absent in pickled state from older + # code versions so _ensure_started does not raise AttributeError. + state.setdefault("_init_pid", 0) self.__dict__.update(state) + def _reset_runtime_state(self) -> None: + """Resets all runtime state after a fork. + + gRPC channels and asyncio locks are not safe to use after + ``os.fork()``. This method clears them so the next call to + ``_ensure_started()`` re-initializes everything in the child + process. Pure-data fields like ``_schema`` and + ``arrow_schema`` are kept because they are safe across fork. + """ + self._setup_lock = None + self.client = None + self._loop_state_by_loop = {} + self._write_stream_name = None + self._executor = None + self.offloader = None + self.parser = None + self._started = False + self._startup_error = None + self._is_shutting_down = False + self._init_pid = os.getpid() + async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin: await self._ensure_started() return self @@ -2047,6 +2318,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: async def _ensure_started(self, **kwargs) -> None: """Ensures that the plugin is started and initialized.""" + if os.getpid() != self._init_pid: + self._reset_runtime_state() if not self._started: # Kept original lock name as it was not explicitly changed. if self._setup_lock is None: @@ -2056,31 +2329,59 @@ async def _ensure_started(self, **kwargs) -> None: try: await self._lazy_setup(**kwargs) self._started = True + self._startup_error = None except Exception as e: + self._startup_error = e logger.error("Failed to initialize BigQuery Plugin: %s", e) @staticmethod - def _resolve_span_ids( + def _resolve_ids( event_data: EventData, - ) -> tuple[str, str]: - """Reads span/parent overrides from EventData, falling back to TraceManager. + callback_context: CallbackContext, + ) -> tuple[Optional[str], Optional[str], Optional[str]]: + """Resolves trace_id, span_id, and parent_span_id for a log row. + + Priority order (highest first): + 1. Explicit ``EventData`` overrides (needed for post-pop callbacks). + 2. Ambient OTel span (the framework's ``start_as_current_span``). + When present this aligns BQ rows with Cloud Trace / o11y. + 3. Plugin's internal span stack (``TraceManager``). + 4. ``invocation_id`` fallback for trace_id. Returns: - (span_id, parent_span_id) + (trace_id, span_id, parent_span_id) """ - current_span_id, current_parent_span_id = ( + # --- Layer 3: plugin stack baseline --- + trace_id = TraceManager.get_trace_id(callback_context) + plugin_span_id, plugin_parent_span_id = ( TraceManager.get_current_span_and_parent() ) - - span_id = current_span_id + span_id = plugin_span_id + parent_span_id = plugin_parent_span_id + + # --- Layer 2: ambient OTel span --- + ambient = trace.get_current_span() + ambient_ctx = ambient.get_span_context() + if ambient_ctx.is_valid: + trace_id = format(ambient_ctx.trace_id, "032x") + span_id = format(ambient_ctx.span_id, "016x") + # Reset parent — stale plugin-stack parent must not leak through + # when the ambient span is a root (no parent). + parent_span_id = None + # SDK spans expose .parent; non-recording spans do not. + parent_ctx = getattr(ambient, "parent", None) + if parent_ctx is not None and parent_ctx.span_id: + parent_span_id = format(parent_ctx.span_id, "016x") + + # --- Layer 1: explicit EventData overrides --- + if event_data.trace_id_override is not None: + trace_id = event_data.trace_id_override if event_data.span_id_override is not None: span_id = event_data.span_id_override - - parent_span_id = current_parent_span_id if event_data.parent_span_id_override is not None: parent_span_id = event_data.parent_span_id_override - return span_id, parent_span_id + return trace_id, span_id, parent_span_id @staticmethod def _extract_latency( @@ -2193,8 +2494,9 @@ async def _log_event( except Exception as e: logger.warning("Content formatter failed: %s", e) - trace_id = TraceManager.get_trace_id(callback_context) - span_id, parent_span_id = self._resolve_span_ids(event_data) + trace_id, span_id, parent_span_id = self._resolve_ids( + event_data, callback_context + ) if not self.parser: logger.warning("Parser not initialized; skipping event %s.", event_type) @@ -2261,6 +2563,7 @@ async def on_user_message_callback( user_message: The message content received from the user. """ callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "USER_MESSAGE_RECEIVED", callback_ctx, @@ -2395,9 +2698,11 @@ async def before_run_callback( invocation_context: The context of the current invocation. """ await self._ensure_started() + callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "INVOCATION_STARTING", - CallbackContext(invocation_context), + callback_ctx, ) @_safe_callback @@ -2409,12 +2714,40 @@ async def after_run_callback( Args: invocation_context: The context of the current invocation. """ - await self._log_event( - "INVOCATION_COMPLETED", - CallbackContext(invocation_context), - ) - # Ensure all logs are flushed before the agent returns - await self.flush() + try: + # Capture trace_id BEFORE popping the invocation-root span so + # that INVOCATION_COMPLETED shares the same trace_id as all + # earlier events in this invocation (fixes #4645). + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) + + # Pop the invocation-root span pushed by ensure_invocation_span. + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() + + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + + await self._log_event( + "INVOCATION_COMPLETED", + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), + ), + ) + finally: + # Cleanup must run even if _log_event raises, otherwise + # stale invocation metadata leaks into the next invocation. + TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) + # Ensure all logs are flushed before the agent returns. + await self.flush() @_safe_callback async def before_agent_callback( @@ -2445,18 +2778,20 @@ async def after_agent_callback( callback_context: The callback context. """ span_id, duration = TraceManager.pop_span() - # When popping, the current stack now points to parent. - # The event we are logging ("AGENT_COMPLETED") belongs to the span we just popped. - # So we must override span_id to be the popped span, and parent to be current top of stack. parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "AGENT_COMPLETED", callback_context, event_data=EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) @@ -2606,6 +2941,12 @@ async def after_model_callback( # Otherwise log_event will fetch current stack (which is parent). span_id = popped_span_id or span_id + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping LLM_REQUEST/LLM_RESPONSE pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + use_override = is_popped and not has_ambient + await self._log_event( "LLM_RESPONSE", callback_context, @@ -2616,8 +2957,8 @@ async def after_model_callback( time_to_first_token_ms=tfft, model_version=llm_response.model_version, usage_metadata=llm_response.usage_metadata, - span_id_override=span_id if is_popped else None, - parent_span_id_override=(parent_span_id if is_popped else None), + span_id_override=span_id if use_override else None, + parent_span_id_override=(parent_span_id if use_override else None), ), ) @@ -2638,14 +2979,18 @@ async def on_model_error_callback( """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "LLM_ERROR", callback_context, event_data=EventData( error_message=str(error), latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) @@ -2710,10 +3055,13 @@ async def after_tool_callback( span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + event_data = EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ) await self._log_event( "TOOL_COMPLETED", @@ -2749,7 +3097,12 @@ async def on_tool_error_callback( "args": args_truncated, "tool_origin": tool_origin, } - _, duration = TraceManager.pop_span() + span_id, duration = TraceManager.pop_span() + parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "TOOL_ERROR", tool_context, @@ -2758,5 +3111,7 @@ async def on_tool_error_callback( event_data=EventData( error_message=str(error), latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 549263fbae..5d87a17cd9 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -20,8 +20,8 @@ from unittest import mock from google.adk.agents import base_agent -from google.adk.agents import callback_context as callback_context_lib -from google.adk.agents import invocation_context as invocation_context_lib +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.events import event as event_lib from google.adk.events import event_actions as event_actions_lib from google.adk.models import llm_request as llm_request_lib @@ -83,7 +83,7 @@ def invocation_context(mock_agent, mock_session): mock_plugin_manager = mock.create_autospec( plugin_manager_lib.PluginManager, instance=True, spec_set=True ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_agent, session=mock_session, invocation_id="inv-789", @@ -94,9 +94,7 @@ def invocation_context(mock_agent, mock_session): @pytest.fixture def callback_context(invocation_context): - return callback_context_lib.CallbackContext( - invocation_context=invocation_context - ) + return CallbackContext(invocation_context=invocation_context) @pytest.fixture @@ -2152,7 +2150,7 @@ async def test_otel_integration( span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - mock_tracer.start_span.assert_called_with("test_span") + mock_tracer.start_span.assert_called_with("test_span", context=None) assert span_id == format(span_id_int, "016x") # Test get_trace_id # We need to mock trace.get_current_span() to return our mock span @@ -3018,81 +3016,221 @@ async def test_no_config_no_labels( assert "labels" not in attributes -class TestResolveSpanIds: - """Tests for the _resolve_span_ids static helper.""" +class TestResolveIds: + """Tests for the _resolve_ids static helper.""" + + def _resolve(self, ed, callback_context): + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_ids( + ed, callback_context + ) - def test_uses_trace_manager_defaults(self): - """Should use TraceManager values when no overrides provided.""" + def test_uses_trace_manager_defaults(self, callback_context): + """Should use TraceManager values when no overrides and no ambient.""" ed = bigquery_agent_analytics_plugin.EventData( extra_attributes={"some_key": "value"} ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + assert trace_id == "trace-1" assert span_id == "span-1" assert parent_id == "parent-1" - def test_span_id_override(self): + def test_span_id_override(self, callback_context): """Should use span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override="custom-span" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "custom-span" assert parent_id == "parent-1" - def test_parent_span_id_override(self): + def test_parent_span_id_override(self, callback_context): """Should use parent_span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( parent_span_id_override="custom-parent" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "custom-parent" - def test_none_override_keeps_default(self): + def test_none_override_keeps_default(self, callback_context): """None overrides should keep the TraceManager defaults.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override=None, parent_span_id_override=None ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "parent-1" + def test_ambient_otel_span_takes_priority(self, callback_context): + """When an ambient OTel span is valid, its IDs take priority.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData() + + with real_tracer.start_as_current_span("invocation") as parent_span: + with real_tracer.start_as_current_span("agent") as agent_span: + ambient_ctx = agent_span.get_span_context() + expected_trace = format(ambient_ctx.trace_id, "032x") + expected_span = format(ambient_ctx.span_id, "016x") + expected_parent = format(parent_span.get_span_context().span_id, "016x") + + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == expected_trace + assert span_id == expected_span + assert parent_id == expected_parent + provider.shutdown() + + def test_override_beats_ambient(self, callback_context): + """EventData overrides take priority over ambient OTel span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData( + trace_id_override="forced-trace", + span_id_override="forced-span", + parent_span_id_override="forced-parent", + ) + + with real_tracer.start_as_current_span("invocation"): + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == "forced-trace" + assert span_id == "forced-span" + assert parent_id == "forced-parent" + provider.shutdown() + + def test_ambient_root_span_no_self_parent(self, callback_context): + """Ambient root span (no parent) must not produce self-parent.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + # Seed the plugin stack with a span so there's a stale parent. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "plugin-child" + ) + + ed = bigquery_agent_analytics_plugin.EventData() + + # Single root ambient span — no parent. + with real_tracer.start_as_current_span("root_invocation") as root: + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + root_span_id = format(root.get_span_context().span_id, "016x") + + # span_id should be the ambient root's span_id + assert span_id == root_span_id + # parent must be None — not the stale plugin parent, not self + assert parent_id is None + assert span_id != parent_id + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + provider.shutdown() + + def test_ambient_span_used_for_completed_event(self, callback_context): + """Completed event with overrides should use ambient when present. + + When an ambient OTel span is valid, passing None overrides lets + _resolve_ids Layer 2 pick the ambient span — matching the + STARTING event's span_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with real_tracer.start_as_current_span("invoke_agent") as agent_span: + expected_span = format(agent_span.get_span_context().span_id, "016x") + + # Simulate STARTING: no overrides → ambient Layer 2 wins. + ed_starting = bigquery_agent_analytics_plugin.EventData() + _, span_starting, _ = self._resolve(ed_starting, callback_context) + + # Simulate COMPLETED: None overrides (ambient check passed). + ed_completed = bigquery_agent_analytics_plugin.EventData( + span_id_override=None, + parent_span_id_override=None, + latency_ms=42, + ) + _, span_completed, _ = self._resolve(ed_completed, callback_context) + + assert span_starting == expected_span + assert span_completed == expected_span + assert span_starting == span_completed + + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" @@ -3282,7 +3420,7 @@ def _make_invocation_context(agent_name, session, invocation_id="inv-001"): instance=True, spec_set=True, ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_a, session=session, invocation_id=invocation_id, @@ -3488,7 +3626,7 @@ async def test_full_subagent_callback_sequence( """ session = self._make_session() inv_ctx = self._make_invocation_context("schema_explorer", session) - cb_ctx = callback_context_lib.CallbackContext(invocation_context=inv_ctx) + cb_ctx = CallbackContext(invocation_context=inv_ctx) tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) mock_agent = inv_ctx.agent tool = self._make_tool("get_table_info") @@ -3766,9 +3904,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t1_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t1" ) - cb_ctx_t1_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_orch - ) + cb_ctx_t1_orch = CallbackContext(invocation_context=inv_ctx_t1_orch) # Orchestrator agent_starting await plugin.before_agent_callback( @@ -3781,9 +3917,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t1_sub = self._make_invocation_context( "schema_explorer", session, invocation_id="inv-t1" ) - cb_ctx_t1_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_sub - ) + cb_ctx_t1_sub = CallbackContext(invocation_context=inv_ctx_t1_sub) tool_ctx_t1 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t1_sub ) @@ -3831,9 +3965,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t2_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t2" ) - cb_ctx_t2_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_orch - ) + cb_ctx_t2_orch = CallbackContext(invocation_context=inv_ctx_t2_orch) await plugin.before_agent_callback( agent=inv_ctx_t2_orch.agent, @@ -3845,9 +3977,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t2_sub = self._make_invocation_context( "image_describer", session, invocation_id="inv-t2" ) - cb_ctx_t2_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_sub - ) + cb_ctx_t2_sub = CallbackContext(invocation_context=inv_ctx_t2_sub) tool_ctx_t2 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t2_sub ) @@ -4665,3 +4795,1265 @@ def regular_tool() -> str: ), f"Expected no HITL events for regular tool, got {hitl_events}" await bq_plugin.shutdown() + + +# ============================================================================== +# Fork-Safety Tests +# ============================================================================== +class TestForkSafety: + """Tests for fork-safety via PID tracking.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + return plugin + + @pytest.mark.asyncio + async def test_pid_change_triggers_reinit( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Simulating a fork by changing _init_pid forces re-init.""" + plugin = self._make_plugin() + await plugin._ensure_started() + assert plugin._started is True + + # Simulate a fork: set _init_pid to a stale value + plugin._init_pid = -1 + assert plugin._started is True # still True before check + + # _ensure_started should detect PID mismatch and reset + await plugin._ensure_started() + # After reset + re-init, _init_pid should match current + import os + + assert plugin._init_pid == os.getpid() + assert plugin._started is True + await plugin.shutdown() + + @pytest.mark.asyncio + async def test_pid_unchanged_skips_reset( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Same PID should not trigger a reset.""" + plugin = self._make_plugin() + await plugin._ensure_started() + + # Save references to verify they are not recreated + original_client = plugin.client + original_parser = plugin.parser + + await plugin._ensure_started() + assert plugin.client is original_client + assert plugin.parser is original_parser + await plugin.shutdown() + + def test_reset_runtime_state_clears_fields(self): + """_reset_runtime_state clears all runtime fields.""" + plugin = self._make_plugin() + # Fake some runtime state + plugin._started = True + plugin._is_shutting_down = True + plugin.client = mock.MagicMock() + plugin._loop_state_by_loop = {"fake": "state"} + plugin._write_stream_name = "some/stream" + plugin._executor = mock.MagicMock() + plugin.offloader = mock.MagicMock() + plugin.parser = mock.MagicMock() + plugin._setup_lock = mock.MagicMock() + # Keep pure-data fields + plugin._schema = ["kept"] + plugin.arrow_schema = "kept_arrow" + + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._is_shutting_down is False + assert plugin.client is None + assert plugin._loop_state_by_loop == {} + assert plugin._write_stream_name is None + assert plugin._executor is None + assert plugin.offloader is None + assert plugin.parser is None + assert plugin._setup_lock is None + # Pure-data fields are preserved + assert plugin._schema == ["kept"] + assert plugin.arrow_schema == "kept_arrow" + + import os + + assert plugin._init_pid == os.getpid() + + def test_getstate_resets_pid(self): + """Pickle state should have _init_pid = 0 to force re-init.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + assert state["_init_pid"] == 0 + assert state["_started"] is False + + @pytest.mark.asyncio + async def test_unpickle_legacy_state_missing_init_pid( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Unpickling state from older code without _init_pid should not crash.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + # Simulate legacy pickle state that lacks _init_pid entirely + del state["_init_pid"] + + new_plugin = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin.__new__( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin + ) + ) + new_plugin.__setstate__(state) + + # _init_pid should be backfilled to 0, triggering re-init + assert new_plugin._init_pid == 0 + # _ensure_started should not raise AttributeError + await new_plugin._ensure_started() + assert new_plugin._started is True + await new_plugin.shutdown() + + +# ============================================================================== +# Analytics Views Tests +# ============================================================================== +class TestAnalyticsViews: + """Tests for auto-created per-event-type BigQuery views.""" + + def _make_plugin(self, create_views=True): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=create_views, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_views_created_on_new_table(self): + """NotFound path creates all views.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_created_for_existing_table(self): + """Existing table path also creates views.""" + plugin = self._make_plugin(create_views=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_not_created_when_disabled(self): + """create_views=False skips view creation.""" + plugin = self._make_plugin(create_views=False) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + + plugin._ensure_schema_exists() + + plugin.client.query.assert_not_called() + + def test_view_creation_error_logged_not_raised(self): + """Errors during view creation don't crash the plugin.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.query.side_effect = Exception("BQ error") + + # Should not raise + plugin._ensure_schema_exists() + + # Verify it tried to create views (and failed gracefully) + assert plugin.client.query.call_count > 0 + + def test_view_sql_contains_correct_event_filter(self): + """Each SQL has correct WHERE clause and view name.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + calls = plugin.client.query.call_args_list + for call in calls: + sql = call[0][0] + # Each SQL should have CREATE OR REPLACE VIEW + assert "CREATE OR REPLACE VIEW" in sql + # Each SQL should filter by event_type + assert "WHERE" in sql + assert "event_type = " in sql + # View name should start with v_ + assert ".v_" in sql + + # Verify specific views exist + all_sql = " ".join(c[0][0] for c in calls) + for event_type in bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS: + view_name = "v_" + event_type.lower() + assert view_name in all_sql, f"View {view_name} not found in SQL" + + def test_config_create_views_default_true(self): + """Config create_views defaults to True.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.create_views is True + + @pytest.mark.asyncio + async def test_create_analytics_views_ensures_started( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Public create_analytics_views() initializes plugin first.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + assert plugin._started is False + + await plugin.create_analytics_views() + + # Plugin should be started after the call + assert plugin._started is True + # Views should have been created (query called) + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + # _ensure_schema_exists also creates views, so total calls + # = schema-creation views + explicit views + assert mock_bq_client.query.call_count >= expected_count + await plugin.shutdown() + + def test_views_not_created_after_table_creation_failure(self): + """View creation is skipped when create_table raises a non-Conflict error.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = RuntimeError("BQ down") + + plugin._ensure_schema_exists() + + # Views should NOT be attempted since table creation failed + plugin.client.query.assert_not_called() + + @pytest.mark.asyncio + async def test_create_analytics_views_raises_on_startup_failure( + self, mock_auth_default, mock_write_client + ): + """create_analytics_views() raises if plugin init fails.""" + # Make the BQ Client constructor raise so _lazy_setup fails + # before _started is set to True. + with mock.patch.object( + bigquery, "Client", side_effect=Exception("client boom") + ): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + with pytest.raises( + RuntimeError, match="Plugin initialization failed" + ) as exc_info: + await plugin.create_analytics_views() + # Root cause should be chained for debuggability + assert exc_info.value.__cause__ is not None + assert "client boom" in str(exc_info.value.__cause__) + + +# ============================================================================== +# Trace-ID Continuity Tests (Issue #4645) +# ============================================================================== +class TestTraceIdContinuity: + """Tests for trace_id continuity across all events in an invocation. + + Regression tests for https://github.com/google/adk-python/issues/4645. + + When there is no ambient OTel span (e.g. Agent Engine, custom runners), + early events (USER_MESSAGE_RECEIVED, INVOCATION_STARTING) used to fall + back to ``invocation_id`` while AGENT_STARTING got a new OTel hex + trace_id from ``push_span()``. The ``ensure_invocation_span()`` fix + guarantees a root span is always on the stack before any events fire. + """ + + @pytest.mark.asyncio + async def test_trace_id_continuity_no_ambient_span(self, callback_context): + """All events share one trace_id when no ambient OTel span exists. + + Simulates the #4645 scenario: OTel IS configured (real TracerProvider) + but the Runner's ambient span is NOT present (e.g. Agent Engine, + custom runners). + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Create a real TracerProvider and patch the plugin's module-level + # tracer so push_span creates valid spans with proper trace_ids. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span — we do NOT start_as_current_span. + ambient = trace.get_current_span() + assert not ambient.get_span_context().is_valid + + # ensure_invocation_span should push a new span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early is not None + # Should NOT fall back to invocation_id — it should be + # a 32-char hex OTel trace_id. + assert trace_id_early != callback_context.invocation_id + assert len(trace_id_early) == 32 + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + + # Both trace_ids must be identical. + assert trace_id_early == trace_id_agent + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_completed_trace_continuity_no_ambient( + self, callback_context + ): + """INVOCATION_COMPLETED must share trace_id with earlier events. + + Reproduces the completion-event fracture: after_run_callback pops + the invocation span, then _log_event would resolve trace_id via + the fallback to invocation_id. The trace_id_override ensures the + completion event keeps the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset for a clean invocation; no ambient span. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + assert not trace.get_current_span().get_span_context().is_valid + + # --- Simulate the full callback lifecycle --- + # 1. before_run / on_user_message: ensure invocation span + TM.ensure_invocation_span(callback_context) + trace_id_start = TM.get_trace_id(callback_context) + + # 2. before_agent: push agent span + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_start + + # 3. after_agent: pop agent span + TM.pop_span() + + # 4. after_run: capture trace_id THEN pop invocation span + trace_id_before_pop = TM.get_trace_id(callback_context) + assert trace_id_before_pop == trace_id_start + + TM.pop_span() + + # After popping, get_trace_id falls back to invocation_id + trace_id_after_pop = TM.get_trace_id(callback_context) + assert trace_id_after_pop == callback_context.invocation_id + + # The trace_id_override preserves continuity + assert trace_id_before_pop == trace_id_start + assert trace_id_before_pop != trace_id_after_pop + + provider.shutdown() + + @pytest.mark.asyncio + async def test_callbacks_emit_same_trace_id_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Full callback path: all emitted rows share one trace_id. + + Exercises the real before_run → before_agent → after_agent → + after_run callback chain via the plugin instance, then checks + every emitted BQ row has the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset span records for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span — simulates Agent Engine / custom runner. + assert not trace.get_current_span().get_span_context().is_valid + + # Run the full callback lifecycle. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + # Collect all emitted rows. + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "INVOCATION_STARTING" in event_types + assert "INVOCATION_COMPLETED" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + # Should be a 32-char hex OTel trace, not the invocation_id. + sole_trace_id = trace_ids.pop() + assert sole_trace_id != invocation_context.invocation_id + assert len(sole_trace_id) == 32 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_trace_id_continuity_with_ambient_span(self, callback_context): + """All events share one trace_id when an ambient OTel span exists.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Set up a real OTel tracer. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + with real_tracer.start_as_current_span("runner_invocation"): + ambient = trace.get_current_span() + assert ambient.get_span_context().is_valid + ambient_trace_id = format(ambient.get_span_context().trace_id, "032x") + + # ensure_invocation_span should attach the ambient span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early == ambient_trace_id + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + assert trace_id_agent == ambient_trace_id + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation (attached, not owned) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_root_span_isolated_across_turns( + self, callback_context + ): + """Each invocation gets its own root span; turns don't leak.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Turn 1 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn1 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn1 + TM.pop_span() # agent + TM.pop_span() # invocation + + # After popping, the stack should be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert not records + + # --- Turn 2 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn2 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn2 + TM.pop_span() # agent + TM.pop_span() # invocation + + # The two turns must have DIFFERENT trace_ids (different + # root spans). + assert trace_id_turn1 != trace_id_turn2 + + provider.shutdown() + + +class TestSpanIdConsistency: + """Tests that STARTING/COMPLETED event pairs share span IDs. + + Span-ID resolution contract: + - When OTel is active: BQ rows use the same trace/span/parent IDs as + Cloud Trace (ambient framework spans). STARTING and COMPLETED events + in the same lifecycle share the same span_id. + - When OTel is not active: BQ rows use the plugin's internal span + stack. STARTING gets the current top-of-stack; COMPLETED gets the + popped span. + """ + + @pytest.mark.asyncio + async def test_starting_completed_same_span_with_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """With ambient OTel, STARTING and COMPLETED get the same span_id.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # Simulate the framework's ambient spans. + with real_tracer.start_as_current_span("invocation"): + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + with real_tracer.start_as_current_span("invoke_agent"): + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # Both events must share the same span_id (the ambient + # invoke_agent span) — no plugin-synthetic override. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + assert ( + agent_starting[0]["parent_span_id"] + == agent_completed[0]["parent_span_id"] + ) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_starting_completed_use_plugin_span_without_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Without ambient OTel, COMPLETED gets the popped plugin span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # AGENT_STARTING gets the top-of-stack span; AGENT_COMPLETED + # gets the popped span via override — they should match. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_tool_error_captures_span_id( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + dummy_arrow_schema, + ): + """on_tool_error_callback uses the popped span_id (bonus fix).""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_tool = mock.create_autospec(base_tool_lib.BaseTool, instance=True) + type(mock_tool).name = mock.PropertyMock(return_value="my_tool") + tool_ctx = tool_context_lib.ToolContext( + invocation_context=invocation_context + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel — plugin span stack provides IDs. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push tool span via before_tool_callback + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + ) + # Error callback should pop the tool span and use its ID + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + error=RuntimeError("boom"), + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + tool_starting = [r for r in rows if r["event_type"] == "TOOL_STARTING"] + tool_error = [r for r in rows if r["event_type"] == "TOOL_ERROR"] + + assert len(tool_starting) == 1 + assert len(tool_error) == 1 + + # The TOOL_ERROR event must have the same span_id as + # TOOL_STARTING (both correspond to the same tool span). + assert tool_starting[0]["span_id"] == tool_error[0]["span_id"] + assert tool_error[0]["span_id"] is not None + + provider.shutdown() + + +class TestStackLeakSafety: + """Tests for stack leak safety (P2). + + Ensures the plugin's internal span stack doesn't leak records + across invocations when after_run_callback is skipped. + """ + + def test_ensure_invocation_span_clears_stale_records(self, callback_context): + """Pre-populated stack from a different invocation is cleared.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Simulate stale records from incomplete previous invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + # Mark the stale records as belonging to a different invocation. + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set( + "old-inv-stale" + ) + TM.push_span(callback_context, "stale-invocation") + TM.push_span(callback_context, "stale-agent") + + stale_records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale_records) == 2 + + # ensure_invocation_span with the *current* invocation_id should + # detect the mismatch, clear stale records, and re-init. + TM.ensure_invocation_span(callback_context) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh entry (the new invocation span). + assert len(records) == 1 + # The fresh span should NOT be one of the stale ones. + assert records[0].span_id != stale_records[0].span_id + assert records[0].span_id != stale_records[1].span_id + + provider.shutdown() + + def test_clear_stack_ends_owned_spans(self, callback_context): + """clear_stack() ends all owned spans.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + exporter = InMemorySpanExporter() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.push_span(callback_context, "span-a") + TM.push_span(callback_context, "span-b") + + records = list(bigquery_agent_analytics_plugin._span_records_ctx.get()) + assert all(r.owns_span for r in records) + + TM.clear_stack() + + # Stack must be empty after clear. + result = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert result == [] + + # Both owned spans should have been ended (exported). + exported = exporter.get_finished_spans() + assert len(exported) == 2 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_after_run_callback_clears_remaining_stack( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """after_run_callback clears any leftover stack entries.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push an agent span but DON'T pop it (simulate missing + # after_agent_callback due to exception). + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Stack now has [invocation, agent]. + + # after_run_callback should pop invocation + clear remaining. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # Stack must be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_next_invocation_clean_after_incomplete_previous( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + mock_session, + ): + """Next invocation starts clean even if previous was incomplete.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # --- Incomplete invocation 1: no after_run_callback --- + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Skip after_agent and after_run — simulates exception. + + stale = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale) >= 2 # invocation + agent + + # --- Invocation 2 with a different invocation_id --- + mock_write_client.append_rows.reset_mock() + inv_ctx_2 = InvocationContext( + agent=mock_agent, + session=mock_session, + invocation_id="inv-NEW-002", + session_service=invocation_context.session_service, + plugin_manager=invocation_context.plugin_manager, + ) + await bq_plugin_inst.before_run_callback(invocation_context=inv_ctx_2) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh invocation span. + assert len(records) == 1 + + # Cleanup + await bq_plugin_inst.after_run_callback(invocation_context=inv_ctx_2) + + provider.shutdown() + + def test_ensure_invocation_span_idempotent_same_invocation( + self, callback_context + ): + """Calling ensure_invocation_span twice in the same invocation is a no-op.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # First call: creates invocation span. + TM.ensure_invocation_span(callback_context) + records_after_first = list( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_first) == 1 + first_span_id = records_after_first[0].span_id + + # Second call (same invocation): must be a no-op. + TM.ensure_invocation_span(callback_context) + records_after_second = ( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_second) == 1 + assert records_after_second[0].span_id == first_span_id + + # Cleanup + TM.pop_span() + + provider.shutdown() + + @pytest.mark.asyncio + async def test_user_message_then_before_run_same_trace_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Regression: on_user_message → before_run must share one trace_id. + + Without the invocation-ID guard, the second ensure_invocation_span() + call would clear the stack and create a new root span with a + different trace_id, fracturing USER_MESSAGE_RECEIVED from + INVOCATION_STARTING. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + user_msg = types.Content(parts=[types.Part(text="hello")], role="user") + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=user_msg, + ) + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "USER_MESSAGE_RECEIVED" in event_types + assert "INVOCATION_STARTING" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + + provider.shutdown() + + +class TestRootAgentNameAcrossInvocations: + """Regression: root_agent_name must refresh across invocations.""" + + @pytest.mark.asyncio + async def test_root_agent_name_updates_between_invocations( + self, + bq_plugin_inst, + mock_write_client, + mock_session, + dummy_arrow_schema, + ): + """Two invocations with different root agents must log correct names. + + Previously init_trace() only set _root_agent_name_ctx when it was + None, so the second invocation would inherit the first's root agent. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ) + + def _make_inv_ctx(agent_name, inv_id): + agent = mock.create_autospec( + base_agent.BaseAgent, instance=True, spec_set=True + ) + type(agent).name = mock.PropertyMock(return_value=agent_name) + type(agent).instruction = mock.PropertyMock(return_value="") + # root_agent returns itself (no parent). + agent.root_agent = agent + return InvocationContext( + agent=agent, + session=mock_session, + invocation_id=inv_id, + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Invocation 1: root agent = "RootA" --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + inv1 = _make_inv_ctx("RootA", "inv-001") + cb1 = CallbackContext(inv1) + await bq_plugin_inst.before_run_callback(invocation_context=inv1) + await bq_plugin_inst.before_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv1) + await asyncio.sleep(0.01) + + rows_inv1 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # --- Invocation 2: root agent = "RootB" --- + mock_write_client.append_rows.reset_mock() + + inv2 = _make_inv_ctx("RootB", "inv-002") + cb2 = CallbackContext(inv2) + await bq_plugin_inst.before_run_callback(invocation_context=inv2) + await bq_plugin_inst.before_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv2) + await asyncio.sleep(0.01) + + rows_inv2 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # Parse root_agent_name from the attributes JSON column. + def _get_root_names(rows): + names = set() + for r in rows: + attrs = r.get("attributes") + if attrs: + parsed = json.loads(attrs) if isinstance(attrs, str) else attrs + if "root_agent_name" in parsed: + names.add(parsed["root_agent_name"]) + return names + + names_inv1 = _get_root_names(rows_inv1) + names_inv2 = _get_root_names(rows_inv2) + + # Invocation 1 should only have "RootA". + assert names_inv1 == {"RootA"}, f"Expected {{'RootA'}}, got {names_inv1}" + # Invocation 2 must have "RootB", NOT stale "RootA". + assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}" + + provider.shutdown() + + +class TestAfterRunCleanupExceptionSafety: + """after_run_callback cleanup must execute even if _log_event fails.""" + + @pytest.mark.asyncio + async def test_cleanup_runs_when_log_event_raises( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + ): + """Stale state is cleared even when _log_event raises.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + # Run a normal before_run to initialise state. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + + # Verify state is populated. + assert bigquery_agent_analytics_plugin._span_records_ctx.get() + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is not None + ) + + # Make _log_event raise inside after_run_callback. + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("boom"), + ): + # _safe_callback swallows the exception, but cleanup in + # the finally block must still execute. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # All invocation state must be cleaned up despite the error. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] or records is None + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is None + ) + assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None + + provider.shutdown()