diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index 5d2e9b9..6ae4361 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -612,31 +612,7 @@ def _extract_response_data(self, struct: Any) -> dict[str, Any]: return {} try: - - def value_to_python(value): - """Convert protobuf Value to Python type.""" - if hasattr(value, "null_value"): - return None - elif hasattr(value, "number_value"): - return value.number_value - elif hasattr(value, "string_value"): - return value.string_value - elif hasattr(value, "bool_value"): - return value.bool_value - elif hasattr(value, "struct_value") and value.struct_value: - return struct_to_dict(value.struct_value) - elif hasattr(value, "list_value") and value.list_value: - return [value_to_python(v) for v in value.list_value.values] - return None - - def struct_to_dict(s): - """Convert protobuf Struct to Python dict.""" - if not hasattr(s, "fields"): - return {} - result = {} - for key, value in s.fields.items(): - result[key] = value_to_python(value) - return result + from ..protobuf_utils import struct_to_dict data = struct_to_dict(struct) diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 179c528..0ecd7aa 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -379,6 +379,8 @@ def extract_response_data(struct: Any) -> dict[str, Any]: The CLI returns response data wrapped in a Struct with a "response" field. """ + from ..protobuf_utils import value_to_python + try: # Handle betterproto dict-like struct if hasattr(struct, "items"): @@ -391,8 +393,8 @@ def extract_response_data(struct: Any) -> dict[str, Any]: if hasattr(struct, "fields"): fields = struct.fields if "response" in fields: - return _value_to_python(fields["response"]) - return {k: _value_to_python(v) for k, v in fields.items()} + return value_to_python(fields["response"]) + return {k: value_to_python(v) for k, v in fields.items()} # Direct dict access if isinstance(struct, dict): @@ -403,20 +405,3 @@ def extract_response_data(struct: Any) -> dict[str, Any]: return {} except Exception: return {} - - -def _value_to_python(value: Any) -> Any: - """Convert a protobuf Value to Python native type.""" - if hasattr(value, "null_value"): - return None - if hasattr(value, "number_value"): - return value.number_value - if hasattr(value, "string_value"): - return value.string_value - if hasattr(value, "bool_value"): - return value.bool_value - if hasattr(value, "struct_value"): - return {k: _value_to_python(v) for k, v in value.struct_value.fields.items()} - if hasattr(value, "list_value"): - return [_value_to_python(v) for v in value.list_value.values] - return value diff --git a/drift/core/protobuf_utils.py b/drift/core/protobuf_utils.py new file mode 100644 index 0000000..3912481 --- /dev/null +++ b/drift/core/protobuf_utils.py @@ -0,0 +1,102 @@ +"""Protobuf conversion utilities. + +This module provides utilities for converting protobuf Value and Struct objects +to Python native types. Handles both Google protobuf and betterproto variants. +""" + +from __future__ import annotations + +from typing import Any + + +def value_to_python(value: Any) -> Any: + """Convert protobuf Value or Python native type to Python native type. + + Handles: + 1. Python native types (from betterproto's Struct which stores values directly) + 2. Google protobuf Value objects (from google.protobuf.struct_pb2) + 3. betterproto Value objects (from betterproto.lib.google.protobuf) + + Args: + value: A protobuf Value, native Python type, or nested structure + + Returns: + Python native type (None, bool, int, float, str, list, or dict) + """ + # Handle Python native types (from betterproto or already converted) + if value is None: + return None + if isinstance(value, (bool, int, float, str)): + # Note: bool must be checked before int since bool is a subclass of int + return value + if isinstance(value, list): + return [value_to_python(v) for v in value] + if isinstance(value, dict): + return {k: value_to_python(v) for k, v in value.items()} + + # Handle Google protobuf Value objects using WhichOneof + if hasattr(value, "WhichOneof"): + kind = value.WhichOneof("kind") + if kind == "null_value": + return None + elif kind == "number_value": + return value.number_value + elif kind == "string_value": + return value.string_value + elif kind == "bool_value": + return value.bool_value + elif kind == "struct_value": + return struct_to_dict(value.struct_value) + elif kind == "list_value": + return [value_to_python(v) for v in value.list_value.values] + + # Handle betterproto Value objects using is_set method + if hasattr(value, "is_set"): + try: + if value.is_set("null_value"): + return None + elif value.is_set("number_value"): + return value.number_value + elif value.is_set("string_value"): + return value.string_value + elif value.is_set("bool_value"): + return value.bool_value + elif value.is_set("struct_value"): + sv = value.struct_value + if isinstance(sv, dict): + return {k: value_to_python(v) for k, v in sv.get("fields", sv).items()} + return struct_to_dict(sv) + elif value.is_set("list_value"): + lv = value.list_value + # Handle dict-style list_value (betterproto can store dicts) + if isinstance(lv, dict): + return [value_to_python(v) for v in lv.get("values", [])] + return [value_to_python(v) for v in lv.values] + except (AttributeError, TypeError): + pass # Not a betterproto Value, fall through + + # Fallback: return the value as-is + return value + + +def struct_to_dict(struct: Any) -> dict[str, Any]: + """Convert protobuf Struct to Python dict. + + Args: + struct: A protobuf Struct object (Google or betterproto) + + Returns: + Python dict with converted values + """ + if not struct: + return {} + + # Handle dict-like struct (betterproto sometimes returns dicts directly) + if isinstance(struct, dict): + return {k: value_to_python(v) for k, v in struct.items()} + + # Handle struct with fields attribute + if hasattr(struct, "fields"): + return {k: value_to_python(v) for k, v in struct.fields.items()} + + return {} diff --git a/drift/instrumentation/django/csrf_utils.py b/drift/instrumentation/django/csrf_utils.py deleted file mode 100644 index da1721d..0000000 --- a/drift/instrumentation/django/csrf_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Django CSRF token utilities for consistent record/replay testing. - -This module provides utilities to normalize CSRF tokens so that recorded -and replayed responses produce identical output for comparison. -""" - -from __future__ import annotations - -import logging -import re - -logger = logging.getLogger(__name__) - -CSRF_PLACEHOLDER = "__DRIFT_CSRF__" - - -def normalize_csrf_in_body(body: bytes | None) -> bytes | None: - """Normalize CSRF tokens in response body for consistent record/replay comparison. - - Replaces Django CSRF tokens with a fixed placeholder so that recorded - responses match replayed responses during comparison. - - This should be called after the response is sent to the browser, - but before storing in the span. The actual response to the browser - is unchanged. - - Args: - body: Response body bytes (typically HTML) - - Returns: - Body with CSRF tokens normalized, or original body if not applicable - """ - if not body: - return body - - try: - body_str = body.decode("utf-8") - - # Pattern 1: Hidden input fields with csrfmiddlewaretoken - # - # Handles both single and double quotes, various attribute orders - csrf_input_pattern = ( - r'(]*name=["\']csrfmiddlewaretoken["\'][^>]*value=["\'])' - r'[^"\']+(["\'])' - ) - body_str = re.sub( - csrf_input_pattern, - rf"\g<1>{CSRF_PLACEHOLDER}\2", - body_str, - flags=re.IGNORECASE, - ) - - # Pattern 2: Also handle value before name (different attribute order) - # - csrf_input_pattern_alt = r'(]*value=["\'])[^"\']+(["\'][^>]*name=["\']csrfmiddlewaretoken["\'])' - body_str = re.sub( - csrf_input_pattern_alt, - rf"\g<1>{CSRF_PLACEHOLDER}\2", - body_str, - flags=re.IGNORECASE, - ) - - return body_str.encode("utf-8") - - except Exception as e: - logger.debug(f"Error normalizing CSRF tokens: {e}") - return body diff --git a/drift/instrumentation/django/html_utils.py b/drift/instrumentation/django/html_utils.py new file mode 100644 index 0000000..6d51346 --- /dev/null +++ b/drift/instrumentation/django/html_utils.py @@ -0,0 +1,188 @@ +"""Django HTML utilities for consistent record/replay testing. + +This module provides utilities to normalize HTML content so that recorded +and replayed responses produce identical output for comparison. Includes +CSRF token normalization and HTML class ordering normalization. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from django.http import HttpResponse + +logger = logging.getLogger(__name__) + +# Placeholder used to replace CSRF tokens for consistent comparison +CSRF_PLACEHOLDER = "__DRIFT_CSRF__" + + +def normalize_csrf_in_body(body: bytes | None) -> bytes | None: + """Normalize CSRF tokens in response body for consistent record/replay comparison. + + Replaces Django CSRF tokens with a fixed placeholder so that recorded + responses match replayed responses during comparison. + + Args: + body: Response body bytes (typically HTML) + + Returns: + Body with CSRF tokens normalized, or original body if not applicable + """ + if not body: + return body + + try: + body_str = body.decode("utf-8") + + # Pattern 1: Hidden input fields with csrfmiddlewaretoken + # + # Handles both single and double quotes, various attribute orders + csrf_input_pattern = ( + r'(]*name=["\']csrfmiddlewaretoken["\'][^>]*value=["\'])' + r'[^"\']+(["\'])' + ) + body_str = re.sub( + csrf_input_pattern, + rf"\g<1>{CSRF_PLACEHOLDER}\2", + body_str, + flags=re.IGNORECASE, + ) + + # Pattern 2: Also handle value before name (different attribute order) + # + csrf_input_pattern_alt = r'(]*value=["\'])[^"\']+(["\'][^>]*name=["\']csrfmiddlewaretoken["\'])' + body_str = re.sub( + csrf_input_pattern_alt, + rf"\g<1>{CSRF_PLACEHOLDER}\2", + body_str, + flags=re.IGNORECASE, + ) + + return body_str.encode("utf-8") + + except Exception as e: + logger.debug(f"Error normalizing CSRF tokens: {e}") + return body + + +def _sort_classes(match: re.Match) -> str: + """Sort CSS classes within a class attribute match. + + Args: + match: Regex match object containing the class attribute + + Returns: + The class attribute with sorted class names + """ + prefix = match.group(1) # 'class="' or "class='" + classes = match.group(2) # The actual class names + quote = match.group(3) # Closing quote + + # Split classes by whitespace, sort them, and rejoin + class_list = classes.split() + sorted_classes = " ".join(sorted(class_list)) + + return f"{prefix}{sorted_classes}{quote}" + + +def normalize_html_class_ordering(body: bytes | None) -> bytes | None: + """Normalize HTML class attribute ordering for consistent record/replay comparison. + + Sorts CSS class names alphabetically within each class attribute so that + non-deterministic class ordering (e.g., from django-unfold) doesn't cause + false deviations during replay comparison. + + Args: + body: Response body bytes (typically HTML) + + Returns: + Body with class attributes normalized, or original body if not applicable + """ + if not body: + return body + + try: + body_str = body.decode("utf-8") + + # Pattern to match class attributes with double or single quotes + # class="class1 class2 class3" or class='class1 class2 class3' + # Captures: (1) prefix including quote, (2) class names, (3) closing quote + class_pattern = r'(class=["\'])([^"\']+)(["\'])' + + body_str = re.sub( + class_pattern, + _sort_classes, + body_str, + flags=re.IGNORECASE, + ) + + return body_str.encode("utf-8") + + except Exception as e: + logger.debug(f"Error normalizing HTML class ordering: {e}") + return body + + +def normalize_html_body(body: bytes | None, content_type: str, content_encoding: str = "") -> bytes | None: + """Normalize HTML body for consistent record/replay comparison. + + Applies CSRF token normalization and HTML class ordering normalization. + Only processes uncompressed text/html responses. + + Args: + body: Response body bytes + content_type: Content-Type header value + content_encoding: Content-Encoding header value (optional) + + Returns: + Normalized body bytes, or original body if not applicable + """ + if not body: + return body + + if "text/html" not in content_type.lower(): + return body + + # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body + encoding = content_encoding.lower() if content_encoding else "" + if encoding and encoding != "identity": + return body + + normalized = normalize_csrf_in_body(body) + normalized = normalize_html_class_ordering(normalized) + + return normalized + + +def normalize_html_response(response: HttpResponse) -> HttpResponse: + """Normalize CSRF tokens and HTML class ordering in Django response. + + Modifies the response body in-place. Only affects uncompressed HTML responses. + Used in REPLAY mode to ensure the actual response matches the recorded + (normalized) response. + + Args: + response: Django HttpResponse object + + Returns: + Modified response with normalized body + """ + content_type = response.get("Content-Type", "") + content_encoding = response.get("Content-Encoding", "") + + if not hasattr(response, "content") or not response.content: + return response + + normalized_body = normalize_html_body(response.content, content_type, content_encoding) + + if normalized_body is not None and normalized_body != response.content: + response.content = normalized_body + # Update Content-Length header if present + if "Content-Length" in response: + response["Content-Length"] = len(normalized_body) + + return response diff --git a/drift/instrumentation/django/middleware.py b/drift/instrumentation/django/middleware.py index 23b014e..d052429 100644 --- a/drift/instrumentation/django/middleware.py +++ b/drift/instrumentation/django/middleware.py @@ -147,8 +147,8 @@ def _handle_replay_request(self, request: HttpRequest, sdk) -> HttpResponse: with SpanUtils.with_span(span_info): response = self.get_response(request) # REPLAY mode: don't capture the span (it's already recorded) - # But do normalize CSRF tokens in the response so comparison succeeds - response = self._normalize_csrf_in_response(response) + # But do normalize the response so comparison succeeds + response = self._normalize_html_response(response) return response finally: # Reset context @@ -264,42 +264,22 @@ def process_view( if route: request._drift_route_template = route # type: ignore - def _normalize_csrf_in_response(self, response: HttpResponse) -> HttpResponse: - """Normalize CSRF tokens in the actual response body for REPLAY mode. + def _normalize_html_response(self, response: HttpResponse) -> HttpResponse: + """Normalize HTML response body for REPLAY mode comparison. In REPLAY mode, we need the actual HTTP response to match the recorded - response (which had CSRF tokens normalized during recording). This modifies - the response body to replace real CSRF tokens with the normalized placeholder. - - This only affects HTML responses. + response (which had CSRF tokens and class ordering normalized during recording). + This modifies the response body in-place. Args: response: Django HttpResponse object Returns: - Modified response with normalized CSRF tokens + Modified response with normalized body """ - content_type = response.get("Content-Type", "") - if "text/html" not in content_type.lower(): - return response - - # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body - content_encoding = response.get("Content-Encoding", "").lower() - if content_encoding and content_encoding != "identity": - return response - - # Get response body and normalize CSRF tokens - if hasattr(response, "content") and response.content: - from .csrf_utils import normalize_csrf_in_body + from .html_utils import normalize_html_response - normalized_body = normalize_csrf_in_body(response.content) - if normalized_body is not None and normalized_body != response.content: - response.content = normalized_body - # Update Content-Length header if present - if "Content-Length" in response: - response["Content-Length"] = len(normalized_body) - - return response + return normalize_html_response(response) def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo) -> None: """Create and collect a span from request/response data. @@ -340,16 +320,14 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: if isinstance(content, bytes) and len(content) > 0: response_body = content - # Normalize CSRF tokens in HTML responses for consistent record/replay comparison + # Normalize HTML responses for consistent record/replay comparison # This only affects what is stored in the span, not what the browser receives if response_body: - content_type = response_headers.get("Content-Type", "") - content_encoding = response_headers.get("Content-Encoding", "").lower() - # Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body - if "text/html" in content_type.lower() and (not content_encoding or content_encoding == "identity"): - from .csrf_utils import normalize_csrf_in_body + from .html_utils import normalize_html_body - response_body = normalize_csrf_in_body(response_body) + content_type = response_headers.get("Content-Type", "") + content_encoding = response_headers.get("Content-Encoding", "") + response_body = normalize_html_body(response_body, content_type, content_encoding) output_value = build_output_value( status_code, diff --git a/tests/unit/test_html_utils.py b/tests/unit/test_html_utils.py new file mode 100644 index 0000000..c6cea6a --- /dev/null +++ b/tests/unit/test_html_utils.py @@ -0,0 +1,290 @@ +"""Unit tests for Django HTML utilities. + +Tests for HTML normalization functions including CSRF token and class ordering. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from drift.instrumentation.django.html_utils import ( + CSRF_PLACEHOLDER, + normalize_csrf_in_body, + normalize_html_body, + normalize_html_class_ordering, + normalize_html_response, +) + + +class TestNormalizeHtmlClassOrdering: + """Tests for normalize_html_class_ordering function.""" + + def test_sorts_classes_alphabetically(self): + """Classes within a class attribute should be sorted alphabetically.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_multiple_class_attributes(self): + """Multiple class attributes in the same HTML should all be normalized.""" + html = b'
text
' + result = normalize_html_class_ordering(html) + assert result == b'
text
' + + def test_handles_single_quotes(self): + """Class attributes with single quotes should be handled.""" + html = b"
content
" + result = normalize_html_class_ordering(html) + assert result == b"
content
" + + def test_preserves_single_class(self): + """Single class should remain unchanged.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_empty_class_attribute(self): + """Empty class attribute should remain unchanged.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_preserves_other_attributes(self): + """Other attributes should not be affected.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_none_input(self): + """None input should return None.""" + assert normalize_html_class_ordering(None) is None + + def test_handles_empty_bytes(self): + """Empty bytes should return empty bytes.""" + assert normalize_html_class_ordering(b"") == b"" + + def test_handles_html_without_class_attributes(self): + """HTML without class attributes should remain unchanged.""" + html = b"

Hello

" + result = normalize_html_class_ordering(html) + assert result == html + + def test_case_insensitive_class_attribute(self): + """CLASS (uppercase) should also be handled.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_extra_whitespace_between_classes(self): + """Extra whitespace between classes should be normalized to single space.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_tailwind_style_classes(self): + """Tailwind-style classes with special characters should be sorted correctly.""" + html = b'
content
' + result = normalize_html_class_ordering(html) + assert result == b'
content
' + + def test_handles_non_utf8_gracefully(self): + """Non-UTF8 content should be returned unchanged.""" + # Invalid UTF-8 sequence + html = b"\xff\xfe invalid utf8" + result = normalize_html_class_ordering(html) + assert result == html + + +class TestNormalizeCsrfInBody: + """Tests for normalize_csrf_in_body function.""" + + def test_normalizes_csrf_token_name_before_value(self): + """CSRF token with name before value should be normalized.""" + html = b'' + result = normalize_csrf_in_body(html) + assert result is not None + assert CSRF_PLACEHOLDER.encode() in result + assert b"abc123xyz" not in result + + def test_normalizes_csrf_token_value_before_name(self): + """CSRF token with value before name should be normalized.""" + html = b'' + result = normalize_csrf_in_body(html) + assert result is not None + assert CSRF_PLACEHOLDER.encode() in result + assert b"abc123xyz" not in result + + def test_handles_single_quotes(self): + """CSRF token with single quotes should be normalized.""" + html = b"" + result = normalize_csrf_in_body(html) + assert result is not None + assert CSRF_PLACEHOLDER.encode() in result + assert b"abc123xyz" not in result + + def test_preserves_surrounding_html(self): + """Other HTML content should be preserved.""" + html = b'
' + result = normalize_csrf_in_body(html) + assert result is not None + assert b"
" in result + assert b"
" in result + assert b'name="username"' in result + + def test_handles_multiple_csrf_tokens(self): + """Multiple CSRF tokens in the same HTML should all be normalized.""" + html = b""" +
+
+ """ + result = normalize_csrf_in_body(html) + assert result is not None + assert result.count(CSRF_PLACEHOLDER.encode()) == 2 + assert b"token1" not in result + assert b"token2" not in result + + def test_handles_none_input(self): + """None input should return None.""" + assert normalize_csrf_in_body(None) is None + + def test_handles_empty_bytes(self): + """Empty bytes should return empty bytes.""" + assert normalize_csrf_in_body(b"") == b"" + + def test_handles_html_without_csrf(self): + """HTML without CSRF tokens should remain unchanged.""" + html = b"
" + result = normalize_csrf_in_body(html) + assert result == html + + def test_case_insensitive_csrf_name(self): + """CSRF token name matching should be case-insensitive.""" + html = b'' + result = normalize_csrf_in_body(html) + assert result is not None + assert CSRF_PLACEHOLDER.encode() in result + + def test_handles_non_utf8_gracefully(self): + """Non-UTF8 content should be returned unchanged.""" + html = b"\xff\xfe invalid utf8" + result = normalize_csrf_in_body(html) + assert result == html + + +class TestNormalizeHtmlBody: + """Tests for normalize_html_body function.""" + + def test_normalizes_html_content(self): + """HTML content should be normalized.""" + html = b'
content
' + result = normalize_html_body(html, "text/html") + assert result == b'
content
' + + def test_skips_non_html_content(self): + """Non-HTML content types should not be modified.""" + json_content = b'{"class": "b a"}' + result = normalize_html_body(json_content, "application/json") + assert result == json_content + + def test_skips_compressed_content(self): + """Compressed content should not be modified.""" + html = b'
content
' + result = normalize_html_body(html, "text/html", "gzip") + assert result == html + + def test_allows_identity_encoding(self): + """Identity encoding should be processed normally.""" + html = b'
content
' + result = normalize_html_body(html, "text/html", "identity") + assert result == b'
content
' + + def test_handles_none_body(self): + """None body should return None.""" + assert normalize_html_body(None, "text/html") is None + + def test_handles_empty_body(self): + """Empty body should return empty body.""" + assert normalize_html_body(b"", "text/html") == b"" + + def test_handles_text_html_with_charset(self): + """text/html with charset should be recognized as HTML.""" + html = b'
content
' + result = normalize_html_body(html, "text/html; charset=utf-8") + assert result == b'
content
' + + def test_case_insensitive_content_type(self): + """Content-Type matching should be case-insensitive.""" + html = b'
content
' + result = normalize_html_body(html, "TEXT/HTML") + assert result == b'
content
' + + def test_skips_deflate_encoding(self): + """Deflate encoding should be skipped.""" + html = b'
content
' + result = normalize_html_body(html, "text/html", "deflate") + assert result == html + + def test_skips_br_encoding(self): + """Brotli encoding should be skipped.""" + html = b'
content
' + result = normalize_html_body(html, "text/html", "br") + assert result == html + + +class TestNormalizeHtmlResponse: + """Tests for normalize_html_response function.""" + + def test_normalizes_html_response(self): + """HTML response content should be normalized.""" + response = MagicMock() + response.get.side_effect = lambda key, default="": { + "Content-Type": "text/html", + "Content-Encoding": "", + }.get(key, default) + response.content = b'
content
' + response.__contains__ = lambda self, key: key == "Content-Length" + response.__setitem__ = MagicMock() + + result = normalize_html_response(response) + + assert result.content == b'
content
' + + def test_updates_content_length(self): + """Content-Length header should be updated if present.""" + response = MagicMock() + response.get.side_effect = lambda key, default="": { + "Content-Type": "text/html", + "Content-Encoding": "", + }.get(key, default) + response.content = b'
content
' + response.__contains__ = lambda self, key: key == "Content-Length" + + normalize_html_response(response) + + response.__setitem__.assert_called_with("Content-Length", len(b'
content
')) + + def test_skips_non_html_response(self): + """Non-HTML responses should not be modified.""" + response = MagicMock() + response.get.side_effect = lambda key, default="": { + "Content-Type": "application/json", + "Content-Encoding": "", + }.get(key, default) + original_content = b'{"class": "b a"}' + response.content = original_content + + result = normalize_html_response(response) + + assert result.content == original_content + + def test_handles_response_without_content(self): + """Response without content attribute should be returned unchanged.""" + # spec=["get"] ensures mock only has 'get' attribute, no 'content' + response = MagicMock(spec=["get"]) + response.get.return_value = "text/html" + + result = normalize_html_response(response) + + assert result is response diff --git a/tests/unit/test_protobuf_utils.py b/tests/unit/test_protobuf_utils.py new file mode 100644 index 0000000..bb46de4 --- /dev/null +++ b/tests/unit/test_protobuf_utils.py @@ -0,0 +1,248 @@ +"""Unit tests for protobuf conversion utilities.""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from drift.core.protobuf_utils import struct_to_dict, value_to_python + + +class TestValueToPythonNativeTypes: + """Tests for value_to_python with Python native types.""" + + def test_returns_none_for_none(self): + """None should return None.""" + assert value_to_python(None) is None + + def test_returns_bool_unchanged(self): + """Boolean values should be returned unchanged.""" + assert value_to_python(True) is True + assert value_to_python(False) is False + + def test_returns_int_unchanged(self): + """Integer values should be returned unchanged.""" + assert value_to_python(42) == 42 + assert value_to_python(0) == 0 + assert value_to_python(-1) == -1 + + def test_returns_float_unchanged(self): + """Float values should be returned unchanged.""" + assert value_to_python(3.14) == 3.14 + assert value_to_python(0.0) == 0.0 + + def test_returns_str_unchanged(self): + """String values should be returned unchanged.""" + assert value_to_python("hello") == "hello" + assert value_to_python("") == "" + + def test_processes_list_recursively(self): + """List values should be processed recursively.""" + assert value_to_python([1, 2, 3]) == [1, 2, 3] + assert value_to_python(["a", "b"]) == ["a", "b"] + assert value_to_python([]) == [] + + def test_processes_dict_recursively(self): + """Dict values should be processed recursively.""" + assert value_to_python({"a": 1, "b": 2}) == {"a": 1, "b": 2} + assert value_to_python({}) == {} + + def test_processes_nested_structures(self): + """Nested structures should be processed recursively.""" + nested = { + "users": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + ], + "count": 2, + } + assert value_to_python(nested) == nested + + def test_bool_not_confused_with_int(self): + """Bool should be returned as bool, not converted to int.""" + # This is important because bool is a subclass of int in Python + result = value_to_python(True) + assert result is True + assert type(result) is bool + + +class TestValueToPythonGoogleProtobuf: + """Tests for value_to_python with Google protobuf Value objects.""" + + def test_handles_null_value(self): + """Google protobuf null_value should return None.""" + mock_value = MagicMock() + mock_value.WhichOneof.return_value = "null_value" + assert value_to_python(mock_value) is None + + def test_handles_number_value(self): + """Google protobuf number_value should return the number.""" + mock_value = MagicMock() + mock_value.WhichOneof.return_value = "number_value" + mock_value.number_value = 42.0 + assert value_to_python(mock_value) == 42.0 + + def test_handles_string_value(self): + """Google protobuf string_value should return the string.""" + mock_value = MagicMock() + mock_value.WhichOneof.return_value = "string_value" + mock_value.string_value = "hello" + assert value_to_python(mock_value) == "hello" + + def test_handles_bool_value(self): + """Google protobuf bool_value should return the bool.""" + mock_value = MagicMock() + mock_value.WhichOneof.return_value = "bool_value" + mock_value.bool_value = True + assert value_to_python(mock_value) is True + + def test_handles_list_value(self): + """Google protobuf list_value should return a list.""" + # Create mock list items + item1 = MagicMock() + item1.WhichOneof.return_value = "number_value" + item1.number_value = 1.0 + + item2 = MagicMock() + item2.WhichOneof.return_value = "number_value" + item2.number_value = 2.0 + + mock_value = MagicMock() + mock_value.WhichOneof.return_value = "list_value" + mock_value.list_value.values = [item1, item2] + + assert value_to_python(mock_value) == [1.0, 2.0] + + +class TestValueToPythonBetterproto: + """Tests for value_to_python with betterproto Value objects.""" + + def _make_betterproto_value(self, field_name, field_value): + """Helper to create a mock betterproto Value.""" + mock_value = MagicMock() + # No WhichOneof (not Google protobuf) + del mock_value.WhichOneof + + def is_set(name): + return name == field_name + + mock_value.is_set = is_set + setattr(mock_value, field_name, field_value) + return mock_value + + def test_handles_null_value(self): + """Betterproto null_value should return None.""" + mock_value = self._make_betterproto_value("null_value", 0) + assert value_to_python(mock_value) is None + + def test_handles_number_value(self): + """Betterproto number_value should return the number.""" + mock_value = self._make_betterproto_value("number_value", 42.0) + assert value_to_python(mock_value) == 42.0 + + def test_handles_string_value(self): + """Betterproto string_value should return the string.""" + mock_value = self._make_betterproto_value("string_value", "hello") + assert value_to_python(mock_value) == "hello" + + def test_handles_bool_value(self): + """Betterproto bool_value should return the bool.""" + mock_value = self._make_betterproto_value("bool_value", True) + assert value_to_python(mock_value) is True + + def test_handles_dict_style_struct_value(self): + """Betterproto struct_value as dict should be handled.""" + mock_value = MagicMock() + del mock_value.WhichOneof + + def is_set(name): + return name == "struct_value" + + mock_value.is_set = is_set + mock_value.struct_value = {"key": "value"} + + result = value_to_python(mock_value) + assert result == {"key": "value"} + + def test_handles_dict_style_list_value(self): + """Betterproto list_value as dict should be handled.""" + mock_value = MagicMock() + del mock_value.WhichOneof + + def is_set(name): + return name == "list_value" + + mock_value.is_set = is_set + mock_value.list_value = {"values": [1, 2, 3]} + + result = value_to_python(mock_value) + assert result == [1, 2, 3] + + +class TestValueToPythonFallback: + """Tests for value_to_python fallback behavior.""" + + def test_returns_unknown_type_as_is(self): + """Unknown types should be returned as-is (not None).""" + + class CustomType: + pass + + obj = CustomType() + result = value_to_python(obj) + assert result is obj + + def test_handles_is_set_exception_gracefully(self): + """If is_set raises an exception, should fall through gracefully.""" + mock_value = MagicMock() + del mock_value.WhichOneof + + def is_set(name): + raise TypeError("Unexpected error") + + mock_value.is_set = is_set + + # Should not raise, should return value as-is + result = value_to_python(mock_value) + assert result is mock_value + + +class TestStructToDict: + """Tests for struct_to_dict function.""" + + def test_returns_empty_dict_for_none(self): + """None should return empty dict.""" + assert struct_to_dict(None) == {} + + def test_returns_empty_dict_for_empty_struct(self): + """Empty struct should return empty dict.""" + mock_struct = MagicMock() + mock_struct.fields = {} + assert struct_to_dict(mock_struct) == {} + + def test_handles_dict_input(self): + """Dict input should be processed and returned.""" + result = struct_to_dict({"a": 1, "b": "hello"}) + assert result == {"a": 1, "b": "hello"} + + def test_handles_struct_with_fields(self): + """Struct with fields attribute should be converted.""" + mock_struct = MagicMock() + mock_struct.fields = {"name": "Alice", "age": 30} + # Remove items() to ensure we go through the fields path + del mock_struct.items + + result = struct_to_dict(mock_struct) + assert result == {"name": "Alice", "age": 30} + + def test_processes_nested_values(self): + """Nested values in struct should be processed recursively.""" + mock_struct = MagicMock() + mock_struct.fields = { + "user": {"name": "Bob", "scores": [1, 2, 3]}, + } + del mock_struct.items + + result = struct_to_dict(mock_struct) + assert result == {"user": {"name": "Bob", "scores": [1, 2, 3]}}