diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index 5d2e9b9..6ae4361 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -612,31 +612,7 @@ def _extract_response_data(self, struct: Any) -> dict[str, Any]: return {} try: - - def value_to_python(value): - """Convert protobuf Value to Python type.""" - if hasattr(value, "null_value"): - return None - elif hasattr(value, "number_value"): - return value.number_value - elif hasattr(value, "string_value"): - return value.string_value - elif hasattr(value, "bool_value"): - return value.bool_value - elif hasattr(value, "struct_value") and value.struct_value: - return struct_to_dict(value.struct_value) - elif hasattr(value, "list_value") and value.list_value: - return [value_to_python(v) for v in value.list_value.values] - return None - - def struct_to_dict(s): - """Convert protobuf Struct to Python dict.""" - if not hasattr(s, "fields"): - return {} - result = {} - for key, value in s.fields.items(): - result[key] = value_to_python(value) - return result + from ..protobuf_utils import struct_to_dict data = struct_to_dict(struct) diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 179c528..0ecd7aa 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -379,6 +379,8 @@ def extract_response_data(struct: Any) -> dict[str, Any]: The CLI returns response data wrapped in a Struct with a "response" field. """ + from ..protobuf_utils import value_to_python + try: # Handle betterproto dict-like struct if hasattr(struct, "items"): @@ -391,8 +393,8 @@ def extract_response_data(struct: Any) -> dict[str, Any]: if hasattr(struct, "fields"): fields = struct.fields if "response" in fields: - return _value_to_python(fields["response"]) - return {k: _value_to_python(v) for k, v in fields.items()} + return value_to_python(fields["response"]) + return {k: value_to_python(v) for k, v in fields.items()} # Direct dict access if isinstance(struct, dict): @@ -403,20 +405,3 @@ def extract_response_data(struct: Any) -> dict[str, Any]: return {} except Exception: return {} - - -def _value_to_python(value: Any) -> Any: - """Convert a protobuf Value to Python native type.""" - if hasattr(value, "null_value"): - return None - if hasattr(value, "number_value"): - return value.number_value - if hasattr(value, "string_value"): - return value.string_value - if hasattr(value, "bool_value"): - return value.bool_value - if hasattr(value, "struct_value"): - return {k: _value_to_python(v) for k, v in value.struct_value.fields.items()} - if hasattr(value, "list_value"): - return [_value_to_python(v) for v in value.list_value.values] - return value diff --git a/drift/core/protobuf_utils.py b/drift/core/protobuf_utils.py new file mode 100644 index 0000000..3912481 --- /dev/null +++ b/drift/core/protobuf_utils.py @@ -0,0 +1,102 @@ +"""Protobuf conversion utilities. + +This module provides utilities for converting protobuf Value and Struct objects +to Python native types. Handles both Google protobuf and betterproto variants. +""" + +from __future__ import annotations + +from typing import Any + + +def value_to_python(value: Any) -> Any: + """Convert protobuf Value or Python native type to Python native type. + + Handles: + 1. Python native types (from betterproto's Struct which stores values directly) + 2. Google protobuf Value objects (from google.protobuf.struct_pb2) + 3. betterproto Value objects (from betterproto.lib.google.protobuf) + + Args: + value: A protobuf Value, native Python type, or nested structure + + Returns: + Python native type (None, bool, int, float, str, list, or dict) + """ + # Handle Python native types (from betterproto or already converted) + if value is None: + return None + if isinstance(value, (bool, int, float, str)): + # Note: bool must be checked before int since bool is a subclass of int + return value + if isinstance(value, list): + return [value_to_python(v) for v in value] + if isinstance(value, dict): + return {k: value_to_python(v) for k, v in value.items()} + + # Handle Google protobuf Value objects using WhichOneof + if hasattr(value, "WhichOneof"): + kind = value.WhichOneof("kind") + if kind == "null_value": + return None + elif kind == "number_value": + return value.number_value + elif kind == "string_value": + return value.string_value + elif kind == "bool_value": + return value.bool_value + elif kind == "struct_value": + return struct_to_dict(value.struct_value) + elif kind == "list_value": + return [value_to_python(v) for v in value.list_value.values] + + # Handle betterproto Value objects using is_set method + if hasattr(value, "is_set"): + try: + if value.is_set("null_value"): + return None + elif value.is_set("number_value"): + return value.number_value + elif value.is_set("string_value"): + return value.string_value + elif value.is_set("bool_value"): + return value.bool_value + elif value.is_set("struct_value"): + sv = value.struct_value + if isinstance(sv, dict): + return {k: value_to_python(v) for k, v in sv.get("fields", sv).items()} + return struct_to_dict(sv) + elif value.is_set("list_value"): + lv = value.list_value + # Handle dict-style list_value (betterproto can store dicts) + if isinstance(lv, dict): + return [value_to_python(v) for v in lv.get("values", [])] + return [value_to_python(v) for v in lv.values] + except (AttributeError, TypeError): + pass # Not a betterproto Value, fall through + + # Fallback: return the value as-is + return value + + +def struct_to_dict(struct: Any) -> dict[str, Any]: + """Convert protobuf Struct to Python dict. + + Args: + struct: A protobuf Struct object (Google or betterproto) + + Returns: + Python dict with converted values + """ + if not struct: + return {} + + # Handle dict-like struct (betterproto sometimes returns dicts directly) + if isinstance(struct, dict): + return {k: value_to_python(v) for k, v in struct.items()} + + # Handle struct with fields attribute + if hasattr(struct, "fields"): + return {k: value_to_python(v) for k, v in struct.fields.items()} + + return {} diff --git a/drift/instrumentation/django/csrf_utils.py b/drift/instrumentation/django/csrf_utils.py deleted file mode 100644 index da1721d..0000000 --- a/drift/instrumentation/django/csrf_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Django CSRF token utilities for consistent record/replay testing. - -This module provides utilities to normalize CSRF tokens so that recorded -and replayed responses produce identical output for comparison. -""" - -from __future__ import annotations - -import logging -import re - -logger = logging.getLogger(__name__) - -CSRF_PLACEHOLDER = "__DRIFT_CSRF__" - - -def normalize_csrf_in_body(body: bytes | None) -> bytes | None: - """Normalize CSRF tokens in response body for consistent record/replay comparison. - - Replaces Django CSRF tokens with a fixed placeholder so that recorded - responses match replayed responses during comparison. - - This should be called after the response is sent to the browser, - but before storing in the span. The actual response to the browser - is unchanged. - - Args: - body: Response body bytes (typically HTML) - - Returns: - Body with CSRF tokens normalized, or original body if not applicable - """ - if not body: - return body - - try: - body_str = body.decode("utf-8") - - # Pattern 1: Hidden input fields with csrfmiddlewaretoken - # - # Handles both single and double quotes, various attribute orders - csrf_input_pattern = ( - r'(]*name=["\']csrfmiddlewaretoken["\'][^>]*value=["\'])' - r'[^"\']+(["\'])' - ) - body_str = re.sub( - csrf_input_pattern, - rf"\g<1>{CSRF_PLACEHOLDER}\2", - body_str, - flags=re.IGNORECASE, - ) - - # Pattern 2: Also handle value before name (different attribute order) - # - csrf_input_pattern_alt = r'(]*value=["\'])[^"\']+(["\'][^>]*name=["\']csrfmiddlewaretoken["\'])' - body_str = re.sub( - csrf_input_pattern_alt, - rf"\g<1>{CSRF_PLACEHOLDER}\2", - body_str, - flags=re.IGNORECASE, - ) - - return body_str.encode("utf-8") - - except Exception as e: - logger.debug(f"Error normalizing CSRF tokens: {e}") - return body diff --git a/drift/instrumentation/django/html_utils.py b/drift/instrumentation/django/html_utils.py new file mode 100644 index 0000000..6d51346 --- /dev/null +++ b/drift/instrumentation/django/html_utils.py @@ -0,0 +1,188 @@ +"""Django HTML utilities for consistent record/replay testing. + +This module provides utilities to normalize HTML content so that recorded +and replayed responses produce identical output for comparison. Includes +CSRF token normalization and HTML class ordering normalization. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from django.http import HttpResponse + +logger = logging.getLogger(__name__) + +# Placeholder used to replace CSRF tokens for consistent comparison +CSRF_PLACEHOLDER = "__DRIFT_CSRF__" + + +def normalize_csrf_in_body(body: bytes | None) -> bytes | None: + """Normalize CSRF tokens in response body for consistent record/replay comparison. + + Replaces Django CSRF tokens with a fixed placeholder so that recorded + responses match replayed responses during comparison. + + Args: + body: Response body bytes (typically HTML) + + Returns: + Body with CSRF tokens normalized, or original body if not applicable + """ + if not body: + return body + + try: + body_str = body.decode("utf-8") + + # Pattern 1: Hidden input fields with csrfmiddlewaretoken + # + # Handles both single and double quotes, various attribute orders + csrf_input_pattern = ( + r'(]*name=["\']csrfmiddlewaretoken["\'][^>]*value=["\'])' + r'[^"\']+(["\'])' + ) + body_str = re.sub( + csrf_input_pattern, + rf"\g<1>{CSRF_PLACEHOLDER}\2", + body_str, + flags=re.IGNORECASE, + ) + + # Pattern 2: Also handle value before name (different attribute order) + # + csrf_input_pattern_alt = r'(]*value=["\'])[^"\']+(["\'][^>]*name=["\']csrfmiddlewaretoken["\'])' + body_str = re.sub( + csrf_input_pattern_alt, + rf"\g<1>{CSRF_PLACEHOLDER}\2", + body_str, + flags=re.IGNORECASE, + ) + + return body_str.encode("utf-8") + + except Exception as e: + logger.debug(f"Error normalizing CSRF tokens: {e}") + return body + + +def _sort_classes(match: re.Match) -> str: + """Sort CSS classes within a class attribute match. + + Args: + match: Regex match object containing the class attribute + + Returns: + The class attribute with sorted class names + """ + prefix = match.group(1) # 'class="' or "class='" + classes = match.group(2) # The actual class names + quote = match.group(3) # Closing quote + + # Split classes by whitespace, sort them, and rejoin + class_list = classes.split() + sorted_classes = " ".join(sorted(class_list)) + + return f"{prefix}{sorted_classes}{quote}" + + +def normalize_html_class_ordering(body: bytes | None) -> bytes | None: + """Normalize HTML class attribute ordering for consistent record/replay comparison. + + Sorts CSS class names alphabetically within each class attribute so that + non-deterministic class ordering (e.g., from django-unfold) doesn't cause + false deviations during replay comparison. + + Args: + body: Response body bytes (typically HTML) + + Returns: + Body with class attributes normalized, or original body if not applicable + """ + if not body: + return body + + try: + body_str = body.decode("utf-8") + + # Pattern to match class attributes with double or single quotes + # class="class1 class2 class3" or class='class1 class2 class3' + # Captures: (1) prefix including quote, (2) class names, (3) closing quote + class_pattern = r'(class=["\'])([^"\']+)(["\'])' + + body_str = re.sub( + class_pattern, + _sort_classes, + body_str, + flags=re.IGNORECASE, + ) + + return body_str.encode("utf-8") + + except Exception as e: + logger.debug(f"Error normalizing HTML class ordering: {e}") + return body + + +def normalize_html_body(body: bytes | None, content_type: str, content_encoding: str = "") -> bytes | None: + """Normalize HTML body for consistent record/replay comparison. + + Applies CSRF token normalization and HTML class ordering normalization. + Only processes uncompressed text/html responses. + + Args: + body: Response body bytes + content_type: Content-Type header value + content_encoding: Content-Encoding header value (optional) + + Returns: + Normalized body bytes, or original body if not applicable + """ + if not body: + return body + + if "text/html" not in content_type.lower(): + return body + + # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body + encoding = content_encoding.lower() if content_encoding else "" + if encoding and encoding != "identity": + return body + + normalized = normalize_csrf_in_body(body) + normalized = normalize_html_class_ordering(normalized) + + return normalized + + +def normalize_html_response(response: HttpResponse) -> HttpResponse: + """Normalize CSRF tokens and HTML class ordering in Django response. + + Modifies the response body in-place. Only affects uncompressed HTML responses. + Used in REPLAY mode to ensure the actual response matches the recorded + (normalized) response. + + Args: + response: Django HttpResponse object + + Returns: + Modified response with normalized body + """ + content_type = response.get("Content-Type", "") + content_encoding = response.get("Content-Encoding", "") + + if not hasattr(response, "content") or not response.content: + return response + + normalized_body = normalize_html_body(response.content, content_type, content_encoding) + + if normalized_body is not None and normalized_body != response.content: + response.content = normalized_body + # Update Content-Length header if present + if "Content-Length" in response: + response["Content-Length"] = len(normalized_body) + + return response diff --git a/drift/instrumentation/django/middleware.py b/drift/instrumentation/django/middleware.py index 23b014e..d052429 100644 --- a/drift/instrumentation/django/middleware.py +++ b/drift/instrumentation/django/middleware.py @@ -147,8 +147,8 @@ def _handle_replay_request(self, request: HttpRequest, sdk) -> HttpResponse: with SpanUtils.with_span(span_info): response = self.get_response(request) # REPLAY mode: don't capture the span (it's already recorded) - # But do normalize CSRF tokens in the response so comparison succeeds - response = self._normalize_csrf_in_response(response) + # But do normalize the response so comparison succeeds + response = self._normalize_html_response(response) return response finally: # Reset context @@ -264,42 +264,22 @@ def process_view( if route: request._drift_route_template = route # type: ignore - def _normalize_csrf_in_response(self, response: HttpResponse) -> HttpResponse: - """Normalize CSRF tokens in the actual response body for REPLAY mode. + def _normalize_html_response(self, response: HttpResponse) -> HttpResponse: + """Normalize HTML response body for REPLAY mode comparison. In REPLAY mode, we need the actual HTTP response to match the recorded - response (which had CSRF tokens normalized during recording). This modifies - the response body to replace real CSRF tokens with the normalized placeholder. - - This only affects HTML responses. + response (which had CSRF tokens and class ordering normalized during recording). + This modifies the response body in-place. Args: response: Django HttpResponse object Returns: - Modified response with normalized CSRF tokens + Modified response with normalized body """ - content_type = response.get("Content-Type", "") - if "text/html" not in content_type.lower(): - return response - - # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body - content_encoding = response.get("Content-Encoding", "").lower() - if content_encoding and content_encoding != "identity": - return response - - # Get response body and normalize CSRF tokens - if hasattr(response, "content") and response.content: - from .csrf_utils import normalize_csrf_in_body + from .html_utils import normalize_html_response - normalized_body = normalize_csrf_in_body(response.content) - if normalized_body is not None and normalized_body != response.content: - response.content = normalized_body - # Update Content-Length header if present - if "Content-Length" in response: - response["Content-Length"] = len(normalized_body) - - return response + return normalize_html_response(response) def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo) -> None: """Create and collect a span from request/response data. @@ -340,16 +320,14 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: if isinstance(content, bytes) and len(content) > 0: response_body = content - # Normalize CSRF tokens in HTML responses for consistent record/replay comparison + # Normalize HTML responses for consistent record/replay comparison # This only affects what is stored in the span, not what the browser receives if response_body: - content_type = response_headers.get("Content-Type", "") - content_encoding = response_headers.get("Content-Encoding", "").lower() - # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body - if "text/html" in content_type.lower() and (not content_encoding or content_encoding == "identity"): - from .csrf_utils import normalize_csrf_in_body + from .html_utils import normalize_html_body - response_body = normalize_csrf_in_body(response_body) + content_type = response_headers.get("Content-Type", "") + content_encoding = response_headers.get("Content-Encoding", "") + response_body = normalize_html_body(response_body, content_type, content_encoding) output_value = build_output_value( status_code, diff --git a/tests/unit/test_html_utils.py b/tests/unit/test_html_utils.py new file mode 100644 index 0000000..c6cea6a --- /dev/null +++ b/tests/unit/test_html_utils.py @@ -0,0 +1,290 @@ +"""Unit tests for Django HTML utilities. + +Tests for HTML normalization functions including CSRF token and class ordering. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from drift.instrumentation.django.html_utils import ( + CSRF_PLACEHOLDER, + normalize_csrf_in_body, + normalize_html_body, + normalize_html_class_ordering, + normalize_html_response, +) + + +class TestNormalizeHtmlClassOrdering: + """Tests for normalize_html_class_ordering function.""" + + def test_sorts_classes_alphabetically(self): + """Classes within a class attribute should be sorted alphabetically.""" + html = b'
Hello