Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 48 additions & 153 deletions drift/instrumentation/django/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from __future__ import annotations

import json
import logging
import time
from collections.abc import Callable
from typing import TYPE_CHECKING

from opentelemetry.trace import SpanKind as OTelSpanKind
from opentelemetry.trace import Status
from opentelemetry.trace import StatusCode as OTelStatusCode

logger = logging.getLogger(__name__)

Expand All @@ -17,13 +20,8 @@
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
from ...core.types import (
CleanSpanData,
Duration,
PackageType,
SpanKind,
SpanStatus,
StatusCode,
Timestamp,
TuskDriftMode,
replay_trace_id_context,
span_kind_context,
Expand Down Expand Up @@ -282,7 +280,12 @@ def _normalize_html_response(self, response: HttpResponse) -> HttpResponse:
return normalize_html_response(response)

def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo) -> None:
"""Create and collect a span from request/response data.
"""Finalize span with request/response data by setting OTel attributes.

Sets INPUT_VALUE, OUTPUT_VALUE, schema merges, and status on the OTel
span. When span.end() is called, TdSpanProcessor.on_end() converts
these attributes to CleanSpanData and exports it - the same single-write
pattern used by Flask (WSGI handler) and FastAPI.

Args:
request: Django HttpRequest object
Expand All @@ -294,12 +297,7 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info:
if not start_time_ns or not span_info.span.is_recording():
return

# Use trace_id and span_id from span_info
trace_id = span_info.trace_id
span_id = span_info.span_id

end_time_ns = time.time_ns()
duration_ns = end_time_ns - start_time_ns

# Build input_value using WSGI utilities
request_body = getattr(request, "_drift_request_body", None)
Expand Down Expand Up @@ -357,7 +355,8 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info:
f"Blocking trace {trace_id} - binary response: {content_type} "
f"(decoded as {decoded_type.name if decoded_type else 'unknown'})"
)
return # Skip span creation
span_info.span.set_status(Status(OTelStatusCode.ERROR, "Binary content blocked"))
return

# Apply transforms if present
transform_metadata = None
Expand All @@ -372,93 +371,41 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info:
input_value = span_data.input_value or input_value
output_value = span_data.output_value or output_value

# Build schema merges and generate schemas
# Note: Django uses direct CleanSpanData creation instead of OTel spans,
# so we need to generate schemas here instead of in the converter
from ...core.json_schema_helper import JsonSchemaHelper

input_schema_merges_dict = build_input_schema_merges(input_value)
output_schema_merges_dict = build_output_schema_merges(output_value)

# Convert dict back to SchemaMerge objects for JsonSchemaHelper
from ...core.json_schema_helper import DecodedType, EncodingType, SchemaMerge

def dict_to_schema_merges(merges_dict):
result = {}
for key, merge_data in merges_dict.items():
encoding = EncodingType(merge_data["encoding"]) if "encoding" in merge_data else None
decoded_type = DecodedType(merge_data["decoded_type"]) if "decoded_type" in merge_data else None
match_importance = merge_data.get("match_importance")
result[key] = SchemaMerge(
encoding=encoding, decoded_type=decoded_type, match_importance=match_importance
)
return result

input_schema_merges = dict_to_schema_merges(input_schema_merges_dict)
output_schema_merges = dict_to_schema_merges(output_schema_merges_dict)

input_schema_info = JsonSchemaHelper.generate_schema_and_hash(input_value, input_schema_merges)
output_schema_info = JsonSchemaHelper.generate_schema_and_hash(output_value, output_schema_merges)

from ...core.drift_sdk import TuskDrift

sdk = TuskDrift.get_instance()
# Derive timestamp from start_time_ns
timestamp_seconds = start_time_ns // 1_000_000_000
timestamp_nanos = start_time_ns % 1_000_000_000
duration_seconds = duration_ns // 1_000_000_000
duration_nanos = duration_ns % 1_000_000_000

# Match Node SDK: >= 400 is considered an error
if status_code >= 400:
status = SpanStatus(code=StatusCode.ERROR, message=f"HTTP {status_code}")
else:
status = SpanStatus(code=StatusCode.OK, message="")

