From 2e3bb3376826ff1ee9c2f3f1a54c9b14db99a32c Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:11:19 -0400 Subject: [PATCH 1/4] Add smithy-xml package --- .../smithy-core/src/smithy_core/traits.py | 40 ++ packages/smithy-xml/CHANGELOG.md | 1 + packages/smithy-xml/NOTICE | 1 + packages/smithy-xml/README.md | 4 + packages/smithy-xml/pyproject.toml | 50 ++ .../smithy-xml/src/smithy_xml/__init__.py | 55 ++ .../src/smithy_xml/_private/__init__.py | 2 + .../src/smithy_xml/_private/deserializers.py | 378 ++++++++++ .../src/smithy_xml/_private/readers.py | 55 ++ .../src/smithy_xml/_private/serializers.py | 372 ++++++++++ packages/smithy-xml/src/smithy_xml/py.typed | 1 + .../smithy-xml/src/smithy_xml/settings.py | 16 + packages/smithy-xml/tests/__init__.py | 2 + packages/smithy-xml/tests/unit/__init__.py | 662 ++++++++++++++++++ .../tests/unit/test_deserializers.py | 145 ++++ .../smithy-xml/tests/unit/test_serializers.py | 209 ++++++ pyproject.toml | 1 + uv.lock | 11 + 18 files changed, 2005 insertions(+) create mode 100644 packages/smithy-xml/CHANGELOG.md create mode 100644 packages/smithy-xml/NOTICE create mode 100644 packages/smithy-xml/README.md create mode 100644 packages/smithy-xml/pyproject.toml create mode 100644 packages/smithy-xml/src/smithy_xml/__init__.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/__init__.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/deserializers.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/readers.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/serializers.py create mode 100644 packages/smithy-xml/src/smithy_xml/py.typed create mode 100644 packages/smithy-xml/src/smithy_xml/settings.py create mode 100644 packages/smithy-xml/tests/__init__.py create mode 100644 packages/smithy-xml/tests/unit/__init__.py create mode 100644 packages/smithy-xml/tests/unit/test_deserializers.py create mode 100644 
packages/smithy-xml/tests/unit/test_serializers.py diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index d7dfd22cf..5b64c82f8 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -350,3 +350,43 @@ def name(self) -> str: @property def scheme(self) -> str | None: return self.document_value.get("scheme") # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlNameTrait(Trait, id=ShapeID("smithy.api#xmlName")): + document_value: str | None = None + + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def value(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlNamespaceTrait(Trait, id=ShapeID("smithy.api#xmlNamespace")): + def __post_init__(self): + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value["uri"], str) + assert isinstance(self.document_value.get("prefix"), str | None) + + @property + def uri(self) -> str: + return self.document_value["uri"] # type: ignore + + @property + def prefix(self) -> str | None: + return self.document_value.get("prefix") # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlFlattenedTrait(Trait, id=ShapeID("smithy.api#xmlFlattened")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class XmlAttributeTrait(Trait, id=ShapeID("smithy.api#xmlAttribute")): + def __post_init__(self): + assert self.document_value is None diff --git a/packages/smithy-xml/CHANGELOG.md b/packages/smithy-xml/CHANGELOG.md new file mode 100644 index 000000000..5ddad421e --- /dev/null +++ b/packages/smithy-xml/CHANGELOG.md @@ -0,0 +1 @@ +# Changelog \ No newline at end of file diff --git a/packages/smithy-xml/NOTICE b/packages/smithy-xml/NOTICE new file mode 100644 index 000000000..616fc5889 --- /dev/null +++ b/packages/smithy-xml/NOTICE 
@@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/packages/smithy-xml/README.md b/packages/smithy-xml/README.md new file mode 100644 index 000000000..bcd7da956 --- /dev/null +++ b/packages/smithy-xml/README.md @@ -0,0 +1,4 @@ +# smithy-xml + +This package provides generic XML serialization and deserialization support +for Smithy clients and servers. diff --git a/packages/smithy-xml/pyproject.toml b/packages/smithy-xml/pyproject.toml new file mode 100644 index 000000000..eca3b80f9 --- /dev/null +++ b/packages/smithy-xml/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "smithy-xml" +dynamic = ["version"] +requires-python = ">=3.12" +authors = [ + {name = "Amazon Web Services"}, +] +description = "XML serialization and deserialization support for Smithy tooling." +readme = "README.md" +license = {text = "Apache License 2.0"} +keywords = ["smithy", "sdk", "xml"] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Software Development :: Libraries" +] +dependencies = [ + "smithy-core", +] + +[project.urls] +"Changelog" = "https://github.com/smithy-lang/smithy-python/blob/develop/packages/smithy-xml/CHANGELOG.md" +"Code" = "https://github.com/smithy-lang/smithy-python/blob/develop/packages/smithy-xml/" +"Issue tracker" = "https://github.com/smithy-lang/smithy-python/issues" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path 
= "src/smithy_xml/__init__.py" + +[tool.hatch.build] +exclude = [ + "tests", +] + +[tool.ruff] +src = ["src"] diff --git a/packages/smithy-xml/src/smithy_xml/__init__.py b/packages/smithy-xml/src/smithy_xml/__init__.py new file mode 100644 index 000000000..9616aa8ef --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/__init__.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from io import BytesIO +from xml.etree.ElementTree import iterparse + +from smithy_core.codecs import Codec +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.interfaces import BytesReader, BytesWriter +from smithy_core.serializers import ShapeSerializer +from smithy_core.types import TimestampFormat + +from ._private.deserializers import XMLShapeDeserializer as _XMLShapeDeserializer +from ._private.readers import XMLEventReader as _XMLEventReader +from ._private.serializers import XMLShapeSerializer as _XMLShapeSerializer +from .settings import XMLSettings + +__version__ = "0.0.1" +__all__ = ("XMLCodec", "XMLSettings") + + +class XMLCodec(Codec): + """A codec for converting shapes to/from XML.""" + + def __init__( + self, + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + default_namespace: str | None = None, + ) -> None: + self._settings = XMLSettings( + default_timestamp_format=default_timestamp_format, + default_namespace=default_namespace, + ) + + @property + def media_type(self) -> str: + return "application/xml" + + def create_serializer(self, sink: BytesWriter) -> ShapeSerializer: + return _XMLShapeSerializer(sink=sink, settings=self._settings) + + def create_deserializer( + self, + source: bytes | BytesReader, + *, + wrapper_elements: tuple[str, ...] 
= (), + ) -> ShapeDeserializer: + if isinstance(source, bytes): + source = BytesIO(source) + reader = _XMLEventReader( + iterparse(source, events=("start", "end")) # noqa: S314 + ) + return _XMLShapeDeserializer( + settings=self._settings, reader=reader, wrapper_elements=wrapper_elements + ) diff --git a/packages/smithy-xml/src/smithy_xml/_private/__init__.py b/packages/smithy-xml/src/smithy_xml/_private/__init__.py new file mode 100644 index 000000000..33cbe867a --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-xml/src/smithy_xml/_private/deserializers.py b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py new file mode 100644 index 000000000..ae1ece78f --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py @@ -0,0 +1,378 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +from base64 import b64decode +from collections.abc import Callable +from decimal import Decimal +from xml.etree.ElementTree import Element + +from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer +from smithy_core.documents import Document +from smithy_core.exceptions import SmithyError +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + TimestampFormatTrait, + XmlAttributeTrait, + XmlFlattenedTrait, + XmlNameTrait, +) + +from ..settings import XMLSettings +from .readers import XMLEvent, XMLEventReader + + +def _local_name(tag: str) -> str: + """Strip namespace URI from an element tag: {uri}local -> local.""" + if tag.startswith("{"): + return tag.split("}", 1)[1] + return tag + + +def _expected_root_name(schema: Schema) -> str | None: + """Get the expected root element name for root validation.""" + if schema.shape_type not in (ShapeType.STRUCTURE, ShapeType.UNION): + return None + if xml_name := schema.get_trait(XmlNameTrait): + return xml_name.value + return schema.id.name + + +def _validate_element_name(expected: str, elem: Element) -> None: + """Raise XMLParseError if the element's local name doesn't match expected.""" + found = _local_name(elem.tag) + if found != expected: + raise XMLParseError(f"Expected element '{expected}', got '{found}'") + + +def _xml_member_name(member_schema: Schema) -> str: + """Get the XML element name for a member, respecting @xmlName.""" + if xml_name := member_schema.get_trait(XmlNameTrait): + return xml_name.value + return member_schema.expect_member_name() + + +def _parse_xml_float(text: str) -> float: + """Parse an XML float string, handling NaN and Infinity.""" + match text: + case "NaN": + return float("nan") + case "Infinity": + return float("inf") + case "-Infinity": + return float("-inf") + case _: + return float(text) + + +class XMLParseError(SmithyError): + def 
__init__(self, message: str) -> None: + super().__init__(f"Error parsing XML: {message}") + + +class XMLShapeDeserializer(ShapeDeserializer): + """Deserializer that reads XML from a streaming pull parser.""" + + def __init__( + self, + settings: XMLSettings, + reader: XMLEventReader, + wrapper_elements: tuple[str, ...] = (), + ) -> None: + self._settings = settings + self._reader = reader + self._is_root = not bool(wrapper_elements) + self._xml_names: dict[ShapeID, dict[str, Schema]] = {} + self._preconsumed_start: Element | None = None + + # Wrapper elements are protocol transport containers (e.g. awsQuery's + # ). The last wrapper's start element is kept + # so that the next read can reuse it. + for wrapper in wrapper_elements: + event = next(self._reader) + if event.type != "start": + raise XMLParseError(f"Expected start element, got '{event.type}'") + _validate_element_name(wrapper, event.elem) + self._preconsumed_start = event.elem + + def is_null(self) -> bool: + return False + + def read_null(self) -> None: + return None + + def read_boolean(self, schema: Schema) -> bool: + text = self._read_text() + match text: + case "true": + return True + case "false": + return False + case _: + raise XMLParseError(f"Expected 'true' or 'false', got '{text}'") + + def read_blob(self, schema: Schema) -> bytes: + return b64decode(self._read_text()) + + def read_integer(self, schema: Schema) -> int: + return int(self._read_text()) + + def read_float(self, schema: Schema) -> float: + return _parse_xml_float(self._read_text()) + + def read_big_decimal(self, schema: Schema) -> Decimal: + return Decimal(self._read_text()) + + def read_string(self, schema: Schema) -> str: + return self._read_text() + + def read_document(self, schema: Schema) -> Document: + raise NotImplementedError("XML does not support document types") + + def read_timestamp(self, schema: Schema) -> datetime.datetime: + fmt = self._settings.default_timestamp_format + if format_trait := 
schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + + text = self._read_text() + return fmt.deserialize(text) + + def read_struct( + self, + schema: Schema, + consumer: Callable[[Schema, "ShapeDeserializer"], None], + ) -> None: + xml_names = self._get_xml_names(schema) + start_from_wrapper = self._preconsumed_start is not None + start_elem = self._consume_start_event() + if self._is_root: + self._is_root = False + expected = _expected_root_name(schema) + if expected is not None: + _validate_element_name(expected, start_elem) + + # Wrapper elements are protocol transport containers, not modeled structs, + # so their attributes cannot be deserialized as struct members. + if not start_from_wrapper: + for member_schema in schema.members.values(): + if member_schema.get_trait(XmlAttributeTrait) is None: + continue + expected_attr_name = _xml_member_name(member_schema) + for attr_name, attr_value in start_elem.attrib.items(): + attr_local_name = _local_name(attr_name) + if attr_local_name == expected_attr_name: + consumer( + member_schema, + _AttributeDeserializer(attr_value, self._settings), + ) + break + + # Flattened members lack an enclosing element, so there is no way to + # know when all items have been parsed. Their events are collected + # during iteration and replayed through a bounded reader afterwards. 
+ flattened_buffers: dict[str, list[XMLEvent]] = {} + flattened_names = { + xml_name: member_schema + for xml_name, member_schema in xml_names.items() + if member_schema.get_trait(XmlFlattenedTrait) is not None + } + + while self._reader.peek().type != "end": + tag = _local_name(self._reader.peek().elem.tag) + + if tag in flattened_names: + flattened_buffers.setdefault(tag, []).extend(self._buffer_element()) + elif tag in xml_names: + consumer(xml_names[tag], self) + else: + # Skip unknown tag + self._consume_start_event() + self._skip_to_end() + + next(self._reader) + + for tag, events in flattened_buffers.items(): + member_schema = flattened_names[tag] + buffered_de = XMLShapeDeserializer( + self._settings, + XMLEventReader(iter(events)), + ) + consumer(member_schema, buffered_de) + + def read_list( + self, + schema: Schema, + consumer: Callable[["ShapeDeserializer"], None], + ) -> None: + is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + if not is_flattened: + self._consume_start_event() + while self._reader.peek().type != "end": + consumer(self) + else: + while self._reader.has_next(): + consumer(self) + + if not is_flattened: + next(self._reader) + + def read_map( + self, + schema: Schema, + consumer: Callable[[str, "ShapeDeserializer"], None], + ) -> None: + is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + key_schema = schema.members["key"] + value_schema = schema.members["value"] + key_tag = _xml_member_name(key_schema) + value_tag = _xml_member_name(value_schema) + + if not is_flattened: + self._consume_start_event() + while self._reader.peek().type != "end": + self._read_map_entry(key_tag, value_tag, consumer) + else: + while self._reader.has_next(): + self._read_map_entry(key_tag, value_tag, consumer) + + if not is_flattened: + next(self._reader) + + def _read_text(self) -> str: + """Consume a complete element (start through end) and return its text.""" + elem = self._consume_start_event() + self._skip_to_end() + # elem.text 
is populated only after consuming the "end" event + return elem.text or "" + + def _consume_start_event(self) -> Element: + """Consume and return the next start element. + + If a start element was pre-consumed (e.g. from consuming wrapper elements), + it is returned first and cleared. + """ + if self._preconsumed_start is not None: + elem = self._preconsumed_start + self._preconsumed_start = None + return elem + event = next(self._reader) + if event.type != "start": + raise XMLParseError(f"Expected start element, got '{event.type}'") + return event.elem + + def _skip_to_end(self) -> None: + """Skip to the matching end event. Assumes start was already consumed.""" + depth = 1 + while depth > 0: + event = next(self._reader) + if event.type == "start": + depth += 1 + elif event.type == "end": + depth -= 1 + + def _buffer_element(self) -> list[XMLEvent]: + """Buffer a complete element's events (start through matching end).""" + events: list[XMLEvent] = [] + event = next(self._reader) + events.append(event) + depth = 1 + while depth > 0: + event = next(self._reader) + events.append(event) + if event.type == "start": + depth += 1 + elif event.type == "end": + depth -= 1 + return events + + def _get_xml_names(self, schema: Schema) -> dict[str, Schema]: + """Get or build the XML element name -> member schema mapping for a shape.""" + if schema.id in self._xml_names: + return self._xml_names[schema.id] + result: dict[str, Schema] = {} + for member_schema in schema.members.values(): + if member_schema.get_trait(XmlAttributeTrait) is not None: + continue + xml_name = _xml_member_name(member_schema) + result[xml_name] = member_schema + self._xml_names[schema.id] = result + return result + + def _read_map_entry( + self, + key_tag: str, + value_tag: str, + consumer: Callable[[str, "ShapeDeserializer"], None], + ) -> None: + """Read one map entry element and emit key/value pairs via consumer.""" + self._consume_start_event() + + key: str | None = None + while 
self._reader.peek().type != "end": + child_tag = _local_name(self._reader.peek().elem.tag) + if child_tag == key_tag: + key = self._read_text() + elif child_tag == value_tag: + if key is None: + raise XMLParseError( + "Map key element must appear before value element" + ) + consumer(key, self) + else: + # Skip unknown child tag + self._consume_start_event() + self._skip_to_end() + + next(self._reader) + + +class _AttributeDeserializer(SpecificShapeDeserializer): + """Deserializer for a value extracted from an XML attribute string.""" + + def __init__(self, value: str, settings: XMLSettings) -> None: + self._value = value + self._settings = settings + + def read_string(self, schema: Schema) -> str: + return self._value + + def read_boolean(self, schema: Schema) -> bool: + match self._value: + case "true": + return True + case "false": + return False + case _: + raise XMLParseError(f"Expected 'true' or 'false', got '{self._value}'") + + def read_byte(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_short(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_integer(self, schema: Schema) -> int: + return int(self._value) + + def read_long(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_big_integer(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_float(self, schema: Schema) -> float: + return _parse_xml_float(self._value) + + def read_double(self, schema: Schema) -> float: + return self.read_float(schema) + + def read_big_decimal(self, schema: Schema) -> Decimal: + return Decimal(self._value) + + def read_timestamp(self, schema: Schema) -> datetime.datetime: + fmt = self._settings.default_timestamp_format + if format_trait := schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + + return fmt.deserialize(self._value) diff --git a/packages/smithy-xml/src/smithy_xml/_private/readers.py b/packages/smithy-xml/src/smithy_xml/_private/readers.py new 
file mode 100644 index 000000000..98164f1e2 --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/readers.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterator +from typing import NamedTuple +from xml.etree.ElementTree import Element + + +class XMLEvent(NamedTuple): + type: str + elem: Element + + +class XMLEventReader: + """Buffered iterator over XML pull parser events with peek support. + + Wraps an iterator of ``(event, element)`` tuples — either from + ``iterparse`` (streaming from a byte source) or from an in-memory list + (for flattened member replay). + """ + + def __init__(self, events: Iterator[tuple[str, Element] | XMLEvent]) -> None: + self._iter = events + self._pending: XMLEvent | None = None + + def __iter__(self): + return self + + def __next__(self) -> XMLEvent: + if self._pending is not None: + result = self._pending + self._pending = None + return result + return self._next() + + def _next(self) -> XMLEvent: + event = next(self._iter) + if isinstance(event, XMLEvent): + return event + event_type, elem = event + return XMLEvent(event_type, elem) + + def has_next(self) -> bool: + if self._pending is not None: + return True + try: + self._pending = self._next() + return True + except StopIteration: + return False + + def peek(self) -> XMLEvent: + if self._pending is None: + self._pending = self._next() + return self._pending diff --git a/packages/smithy-xml/src/smithy_xml/_private/serializers.py b/packages/smithy-xml/src/smithy_xml/_private/serializers.py new file mode 100644 index 000000000..c3f242070 --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/serializers.py @@ -0,0 +1,372 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0
from base64 import b64encode
from collections.abc import Callable
from contextlib import AbstractContextManager
from datetime import datetime
from decimal import Decimal
from types import TracebackType
from typing import Self
from xml.etree.ElementTree import Element, SubElement, tostring

from smithy_core.documents import Document
from smithy_core.interfaces import BytesWriter
from smithy_core.schemas import Schema
from smithy_core.serializers import (
    InterceptingSerializer,
    MapSerializer,
    ShapeSerializer,
    SpecificShapeSerializer,
)
from smithy_core.shapes import ShapeType
from smithy_core.traits import (
    TimestampFormatTrait,
    XmlAttributeTrait,
    XmlFlattenedTrait,
    XmlNamespaceTrait,
    XmlNameTrait,
)

from ..settings import XMLSettings

_INF: float = float("inf")
_NEG_INF: float = float("-inf")


def _xml_member_name(member_schema: Schema) -> str:
    """Get the XML element name for a member, respecting @xmlName."""
    if xml_name := member_schema.get_trait(XmlNameTrait):
        return xml_name.value
    return member_schema.expect_member_name()


def _xml_root_name(schema: Schema) -> str:
    """Get the XML root element name, respecting @xmlName and member targets.

    Precedence: an explicit @xmlName, then the target shape's name when the
    schema is a member, then the schema's own shape name.
    """
    if xml_name := schema.get_trait(XmlNameTrait):
        return xml_name.value
    if schema.member_target is not None:
        return schema.expect_member_target().id.name
    return schema.id.name


def _set_xml_namespace(
    element: Element,
    schema: Schema,
    settings: XMLSettings,
    *,
    is_root: bool = False,
) -> None:
    """Apply @xmlNamespace to an element, or the default namespace if root.

    The namespace is written as a literal ``xmlns`` / ``xmlns:<prefix>``
    attribute. An @xmlNamespace trait on the schema always wins; the settings'
    default namespace is only considered for the root element.
    """
    if namespace_trait := schema.get_trait(XmlNamespaceTrait):
        if namespace_trait.prefix:
            element.set(f"xmlns:{namespace_trait.prefix}", namespace_trait.uri)
        else:
            element.set("xmlns", namespace_trait.uri)
        return

    if is_root and settings.default_namespace:
        element.set("xmlns", settings.default_namespace)


def _format_xml_float(value: float) -> str:
    """Format a float for XML, handling NaN and Infinity."""
    # NaN is the only float that compares unequal to itself.
    if value != value:
        return "NaN"
    if value == _INF:
        return "Infinity"
    if value == _NEG_INF:
        return "-Infinity"
    return repr(value)


def _is_flattened_collection_schema(schema: Schema) -> bool:
    """Check if a schema is a flattened list or map."""
    return schema.get_trait(XmlFlattenedTrait) is not None and schema.shape_type in (
        ShapeType.LIST,
        ShapeType.MAP,
    )


class XMLShapeSerializer(ShapeSerializer):
    """Serializes Smithy shapes into XML and writes the result to a BytesWriter.

    Builds an in-memory XML tree backed by an element stack. ``write_*``
    methods target the top element, and struct/list/map serializers push and
    pop child elements to control nesting. ``flush`` writes the tree to the
    sink.
    """

    def __init__(self, sink: BytesWriter, settings: XMLSettings) -> None:
        self._sink = sink
        self.settings = settings
        # Created lazily by ensure_root() on the first write.
        self._root: Element | None = None
        # The last entry is the element that write_* methods currently target.
        self.element_stack: list[Element] = []

    @property
    def current(self) -> Element:
        """The element on top of the stack, i.e. the current write target."""
        return self.element_stack[-1]

    def ensure_root(self, schema: Schema) -> None:
        """Create the root element from ``schema`` if none exists yet.

        A no-op once a root exists, so nested writes never replace it.
        """
        if self._root is not None:
            return
        root = Element(_xml_root_name(schema))
        _set_xml_namespace(root, schema, self.settings, is_root=True)
        self._root = root
        self.element_stack.append(root)

    def begin_struct(
        self, schema: "Schema"
    ) -> AbstractContextManager["ShapeSerializer"]:
        return XMLStructSerializer(self, schema)

    def begin_list(
        self, schema: "Schema", size: int
    ) -> AbstractContextManager["ShapeSerializer"]:
        # ``size`` is not needed: elements are appended as items arrive.
        return XMLListSerializer(self, schema)

    def begin_map(
        self, schema: "Schema", size: int
    ) -> AbstractContextManager["MapSerializer"]:
        return XMLMapSerializer(self, schema)

    def write_null(self, schema: "Schema") -> None:
        # Leaves the current element empty; only guarantees a root exists.
        self.ensure_root(schema)

    def write_boolean(self, schema: "Schema", value: bool) -> None:
        self.ensure_root(schema)
        self.current.text = "true" if value else "false"

    def write_integer(self, schema: "Schema", value: int) -> None:
        self.ensure_root(schema)
        self.current.text = str(value)

    def write_float(self, schema: "Schema", value: float) -> None:
        self.ensure_root(schema)
        self.current.text = _format_xml_float(value)

    def write_big_decimal(self, schema: "Schema", value: Decimal) -> None:
        self.ensure_root(schema)
        # NOTE(review): Decimal.normalize() can yield exponent notation
        # (e.g. Decimal("100") -> "1E+2"); confirm that is acceptable on the wire.
        self.current.text = str(value.normalize())

    def write_string(self, schema: "Schema", value: str) -> None:
        self.ensure_root(schema)
        self.current.text = value

    def write_blob(self, schema: "Schema", value: bytes) -> None:
        self.ensure_root(schema)
        self.current.text = b64encode(value).decode("utf-8")

    def write_timestamp(self, schema: "Schema", value: datetime) -> None:
        self.ensure_root(schema)
        # @timestampFormat on the schema overrides the configured default.
        fmt = self.settings.default_timestamp_format
        if format_trait := schema.get_trait(TimestampFormatTrait):
            fmt = format_trait.format
        self.current.text = str(fmt.serialize(value))

    def write_document(self, schema: "Schema", value: Document) -> None:
        raise NotImplementedError("XML does not support document types.")

    def flush(self) -> None:
        """Write the built tree to the sink as UTF-8 and reset internal state.

        Does nothing if no root element was ever created.
        """
        if self._root is None:
            return
        xml_bytes = tostring(self._root, encoding="utf-8", xml_declaration=False)
        self._sink.write(xml_bytes)

        self._root = None
        self.element_stack.clear()


class XMLStructSerializer(InterceptingSerializer):
    """Serializes struct members as child XML elements.

    ``before`` pushes a child element for the member onto the parent's stack,
    and ``after`` pops it. Attributes and flattened collections are special-cased
    to skip the push/pop.
    """

    def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None:
        self._parent = parent
        self._schema = schema

    def __enter__(self) -> Self:
        self._parent.ensure_root(self._schema)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Nothing to close: the struct's element is pushed/popped by whoever
        # wrote it as a member (or is the root).
        pass

    def before(self, schema: "Schema") -> ShapeSerializer:
        member_name = _xml_member_name(schema)

        # Attributes are written on the current element, not as children.
        if schema.get_trait(XmlAttributeTrait) is not None:
            return XMLAttributeSerializer(
                self._parent.current, member_name, self._parent.settings
            )

        # Flattened collections have no wrapper element. Items are added
        # directly under the current element without changing the stack.
        if _is_flattened_collection_schema(schema):
            return self._parent

        # Non-flattened collections push a wrapper element onto the stack.
        child = SubElement(self._parent.current, member_name)
        _set_xml_namespace(child, schema, self._parent.settings)
        self._parent.element_stack.append(child)
        return self._parent

    def after(self, schema: "Schema") -> None:
        # Attributes and flattened collections didn't push, so don't pop.
        if schema.get_trait(XmlAttributeTrait) is not None:
            return
        if _is_flattened_collection_schema(schema):
            return
        self._parent.element_stack.pop()


class XMLListSerializer(InterceptingSerializer):
    """Serializes list items as repeated child elements.

    ``before`` pushes a child element for each item, ``after`` pops it.
    """

    def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None:
        self._parent = parent
        self._schema = schema
        is_flattened = schema.get_trait(XmlFlattenedTrait) is not None

        # A flattened list repeats the list's own element name for every item;
        # a wrapped list uses the name of its ``member`` member.
        if is_flattened:
            if schema.member_target is not None:
                self._item_tag = _xml_member_name(schema)
            else:
                self._item_tag = _xml_root_name(schema)
        else:
            self._item_tag = _xml_member_name(schema.members["member"])

    def __enter__(self) -> Self:
        self._parent.ensure_root(self._schema)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        pass

    def before(self, schema: "Schema") -> ShapeSerializer:
        child = SubElement(self._parent.current, self._item_tag)
        _set_xml_namespace(child, schema, self._parent.settings)
        self._parent.element_stack.append(child)
        return self._parent

    def after(self, schema: "Schema") -> None:
        self._parent.element_stack.pop()


class XMLMapSerializer(MapSerializer):
    """Serializes map entries as ``<entry><key>...</key><value>...</value></entry>`` elements.

    The tag names shown are the defaults: a flattened map uses the map's own
    element name in place of ``entry``, and the key/value tags honour @xmlName.

    Each ``entry`` call pushes the value element onto the stack for the
    value writer callback, then pops it.
    """

    def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None:
        self._parent = parent
        self._schema = schema
        self._is_flattened = schema.get_trait(XmlFlattenedTrait) is not None

        self._key_schema = schema.members["key"]
        self._value_schema = schema.members["value"]
        self._key_tag = _xml_member_name(self._key_schema)
        self._value_tag = _xml_member_name(self._value_schema)

        if self._is_flattened:
            if schema.member_target is not None:
                self._entry_tag = _xml_member_name(schema)
            else:
                self._entry_tag = _xml_root_name(schema)
        else:
            self._entry_tag = "entry"

    def __enter__(self) -> Self:
        self._parent.ensure_root(self._schema)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        pass

    def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]) -> None:
        settings = self._parent.settings
        entry_el = SubElement(self._parent.current, self._entry_tag)
        # In a flattened map the entry element stands in for the map itself,
        # so it carries the map's namespace.
        if self._is_flattened:
            _set_xml_namespace(entry_el, self._schema, settings)

        key_el = SubElement(entry_el, self._key_tag)
        _set_xml_namespace(key_el, self._key_schema, settings)
        key_el.text = key

        value_el = SubElement(entry_el, self._value_tag)
        _set_xml_namespace(value_el, self._value_schema, settings)
        # The value writer targets ``current``, so expose the value element
        # for the duration of the callback only.
        self._parent.element_stack.append(value_el)
        value_writer(self._parent)
        self._parent.element_stack.pop()


