18 changes: 18 additions & 0 deletions posthog/ai/anthropic/anthropic_converter.py
@@ -17,6 +17,7 @@
TokenUsage,
ToolInProgress,
)
from posthog.ai.utils import serialize_raw_usage


def format_anthropic_response(response: Any) -> List[FormattedMessage]:
@@ -221,6 +222,12 @@ def extract_anthropic_usage_from_response(response: Any) -> TokenUsage:
if web_search_count > 0:
result["web_search_count"] = web_search_count

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(response.usage)
if serialized:
result["raw_usage"] = serialized

return result


@@ -247,6 +254,11 @@ def extract_anthropic_usage_from_event(event: Any) -> TokenUsage:
usage["cache_read_input_tokens"] = getattr(
event.message.usage, "cache_read_input_tokens", 0
)
# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(event.message.usage)
if serialized:
usage["raw_usage"] = serialized

# Handle usage stats from message_delta event
if hasattr(event, "usage") and event.usage:
@@ -262,6 +274,12 @@ def extract_anthropic_usage_from_event(event: Any) -> TokenUsage:
if web_search_count > 0:
usage["web_search_count"] = web_search_count

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(event.usage)
if serialized:
usage["raw_usage"] = serialized

return usage


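The comment "Serialize to dict here in the converter (not in utils)" reflects a boundary choice: the converter is the last place that still holds the provider's native usage object, so it flattens it before the value enters the shared TokenUsage dict. A minimal sketch of that step, using a hypothetical FakeAnthropicUsage stand-in for the Anthropic SDK's Pydantic usage model (field names mirror the ones asserted in the tests below; values are illustrative):

from posthog.ai.utils import serialize_raw_usage


class FakeAnthropicUsage:
    """Hypothetical stand-in for the Anthropic SDK's Pydantic usage object."""

    input_tokens = 25
    output_tokens = 42
    cache_read_input_tokens = 5

    def model_dump(self) -> dict:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cache_read_input_tokens": self.cache_read_input_tokens,
        }


usage: dict = {"input_tokens": 25, "output_tokens": 42}
serialized = serialize_raw_usage(FakeAnthropicUsage())  # hits the model_dump() branch
if serialized:
    usage["raw_usage"] = serialized  # plain dict, ready to become the $ai_usage property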
7 changes: 7 additions & 0 deletions posthog/ai/gemini/gemini_converter.py
@@ -12,6 +12,7 @@
FormattedMessage,
TokenUsage,
)
from posthog.ai.utils import serialize_raw_usage


class GeminiPart(TypedDict, total=False):
@@ -487,6 +488,12 @@ def _extract_usage_from_metadata(metadata: Any) -> TokenUsage:
if reasoning_tokens and reasoning_tokens > 0:
usage["reasoning_tokens"] = reasoning_tokens

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(metadata)
if serialized:
usage["raw_usage"] = serialized

return usage


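The Gemini fixtures further down mock model_dump(), but serialize_raw_usage (added in posthog/ai/utils.py below) also probes to_dict() and vars() so that protobuf-style usage_metadata objects serialize as well. A rough sketch of the to_dict() path, with a hypothetical FakeUsageMetadata stand-in and illustrative values:

from posthog.ai.utils import serialize_raw_usage


class FakeUsageMetadata:
    """Hypothetical stand-in for a protobuf-like Gemini usage_metadata object."""

    def to_dict(self) -> dict:
        return {
            "prompt_token_count": 20,
            "candidates_token_count": 10,
            "cached_content_token_count": 0,
            "thoughts_token_count": 0,
        }


# No model_dump() on this object, so serialize_raw_usage falls through to to_dict().
serialized = serialize_raw_usage(FakeUsageMetadata())
assert serialized == {
    "prompt_token_count": 20,
    "candidates_token_count": 10,
    "cached_content_token_count": 0,
    "thoughts_token_count": 0,
}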
19 changes: 19 additions & 0 deletions posthog/ai/openai/openai_converter.py
@@ -16,6 +16,7 @@
FormattedTextContent,
TokenUsage,
)
from posthog.ai.utils import serialize_raw_usage


def format_openai_response(response: Any) -> List[FormattedMessage]:
@@ -429,6 +430,12 @@ def extract_openai_usage_from_response(response: Any) -> TokenUsage:
if web_search_count > 0:
result["web_search_count"] = web_search_count

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(response.usage)
if serialized:
result["raw_usage"] = serialized

return result


@@ -482,6 +489,12 @@ def extract_openai_usage_from_chunk(
chunk.usage.completion_tokens_details.reasoning_tokens
)

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(chunk.usage)
if serialized:
usage["raw_usage"] = serialized

elif provider_type == "responses":
# For Responses API, usage is only in chunk.response.usage for completed events
if hasattr(chunk, "type") and chunk.type == "response.completed":
@@ -516,6 +529,12 @@ def extract_openai_usage_from_chunk(
if web_search_count > 0:
usage["web_search_count"] = web_search_count

# Capture raw usage metadata for backend processing
# Serialize to dict here in the converter (not in utils)
serialized = serialize_raw_usage(response_usage)
if serialized:
usage["raw_usage"] = serialized

return usage


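For the Chat Completions path the serialized object is typically the SDK's CompletionUsage model, so nested blocks such as completion_tokens_details survive into raw_usage even though the normalized TokenUsage keeps only a flattened reasoning_tokens count. A minimal sketch; FakeCompletionUsage and its values are illustrative stand-ins, not the PR's own code:

from posthog.ai.utils import serialize_raw_usage


class FakeCompletionUsage:
    """Hypothetical stand-in for the OpenAI SDK's CompletionUsage Pydantic model."""

    def model_dump(self) -> dict:
        return {
            "prompt_tokens": 120,
            "completion_tokens": 40,
            "total_tokens": 160,
            "completion_tokens_details": {"reasoning_tokens": 12},
        }


usage: dict = {"input_tokens": 120, "output_tokens": 40, "reasoning_tokens": 12}
serialized = serialize_raw_usage(FakeCompletionUsage())
if serialized:
    usage["raw_usage"] = serialized  # nested details preserved verbatim for the backend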
1 change: 1 addition & 0 deletions posthog/ai/types.py
@@ -64,6 +64,7 @@ class TokenUsage(TypedDict, total=False):
cache_creation_input_tokens: Optional[int]
reasoning_tokens: Optional[int]
web_search_count: Optional[int]
raw_usage: Optional[Any] # Raw provider usage metadata for backend processing


class ProviderResponse(TypedDict, total=False):
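With the new field, a populated TokenUsage might look like the sketch below; the token-count keys and values are illustrative, and raw_usage is deliberately typed Optional[Any] because its keys are provider-specific and interpreted by the backend rather than by the SDK:

from posthog.ai.types import TokenUsage

usage: TokenUsage = {
    "input_tokens": 25,
    "output_tokens": 42,
    "cache_read_input_tokens": 5,
    "raw_usage": {  # provider-native keys, passed through untouched
        "input_tokens": 25,
        "output_tokens": 42,
        "cache_read_input_tokens": 5,
        "cache_creation_input_tokens": 0,
    },
}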
78 changes: 78 additions & 0 deletions posthog/ai/utils.py
@@ -13,6 +13,54 @@
from posthog.client import Client as PostHogClient


def serialize_raw_usage(raw_usage: Any) -> Optional[Dict[str, Any]]:
"""
Convert raw provider usage objects to JSON-serializable dicts.

Handles Pydantic models (OpenAI/Anthropic) and protobuf-like objects (Gemini)
with a fallback chain to ensure we never pass unserializable objects to PostHog.

Args:
raw_usage: Raw usage object from provider SDK

Returns:
Plain dict or None if conversion fails
"""
if raw_usage is None:
return None

# Already a dict
if isinstance(raw_usage, dict):
return raw_usage

# Try Pydantic model_dump() (OpenAI/Anthropic)
if hasattr(raw_usage, "model_dump") and callable(raw_usage.model_dump):
try:
return raw_usage.model_dump()
except Exception:
pass

# Try to_dict() (some protobuf objects)
if hasattr(raw_usage, "to_dict") and callable(raw_usage.to_dict):
try:
return raw_usage.to_dict()
except Exception:
pass

# Try __dict__ / vars() for simple objects
try:
return vars(raw_usage)
except Exception:
pass

# Last resort: convert to string representation
# This ensures we always return something rather than failing
try:
return {"_raw": str(raw_usage)}
except Exception:
return None


def merge_usage_stats(
target: TokenUsage, source: TokenUsage, mode: str = "incremental"
) -> None:
@@ -60,6 +108,17 @@ def merge_usage_stats(
current = target.get("web_search_count") or 0
target["web_search_count"] = max(current, source_web_search)

# Merge raw_usage to avoid losing data from earlier events
# For Anthropic streaming: message_start has input tokens, message_delta has output
# Note: raw_usage is already serialized by converters, so it's a dict
source_raw_usage = source.get("raw_usage")
if source_raw_usage is not None and isinstance(source_raw_usage, dict):
current_raw_value = target.get("raw_usage")
current_raw: Dict[str, Any] = (
current_raw_value if isinstance(current_raw_value, dict) else {}
)
target["raw_usage"] = {**current_raw, **source_raw_usage}

elif mode == "cumulative":
# Replace with latest values (already cumulative)
if source.get("input_tokens") is not None:
@@ -76,6 +135,9 @@ def merge_usage_stats(
target["reasoning_tokens"] = source["reasoning_tokens"]
if source.get("web_search_count") is not None:
target["web_search_count"] = source["web_search_count"]
# Note: raw_usage is already serialized by converters, so it's a dict
if source.get("raw_usage") is not None:
target["raw_usage"] = source["raw_usage"]

else:
raise ValueError(f"Invalid mode: {mode}. Must be 'incremental' or 'cumulative'")
@@ -332,6 +394,11 @@ def call_llm_and_track_usage(
if web_search_count is not None and web_search_count > 0:
tag("$ai_web_search_count", web_search_count)

raw_usage = usage.get("raw_usage")
if raw_usage is not None:
# Already serialized by converters
tag("$ai_usage", raw_usage)

if posthog_distinct_id is None:
tag("$process_person_profile", False)

@@ -457,6 +524,11 @@ async def call_llm_and_track_usage_async(
if web_search_count is not None and web_search_count > 0:
tag("$ai_web_search_count", web_search_count)

raw_usage = usage.get("raw_usage")
if raw_usage is not None:
# Already serialized by converters
tag("$ai_usage", raw_usage)

if posthog_distinct_id is None:
tag("$process_person_profile", False)

@@ -594,6 +666,12 @@ def capture_streaming_event(
):
event_properties["$ai_web_search_count"] = web_search_count

# Add raw usage metadata if present (all providers)
raw_usage = event_data["usage_stats"].get("raw_usage")
if raw_usage is not None:
# Already serialized by converters
event_properties["$ai_usage"] = raw_usage

# Handle provider-specific fields
if (
event_data["provider"] == "openai"
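Taken together, the utils changes do three things: serialize_raw_usage normalizes whatever usage object a provider hands back, merge_usage_stats unions partial raw_usage dicts across streaming events, and the capture paths attach the merged dict unmodified as the $ai_usage property (both in call_llm_and_track_usage and in capture_streaming_event). A rough end-to-end sketch of the first two; the event payloads are made up, and for Anthropic streaming the message_start event carries input counts while message_delta carries output counts, as the merge comment notes:

from posthog.ai.types import TokenUsage
from posthog.ai.utils import merge_usage_stats, serialize_raw_usage

# Fallback chain: dicts pass through unchanged, model_dump()/to_dict() are tried next,
# vars() and str() are the last resorts, and None stays None.
assert serialize_raw_usage({"input_tokens": 25}) == {"input_tokens": 25}
assert serialize_raw_usage(None) is None

# Streaming merge: each event's usage was already serialized by the converter,
# so incremental mode just unions the partial dicts.
total: TokenUsage = {"input_tokens": 25, "raw_usage": {"input_tokens": 25}}
delta: TokenUsage = {"output_tokens": 42, "raw_usage": {"output_tokens": 42}}
merge_usage_stats(total, delta, mode="incremental")
assert total["raw_usage"] == {"input_tokens": 25, "output_tokens": 42}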
20 changes: 20 additions & 0 deletions posthog/test/ai/anthropic/test_anthropic.py
@@ -1,3 +1,4 @@
import json
from unittest.mock import patch

import pytest
@@ -306,6 +307,15 @@ def test_basic_completion(mock_client, mock_anthropic_response):
assert props["$ai_http_status"] == 200
assert props["foo"] == "bar"
assert isinstance(props["$ai_latency"], float)
# Verify raw usage metadata is passed for backend processing
assert "$ai_usage" in props
assert props["$ai_usage"] is not None
# Verify it's JSON-serializable
json.dumps(props["$ai_usage"])
# Verify it has expected structure
assert isinstance(props["$ai_usage"], dict)
assert "input_tokens" in props["$ai_usage"]
assert "output_tokens" in props["$ai_usage"]


def test_groups(mock_client, mock_anthropic_response):
@@ -918,6 +928,16 @@ def test_streaming_with_tool_calls(mock_client, mock_anthropic_stream_with_tools
assert props["$ai_cache_read_input_tokens"] == 5
assert props["$ai_cache_creation_input_tokens"] == 0

# Verify raw usage is captured in streaming mode (merged from events)
assert "$ai_usage" in props
assert props["$ai_usage"] is not None
# Verify it's JSON-serializable
json.dumps(props["$ai_usage"])
# Verify it has expected structure (merged from message_start and message_delta)
assert isinstance(props["$ai_usage"], dict)
assert "input_tokens" in props["$ai_usage"]
assert "output_tokens" in props["$ai_usage"]


def test_async_streaming_with_tool_calls(mock_client, mock_anthropic_stream_with_tools):
"""Test that tool calls are properly captured in async streaming mode."""
55 changes: 55 additions & 0 deletions posthog/test/ai/gemini/test_gemini.py
@@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock, patch

import pytest
@@ -34,6 +35,13 @@ def mock_gemini_response():
# Ensure cache and reasoning tokens are not present (not MagicMock)
mock_usage.cached_content_token_count = 0
mock_usage.thoughts_token_count = 0
# Make model_dump() return a proper dict for serialization
mock_usage.model_dump.return_value = {
"prompt_token_count": 20,
"candidates_token_count": 10,
"cached_content_token_count": 0,
"thoughts_token_count": 0,
}
mock_response.usage_metadata = mock_usage

mock_candidate = MagicMock()
@@ -69,6 +77,13 @@ def mock_gemini_response_with_function_calls():
mock_usage.candidates_token_count = 15
mock_usage.cached_content_token_count = 0
mock_usage.thoughts_token_count = 0
# Make model_dump() return a proper dict for serialization
mock_usage.model_dump.return_value = {
"prompt_token_count": 25,
"candidates_token_count": 15,
"cached_content_token_count": 0,
"thoughts_token_count": 0,
}
mock_response.usage_metadata = mock_usage

# Mock function call
@@ -117,6 +132,13 @@ def mock_gemini_response_function_calls_only():
mock_usage.candidates_token_count = 12
mock_usage.cached_content_token_count = 0
mock_usage.thoughts_token_count = 0
# Make model_dump() return a proper dict for serialization
mock_usage.model_dump.return_value = {
"prompt_token_count": 30,
"candidates_token_count": 12,
"cached_content_token_count": 0,
"thoughts_token_count": 0,
}
mock_response.usage_metadata = mock_usage

# Mock function call
@@ -174,6 +196,15 @@ def test_new_client_basic_generation(
assert props["foo"] == "bar"
assert "$ai_trace_id" in props
assert props["$ai_latency"] > 0
# Verify raw usage metadata is passed for backend processing
assert "$ai_usage" in props
assert props["$ai_usage"] is not None
# Verify it's JSON-serializable
json.dumps(props["$ai_usage"])
# Verify it has expected structure
assert isinstance(props["$ai_usage"], dict)
assert "prompt_token_count" in props["$ai_usage"]
assert "candidates_token_count" in props["$ai_usage"]


def test_new_client_streaming_with_generate_content_stream(
@@ -810,6 +841,13 @@ def test_streaming_cache_and_reasoning_tokens(mock_client, mock_google_genai_cli
chunk1_usage.candidates_token_count = 5
chunk1_usage.cached_content_token_count = 30 # Cache tokens
chunk1_usage.thoughts_token_count = 0
# Make model_dump() return a proper dict for serialization
chunk1_usage.model_dump.return_value = {
"prompt_token_count": 100,
"candidates_token_count": 5,
"cached_content_token_count": 30,
"thoughts_token_count": 0,
}
chunk1.usage_metadata = chunk1_usage

chunk2 = MagicMock()
@@ -819,6 +857,13 @@
chunk2_usage.candidates_token_count = 10
chunk2_usage.cached_content_token_count = 30 # Same cache tokens
chunk2_usage.thoughts_token_count = 5 # Reasoning tokens
# Make model_dump() return a proper dict for serialization
chunk2_usage.model_dump.return_value = {
"prompt_token_count": 100,
"candidates_token_count": 10,
"cached_content_token_count": 30,
"thoughts_token_count": 5,
}
chunk2.usage_metadata = chunk2_usage

mock_stream = iter([chunk1, chunk2])
@@ -848,6 +893,16 @@ def test_streaming_cache_and_reasoning_tokens(mock_client, mock_google_genai_cli
assert props["$ai_cache_read_input_tokens"] == 30
assert props["$ai_reasoning_tokens"] == 5

# Verify raw usage is captured in streaming mode (merged from chunks)
assert "$ai_usage" in props
assert props["$ai_usage"] is not None
# Verify it's JSON-serializable
json.dumps(props["$ai_usage"])
# Verify it has expected structure
assert isinstance(props["$ai_usage"], dict)
assert "prompt_token_count" in props["$ai_usage"]
assert "candidates_token_count" in props["$ai_usage"]


def test_web_search_grounding(mock_client, mock_google_genai_client):
"""Test web search detection via grounding_metadata."""