# Django-specific: use route template for span name to avoid cardinality explosion
method = request.method or ""
route_template = getattr(request, "_drift_route_template", None)
if route_template:
# Use route template (e.g., "users/<int:id>/")
span_name = f"{method} {route_template}"
else:
# Fallback to literal path (e.g., for 404s)
span_name = f"{method} {request.path}"
span_info.span.set_attribute(TdSpanAttributes.NAME, span_name)

# Only create and collect span in RECORD mode
# In REPLAY mode, we only set up context for child spans but don't record the root span
if sdk.mode == TuskDriftMode.RECORD:
clean_span = CleanSpanData(
trace_id=trace_id,
span_id=span_id,
parent_span_id="",
name=span_name,
package_name="django",
instrumentation_name="DjangoInstrumentation",
submodule_name=method,
package_type=PackageType.HTTP,
kind=SpanKind.SERVER,
input_value=input_value,
output_value=output_value,
input_schema=input_schema_info.schema,
output_schema=output_schema_info.schema,
input_value_hash=input_schema_info.decoded_value_hash,
output_value_hash=output_schema_info.decoded_value_hash,
input_schema_hash=input_schema_info.decoded_schema_hash,
output_schema_hash=output_schema_info.decoded_schema_hash,
status=status,
is_pre_app_start=span_info.is_pre_app_start,
is_root_span=True,
timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos),
duration=Duration(seconds=duration_seconds, nanos=duration_nanos),
transform_metadata=transform_metadata,
metadata=None,
)
# Set data attributes - TdSpanProcessor.on_end() reads these to build CleanSpanData
span_info.span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value))
span_info.span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value))

# Set schema merge hints (schemas are generated at export time by the converter)
input_schema_merges = build_input_schema_merges(input_value)
output_schema_merges = build_output_schema_merges(output_value)
span_info.span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_MERGES, json.dumps(input_schema_merges))
span_info.span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_MERGES, json.dumps(output_schema_merges))

if transform_metadata:
span_info.span.set_attribute(TdSpanAttributes.TRANSFORM_METADATA, json.dumps(transform_metadata))

sdk.collect_span(clean_span)
# Set status based on HTTP status code
if status_code >= 400:
span_info.span.set_status(Status(OTelStatusCode.ERROR, f"HTTP {status_code}"))
else:
span_info.span.set_status(Status(OTelStatusCode.OK))

def _capture_error_span(self, request: HttpRequest, exception: Exception, span_info: SpanInfo) -> None:
"""Create and collect an error span.
"""Finalize span with error data by setting OTel attributes.

Sets INPUT_VALUE, OUTPUT_VALUE (with error info), schema merges, and
ERROR status on the OTel span. When span.end() is called,
TdSpanProcessor.on_end() converts and exports - same pattern as
Flask/FastAPI.