class XMLAttributeSerializer(SpecificShapeSerializer):
    """Serializer that writes values as XML attributes on the parent element."""

    def __init__(self, element: Element, attr_name: str, settings: XMLSettings) -> None:
        self._element = element
        self._attr_name = attr_name
        self._settings = settings

    def write_null(self, schema: "Schema") -> None:
        # A null attribute is simply omitted.
        pass

    def write_boolean(self, schema: "Schema", value: bool) -> None:
        self._element.set(self._attr_name, "true" if value else "false")

    def write_byte(self, schema: "Schema", value: int) -> None:
        self.write_integer(schema, value)

    def write_short(self, schema: "Schema", value: int) -> None:
        self.write_integer(schema, value)

    def write_integer(self, schema: "Schema", value: int) -> None:
        self._element.set(self._attr_name, str(value))

    def write_long(self, schema: "Schema", value: int) -> None:
        self.write_integer(schema, value)

    def write_big_integer(self, schema: "Schema", value: int) -> None:
        self.write_integer(schema, value)

    def write_float(self, schema: "Schema", value: float) -> None:
        self._element.set(self._attr_name, _format_xml_float(value))

    def write_double(self, schema: "Schema", value: float) -> None:
        self.write_float(schema, value)

    def write_big_decimal(self, schema: "Schema", value: Decimal) -> None:
        self._element.set(self._attr_name, str(value.normalize()))

    def write_string(self, schema: "Schema", value: str) -> None:
        self._element.set(self._attr_name, value)

    def write_timestamp(self, schema: "Schema", value: datetime) -> None:
        fmt = self._settings.default_timestamp_format
        if format_trait := schema.get_trait(TimestampFormatTrait):
            fmt = format_trait.format
        self._element.set(self._attr_name, str(fmt.serialize(value)))
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from smithy_core.types import TimestampFormat + + +@dataclass(frozen=True) +class XMLSettings: + """Configuration for XML serialization/deserialization.""" + + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME + """Default timestamp format when a member does not define @timestampFormat.""" + + default_namespace: str | None = None + """Default XML namespace (``xmlns``) applied to the root element during serialization.""" diff --git a/packages/smithy-xml/tests/__init__.py b/packages/smithy-xml/tests/__init__.py new file mode 100644 index 000000000..04f8b7b76 --- /dev/null +++ b/packages/smithy-xml/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-xml/tests/unit/__init__.py b/packages/smithy-xml/tests/unit/__init__.py new file mode 100644 index 000000000..630200031 --- /dev/null +++ b/packages/smithy-xml/tests/unit/__init__.py @@ -0,0 +1,662 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any, Self + +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + TimestampFormatTrait, + XmlAttributeTrait, + XmlFlattenedTrait, + XmlNamespaceTrait, + XmlNameTrait, +) + +STRING_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#StringList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + } + }, +) + +STRING_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + }, + "value": { + "target": STRING, + }, + }, +) + +# List with @xmlName on the member element +RENAMED_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + "traits": [XmlNameTrait("item")], + } + }, +) + +# Map with @xmlName on key/value members +RENAMED_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + "traits": [XmlNameTrait("Attribute")], + }, + "value": { + "target": STRING, + "traits": [XmlNameTrait("Setting")], + }, + }, +) + +# Map with @xmlName and @xmlNamespace on key/value members +RENAMED_NS_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedNsMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + "traits": [ + XmlNameTrait("K"), + XmlNamespaceTrait({"uri": "https://the-key.example.com"}), + ], + }, + "value": { + "target": STRING, + "traits": [ + XmlNameTrait("V"), + 
XmlNamespaceTrait({"uri": "https://the-value.example.com"}), + ], + }, + }, +) + +# List with @xmlNamespace on member +NAMESPACED_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#NamespacedList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + "traits": [XmlNamespaceTrait({"uri": "http://bux.com"})], + } + }, +) + +# Struct with @xmlNamespace (default xmlns) +NAMESPACED_STRUCT_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#NsStruct"), + traits=[XmlNamespaceTrait({"uri": "https://example.com"})], + members={ + "value": {"target": STRING}, + }, +) + +# Struct with @xmlNamespace (prefixed xmlns) +PREFIXED_NS_STRUCT_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#PrefixedNsStruct"), + traits=[XmlNamespaceTrait({"uri": "https://example.com", "prefix": "baz"})], + members={ + "value": {"target": STRING}, + }, +) + +SCHEMA: Schema = Schema.collection( + id=ShapeID("smithy.example#SerdeShape"), + members={ + "booleanMember": { + "target": BOOLEAN, + }, + "integerMember": { + "target": INTEGER, + }, + "floatMember": { + "target": FLOAT, + }, + "bigDecimalMember": { + "target": BIG_DECIMAL, + }, + "stringMember": { + "target": STRING, + }, + "xmlNameMember": { + "target": STRING, + "traits": [XmlNameTrait("CustomName")], + }, + "blobMember": { + "target": BLOB, + }, + "timestampMember": { + "target": TIMESTAMP, + }, + "dateTimeMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("date-time")], + }, + "httpDateMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("http-date")], + }, + "epochSecondsMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("epoch-seconds")], + }, + "listMember": { + "target": STRING_LIST_SCHEMA, + }, + "mapMember": { + "target": STRING_MAP_SCHEMA, + }, + "structMember": None, + "xmlAttributeMember": { + "target": STRING, + "traits": [XmlAttributeTrait()], + }, + "renamedListMember": { + "target": RENAMED_LIST_SCHEMA, + }, + 
"flattenedListMember": { + "target": STRING_LIST_SCHEMA, + "traits": [XmlFlattenedTrait()], + }, + "flattenedMapMember": { + "target": STRING_MAP_SCHEMA, + "traits": [XmlFlattenedTrait()], + }, + "flattenedRenamedListMember": { + "target": STRING_LIST_SCHEMA, + "traits": [XmlFlattenedTrait(), XmlNameTrait("customItem")], + }, + "flattenedRenamedMapMember": { + "target": RENAMED_MAP_SCHEMA, + "traits": [XmlFlattenedTrait(), XmlNameTrait("KVP")], + }, + "xmlAttributeNamedMember": { + "target": STRING, + "traits": [XmlAttributeTrait(), XmlNameTrait("test")], + }, + }, +) +SCHEMA.members["structMember"] = Schema.member( + id=SCHEMA.id.with_member("structMember"), + target=SCHEMA, + index=13, +) + + +@dataclass +class SerdeShape: + boolean_member: bool | None = None + integer_member: int | None = None + float_member: float | None = None + big_decimal_member: Decimal | None = None + string_member: str | None = None + xml_name_member: str | None = None + blob_member: bytes | None = None + timestamp_member: datetime | None = None + date_time_member: datetime | None = None + http_date_member: datetime | None = None + epoch_seconds_member: datetime | None = None + list_member: list[str] | None = None + map_member: dict[str, str] | None = None + struct_member: "SerdeShape | None" = None + xml_attribute_member: str | None = None + renamed_list_member: list[str] | None = None + flattened_list_member: list[str] | None = None + flattened_map_member: dict[str, str] | None = None + flattened_renamed_list_member: list[str] | None = None + flattened_renamed_map_member: dict[str, str] | None = None + xml_attribute_named_member: str | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.boolean_member is not None: + serializer.write_boolean( + SCHEMA.members["booleanMember"], self.boolean_member + ) + if self.integer_member is not None: + 
serializer.write_integer( + SCHEMA.members["integerMember"], self.integer_member + ) + if self.float_member is not None: + serializer.write_float(SCHEMA.members["floatMember"], self.float_member) + if self.big_decimal_member is not None: + serializer.write_big_decimal( + SCHEMA.members["bigDecimalMember"], self.big_decimal_member + ) + if self.string_member is not None: + serializer.write_string(SCHEMA.members["stringMember"], self.string_member) + if self.xml_name_member is not None: + serializer.write_string( + SCHEMA.members["xmlNameMember"], self.xml_name_member + ) + if self.blob_member is not None: + serializer.write_blob(SCHEMA.members["blobMember"], self.blob_member) + if self.timestamp_member is not None: + serializer.write_timestamp( + SCHEMA.members["timestampMember"], self.timestamp_member + ) + if self.date_time_member is not None: + serializer.write_timestamp( + SCHEMA.members["dateTimeMember"], self.date_time_member + ) + if self.http_date_member is not None: + serializer.write_timestamp( + SCHEMA.members["httpDateMember"], self.http_date_member + ) + if self.epoch_seconds_member is not None: + serializer.write_timestamp( + SCHEMA.members["epochSecondsMember"], self.epoch_seconds_member + ) + if self.list_member is not None: + schema = SCHEMA.members["listMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.list_member)) as ls: + for element in self.list_member: + ls.write_string(target_schema, element) + if self.map_member is not None: + schema = SCHEMA.members["mapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map(schema, len(self.map_member)) as ms: + for key, value in self.map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.struct_member is not None: + serializer.write_struct(SCHEMA.members["structMember"], self.struct_member) + if self.xml_attribute_member is not None: + 
serializer.write_string( + SCHEMA.members["xmlAttributeMember"], self.xml_attribute_member + ) + if self.renamed_list_member is not None: + schema = SCHEMA.members["renamedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.renamed_list_member)) as ls: + for element in self.renamed_list_member: + ls.write_string(target_schema, element) + if self.flattened_list_member is not None: + schema = SCHEMA.members["flattenedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.flattened_list_member)) as ls: + for element in self.flattened_list_member: + ls.write_string(target_schema, element) + if self.flattened_map_member is not None: + schema = SCHEMA.members["flattenedMapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map(schema, len(self.flattened_map_member)) as ms: + for key, value in self.flattened_map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.flattened_renamed_list_member is not None: + schema = SCHEMA.members["flattenedRenamedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list( + schema, len(self.flattened_renamed_list_member) + ) as ls: + for element in self.flattened_renamed_list_member: + ls.write_string(target_schema, element) + if self.flattened_renamed_map_member is not None: + schema = SCHEMA.members["flattenedRenamedMapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map( + schema, len(self.flattened_renamed_map_member) + ) as ms: + for key, value in self.flattened_renamed_map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.xml_attribute_named_member is not None: + serializer.write_string( + SCHEMA.members["xmlAttributeNamedMember"], + 
self.xml_attribute_named_member, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["boolean_member"] = de.read_boolean( + SCHEMA.members["booleanMember"] + ) + case 1: + kwargs["integer_member"] = de.read_integer( + SCHEMA.members["integerMember"] + ) + case 2: + kwargs["float_member"] = de.read_float( + SCHEMA.members["floatMember"] + ) + case 3: + kwargs["big_decimal_member"] = de.read_big_decimal( + SCHEMA.members["bigDecimalMember"] + ) + case 4: + kwargs["string_member"] = de.read_string( + SCHEMA.members["stringMember"] + ) + case 5: + kwargs["xml_name_member"] = de.read_string( + SCHEMA.members["xmlNameMember"] + ) + case 6: + kwargs["blob_member"] = de.read_blob(SCHEMA.members["blobMember"]) + case 7: + kwargs["timestamp_member"] = de.read_timestamp( + SCHEMA.members["timestampMember"] + ) + case 8: + kwargs["date_time_member"] = de.read_timestamp( + SCHEMA.members["dateTimeMember"] + ) + case 9: + kwargs["http_date_member"] = de.read_timestamp( + SCHEMA.members["httpDateMember"] + ) + case 10: + kwargs["epoch_seconds_member"] = de.read_timestamp( + SCHEMA.members["epochSecondsMember"] + ) + case 11: + list_value: list[str] = [] + de.read_list( + SCHEMA.members["listMember"], + lambda d: list_value.append(d.read_string(STRING)), + ) + kwargs["list_member"] = list_value + case 12: + map_value: dict[str, str] = {} + de.read_map( + SCHEMA.members["mapMember"], + lambda k, d: map_value.__setitem__(k, d.read_string(STRING)), + ) + kwargs["map_member"] = map_value + case 13: + kwargs["struct_member"] = SerdeShape.deserialize(de) + case 14: + kwargs["xml_attribute_member"] = de.read_string( + SCHEMA.members["xmlAttributeMember"] + ) + case 15: + renamed_list_value: list[str] = [] + de.read_list( + SCHEMA.members["renamedListMember"], + lambda d: 
renamed_list_value.append(d.read_string(STRING)), + ) + kwargs["renamed_list_member"] = renamed_list_value + case 16: + flat_list_value: list[str] = [] + de.read_list( + SCHEMA.members["flattenedListMember"], + lambda d: flat_list_value.append(d.read_string(STRING)), + ) + kwargs["flattened_list_member"] = flat_list_value + case 17: + flat_map_value: dict[str, str] = {} + de.read_map( + SCHEMA.members["flattenedMapMember"], + lambda k, d: flat_map_value.__setitem__( + k, d.read_string(STRING) + ), + ) + kwargs["flattened_map_member"] = flat_map_value + case 18: + flat_renamed_list: list[str] = [] + de.read_list( + SCHEMA.members["flattenedRenamedListMember"], + lambda d: flat_renamed_list.append(d.read_string(STRING)), + ) + kwargs["flattened_renamed_list_member"] = flat_renamed_list + case 19: + flat_renamed_map: dict[str, str] = {} + de.read_map( + SCHEMA.members["flattenedRenamedMapMember"], + lambda k, d: flat_renamed_map.__setitem__( + k, d.read_string(STRING) + ), + ) + kwargs["flattened_renamed_map_member"] = flat_renamed_map + case 20: + kwargs["xml_attribute_named_member"] = de.read_string( + SCHEMA.members["xmlAttributeNamedMember"] + ) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +# ---- Serde test cases ---- +# Each entry is (value, xml_bytes) for round-trip testing. 
+# Inspired by the awsQuery/restXml protocol compliance tests + +XML_SERDE_CASES: list[tuple[Any, bytes]] = [ + # Scalars + (True, b"true"), + (False, b"false"), + (1, b"1"), + (1.5, b"1.5"), + (float("inf"), b"Infinity"), + (float("-inf"), b"-Infinity"), + (Decimal("1.1"), b"1.1"), + (b"value", b"dmFsdWU="), + ("foo", b"foo"), + ( + datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC), + b"2014-04-29T18:30:38Z", + ), + # Wrapped list — elements + ( + ["foo", "bar"], + b"foobar", + ), + # Wrapped map — + ( + {"foo": "bar"}, + b"foobar", + ), + # Struct with single scalar + ( + SerdeShape(string_member="foo"), + b"foo", + ), + ( + SerdeShape(boolean_member=True), + b"true", + ), + ( + SerdeShape(integer_member=3), + b"3", + ), + ( + SerdeShape(float_member=5.5), + b"5.5", + ), + ( + SerdeShape(big_decimal_member=Decimal("1.1")), + b"1.1", + ), + ( + SerdeShape(blob_member=b"value"), + b"dmFsdWU=", + ), + # @xmlName — member serialized under custom element name + ( + SerdeShape(xml_name_member="bar"), + b"bar", + ), + # Timestamps with different formats + ( + SerdeShape(timestamp_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"2014-04-29T18:30:38Z", + ), + ( + SerdeShape(date_time_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"2014-04-29T18:30:38Z", + ), + ( + SerdeShape(http_date_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"Tue, 29 Apr 2014 18:30:38 GMT", + ), + ( + SerdeShape(epoch_seconds_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"1398796238", + ), + # List inside struct + ( + SerdeShape(list_member=["foo", "bar"]), + ( + b"" + b"foobar" + b"" + ), + ), + # Map inside struct + ( + SerdeShape(map_member={"foo": "bar"}), + ( + b"" + b"foobar" + b"" + ), + ), + # Nested struct + ( + SerdeShape(struct_member=SerdeShape(string_member="nested")), + ( + b"" + b"nested" + b"" + ), + ), + # @xmlAttribute — member as attribute on parent element + ( + SerdeShape(xml_attribute_member="attr_val"), + b'', + ), + # List with 
@xmlName("item") on member + ( + SerdeShape(renamed_list_member=["foo", "bar"]), + ( + b"" + b"foobar" + b"" + ), + ), + # @xmlFlattened list + ( + SerdeShape(flattened_list_member=["hi", "bye"]), + ( + b"" + b"hi" + b"bye" + b"" + ), + ), + # @xmlFlattened map + ( + SerdeShape(flattened_map_member={"foo": "Foo", "baz": "Baz"}), + ( + b"" + b"fooFoo" + b"bazBaz" + b"" + ), + ), + # @xmlFlattened + @xmlName on list member + ( + SerdeShape(flattened_renamed_list_member=["hi", "bye"]), + ( + b"" + b"hi" + b"bye" + b"" + ), + ), + # @xmlFlattened + @xmlName on map member with renamed key/value + ( + SerdeShape(flattened_renamed_map_member={"foo": "Foo"}), + ( + b"" + b"fooFoo" + b"" + ), + ), + # @xmlAttribute + @xmlName + ( + SerdeShape(xml_attribute_named_member="attr_val"), + b'', + ), + # Multiple members in one struct — realistic multi-member test + ( + SerdeShape( + boolean_member=True, + integer_member=42, + string_member="hello", + list_member=["a", "b"], + ), + ( + b"" + b"true" + b"42" + b"hello" + b"ab" + b"" + ), + ), + # Nested struct 3 levels deep + ( + SerdeShape( + struct_member=SerdeShape(struct_member=SerdeShape(string_member="deep")) + ), + ( + b"" + b"" + b"deep" + b"" + b"" + ), + ), + # XML escaping in text content + ( + SerdeShape(string_member=""), + b"<foo&bar>", + ), + # Empty collections — wrapper element with no children + ( + SerdeShape(list_member=[]), + b"", + ), + ( + SerdeShape(map_member={}), + b"", + ), +] diff --git a/packages/smithy-xml/tests/unit/test_deserializers.py b/packages/smithy-xml/tests/unit/test_deserializers.py new file mode 100644 index 000000000..f3ee9e703 --- /dev/null +++ b/packages/smithy-xml/tests/unit/test_deserializers.py @@ -0,0 +1,145 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import math +from datetime import datetime +from decimal import Decimal +from typing import Any + +import pytest +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + DOCUMENT, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_xml import XMLCodec + +from . import ( + STRING_LIST_SCHEMA, + STRING_MAP_SCHEMA, + XML_SERDE_CASES, + SerdeShape, +) + + +@pytest.mark.parametrize("expected, given", XML_SERDE_CASES) +def test_xml_deserializer(expected: Any, given: bytes) -> None: + codec = XMLCodec() + deserializer = codec.create_deserializer(given) + match expected: + case bool(): + actual = deserializer.read_boolean(BOOLEAN) + case int(): + actual = deserializer.read_integer(INTEGER) + case float(): + actual = deserializer.read_float(FLOAT) + case Decimal(): + actual = deserializer.read_big_decimal(BIG_DECIMAL) + case bytes(): + actual = deserializer.read_blob(BLOB) + case str(): + actual = deserializer.read_string(STRING) + case datetime(): + actual = deserializer.read_timestamp(TIMESTAMP) + case list(): + actual_list: list[str] = [] + deserializer.read_list( + STRING_LIST_SCHEMA, + lambda d: actual_list.append(d.read_string(STRING)), + ) + actual = actual_list + case dict(): + actual_map: dict[str, str] = {} + deserializer.read_map( + STRING_MAP_SCHEMA, + lambda k, d: actual_map.__setitem__(k, d.read_string(STRING)), + ) + actual = actual_map + case SerdeShape(): + actual = SerdeShape.deserialize(deserializer) + case _: + raise Exception(f"Unexpected type: {type(expected)}") + + assert actual == expected + + +def test_read_document_raises() -> None: + """XML does not support document types.""" + deserializer = XMLCodec().create_deserializer(b"foo") + with pytest.raises( + NotImplementedError, match="XML does not support document types" + ): + deserializer.read_document(DOCUMENT) + + +def test_deserialize_nan() -> None: + actual = XMLCodec().create_deserializer(b"NaN").read_float(FLOAT) + assert 
math.isnan(actual) + + +def test_deserialize_empty_string_self_closed() -> None: + assert XMLCodec().create_deserializer(b"").read_string(STRING) == "" + + +def test_deserialize_empty_string_open_close() -> None: + assert XMLCodec().create_deserializer(b"").read_string(STRING) == "" + + +def test_deserialize_empty_blob() -> None: + assert XMLCodec().create_deserializer(b"").read_blob(BLOB) == b"" + + +def test_deserialize_empty_blob_self_closed() -> None: + assert XMLCodec().create_deserializer(b"").read_blob(BLOB) == b"" + + +def test_wrapper_elements() -> None: + """Deserializer can unwrap awsQuery-style response wrappers.""" + xml = ( + b"" + b"hello" + b"" + ) + deserializer = XMLCodec().create_deserializer( + xml, wrapper_elements=("OpResponse", "OpResult") + ) + result = SerdeShape.deserialize(deserializer) + assert result.string_member == "hello" + + +def test_wrapper_elements_scalar_read() -> None: + xml = b"hello" + deserializer = XMLCodec().create_deserializer( + xml, wrapper_elements=("OpResponse", "OpResult") + ) + assert deserializer.read_string(STRING) == "hello" + + +def test_flattened_list_interleaved_with_other_members() -> None: + """Flattened list elements can be interleaved with other struct members.""" + xml = ( + b"" + b"first" + b"middle" + b"second" + b"" + ) + result = SerdeShape.deserialize(XMLCodec().create_deserializer(xml)) + assert result.flattened_list_member == ["first", "second"] + assert result.string_member == "middle" + + +def test_unknown_members_skipped() -> None: + xml = ( + b"" + b"keep" + b"ignore" + b"5" + b"" + ) + result = SerdeShape.deserialize(XMLCodec().create_deserializer(xml)) + assert result == SerdeShape(string_member="keep", integer_member=5) diff --git a/packages/smithy-xml/tests/unit/test_serializers.py b/packages/smithy-xml/tests/unit/test_serializers.py new file mode 100644 index 000000000..2356c0df9 --- /dev/null +++ b/packages/smithy-xml/tests/unit/test_serializers.py @@ -0,0 +1,209 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from decimal import Decimal +from io import BytesIO +from typing import Any, cast +from xml.etree.ElementTree import canonicalize + +import pytest +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_xml import XMLCodec + +from . import ( + NAMESPACED_LIST_SCHEMA, + NAMESPACED_STRUCT_SCHEMA, + PREFIXED_NS_STRUCT_SCHEMA, + RENAMED_NS_MAP_SCHEMA, + STRING_LIST_SCHEMA, + STRING_MAP_SCHEMA, + XML_SERDE_CASES, + SerdeShape, +) + + +def _canonicalize(xml_bytes: bytes) -> str: + """Canonicalize XML for comparison, stripping whitespace differences.""" + return canonicalize(xml_bytes, strip_text=True) + + +@pytest.mark.parametrize("given, expected", XML_SERDE_CASES) +def test_xml_serializer(given: Any, expected: bytes) -> None: + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + match given: + case bool(): + serializer.write_boolean(BOOLEAN, given) + case int(): + serializer.write_integer(INTEGER, given) + case float(): + serializer.write_float(FLOAT, given) + case Decimal(): + serializer.write_big_decimal(BIG_DECIMAL, given) + case bytes(): + serializer.write_blob(BLOB, given) + case str(): + serializer.write_string(STRING, given) + case datetime(): + serializer.write_timestamp(TIMESTAMP, given) + case list(): + given = cast(list[str], given) + with serializer.begin_list(STRING_LIST_SCHEMA, len(given)) as ls: + member_schema = STRING_LIST_SCHEMA.members["member"] + for e in given: + ls.write_string(member_schema, e) + case dict(): + given = cast(dict[str, str], given) + with serializer.begin_map(STRING_MAP_SCHEMA, len(given)) as ms: + member_schema = STRING_MAP_SCHEMA.members["value"] + for k, v in given.items(): + ms.entry(k, lambda vs: vs.write_string(member_schema, v)) # type: ignore + case SerdeShape(): + given.serialize(serializer) + case _: + raise 
Exception(f"Unexpected type: {type(given)}") + + serializer.flush() + sink.seek(0) + actual = sink.read() + assert _canonicalize(actual) == _canonicalize(expected) + + +def test_write_null() -> None: + """write_null creates an empty element (no text content).""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.write_null(STRING) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b"" + + +def test_write_document_raises() -> None: + """XML does not support document types.""" + from smithy_core.documents import Document + from smithy_core.prelude import DOCUMENT + + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with pytest.raises(NotImplementedError, match="XML does not support document"): + serializer.write_document(DOCUMENT, Document(None)) + + +def test_float_nan() -> None: + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.write_float(FLOAT, float("nan")) + serializer.flush() + sink.seek(0) + assert sink.read() == b"NaN" + + +def test_default_namespace() -> None: + """Default namespace is set on the root element.""" + sink = BytesIO() + serializer = XMLCodec(default_namespace="https://example.com/").create_serializer( + sink + ) + serializer.write_string(STRING, "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b'hi' + + +def test_xml_escaping_in_attribute() -> None: + """XML special characters are escaped in attribute values.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + shape = SerdeShape(xml_attribute_named_member='<"test">&') + shape.serialize(serializer) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b'' + + +def test_flush_with_no_writes() -> None: + """Flushing without any writes produces no output.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.flush() + sink.seek(0) + assert sink.read() == b"" + + +def 
test_list_with_namespace_on_member() -> None: + """@xmlNamespace on list member adds xmlns to each item element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + items = ["Bar"] + with serializer.begin_list(NAMESPACED_LIST_SCHEMA, len(items)) as ls: + member_schema = NAMESPACED_LIST_SCHEMA.members["member"] + for e in items: + ls.write_string(member_schema, e) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual + == b'Bar' + ) + + +def test_map_with_xmlname_and_namespace() -> None: + """Map with @xmlName + @xmlNamespace on key and value members.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + data = {"a": "A"} + with serializer.begin_map(RENAMED_NS_MAP_SCHEMA, len(data)) as ms: + member_schema = RENAMED_NS_MAP_SCHEMA.members["value"] + for k, v in data.items(): + ms.entry(k, lambda vs: vs.write_string(member_schema, v)) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == ( + b"" + b'a' + b'A' + b"" + ) + + +def test_struct_with_xml_namespace() -> None: + """@xmlNamespace on struct adds default xmlns to root element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with serializer.begin_struct(NAMESPACED_STRUCT_SCHEMA) as ss: + ss.write_string(NAMESPACED_STRUCT_SCHEMA.members["value"], "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual == b'hi' + ) + + +def test_struct_with_xml_namespace_prefix() -> None: + """@xmlNamespace with prefix adds prefixed xmlns to root element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with serializer.begin_struct(PREFIXED_NS_STRUCT_SCHEMA) as ss: + ss.write_string(PREFIXED_NS_STRUCT_SCHEMA.members["value"], "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual + == b'hi' + ) diff --git a/pyproject.toml b/pyproject.toml index d7a23d99e..fa33308f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ 
members = ["packages/*"] smithy_core = { workspace = true } smithy_http = { workspace = true } smithy_json = { workspace = true } +smithy_xml = { workspace = true } smithy_aws_core = { workspace = true } smithy_aws_event_stream = { workspace = true } aws_sdk_signers = {workspace = true } diff --git a/uv.lock b/uv.lock index ca20fac7e..53f054dc5 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ members = [ "smithy-http", "smithy-json", "smithy-python", + "smithy-xml", ] [[package]] @@ -777,6 +778,16 @@ test = [ ] typing = [{ name = "pyright", specifier = ">=1.1.400" }] +[[package]] +name = "smithy-xml" +source = { editable = "packages/smithy-xml" } +dependencies = [ + { name = "smithy-core" }, +] + +[package.metadata] +requires-dist = [{ name = "smithy-core", editable = "packages/smithy-core" }] + [[package]] name = "typing-extensions" version = "4.13.2" From 48b7047df703c5b9b61b35a9d7112e4f11f24b45 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:13:23 -0400 Subject: [PATCH 2/4] Remove serializers from initial smithy-xml package --- ...ture-f4925439be784276a4639947064c99f9.json | 4 + .../smithy-xml/src/smithy_xml/__init__.py | 16 +- .../src/smithy_xml/_private/deserializers.py | 18 +- .../src/smithy_xml/_private/serializers.py | 372 ------------------ .../smithy-xml/src/smithy_xml/settings.py | 12 +- .../smithy-xml/tests/unit/test_serializers.py | 204 +--------- 6 files changed, 38 insertions(+), 588 deletions(-) create mode 100644 packages/smithy-xml/.changes/next-release/smithy-xml-feature-f4925439be784276a4639947064c99f9.json delete mode 100644 packages/smithy-xml/src/smithy_xml/_private/serializers.py diff --git a/packages/smithy-xml/.changes/next-release/smithy-xml-feature-f4925439be784276a4639947064c99f9.json b/packages/smithy-xml/.changes/next-release/smithy-xml-feature-f4925439be784276a4639947064c99f9.json new file mode 100644 index 000000000..35b2cb2f7 --- /dev/null +++ 
b/packages/smithy-xml/.changes/next-release/smithy-xml-feature-f4925439be784276a4639947064c99f9.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Added initial support for XML deserialization in Smithy clients." +} \ No newline at end of file diff --git a/packages/smithy-xml/src/smithy_xml/__init__.py b/packages/smithy-xml/src/smithy_xml/__init__.py index 9616aa8ef..d50d808e4 100644 --- a/packages/smithy-xml/src/smithy_xml/__init__.py +++ b/packages/smithy-xml/src/smithy_xml/__init__.py @@ -12,10 +12,9 @@ from ._private.deserializers import XMLShapeDeserializer as _XMLShapeDeserializer from ._private.readers import XMLEventReader as _XMLEventReader -from ._private.serializers import XMLShapeSerializer as _XMLShapeSerializer from .settings import XMLSettings -__version__ = "0.0.1" +__version__ = "0.1.0" __all__ = ("XMLCodec", "XMLSettings") @@ -24,10 +23,21 @@ class XMLCodec(Codec): def __init__( self, + use_timestamp_format: bool = True, default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, default_namespace: str | None = None, ) -> None: + """Initializes an XMLCodec. + + :param use_timestamp_format: Whether the codec should use the + `smithy.api#timestampFormat` trait, if present. + :param default_timestamp_format: The default timestamp format to use if the + `smithy.api#timestampFormat` trait is not enabled or not present. + :param default_namespace: Default XML namespace (`xmlns`) applied to the root + element during serialization. 
+ """ self._settings = XMLSettings( + use_timestamp_format=use_timestamp_format, default_timestamp_format=default_timestamp_format, default_namespace=default_namespace, ) @@ -37,7 +47,7 @@ def media_type(self) -> str: return "application/xml" def create_serializer(self, sink: BytesWriter) -> ShapeSerializer: - return _XMLShapeSerializer(sink=sink, settings=self._settings) + raise NotImplementedError("XML serialization is not supported") def create_deserializer( self, diff --git a/packages/smithy-xml/src/smithy_xml/_private/deserializers.py b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py index ae1ece78f..1e5f13b9d 100644 --- a/packages/smithy-xml/src/smithy_xml/_private/deserializers.py +++ b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py @@ -131,12 +131,13 @@ def read_document(self, schema: Schema) -> Document: raise NotImplementedError("XML does not support document types") def read_timestamp(self, schema: Schema) -> datetime.datetime: - fmt = self._settings.default_timestamp_format - if format_trait := schema.get_trait(TimestampFormatTrait): - fmt = format_trait.format + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: + if format_trait := schema.get_trait(TimestampFormatTrait): + format = format_trait.format text = self._read_text() - return fmt.deserialize(text) + return format.deserialize(text) def read_struct( self, @@ -371,8 +372,9 @@ def read_big_decimal(self, schema: Schema) -> Decimal: return Decimal(self._value) def read_timestamp(self, schema: Schema) -> datetime.datetime: - fmt = self._settings.default_timestamp_format - if format_trait := schema.get_trait(TimestampFormatTrait): - fmt = format_trait.format + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: + if format_trait := schema.get_trait(TimestampFormatTrait): + format = format_trait.format - return fmt.deserialize(self._value) + return format.deserialize(self._value) diff --git 
a/packages/smithy-xml/src/smithy_xml/_private/serializers.py b/packages/smithy-xml/src/smithy_xml/_private/serializers.py deleted file mode 100644 index c3f242070..000000000 --- a/packages/smithy-xml/src/smithy_xml/_private/serializers.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -from base64 import b64encode -from collections.abc import Callable -from contextlib import AbstractContextManager -from datetime import datetime -from decimal import Decimal -from types import TracebackType -from typing import Self -from xml.etree.ElementTree import Element, SubElement, tostring - -from smithy_core.documents import Document -from smithy_core.interfaces import BytesWriter -from smithy_core.schemas import Schema -from smithy_core.serializers import ( - InterceptingSerializer, - MapSerializer, - ShapeSerializer, - SpecificShapeSerializer, -) -from smithy_core.shapes import ShapeType -from smithy_core.traits import ( - TimestampFormatTrait, - XmlAttributeTrait, - XmlFlattenedTrait, - XmlNamespaceTrait, - XmlNameTrait, -) - -from ..settings import XMLSettings - -_INF: float = float("inf") -_NEG_INF: float = float("-inf") - - -def _xml_member_name(member_schema: Schema) -> str: - """Get the XML element name for a member, respecting @xmlName.""" - if xml_name := member_schema.get_trait(XmlNameTrait): - return xml_name.value - return member_schema.expect_member_name() - - -def _xml_root_name(schema: Schema) -> str: - """Get the XML root element name, respecting @xmlName and member targets.""" - if xml_name := schema.get_trait(XmlNameTrait): - return xml_name.value - if schema.member_target is not None: - return schema.expect_member_target().id.name - return schema.id.name - - -def _set_xml_namespace( - element: Element, - schema: Schema, - settings: XMLSettings, - *, - is_root: bool = False, -) -> None: - """Apply @xmlNamespace to an element, or the default namespace if root.""" - 
if namespace_trait := schema.get_trait(XmlNamespaceTrait): - if namespace_trait.prefix: - element.set(f"xmlns:{namespace_trait.prefix}", namespace_trait.uri) - else: - element.set("xmlns", namespace_trait.uri) - return - - if is_root and settings.default_namespace: - element.set("xmlns", settings.default_namespace) - - -def _format_xml_float(value: float) -> str: - """Format a float for XML, handling NaN and Infinity.""" - if value != value: - return "NaN" - if value == _INF: - return "Infinity" - if value == _NEG_INF: - return "-Infinity" - return repr(value) - - -def _is_flattened_collection_schema(schema: Schema) -> bool: - """Check if a schema is a flattened list or map.""" - return schema.get_trait(XmlFlattenedTrait) is not None and schema.shape_type in ( - ShapeType.LIST, - ShapeType.MAP, - ) - - -class XMLShapeSerializer(ShapeSerializer): - """Serializes Smithy shapes into XML and writes the result to a BytesWriter. - - Builds an in-memory XML tree backed by an element stack. ``write_*`` - methods target the top element, and struct/list/map serializers push and - pop child elements to control nesting. ``flush`` writes the tree to the - sink. 
- """ - - def __init__(self, sink: BytesWriter, settings: XMLSettings) -> None: - self._sink = sink - self.settings = settings - self._root: Element | None = None - self.element_stack: list[Element] = [] - - @property - def current(self) -> Element: - return self.element_stack[-1] - - def ensure_root(self, schema: Schema) -> None: - if self._root is not None: - return - root = Element(_xml_root_name(schema)) - _set_xml_namespace(root, schema, self.settings, is_root=True) - self._root = root - self.element_stack.append(root) - - def begin_struct( - self, schema: "Schema" - ) -> AbstractContextManager["ShapeSerializer"]: - return XMLStructSerializer(self, schema) - - def begin_list( - self, schema: "Schema", size: int - ) -> AbstractContextManager["ShapeSerializer"]: - return XMLListSerializer(self, schema) - - def begin_map( - self, schema: "Schema", size: int - ) -> AbstractContextManager["MapSerializer"]: - return XMLMapSerializer(self, schema) - - def write_null(self, schema: "Schema") -> None: - self.ensure_root(schema) - - def write_boolean(self, schema: "Schema", value: bool) -> None: - self.ensure_root(schema) - self.current.text = "true" if value else "false" - - def write_integer(self, schema: "Schema", value: int) -> None: - self.ensure_root(schema) - self.current.text = str(value) - - def write_float(self, schema: "Schema", value: float) -> None: - self.ensure_root(schema) - self.current.text = _format_xml_float(value) - - def write_big_decimal(self, schema: "Schema", value: Decimal) -> None: - self.ensure_root(schema) - self.current.text = str(value.normalize()) - - def write_string(self, schema: "Schema", value: str) -> None: - self.ensure_root(schema) - self.current.text = value - - def write_blob(self, schema: "Schema", value: bytes) -> None: - self.ensure_root(schema) - self.current.text = b64encode(value).decode("utf-8") - - def write_timestamp(self, schema: "Schema", value: datetime) -> None: - self.ensure_root(schema) - fmt = 
self.settings.default_timestamp_format - if format_trait := schema.get_trait(TimestampFormatTrait): - fmt = format_trait.format - self.current.text = str(fmt.serialize(value)) - - def write_document(self, schema: "Schema", value: Document) -> None: - raise NotImplementedError("XML does not support document types.") - - def flush(self) -> None: - if self._root is None: - return - xml_bytes = tostring(self._root, encoding="utf-8", xml_declaration=False) - self._sink.write(xml_bytes) - - self._root = None - self.element_stack.clear() - - -class XMLStructSerializer(InterceptingSerializer): - """Serializes struct members as child XML elements. - - ``before`` pushes a child element for the member onto the parent's stack, - and ``after`` pops it. Attributes and flattened collections are special-cased - to skip the push/pop. - """ - - def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: - self._parent = parent - self._schema = schema - - def __enter__(self) -> Self: - self._parent.ensure_root(self._schema) - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - pass - - def before(self, schema: "Schema") -> ShapeSerializer: - member_name = _xml_member_name(schema) - - # Attributes are written on the current element, not as children. - if schema.get_trait(XmlAttributeTrait) is not None: - return XMLAttributeSerializer( - self._parent.current, member_name, self._parent.settings - ) - - # Flattened collections have no wrapper element. Items are added - # directly under the current element without changing the stack. - if _is_flattened_collection_schema(schema): - return self._parent - - # Non-flattened collections push a wrapper element onto the stack. 
- child = SubElement(self._parent.current, member_name) - _set_xml_namespace(child, schema, self._parent.settings) - self._parent.element_stack.append(child) - return self._parent - - def after(self, schema: "Schema") -> None: - # Attributes and flattened collections didn't push, so don't pop. - if schema.get_trait(XmlAttributeTrait) is not None: - return - if _is_flattened_collection_schema(schema): - return - self._parent.element_stack.pop() - - -class XMLListSerializer(InterceptingSerializer): - """Serializes list items as repeated child elements. - - ``before`` pushes a child element for each item, ``after`` pops it. - """ - - def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: - self._parent = parent - self._schema = schema - is_flattened = schema.get_trait(XmlFlattenedTrait) is not None - - if is_flattened: - if schema.member_target is not None: - self._item_tag = _xml_member_name(schema) - else: - self._item_tag = _xml_root_name(schema) - else: - self._item_tag = _xml_member_name(schema.members["member"]) - - def __enter__(self) -> Self: - self._parent.ensure_root(self._schema) - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - pass - - def before(self, schema: "Schema") -> ShapeSerializer: - child = SubElement(self._parent.current, self._item_tag) - _set_xml_namespace(child, schema, self._parent.settings) - self._parent.element_stack.append(child) - return self._parent - - def after(self, schema: "Schema") -> None: - self._parent.element_stack.pop() - - -class XMLMapSerializer(MapSerializer): - """Serializes map entries as ```` elements. - - Each ``entry`` call pushes the value element onto the stack for the - value writer callback, then pops it. 
- """ - - def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: - self._parent = parent - self._schema = schema - self._is_flattened = schema.get_trait(XmlFlattenedTrait) is not None - - self._key_schema = schema.members["key"] - self._value_schema = schema.members["value"] - self._key_tag = _xml_member_name(self._key_schema) - self._value_tag = _xml_member_name(self._value_schema) - - if self._is_flattened: - if schema.member_target is not None: - self._entry_tag = _xml_member_name(schema) - else: - self._entry_tag = _xml_root_name(schema) - else: - self._entry_tag = "entry" - - def __enter__(self) -> Self: - self._parent.ensure_root(self._schema) - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - pass - - def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]) -> None: - settings = self._parent.settings - entry_el = SubElement(self._parent.current, self._entry_tag) - if self._is_flattened: - _set_xml_namespace(entry_el, self._schema, settings) - - key_el = SubElement(entry_el, self._key_tag) - _set_xml_namespace(key_el, self._key_schema, settings) - key_el.text = key - - value_el = SubElement(entry_el, self._value_tag) - _set_xml_namespace(value_el, self._value_schema, settings) - self._parent.element_stack.append(value_el) - value_writer(self._parent) - self._parent.element_stack.pop() - - -class XMLAttributeSerializer(SpecificShapeSerializer): - """Serializer that writes values as XML attributes on the parent element.""" - - def __init__(self, element: Element, attr_name: str, settings: XMLSettings) -> None: - self._element = element - self._attr_name = attr_name - self._settings = settings - - def write_null(self, schema: "Schema") -> None: - pass - - def write_boolean(self, schema: "Schema", value: bool) -> None: - self._element.set(self._attr_name, "true" if value else "false") - - def write_byte(self, schema: 
"Schema", value: int) -> None: - self.write_integer(schema, value) - - def write_short(self, schema: "Schema", value: int) -> None: - self.write_integer(schema, value) - - def write_integer(self, schema: "Schema", value: int) -> None: - self._element.set(self._attr_name, str(value)) - - def write_long(self, schema: "Schema", value: int) -> None: - self.write_integer(schema, value) - - def write_big_integer(self, schema: "Schema", value: int) -> None: - self.write_integer(schema, value) - - def write_float(self, schema: "Schema", value: float) -> None: - self._element.set(self._attr_name, _format_xml_float(value)) - - def write_double(self, schema: "Schema", value: float) -> None: - self.write_float(schema, value) - - def write_big_decimal(self, schema: "Schema", value: Decimal) -> None: - self._element.set(self._attr_name, str(value.normalize())) - - def write_string(self, schema: "Schema", value: str) -> None: - self._element.set(self._attr_name, value) - - def write_timestamp(self, schema: "Schema", value: datetime) -> None: - fmt = self._settings.default_timestamp_format - if format_trait := schema.get_trait(TimestampFormatTrait): - fmt = format_trait.format - self._element.set(self._attr_name, str(fmt.serialize(value))) diff --git a/packages/smithy-xml/src/smithy_xml/settings.py b/packages/smithy-xml/src/smithy_xml/settings.py index ea25cd1b6..86549d3e1 100644 --- a/packages/smithy-xml/src/smithy_xml/settings.py +++ b/packages/smithy-xml/src/smithy_xml/settings.py @@ -5,12 +5,16 @@ from smithy_core.types import TimestampFormat -@dataclass(frozen=True) +@dataclass(slots=True) class XMLSettings: - """Configuration for XML serialization/deserialization.""" + """Settings for the XML codec.""" + + use_timestamp_format: bool = True + """Whether the codec should use the `smithy.api#timestampFormat` trait, if present.""" default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME - """Default timestamp format when a member does not define @timestampFormat.""" 
+ """The default timestamp format to use if the `smithy.api#timestampFormat` trait is + not enabled or not present.""" default_namespace: str | None = None - """Default XML namespace (``xmlns``) applied to the root element during serialization.""" + """Default XML namespace (`xmlns`) applied to the root element during serialization.""" diff --git a/packages/smithy-xml/tests/unit/test_serializers.py b/packages/smithy-xml/tests/unit/test_serializers.py index 2356c0df9..deaef6659 100644 --- a/packages/smithy-xml/tests/unit/test_serializers.py +++ b/packages/smithy-xml/tests/unit/test_serializers.py @@ -1,209 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from datetime import datetime -from decimal import Decimal from io import BytesIO -from typing import Any, cast -from xml.etree.ElementTree import canonicalize import pytest -from smithy_core.prelude import ( - BIG_DECIMAL, - BLOB, - BOOLEAN, - FLOAT, - INTEGER, - STRING, - TIMESTAMP, -) from smithy_xml import XMLCodec -from . 
import ( - NAMESPACED_LIST_SCHEMA, - NAMESPACED_STRUCT_SCHEMA, - PREFIXED_NS_STRUCT_SCHEMA, - RENAMED_NS_MAP_SCHEMA, - STRING_LIST_SCHEMA, - STRING_MAP_SCHEMA, - XML_SERDE_CASES, - SerdeShape, -) - -def _canonicalize(xml_bytes: bytes) -> str: - """Canonicalize XML for comparison, stripping whitespace differences.""" - return canonicalize(xml_bytes, strip_text=True) - - -@pytest.mark.parametrize("given, expected", XML_SERDE_CASES) -def test_xml_serializer(given: Any, expected: bytes) -> None: - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - match given: - case bool(): - serializer.write_boolean(BOOLEAN, given) - case int(): - serializer.write_integer(INTEGER, given) - case float(): - serializer.write_float(FLOAT, given) - case Decimal(): - serializer.write_big_decimal(BIG_DECIMAL, given) - case bytes(): - serializer.write_blob(BLOB, given) - case str(): - serializer.write_string(STRING, given) - case datetime(): - serializer.write_timestamp(TIMESTAMP, given) - case list(): - given = cast(list[str], given) - with serializer.begin_list(STRING_LIST_SCHEMA, len(given)) as ls: - member_schema = STRING_LIST_SCHEMA.members["member"] - for e in given: - ls.write_string(member_schema, e) - case dict(): - given = cast(dict[str, str], given) - with serializer.begin_map(STRING_MAP_SCHEMA, len(given)) as ms: - member_schema = STRING_MAP_SCHEMA.members["value"] - for k, v in given.items(): - ms.entry(k, lambda vs: vs.write_string(member_schema, v)) # type: ignore - case SerdeShape(): - given.serialize(serializer) - case _: - raise Exception(f"Unexpected type: {type(given)}") - - serializer.flush() - sink.seek(0) - actual = sink.read() - assert _canonicalize(actual) == _canonicalize(expected) - - -def test_write_null() -> None: - """write_null creates an empty element (no text content).""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - serializer.write_null(STRING) - serializer.flush() - sink.seek(0) - actual = sink.read() - assert 
actual == b"" - - -def test_write_document_raises() -> None: - """XML does not support document types.""" - from smithy_core.documents import Document - from smithy_core.prelude import DOCUMENT - - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - with pytest.raises(NotImplementedError, match="XML does not support document"): - serializer.write_document(DOCUMENT, Document(None)) - - -def test_float_nan() -> None: - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - serializer.write_float(FLOAT, float("nan")) - serializer.flush() - sink.seek(0) - assert sink.read() == b"NaN" - - -def test_default_namespace() -> None: - """Default namespace is set on the root element.""" - sink = BytesIO() - serializer = XMLCodec(default_namespace="https://example.com/").create_serializer( - sink - ) - serializer.write_string(STRING, "hi") - serializer.flush() - sink.seek(0) - actual = sink.read() - assert actual == b'hi' - - -def test_xml_escaping_in_attribute() -> None: - """XML special characters are escaped in attribute values.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - shape = SerdeShape(xml_attribute_named_member='<"test">&') - shape.serialize(serializer) - serializer.flush() - sink.seek(0) - actual = sink.read() - assert actual == b'' - - -def test_flush_with_no_writes() -> None: - """Flushing without any writes produces no output.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - serializer.flush() - sink.seek(0) - assert sink.read() == b"" - - -def test_list_with_namespace_on_member() -> None: - """@xmlNamespace on list member adds xmlns to each item element.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - items = ["Bar"] - with serializer.begin_list(NAMESPACED_LIST_SCHEMA, len(items)) as ls: - member_schema = NAMESPACED_LIST_SCHEMA.members["member"] - for e in items: - ls.write_string(member_schema, e) - serializer.flush() - sink.seek(0) - actual = 
sink.read() - assert ( - actual - == b'Bar' - ) - - -def test_map_with_xmlname_and_namespace() -> None: - """Map with @xmlName + @xmlNamespace on key and value members.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - data = {"a": "A"} - with serializer.begin_map(RENAMED_NS_MAP_SCHEMA, len(data)) as ms: - member_schema = RENAMED_NS_MAP_SCHEMA.members["value"] - for k, v in data.items(): - ms.entry(k, lambda vs: vs.write_string(member_schema, v)) - serializer.flush() - sink.seek(0) - actual = sink.read() - assert actual == ( - b"" - b'a' - b'A' - b"" - ) - - -def test_struct_with_xml_namespace() -> None: - """@xmlNamespace on struct adds default xmlns to root element.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - with serializer.begin_struct(NAMESPACED_STRUCT_SCHEMA) as ss: - ss.write_string(NAMESPACED_STRUCT_SCHEMA.members["value"], "hi") - serializer.flush() - sink.seek(0) - actual = sink.read() - assert ( - actual == b'hi' - ) - - -def test_struct_with_xml_namespace_prefix() -> None: - """@xmlNamespace with prefix adds prefixed xmlns to root element.""" - sink = BytesIO() - serializer = XMLCodec().create_serializer(sink) - with serializer.begin_struct(PREFIXED_NS_STRUCT_SCHEMA) as ss: - ss.write_string(PREFIXED_NS_STRUCT_SCHEMA.members["value"], "hi") - serializer.flush() - sink.seek(0) - actual = sink.read() - assert ( - actual - == b'hi' - ) +def test_create_serializer_raises() -> None: + with pytest.raises(NotImplementedError, match="XML serialization is not supported"): + XMLCodec().create_serializer(BytesIO()) From 1c7aa1799494347223231f29de7db25e51ab1fc9 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:38:01 -0400 Subject: [PATCH 3/4] smithy-aws-core: add support for awsQuery protocol --- ...ture-ec24e8dbe26b4ee58029f89da5b6526e.json | 4 + packages/smithy-aws-core/pyproject.toml | 3 + .../src/smithy_aws_core/_private/__init__.py | 2 + 
.../_private/query/__init__.py | 10 + .../smithy_aws_core/_private/query/errors.py | 106 +++++++ .../_private/query/serializers.py | 243 +++++++++++++++ .../src/smithy_aws_core/aio/protocols.py | 180 ++++++++++- .../src/smithy_aws_core/traits.py | 22 ++ .../tests/unit/aio/test_protocols.py | 175 ++++++++++- .../smithy-aws-core/tests/unit/test_query.py | 282 ++++++++++++++++++ .../smithy-aws-core/tests/unit/test_traits.py | 14 +- .../smithy-core/src/smithy_core/schemas.py | 3 + uv.lock | 6 +- 13 files changed, 1033 insertions(+), 17 deletions(-) create mode 100644 packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py create mode 100644 packages/smithy-aws-core/tests/unit/test_query.py diff --git a/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json new file mode 100644 index 000000000..377903034 --- /dev/null +++ b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Add `awsQuery` protocol support for Smithy clients." 
+} \ No newline at end of file diff --git a/packages/smithy-aws-core/pyproject.toml b/packages/smithy-aws-core/pyproject.toml index 7dbd5ffbc..f9c0cec4c 100644 --- a/packages/smithy-aws-core/pyproject.toml +++ b/packages/smithy-aws-core/pyproject.toml @@ -50,6 +50,9 @@ eventstream = [ json = [ "smithy-json~=0.2.0" ] +xml = [ + "smithy-xml~=0.1.0" +] [tool.hatch.build] exclude = [ diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py new file mode 100644 index 000000000..33cbe867a --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py new file mode 100644 index 000000000..694350d1f --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .errors import create_aws_query_error +from .serializers import QueryShapeSerializer + +__all__ = ( + "QueryShapeSerializer", + "create_aws_query_error", +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py new file mode 100644 index 000000000..ed0dd97d4 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Any +from xml.etree.ElementTree import Element, ParseError, fromstring + +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError +from smithy_core.schemas import APIOperation +from smithy_core.shapes import ShapeID +from smithy_xml import XMLCodec + +from ...traits import AwsQueryErrorTrait + + +def _local_name(tag: str) -> str: + """Strip namespace URI from an element tag: {uri}local -> local.""" + if tag.startswith("{"): + return tag.split("}", 1)[1] + return tag + + +def _find_child(element: Element, name: str) -> Element | None: + """Return the first child element whose local name matches ``name``.""" + for child in element: + if _local_name(child.tag) == name: + return child + return None + + +def _parse_aws_query_error_code( + body: bytes, wrapper_elements: tuple[str, ...] +) -> str | None: + """Parse the ``Code`` field from a wrapped awsQuery error response.""" + try: + element = fromstring(body) # noqa: S314 + except ParseError: + return None + + if wrapper_elements: + if _local_name(element.tag) != wrapper_elements[0]: + return None + for wrapper in wrapper_elements[1:]: + next_element = _find_child(element, wrapper) + if next_element is None: + return None + element = next_element + + code_element = _find_child(element, "Code") + return code_element.text if code_element is not None else None + + +def _resolve_aws_query_error_shape_id( + *, + code: str, + operation: APIOperation[Any, Any], + error_registry: TypeRegistry, + default_namespace: str, +) -> ShapeID | None: + """Resolve an awsQuery error code to a modeled error shape ID.""" + for error_schema in operation.error_schemas: + trait = error_schema.get_trait(AwsQueryErrorTrait) + if trait is not None and trait.code == code: + if error_schema.id in error_registry: + return error_schema.id + break + + fallback_id = ShapeID.from_parts(namespace=default_namespace, 
name=code) + return fallback_id if fallback_id in error_registry else None + + +def create_aws_query_error( + *, + body: bytes, + operation: APIOperation[Any, Any], + error_registry: TypeRegistry, + default_namespace: str, + wrapper_elements: tuple[str, ...], + status: int, +) -> CallError: + """Create a modeled or generic CallError from an awsQuery error response.""" + code = _parse_aws_query_error_code(body, wrapper_elements) + if code is not None: + shape_id = _resolve_aws_query_error_shape_id( + code=code, + operation=operation, + error_registry=error_registry, + default_namespace=default_namespace, + ) + if shape_id is not None: + error_shape = error_registry.get(shape_id) + if not issubclass(error_shape, ModeledError): + raise ExpectationNotMetError( + "Modeled errors must be derived from 'ModeledError', " + f"but got {error_shape}" + ) + + deserializer = XMLCodec().create_deserializer( + body, wrapper_elements=wrapper_elements + ) + return error_shape.deserialize(deserializer) + + message = f"Unknown error for operation {operation.schema.id} - status: {status}" + if code is not None: + message += f", code: {code}" + fault = "client" if 400 <= status < 500 else "server" + return CallError(message=message, fault=fault) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py new file mode 100644 index 000000000..03b585321 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py @@ -0,0 +1,243 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from base64 import b64encode +from collections.abc import Callable +from contextlib import AbstractContextManager +from datetime import datetime +from decimal import Decimal +from types import TracebackType +from typing import Self +from urllib.parse import quote + +from smithy_core.documents import Document +from smithy_core.exceptions import SerializationError +from smithy_core.interfaces import BytesWriter +from smithy_core.schemas import Schema +from smithy_core.serializers import ( + InterceptingSerializer, + MapSerializer, + ShapeSerializer, +) +from smithy_core.traits import TimestampFormatTrait, XmlFlattenedTrait, XmlNameTrait +from smithy_core.types import TimestampFormat +from smithy_core.utils import serialize_float + + +def _percent_encode_query(value: str) -> str: + """Encode a query key or value using RFC 3986 percent-encoding.""" + return quote(value, safe="-_.~") + + +def _resolve_name(schema: Schema, default: str) -> str: + """Return ``@xmlName`` when present, otherwise ``default``.""" + if (xml_name := schema.get_trait(XmlNameTrait)) is not None: + return xml_name.value + return default + + +def _is_flattened(schema: Schema) -> bool: + """Return whether a collection is ``@xmlFlattened``.""" + return schema.get_trait(XmlFlattenedTrait) is not None + + +class QueryShapeSerializer(ShapeSerializer): + """Serializes Smithy shapes into AWS Query form parameters. + + Tracks a dotted key path and accumulates ``(key, value)`` pairs in a + shared buffer. Struct/list/map serializers create children that extend the + path, and primitives append terminal values at the current path. ``flush`` + emits the buffered pairs as the query payload. + """ + + def __init__( + self, + *, + sink: BytesWriter, + action: str | None = None, + version: str | None = None, + path: tuple[str, ...] 
= (), + params: list[tuple[str, str]] | None = None, + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + ) -> None: + self._sink = sink + self._action = action + self._version = version + self._path = path + self._params = [] if params is None else params + self._default_timestamp_format = default_timestamp_format + + def child(self, *segments: str) -> "QueryShapeSerializer": + return QueryShapeSerializer( + sink=self._sink, + path=(*self._path, *segments), + params=self._params, + default_timestamp_format=self._default_timestamp_format, + ) + + def append(self, value: str) -> None: + if not self._path: + raise SerializationError( + "Unable to serialize AWS Query value without a key path." + ) + self._params.append((".".join(self._path), value)) + + def begin_struct(self, schema: Schema) -> AbstractContextManager[ShapeSerializer]: + return QueryStructSerializer(self) + + def begin_list( + self, schema: Schema, size: int + ) -> AbstractContextManager[ShapeSerializer]: + if size == 0: + self.append("") + return QueryListSerializer(self, schema) + + def begin_map( + self, schema: Schema, size: int + ) -> AbstractContextManager[MapSerializer]: + return QueryMapSerializer(self, schema) + + def write_null(self, schema: Schema) -> None: + return None + + def write_boolean(self, schema: Schema, value: bool) -> None: + self.append("true" if value else "false") + + def write_integer(self, schema: Schema, value: int) -> None: + self.append(str(value)) + + def write_float(self, schema: Schema, value: float) -> None: + self.append(serialize_float(value)) + + def write_big_decimal(self, schema: Schema, value: Decimal) -> None: + self.append(serialize_float(value)) + + def write_string(self, schema: Schema, value: str) -> None: + self.append(value) + + def write_blob(self, schema: Schema, value: bytes) -> None: + self.append(b64encode(value).decode("utf-8")) + + def write_timestamp(self, schema: Schema, value: datetime) -> None: + format = 
self._default_timestamp_format + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + self.append(str(format.serialize(value))) + + def write_document(self, schema: Schema, value: Document) -> None: + raise SerializationError("Query protocols do not support document types.") + + def flush(self) -> None: + serialized: list[tuple[str, str]] = [] + if self._action is not None and self._version is not None: + serialized.extend( + [ + ("Action", self._action), + ("Version", self._version), + ] + ) + serialized.extend(self._params) + body = "&".join( + f"{_percent_encode_query(key)}={_percent_encode_query(value)}" + for key, value in serialized + ).encode("utf-8") + self._sink.write(body) + + +class QueryStructSerializer(InterceptingSerializer): + """Serializes struct members as child query paths. + + ``before`` creates a child serializer rooted at the member name, honoring + ``@xmlName``. + """ + + def __init__(self, parent: QueryShapeSerializer) -> None: + self._parent = parent + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: Schema) -> ShapeSerializer: + return self._parent.child(_resolve_name(schema, schema.expect_member_name())) + + def after(self, schema: Schema) -> None: + pass + + +class QueryListSerializer(InterceptingSerializer): + """Serializes list entries as indexed child query paths. + + ``before`` increments a 1-based index and creates the item path as either + ``.`` or ```` when the list is flattened. 
+ """ + + def __init__(self, parent: QueryShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._is_flattened = _is_flattened(schema) + self._item_name = _resolve_name(schema.members["member"], "member") + self._index = 0 + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: Schema) -> ShapeSerializer: + self._index += 1 + if self._is_flattened: + return self._parent.child(str(self._index)) + return self._parent.child(self._item_name, str(self._index)) + + def after(self, schema: Schema) -> None: + pass + + +class QueryMapSerializer(MapSerializer): + """Serializes map entries as indexed key and value query paths. + + Each entry increments a 1-based index, uses ``entry.`` (or + ```` when flattened), writes the key at ``...``, and + serializes the value at ``...``. + """ + + def __init__(self, parent: QueryShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._is_flattened = _is_flattened(schema) + self._key_name = _resolve_name(schema.members["key"], "key") + self._value_name = _resolve_name(schema.members["value"], "value") + self._index = 0 + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]) -> None: + self._index += 1 + if self._is_flattened: + entry_path = (str(self._index),) + else: + entry_path = ("entry", str(self._index)) + + self._parent.child(*entry_path, self._key_name).append(key) + value_writer(self._parent.child(*entry_path, self._value_name)) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 709651a4a..ab7078b03 100644 
--- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -2,31 +2,55 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Callable from inspect import iscoroutinefunction +from io import BytesIO from typing import TYPE_CHECKING, Any, Final +from smithy_core import URI as _URI from smithy_core.aio.interfaces import AsyncWriter from smithy_core.aio.interfaces.auth import AuthScheme from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver from smithy_core.aio.types import AsyncBytesReader from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import TypeRegistry from smithy_core.exceptions import ( + CallError, DiscriminatorError, MissingDependencyError, UnsupportedStreamError, ) -from smithy_core.interfaces import TypedProperties +from smithy_core.interfaces import TypedProperties, URI from smithy_core.schemas import APIOperation, Schema from smithy_core.serializers import SerializeableShape from smithy_core.shapes import ShapeID, ShapeType from smithy_core.types import TimestampFormat +from smithy_http import tuples_to_fields +from smithy_http.aio import HTTPRequest as _HTTPRequest from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse -from smithy_http.aio.protocols import HttpBindingClientProtocol -from smithy_json import JSONCodec, JSONDocument +from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol +from smithy_http.deserializers import HTTPResponseDeserializer -from ..traits import RestJson1Trait +from .._private.query.errors import ( + create_aws_query_error, +) +from .._private.query.serializers import QueryShapeSerializer +from ..traits import AwsQueryTrait, RestJson1Trait from ..utils import parse_document_discriminator, parse_error_code +try: + from smithy_json import JSONCodec, JSONDocument 
+ + _HAS_JSON = True +except ImportError: + _HAS_JSON = False # type: ignore + +try: + from smithy_xml import XMLCodec + + _HAS_XML = True +except ImportError: + _HAS_XML = False # type: ignore + try: from smithy_aws_event_stream.aio import ( AWSEventPublisher, @@ -44,10 +68,26 @@ AWSEventReceiver, SigningConfig, ) + from smithy_json import JSONCodec, JSONDocument + from smithy_xml import XMLCodec from typing_extensions import TypeForm -def _assert_event_stream_capable() -> None: +def _assert_json() -> None: + if not _HAS_JSON: + raise MissingDependencyError( + "Attempted to use JSON protocol support, but smithy-json is not installed." + ) + + +def _assert_xml() -> None: + if not _HAS_XML: + raise MissingDependencyError( + "Attempted to use XML protocol support, but smithy-xml is not installed." + ) + + +def _assert_event_stream() -> None: if not _HAS_EVENT_STREAM: raise MissingDependencyError( "Attempted to use event streams, but smithy-aws-event-stream " @@ -99,6 +139,7 @@ def __init__(self, service_schema: Schema) -> None: :param service: The schema for the service to interact with. 
""" + _assert_json() self._codec: Final = JSONCodec( document_class=AWSJSONDocument, default_namespace=service_schema.id.namespace, @@ -134,7 +175,7 @@ def create_event_publisher[ context: TypedProperties, auth_scheme: AuthScheme[Any, Any, Any, Any] | None = None, ) -> EventPublisher[Event]: - _assert_event_stream_capable() + _assert_event_stream() signing_config: SigningConfig | None = None if auth_scheme is not None: event_signer = auth_scheme.event_signer(request=request) @@ -177,9 +218,134 @@ def create_event_receiver[ event_deserializer: Callable[[ShapeDeserializer], Event], context: TypedProperties, ) -> EventReceiver[Event]: - _assert_event_stream_capable() + _assert_event_stream() return AWSEventReceiver( payload_codec=self.payload_codec, source=AsyncBytesReader(response.body), deserializer=event_deserializer, ) + + +class AwsQueryClientProtocol(HttpClientProtocol): + """An implementation of the aws.protocols#awsQuery protocol.""" + + _id: Final = AwsQueryTrait.id + _content_type: Final = "application/x-www-form-urlencoded" + + def __init__(self, service_schema: Schema, version: str) -> None: + _assert_xml() + self._default_namespace: Final = service_schema.id.namespace + self._version: Final = version + self._codec: Final = XMLCodec(default_namespace=self._default_namespace) + + @property + def id(self) -> ShapeID: + return self._id + + def serialize_request[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + input: OperationInput, + endpoint: URI, + context: TypedProperties, + ) -> HTTPRequest: + sink = BytesIO() + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=sink, + action=self._action_name(operation), + version=self._version, + params=params, + ) + input.serialize(serializer) + serializer.flush() + content_length = sink.tell() + sink.seek(0) + body = AsyncBytesReader(sink) + return _HTTPRequest( + method="POST", + 
destination=_URI(host="", path="/"), + fields=tuples_to_fields( + [ + ("content-type", self._content_type), + ("content-length", str(content_length)), + ] + ), + body=body, + ) + + async def deserialize_response[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + request: HTTPRequest, + response: HTTPResponse, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> OperationOutput: + body = await response.consume_body_async() + + if response.status >= 300: + raise self._create_error( + operation=operation, + response=response, + response_body=body, + error_registry=error_registry, + ) + + if len(body) == 0: + return operation.output.deserialize( + HTTPResponseDeserializer( + payload_codec=self._codec, + response=response, + body=body, + ) + ) + + wrapper_elements = self._response_wrapper_elements(operation) + deserializer = self._codec.create_deserializer( + body, wrapper_elements=wrapper_elements + ) + return operation.output.deserialize(deserializer) + + def _create_error( + self, + *, + operation: APIOperation[Any, Any], + response: HTTPResponse, + response_body: bytes, + error_registry: TypeRegistry, + ) -> CallError: + return create_aws_query_error( + body=response_body, + operation=operation, + error_registry=error_registry, + default_namespace=self._default_namespace, + wrapper_elements=self._error_wrapper_elements(), + status=response.status, + ) + + def _action_name( + self, + operation: APIOperation[SerializeableShape, DeserializeableShape], + ) -> str: + return operation.schema.id.name + + def _response_wrapper_elements( + self, + operation: APIOperation[SerializeableShape, DeserializeableShape], + ) -> tuple[str, str]: + return ( + f"{operation.schema.id.name}Response", + f"{operation.schema.id.name}Result", + ) + + def _error_wrapper_elements(self) -> tuple[str, ...]: + return ("ErrorResponse", "Error") diff --git 
a/packages/smithy-aws-core/src/smithy_aws_core/traits.py b/packages/smithy-aws-core/src/smithy_aws_core/traits.py index 3902a55ff..2a789b96e 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/traits.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/traits.py @@ -46,6 +46,28 @@ def __init__(self, value: DocumentValue | DynamicTrait = None): ) +@dataclass(frozen=True) +class AwsQueryTrait(Trait, id=ShapeID("aws.protocols#awsQuery")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class AwsQueryErrorTrait(Trait, id=ShapeID("aws.protocols#awsQueryError")): + def __post_init__(self): + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value.get("code"), str) + assert isinstance(self.document_value.get("httpResponseCode"), int | None) + + @property + def code(self) -> str: + return self.document_value["code"] # type: ignore + + @property + def http_response_code(self) -> int | None: + return self.document_value.get("httpResponseCode") # type: ignore + + @dataclass(init=False, frozen=True) class SigV4Trait(Trait, id=ShapeID("aws.auth#sigv4")): def __post_init__(self): diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py index 7b767a080..d02481948 100644 --- a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -1,13 +1,27 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any, cast from unittest.mock import Mock import pytest -from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument -from smithy_core.exceptions import DiscriminatorError +from smithy_aws_core.aio.protocols import ( + AWSErrorIdentifier, + AWSJSONDocument, + AwsQueryClientProtocol, +) +from smithy_aws_core.traits import AwsQueryTrait +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import CallError, DiscriminatorError, ModeledError +from smithy_core.interfaces import URI +from smithy_core.prelude import STRING from smithy_core.schemas import APIOperation, Schema +from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import Trait +from smithy_core.types import TypedProperties from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse from smithy_json import JSONSettings @@ -36,13 +50,11 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N fields = tuples_to_fields([("x-amzn-errortype", header)]) http_response = HTTPResponse(status=500, fields=fields) - operation = Mock(spec=APIOperation) - operation.schema = Schema( - id=ShapeID("com.test#TestOperation"), shape_type=ShapeType.OPERATION - ) - error_identifier = AWSErrorIdentifier() - actual = error_identifier.identify(operation=operation, response=http_response) + actual = error_identifier.identify( + operation=_mock_operation(_operation_schema("TestOperation")), + response=http_response, + ) assert actual == expected @@ -97,3 +109,150 @@ def test_aws_json_document_discriminator( else: discriminator = AWSJSONDocument(document, settings=settings).discriminator assert discriminator == expected + + +_INPUT_SCHEMA = Schema.collection( + id=ShapeID("com.test#TestInput"), + members={"name": 
{"target": STRING}}, +) +_SERVICE_SCHEMA = Schema.collection( + id=ShapeID("com.test#QueryService"), + shape_type=ShapeType.SERVICE, + traits=[AwsQueryTrait(None)], +) +_INVALID_ACTION_ERROR_SCHEMA = Schema.collection( + id=ShapeID("com.test#InvalidActionError"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value={"code": "InvalidAction"}, + ), + ], + members={"message": {"target": STRING}}, +) + + +@dataclass +class _TestInput: + name: str | None = None + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(_INPUT_SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.name is not None: + serializer.write_string(_INPUT_SCHEMA.members["name"], self.name) + + +class _ModeledQueryError(ModeledError): + message: str + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> "_ModeledQueryError": + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + if schema.expect_member_name() == "message": + kwargs["message"] = de.read_string(schema) + + deserializer.read_struct(_INVALID_ACTION_ERROR_SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +def _operation_schema(name: str) -> Schema: + return Schema( + id=ShapeID(f"com.test#{name}"), + shape_type=ShapeType.OPERATION, + ) + + +def _mock_operation( + schema: Schema, + *, + error_schemas: list[Schema] | None = None, +) -> APIOperation[Any, Any]: + operation = Mock(spec=APIOperation) + operation.schema = schema + operation.error_schemas = error_schemas or [] + return cast("APIOperation[Any, Any]", operation) + + +@pytest.mark.asyncio +async def test_aws_query_serializes_base_request_shape() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + request = protocol.serialize_request( + operation=_mock_operation(_operation_schema("TestOperation")), + input=_TestInput(name="example"), + 
endpoint=cast(URI, Mock()), + context=TypedProperties(), + ) + + assert request.method == "POST" + assert request.destination.path == "/" + assert ( + request.fields["content-type"].as_string() + == "application/x-www-form-urlencoded" + ) + body = await request.consume_body_async() + assert request.fields["content-length"].as_string() == str(len(body)) + assert body == b"Action=TestOperation&Version=2020-01-08&name=example" + + +def test_aws_query_resolves_modeled_error_from_query_error_trait() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation( + _operation_schema("FailingOperation"), + error_schemas=[_INVALID_ACTION_ERROR_SCHEMA], + ), + response=HTTPResponse(status=400, fields=tuples_to_fields([])), + response_body=( + b"InvalidAction" + b"bad request" + ), + error_registry=TypeRegistry( + {ShapeID("com.test#InvalidActionError"): _ModeledQueryError} + ), + ) + + assert isinstance(error, _ModeledQueryError) + assert error.message == "bad request" + + +def test_aws_query_resolves_modeled_error_from_default_namespace_fallback() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation(_operation_schema("FailingOperation")), + response=HTTPResponse(status=503, fields=tuples_to_fields([])), + response_body=( + b"ServiceUnavailable" + b"try again" + ), + error_registry=TypeRegistry( + {ShapeID("com.test#ServiceUnavailable"): _ModeledQueryError} + ), + ) + + assert isinstance(error, _ModeledQueryError) + assert error.message == "try again" + + +def test_aws_query_returns_generic_error_for_unknown_code() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation(_operation_schema("FailingOperation")), + response=HTTPResponse(status=500, fields=tuples_to_fields([])), + response_body=( + 
b"UnknownThing" + b"bad request" + ), + error_registry=TypeRegistry({}), + ) + + assert isinstance(error, CallError) + assert not isinstance(error, ModeledError) + assert error.message == ( + "Unknown error for operation com.test#FailingOperation" + " - status: 500, code: UnknownThing" + ) diff --git a/packages/smithy-aws-core/tests/unit/test_query.py b/packages/smithy-aws-core/tests/unit/test_query.py new file mode 100644 index 000000000..ab85bf664 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/test_query.py @@ -0,0 +1,282 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from io import BytesIO + +from smithy_aws_core._private.query.serializers import QueryShapeSerializer +from smithy_core.prelude import STRING +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import XmlFlattenedTrait, XmlNameTrait + + +def test_query_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 2) as list_serializer: + member_schema = list_schema.members["member"] + list_serializer.write_string(member_schema, "a") + list_serializer.write_string(member_schema, "b") + + assert params == [ + ("Items.member.1", "a"), + ("Items.member.2", "b"), + ] + + +def test_query_flattened_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + traits=[XmlFlattenedTrait()], + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + 
serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 2) as list_serializer: + member_schema = list_schema.members["member"] + list_serializer.write_string(member_schema, "a") + list_serializer.write_string(member_schema, "b") + + assert params == [("Items.1", "a"), ("Items.2", "b")] + + +def test_query_empty_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 0): + pass + + assert params == [("Items", "")] + + +def test_query_flattened_list_uses_member_xml_name() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING, "traits": [XmlNameTrait("item")]}}, + ) + input_schema = Schema.collection( + id=ShapeID("com.test#Input"), + members={ + "values": { + "target": list_schema, + "traits": [XmlFlattenedTrait(), XmlNameTrait("Hi")], + } + }, + ) + + @dataclass + class Input: + values: list[str] + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(input_schema) as struct_serializer: + self.serialize_members(struct_serializer) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + schema = input_schema.members["values"] + member_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.values)) as list_serializer: + for value in self.values: + list_serializer.write_string(member_schema, value) + + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), 
action="TestOperation", version="2020-01-08", params=params + ) + Input(values=["a", "b"]).serialize(serializer) + + assert params == [("Hi.1", "a"), ("Hi.2", "b")] + + +def test_query_map_serialization_uses_xml_name_traits() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"target": STRING, "traits": [XmlNameTrait("K")]}, + "value": {"target": STRING, "traits": [XmlNameTrait("V")]}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 1) as map_serializer: + map_serializer.entry( + "one", lambda value_serializer: value_serializer.write_string(STRING, "1") + ) + + assert params == [ + ("Attributes.entry.1.K", "one"), + ("Attributes.entry.1.V", "1"), + ] + + +def test_query_flattened_map_serialization() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + traits=[XmlFlattenedTrait()], + members={ + "key": {"target": STRING}, + "value": {"target": STRING}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 2) as map_serializer: + map_serializer.entry( + "one", lambda value_serializer: value_serializer.write_string(STRING, "1") + ) + map_serializer.entry( + "two", lambda value_serializer: value_serializer.write_string(STRING, "2") + ) + + assert params == [ + ("Attributes.1.key", "one"), + ("Attributes.1.value", "1"), + ("Attributes.2.key", "two"), + ("Attributes.2.value", "2"), + ] + + +def test_query_empty_map_is_omitted() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"target": STRING}, + 
"value": {"target": STRING}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 0): + pass + + assert params == [] + + +def test_query_null_member_is_omitted() -> None: + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Nullable",), + params=params, + ) + + serializer.write_null(STRING) + + assert params == [] + + +def test_query_serializer_flush_writes_body_to_sink() -> None: + sink = BytesIO() + serializer = QueryShapeSerializer( + sink=sink, + action="TestOperation", + version="2020-01-08", + path=("Member Name",), + ) + serializer.write_string(STRING, "hello world") + serializer.flush() + + expected = b"Action=TestOperation&Version=2020-01-08&Member%20Name=hello%20world" + assert sink.getvalue() == expected + + +def test_query_serializer_flush_omits_action_and_version_when_unset() -> None: + sink = BytesIO() + serializer = QueryShapeSerializer(sink=sink, path=("MemberName",)) + serializer.write_string(STRING, "hello world") + serializer.flush() + + assert sink.getvalue() == b"MemberName=hello%20world" + + +def test_query_nested_struct_serialization() -> None: + inner_schema = Schema.collection( + id=ShapeID("com.test#Inner"), + members={"value": {"target": STRING}}, + ) + outer_schema = Schema.collection( + id=ShapeID("com.test#Outer"), + members={"inner": {"target": inner_schema}}, + ) + + @dataclass + class Inner: + value: str + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(inner_schema, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(inner_schema.members["value"], self.value) + + @dataclass + class Outer: + inner: Inner + + def serialize(self, serializer: ShapeSerializer) -> None: + 
serializer.write_struct(outer_schema, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(outer_schema.members["inner"], self.inner) + + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), action="TestOperation", version="2020-01-08", params=params + ) + Outer(inner=Inner("x")).serialize(serializer) + + assert params == [("inner.value", "x")] diff --git a/packages/smithy-aws-core/tests/unit/test_traits.py b/packages/smithy-aws-core/tests/unit/test_traits.py index d4f04ebf1..bf4d7fe6c 100644 --- a/packages/smithy-aws-core/tests/unit/test_traits.py +++ b/packages/smithy-aws-core/tests/unit/test_traits.py @@ -1,9 +1,21 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from smithy_aws_core.traits import RestJson1Trait +from smithy_aws_core.traits import AwsQueryErrorTrait, AwsQueryTrait, RestJson1Trait def test_allows_empty_restjson1_value() -> None: trait = RestJson1Trait(None) assert trait.http == ("http/1.1",) + assert trait.event_stream_http == ("http/1.1",) + + +def test_allows_empty_aws_query_trait_value() -> None: + trait = AwsQueryTrait(None) + assert trait.document_value is None + + +def test_parses_aws_query_error_trait() -> None: + trait = AwsQueryErrorTrait({"code": "InvalidAction", "httpResponseCode": 400}) + assert trait.code == "InvalidAction" + assert trait.http_response_code == 400 diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 925073b42..8c593fdfc 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -303,6 +303,9 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: effective_auth_schemes: Sequence[ShapeID] """A list of effective auth schemes for the operation.""" + error_schemas: Sequence[Schema] = field(repr=False) + """A list of modeled error 
schemas for the operation.""" + @property def idempotency_token_member(self) -> Schema | None: """The input schema member that serves as the idempotency token.""" diff --git a/uv.lock b/uv.lock index 53f054dc5..2a844cde3 100644 --- a/uv.lock +++ b/uv.lock @@ -669,6 +669,9 @@ eventstream = [ json = [ { name = "smithy-json" }, ] +xml = [ + { name = "smithy-xml" }, +] [package.metadata] requires-dist = [ @@ -677,8 +680,9 @@ requires-dist = [ { name = "smithy-core", editable = "packages/smithy-core" }, { name = "smithy-http", editable = "packages/smithy-http" }, { name = "smithy-json", marker = "extra == 'json'", editable = "packages/smithy-json" }, + { name = "smithy-xml", marker = "extra == 'xml'", editable = "packages/smithy-xml" }, ] -provides-extras = ["eventstream", "json"] +provides-extras = ["eventstream", "json", "xml"] [[package]] name = "smithy-aws-event-stream" From b4369c15f099333fabe9e75a9af991fd17c5a483 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:49:08 -0400 Subject: [PATCH 4/4] codegen: generate AwsQueryClientProtocol for awsQuery services and generate protocol tests --- Makefile | 10 +-- codegen/aws/core/build.gradle.kts | 1 + .../aws/codegen/AwsProtocolsIntegration.java | 2 +- .../codegen/AwsQueryProtocolGenerator.java | 69 +++++++++++++++++++ .../codegen/HttpProtocolTestGenerator.java | 36 +++++++++- .../generators/OperationGenerator.java | 13 +++- codegen/protocol-test/build.gradle.kts | 1 + codegen/protocol-test/smithy-build.json | 22 ++++++ 8 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java diff --git a/Makefile b/Makefile index 5e931e1c5..ec2452598 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,12 @@ build-java: ## Builds the Java code generation packages. 
cd codegen && ./gradlew clean build -test-protocols: ## Generates and runs the restJson1 protocol tests. - cd codegen && ./gradlew :protocol-test:build - uv pip install codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen - uv run pytest codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen +test-protocols: ## Generates and runs protocol tests for all supported protocols. + cd codegen && ./gradlew :protocol-test:clean :protocol-test:build + @set -e; for projection_dir in codegen/protocol-test/build/smithyprojections/protocol-test/*/python-client-codegen; do \ + uv pip install "$$projection_dir"; \ + uv run pytest "$$projection_dir"; \ + done lint-py: ## Runs linters and formatters on the python packages. diff --git a/codegen/aws/core/build.gradle.kts b/codegen/aws/core/build.gradle.kts index 3a81c5190..49d506d8d 100644 --- a/codegen/aws/core/build.gradle.kts +++ b/codegen/aws/core/build.gradle.kts @@ -12,4 +12,5 @@ extra["moduleName"] = "software.amazon.smithy.python.aws.codegen" dependencies { implementation(project(":core")) implementation(libs.smithy.aws.traits) + implementation(libs.smithy.protocol.test.traits) } diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java index 63601dd5d..d7d24a4af 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java @@ -16,6 +16,6 @@ public class AwsProtocolsIntegration implements PythonIntegration { @Override public List getProtocolGenerators() { - return List.of(); + return List.of(new AwsQueryProtocolGenerator()); } } diff --git 
a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java new file mode 100644 index 000000000..f02c5e815 --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java @@ -0,0 +1,69 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.Set; +import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.ApplicationProtocol; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.HttpProtocolTestGenerator; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class AwsQueryProtocolGenerator implements ProtocolGenerator { + private static final Set TESTS_TO_SKIP = Set.of( + // TODO: support the request compression trait + // https://smithy.io/2.0/spec/behavior-traits.html#smithy-api-requestcompression-trait + "SDKAppliedContentEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + + // TODO: support idempotency token autofill + "QueryProtocolIdempotencyTokenAutoFill", + + // This test asserts nan == nan, which is never true. + // We should update the generator to make specific assertions for these. 
+ "AwsQuerySupportsNaNFloatOutputs", + + // TODO: support of the endpoint trait + "AwsQueryEndpointTraitWithHostLabel", + "AwsQueryEndpointTrait"); + + @Override + public ShapeId getProtocol() { + return AwsQueryTrait.ID; + } + + @Override + public ApplicationProtocol getApplicationProtocol(GenerationContext context) { + return ApplicationProtocol.createDefaultHttpApplicationProtocol(); + } + + @Override + public void initializeProtocol(GenerationContext context, PythonWriter writer) { + writer.addDependency(AwsPythonDependency.SMITHY_AWS_CORE.withOptionalDependencies("xml")); + writer.addImport("smithy_aws_core.aio.protocols", "AwsQueryClientProtocol"); + var service = context.settings().service(context.model()); + var serviceSymbol = context.symbolProvider().toSymbol(service); + var serviceSchema = serviceSymbol.expectProperty(SymbolProperties.SCHEMA); + var version = service.getVersion(); + writer.write("AwsQueryClientProtocol($T, $S)", serviceSchema, version); + } + + @Override + public void generateProtocolTests(GenerationContext context) { + context.writerDelegator() + .useFileWriter("./tests/test_awsquery_protocol.py", "tests.test_awsquery_protocol", writer -> { + new HttpProtocolTestGenerator( + context, + getProtocol(), + writer, + (shape, testCase) -> TESTS_TO_SKIP.contains(testCase.getId())).run(); + }); + } +} diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 0a06f15bf..33f5191e3 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -16,6 +16,7 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.Stream; +import software.amazon.smithy.aws.traits.auth.SigV4Trait; import 
software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.model.Model; @@ -188,12 +189,14 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t endpoint_uri="https://$L/$L", transport = $T(), retry_strategy=SimpleRetryStrategy(max_attempts=1), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), host, path, - REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL); + REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL, + (Runnable) this::writeSigV4TestConfig); })); // Generate the input using the expected shape and params @@ -418,6 +421,16 @@ private void compareMediaBlob(HttpMessageTestCase testCase, PythonWriter writer) """); return; } + if (contentType.equals("application/x-www-form-urlencoded")) { + writer.addStdlibImport("urllib.parse", "parse_qsl"); + writer.write(""" + actual_params = sorted(parse_qsl(actual_body_content.decode(), keep_blank_values=True)) + expected_params = sorted(parse_qsl(expected_body_content.decode(), keep_blank_values=True)) + assert actual_params == expected_params + + """); + return; + } writer.write("assert actual_body_content == expected_body_content\n"); } @@ -437,13 +450,15 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(), CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().filter(body -> !body.isEmpty()).orElse("")); + testCase.getBody().filter(body -> !body.isEmpty()).orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -490,13 +505,15 @@ private void generateErrorResponseTest( headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(),
CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().orElse("")); + testCase.getBody().orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -607,6 +624,19 @@ private void writeClientBlock( }); } + private void writeSigV4TestConfig() { + if (!service.hasTrait(SigV4Trait.class)) { + return; + } + writer.addImport("smithy_aws_core.identity", "StaticCredentialsResolver"); + writer.write(""" + region="us-east-1", + aws_access_key_id="test-access-key-id", + aws_secret_access_key="test-secret-access-key", + aws_credentials_identity_resolver=StaticCredentialsResolver(), + """); + } + private void writeUtilStubs(Symbol serviceSymbol) { LOGGER.fine(String.format("Writing utility stubs for %s : %s", serviceSymbol.getName(), protocol.getName())); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java index 9ada3907a..9557a93ae 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java @@ -58,6 +58,9 @@ public void run() { }), effective_auth_schemes = [ $8C + ], + error_schemas = [ + $9C ] ) """, @@ -68,7 +71,8 @@ public void run() { inSymbol.expectProperty(SymbolProperties.SCHEMA), outSymbol.expectProperty(SymbolProperties.SCHEMA), writer.consumer(this::writeErrorTypeRegistry), - writer.consumer(this::writeAuthSchemes)); + writer.consumer(this::writeAuthSchemes), + writer.consumer(this::writeErrorSchemas)); } private void writeErrorTypeRegistry(PythonWriter writer) { @@ -82,6 +86,13 @@ private void writeErrorTypeRegistry(PythonWriter writer) { } } + private 
void writeErrorSchemas(PythonWriter writer) { + for (var error : shape.getErrors()) { + var errSymbol = symbolProvider.toSymbol(model.expectShape(error)); + writer.write("$T,", errSymbol.expectProperty(SymbolProperties.SCHEMA)); + } + } + private void writeAuthSchemes(PythonWriter writer) { var authSchemes = ServiceIndex.of(model) .getEffectiveAuthSchemes(context.settings().service(), diff --git a/codegen/protocol-test/build.gradle.kts b/codegen/protocol-test/build.gradle.kts index 5c470b9c4..cddc35e75 100644 --- a/codegen/protocol-test/build.gradle.kts +++ b/codegen/protocol-test/build.gradle.kts @@ -30,5 +30,6 @@ repositories { dependencies { implementation(project(":core")) + implementation(project(":aws:core")) implementation(libs.smithy.aws.protocol.tests) } diff --git a/codegen/protocol-test/smithy-build.json b/codegen/protocol-test/smithy-build.json index cbaccad98..f5ec825f2 100644 --- a/codegen/protocol-test/smithy-build.json +++ b/codegen/protocol-test/smithy-build.json @@ -22,6 +22,28 @@ "moduleVersion": "0.0.1" } } + }, + "aws-query": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.protocoltests.query#AwsQuery" + ] + } + }, + { + "name": "removeUnusedShapes" + } + ], + "plugins": { + "python-client-codegen": { + "service": "aws.protocoltests.query#AwsQuery", + "module": "awsquery", + "moduleVersion": "0.0.1" + } + } } } }