Args:
request: Django HttpRequest object
Expand All @@ -470,13 +417,6 @@ def _capture_error_span(self, request: HttpRequest, exception: Exception, span_i
if not start_time_ns or not span_info.span.is_recording():
return

# Use trace_id and span_id from span_info
trace_id = span_info.trace_id
span_id = span_info.span_id

end_time_ns = time.time_ns()
duration_ns = end_time_ns - start_time_ns

# Build input_value
request_body = getattr(request, "_drift_request_body", None)
input_value = build_input_value(request.META, request_body)
Expand All @@ -490,66 +430,21 @@ def _capture_error_span(self, request: HttpRequest, exception: Exception, span_i
str(exception),
)

# Build schema merges and generate schemas
from ...core.json_schema_helper import DecodedType, EncodingType, JsonSchemaHelper, SchemaMerge

input_schema_merges_dict = build_input_schema_merges(input_value)
output_schema_merges_dict = build_output_schema_merges(output_value)

def dict_to_schema_merges(merges_dict):
result = {}
for key, merge_data in merges_dict.items():
encoding = EncodingType(merge_data["encoding"]) if "encoding" in merge_data else None
decoded_type = DecodedType(merge_data["decoded_type"]) if "decoded_type" in merge_data else None
match_importance = merge_data.get("match_importance")
result[key] = SchemaMerge(
encoding=encoding, decoded_type=decoded_type, match_importance=match_importance
)
return result

input_schema_merges = dict_to_schema_merges(input_schema_merges_dict)
output_schema_merges = dict_to_schema_merges(output_schema_merges_dict)

input_schema_info = JsonSchemaHelper.generate_schema_and_hash(input_value, input_schema_merges)
output_schema_info = JsonSchemaHelper.generate_schema_and_hash(output_value, output_schema_merges)

from ...core.drift_sdk import TuskDrift

sdk = TuskDrift.get_instance()
timestamp_seconds = start_time_ns // 1_000_000_000
timestamp_nanos = start_time_ns % 1_000_000_000
duration_seconds = duration_ns // 1_000_000_000
duration_nanos = duration_ns % 1_000_000_000

# Update span name with route template
method = request.method or ""
route_template = getattr(request, "_drift_route_template", None)
span_name = f"{method} {route_template}" if route_template else f"{method} {request.path}"
span_info.span.set_attribute(TdSpanAttributes.NAME, span_name)

clean_span = CleanSpanData(
trace_id=trace_id,
span_id=span_id,
parent_span_id="",
name=span_name,
package_name="django",
instrumentation_name="DjangoInstrumentation",
submodule_name=method,
package_type=PackageType.HTTP,
kind=SpanKind.SERVER,
input_value=input_value,
output_value=output_value,
input_schema=input_schema_info.schema,
output_schema=output_schema_info.schema,
input_value_hash=input_schema_info.decoded_value_hash,
output_value_hash=output_schema_info.decoded_value_hash,
input_schema_hash=input_schema_info.decoded_schema_hash,
output_schema_hash=output_schema_info.decoded_schema_hash,
status=SpanStatus(code=StatusCode.ERROR, message=f"Exception: {type(exception).__name__}"),
is_pre_app_start=span_info.is_pre_app_start,
is_root_span=True,
timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos),
duration=Duration(seconds=duration_seconds, nanos=duration_nanos),
transform_metadata=None,
metadata=None,
)
# Set data attributes - TdSpanProcessor.on_end() reads these to build CleanSpanData
span_info.span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value))
span_info.span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value))

# Set schema merge hints (schemas are generated at export time by the converter)
input_schema_merges = build_input_schema_merges(input_value)
output_schema_merges = build_output_schema_merges(output_value)
span_info.span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_MERGES, json.dumps(input_schema_merges))
span_info.span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_MERGES, json.dumps(output_schema_merges))

sdk.collect_span(clean_span)
# Set error status
span_info.span.set_status(Status(OTelStatusCode.ERROR, f"Exception: {type(exception).__name__}"))
6 changes: 3 additions & 3 deletions drift/instrumentation/psycopg/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def patched_async_connect(conninfo="", **kwargs):
# Replace the classmethod with our patched version
AsyncConnection.connect = classmethod(
lambda cls, conninfo="", **kwargs: patched_async_connect(conninfo, **kwargs)
) # type: ignore[method-assign]
)
logger.info("psycopg.AsyncConnection.connect instrumented")

# Also patch AsyncConnectionPool to inject cursor_factory
Expand Down Expand Up @@ -281,7 +281,7 @@ def patched_init(pool_self, conninfo="", **kwargs):

return original_init(pool_self, conninfo, **kwargs)

AsyncConnectionPool.__init__ = patched_init # type: ignore[method-assign]
AsyncConnectionPool.__init__ = patched_init
logger.info("psycopg_pool.AsyncConnectionPool.__init__ instrumented")

def _create_async_cursor_factory(self, sdk: TuskDrift, base_factory=None):
Expand All @@ -296,7 +296,7 @@ def _create_async_cursor_factory(self, sdk: TuskDrift, base_factory=None):
from psycopg import AsyncCursor as BaseAsyncCursor
except ImportError:
logger.warning("[ASYNC_CURSOR_FACTORY] Could not import psycopg.AsyncCursor")
BaseAsyncCursor = object # type: ignore
BaseAsyncCursor = object

base = base_factory or BaseAsyncCursor

Expand Down