Skip to content
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: feature
packages:
- "@typespec/http-client-python"
---

Support `datetime.timedelta` for `duration` types encoded as `seconds` or `milliseconds`. SDK users can now pass a `datetime.timedelta` (instead of a raw `int`/`float`) and responses are deserialized back into `datetime.timedelta`.
7 changes: 7 additions & 0 deletions packages/http-client-python/emitter/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,13 @@ function emitBuiltInType(
encode: type.encode,
});
}
if (type.encode === "seconds" || type.encode === "milliseconds") {
return getSimpleTypeResult(context, {
type: type.kind,
encode: type.encode,
wireType: getType(context, type.wireType),
});
}
}
if (type.kind === "utcDateTime" || type.kind === "offsetDateTime") {
if (type.encode === "unixTimestamp") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,22 @@ def serialize_sample_value(value: Any) -> str:


class DurationType(PrimitiveType):
def __init__(self, yaml_data: dict[str, Any], code_model: "CodeModel") -> None:
super().__init__(yaml_data=yaml_data, code_model=code_model)
# ``seconds`` and ``milliseconds`` encodings serialize a timedelta to a numeric
# wire value. ``encode`` is set to a combined format token (e.g.
# ``duration-seconds-int``) so that serialization/deserialization can convert
# between ``datetime.timedelta`` and the numeric wire type. ISO8601 (the default)
# leaves ``encode`` unset and keeps the legacy ISO 8601 string behavior.
self.encode: Optional[str] = None
encode = yaml_data.get("encode")
if encode in ("seconds", "milliseconds"):
wire_type = yaml_data.get("wireType") or {}
wire = "int" if wire_type.get("type") == "integer" else "float"
self.encode = f"duration-{encode}-{wire}"

def serialization_type(self, **kwargs: Any) -> str:
return "duration"
return self.encode or "duration"

def docstring_type(self, **kwargs: Any) -> str:
return "~" + self.type_annotation()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ def _serialize_bytes(o, format: typing.Optional[str] = None) -> str:
return encoded


def _serialize_duration(td: timedelta, format: typing.Optional[str] = None):
"""Serialize a timedelta to its wire representation.

For the ``seconds``/``milliseconds`` encodings the value is converted to a
numeric value, otherwise it falls back to an ISO 8601 duration string.

:param timedelta td: The timedelta to serialize.
:param str format: The duration encoding format.
:rtype: int or float or str
:return: serialized duration
"""
seconds = td.total_seconds()
if format == "duration-seconds-int":
return int(seconds)
if format == "duration-seconds-float":
return seconds
if format == "duration-milliseconds-int":
return int(seconds * 1000)
if format == "duration-milliseconds-float":
return seconds * 1000
return _timedelta_as_isostr(td)


def _serialize_datetime(o, format: typing.Optional[str] = None):
if hasattr(o, "year") and hasattr(o, "hour"):
if format == "rfc7231":
Expand Down Expand Up @@ -307,6 +330,12 @@ def _deserialize_duration(attr):
return isodate.parse_duration(attr)


def _deserialize_duration_numeric(attr, unit):
if isinstance(attr, timedelta):
return attr
return timedelta(**{unit: float(attr)})

Comment thread
msyyc marked this conversation as resolved.

def _deserialize_decimal(attr):
if isinstance(attr, decimal.Decimal):
return attr
Expand Down Expand Up @@ -336,6 +365,10 @@ _DESERIALIZE_MAPPING_WITHFORMAT = {
"unix-timestamp": _deserialize_datetime_unix_timestamp,
"base64": _deserialize_bytes,
"base64url": _deserialize_bytes_base64,
"duration-seconds-int": functools.partial(_deserialize_duration_numeric, unit="seconds"),
"duration-seconds-float": functools.partial(_deserialize_duration_numeric, unit="seconds"),
"duration-milliseconds-int": functools.partial(_deserialize_duration_numeric, unit="milliseconds"),
"duration-milliseconds-float": functools.partial(_deserialize_duration_numeric, unit="milliseconds"),
}


Expand Down Expand Up @@ -576,7 +609,7 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
pass
# Last, try datetime.timedelta
try:
return _timedelta_as_isostr(o)
return _serialize_duration(o, format)
except AttributeError:
# This will be raised when it hits value.total_seconds in the method above
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ class Serializer: # pylint: disable=too-many-public-methods
"rfc-1123": Serializer.serialize_rfc,
"unix-time": Serializer.serialize_unix,
"duration": Serializer.serialize_duration,
"duration-seconds-int": Serializer.serialize_duration_seconds_int,
"duration-seconds-float": Serializer.serialize_duration_seconds_float,
"duration-milliseconds-int": Serializer.serialize_duration_milliseconds_int,
"duration-milliseconds-float": Serializer.serialize_duration_milliseconds_float,
"date": Serializer.serialize_date,
"time": Serializer.serialize_time,
"decimal": Serializer.serialize_decimal,
Expand Down Expand Up @@ -1105,6 +1109,61 @@ class Serializer: # pylint: disable=too-many-public-methods
attr = isodate.parse_duration(attr)
return isodate.duration_isoformat(attr)

@staticmethod
def _serialize_duration_numeric(attr, scale, as_int):
"""Serialize a TimeDelta into a numeric value scaled to the wire unit.

:param TimeDelta attr: Object to be serialized.
:param int scale: Multiplier applied to total seconds (1 for seconds, 1000 for milliseconds).
:param bool as_int: Whether to truncate the result to an int.
:rtype: int or float
:return: serialized duration
"""
if isinstance(attr, str):
attr = isodate.parse_duration(attr)
value = attr.total_seconds() * scale if isinstance(attr, datetime.timedelta) else attr
return int(value) if as_int else float(value)

@staticmethod
def serialize_duration_seconds_int(attr, **kwargs): # pylint: disable=unused-argument
"""Serialize TimeDelta object into an integer number of seconds.

:param TimeDelta attr: Object to be serialized.
:rtype: int
:return: serialized duration
"""
return Serializer._serialize_duration_numeric(attr, 1, True)

@staticmethod
def serialize_duration_seconds_float(attr, **kwargs): # pylint: disable=unused-argument
"""Serialize TimeDelta object into a floating point number of seconds.

:param TimeDelta attr: Object to be serialized.
:rtype: float
:return: serialized duration
"""
return Serializer._serialize_duration_numeric(attr, 1, False)

@staticmethod
def serialize_duration_milliseconds_int(attr, **kwargs): # pylint: disable=unused-argument
"""Serialize TimeDelta object into an integer number of milliseconds.

:param TimeDelta attr: Object to be serialized.
:rtype: int
:return: serialized duration
"""
return Serializer._serialize_duration_numeric(attr, 1000, True)

@staticmethod
def serialize_duration_milliseconds_float(attr, **kwargs): # pylint: disable=unused-argument
"""Serialize TimeDelta object into a floating point number of milliseconds.

:param TimeDelta attr: Object to be serialized.
:rtype: float
:return: serialized duration
"""
return Serializer._serialize_duration_numeric(attr, 1000, False)

@staticmethod
def serialize_rfc(attr, **kwargs): # pylint: disable=unused-argument
"""Serialize Datetime object into RFC-1123 formatted string.
Expand Down Expand Up @@ -1377,6 +1436,10 @@ class Deserializer:
"rfc-1123": Deserializer.deserialize_rfc,
"unix-time": Deserializer.deserialize_unix,
"duration": Deserializer.deserialize_duration,
"duration-seconds-int": Deserializer.deserialize_duration_seconds,
"duration-seconds-float": Deserializer.deserialize_duration_seconds,
"duration-milliseconds-int": Deserializer.deserialize_duration_milliseconds,
"duration-milliseconds-float": Deserializer.deserialize_duration_milliseconds,
"date": Deserializer.deserialize_date,
"time": Deserializer.deserialize_time,
"decimal": Deserializer.deserialize_decimal,
Expand All @@ -1389,6 +1452,10 @@ class Deserializer:
}
self.deserialize_expected_types = {
"duration": (isodate.Duration, datetime.timedelta),
"duration-seconds-int": (isodate.Duration, datetime.timedelta),
"duration-seconds-float": (isodate.Duration, datetime.timedelta),
"duration-milliseconds-int": (isodate.Duration, datetime.timedelta),
"duration-milliseconds-float": (isodate.Duration, datetime.timedelta),
"iso-8601": (datetime.datetime),
}
self.dependencies: dict[str, type] = dict(classes) if classes else {}
Expand Down Expand Up @@ -1950,6 +2017,48 @@ class Deserializer:
raise DeserializationError(msg) from err
return duration

@staticmethod
def _deserialize_duration_numeric(attr, unit):
"""Deserialize a numeric duration value into a TimeDelta object.

:param float attr: response value to be deserialized.
:param str unit: The wire unit, used as the ``timedelta`` keyword
(``"seconds"`` or ``"milliseconds"``).
:return: Deserialized duration
:rtype: TimeDelta
:raises DeserializationError: if value is invalid.
"""
if isinstance(attr, ET.Element):
attr = attr.text
try:
duration = datetime.timedelta(**{unit: float(attr)}) # type: ignore
except (ValueError, OverflowError, TypeError) as err:
msg = "Cannot deserialize duration object."
raise DeserializationError(msg) from err
return duration

@staticmethod
def deserialize_duration_seconds(attr):
"""Deserialize a numeric number of seconds into a TimeDelta object.

:param float attr: response value to be deserialized.
:return: Deserialized duration
:rtype: TimeDelta
:raises DeserializationError: if value is invalid.
"""
return Deserializer._deserialize_duration_numeric(attr, "seconds")

@staticmethod
def deserialize_duration_milliseconds(attr):
"""Deserialize a numeric number of milliseconds into a TimeDelta object.

:param float attr: response value to be deserialized.
:return: Deserialized duration
:rtype: TimeDelta
:raises DeserializationError: if value is invalid.
"""
return Deserializer._deserialize_duration_numeric(attr, "milliseconds")

@staticmethod
def deserialize_date(attr):
"""Deserialize ISO-8601 formatted string into Date object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,20 @@
import pytest_asyncio
from encode.duration.aio import DurationClient
from encode.duration.models import (
Int32SecondsDurationProperty,
DefaultDurationProperty,
ISO8601DurationProperty,
Int32SecondsDurationProperty,
FloatSecondsDurationProperty,
DefaultDurationProperty,
Float64SecondsDurationProperty,
Int32MillisecondsDurationProperty,
FloatMillisecondsDurationProperty,
Float64MillisecondsDurationProperty,
FloatSecondsDurationArrayProperty,
FloatMillisecondsDurationArrayProperty,
Int32SecondsLargerUnitDurationProperty,
FloatSecondsLargerUnitDurationProperty,
Int32MillisecondsLargerUnitDurationProperty,
FloatMillisecondsLargerUnitDurationProperty,
)


Expand All @@ -27,10 +36,20 @@ async def client():
async def test_query(client: DurationClient):
await client.query.default(input=datetime.timedelta(days=40))
await client.query.iso8601(input=datetime.timedelta(days=40))
await client.query.int32_seconds(input=36)
await client.query.int32_seconds_array(input=[36, 47])
await client.query.float_seconds(input=35.625)
await client.query.float64_seconds(input=35.625)
await client.query.int32_seconds(input=datetime.timedelta(seconds=36))
await client.query.int32_seconds_larger_unit(input=datetime.timedelta(seconds=120))
await client.query.int32_seconds_array(input=[datetime.timedelta(seconds=36), datetime.timedelta(seconds=47)])
await client.query.float_seconds(input=datetime.timedelta(seconds=35.625))
await client.query.float_seconds_larger_unit(input=datetime.timedelta(seconds=150))
await client.query.float64_seconds(input=datetime.timedelta(seconds=35.625))
await client.query.int32_milliseconds(input=datetime.timedelta(milliseconds=36000))
await client.query.int32_milliseconds_larger_unit(input=datetime.timedelta(milliseconds=180000))
await client.query.int32_milliseconds_array(
input=[datetime.timedelta(milliseconds=36000), datetime.timedelta(milliseconds=47000)]
)
await client.query.float_milliseconds(input=datetime.timedelta(milliseconds=35625))
await client.query.float_milliseconds_larger_unit(input=datetime.timedelta(milliseconds=210000))
await client.query.float64_milliseconds(input=datetime.timedelta(milliseconds=35625))


@pytest.mark.asyncio
Expand All @@ -43,22 +62,69 @@ async def test_property(client: DurationClient):
assert result.value == datetime.timedelta(days=40)
result = await client.property.iso8601(ISO8601DurationProperty(value="P40D"))
assert result.value == datetime.timedelta(days=40)
result = await client.property.int32_seconds(Int32SecondsDurationProperty(value=36))
assert result.value == 36
result = await client.property.float_seconds(FloatSecondsDurationProperty(value=35.625))
assert abs(result.value - 35.625) < 0.0001
result = await client.property.float64_seconds(FloatSecondsDurationProperty(value=35.625))
assert abs(result.value - 35.625) < 0.0001
result = await client.property.float_seconds_array(FloatSecondsDurationArrayProperty(value=[35.625, 46.75]))
assert abs(result.value[0] - 35.625) < 0.0001
assert abs(result.value[1] - 46.75) < 0.0001
result = await client.property.int32_seconds(Int32SecondsDurationProperty(value=datetime.timedelta(seconds=36)))
assert result.value == datetime.timedelta(seconds=36)
result = await client.property.float_seconds(FloatSecondsDurationProperty(value=datetime.timedelta(seconds=35.625)))
assert result.value == datetime.timedelta(seconds=35.625)
result = await client.property.float64_seconds(
Float64SecondsDurationProperty(value=datetime.timedelta(seconds=35.625))
)
assert result.value == datetime.timedelta(seconds=35.625)
result = await client.property.int32_milliseconds(
Int32MillisecondsDurationProperty(value=datetime.timedelta(milliseconds=36000))
)
assert result.value == datetime.timedelta(milliseconds=36000)
result = await client.property.float_milliseconds(
FloatMillisecondsDurationProperty(value=datetime.timedelta(milliseconds=35625))
)
assert result.value == datetime.timedelta(milliseconds=35625)
result = await client.property.float64_milliseconds(
Float64MillisecondsDurationProperty(value=datetime.timedelta(milliseconds=35625))
)
assert result.value == datetime.timedelta(milliseconds=35625)
result = await client.property.float_seconds_array(
FloatSecondsDurationArrayProperty(value=[datetime.timedelta(seconds=35.625), datetime.timedelta(seconds=46.75)])
)
assert result.value == [datetime.timedelta(seconds=35.625), datetime.timedelta(seconds=46.75)]
result = await client.property.float_milliseconds_array(
FloatMillisecondsDurationArrayProperty(
value=[datetime.timedelta(milliseconds=35625), datetime.timedelta(milliseconds=46750)]
)
)
assert result.value == [datetime.timedelta(milliseconds=35625), datetime.timedelta(milliseconds=46750)]
result = await client.property.int32_seconds_larger_unit(
Int32SecondsLargerUnitDurationProperty(value=datetime.timedelta(seconds=120))
)
assert result.value == datetime.timedelta(seconds=120)
result = await client.property.float_seconds_larger_unit(
FloatSecondsLargerUnitDurationProperty(value=datetime.timedelta(seconds=150))
)
assert result.value == datetime.timedelta(seconds=150)
result = await client.property.int32_milliseconds_larger_unit(
Int32MillisecondsLargerUnitDurationProperty(value=datetime.timedelta(milliseconds=180000))
)
assert result.value == datetime.timedelta(milliseconds=180000)
result = await client.property.float_milliseconds_larger_unit(
FloatMillisecondsLargerUnitDurationProperty(value=datetime.timedelta(milliseconds=210000))
)
assert result.value == datetime.timedelta(milliseconds=210000)


@pytest.mark.asyncio
async def test_header(client: DurationClient):
await client.header.default(duration=datetime.timedelta(days=40))
await client.header.iso8601(duration=datetime.timedelta(days=40))
await client.header.iso8601_array(duration=[datetime.timedelta(days=40), datetime.timedelta(days=50)])
await client.header.int32_seconds(duration=36)
await client.header.float_seconds(duration=35.625)
await client.header.float64_seconds(duration=35.625)
await client.header.int32_seconds(duration=datetime.timedelta(seconds=36))
await client.header.int32_seconds_larger_unit(duration=datetime.timedelta(seconds=120))
await client.header.float_seconds(duration=datetime.timedelta(seconds=35.625))
await client.header.float_seconds_larger_unit(duration=datetime.timedelta(seconds=150))
await client.header.float64_seconds(duration=datetime.timedelta(seconds=35.625))
await client.header.int32_milliseconds(duration=datetime.timedelta(milliseconds=36000))
await client.header.int32_milliseconds_larger_unit(duration=datetime.timedelta(milliseconds=180000))
await client.header.int32_milliseconds_array(
duration=[datetime.timedelta(milliseconds=36000), datetime.timedelta(milliseconds=47000)]
)
await client.header.float_milliseconds(duration=datetime.timedelta(milliseconds=35625))
await client.header.float_milliseconds_larger_unit(duration=datetime.timedelta(milliseconds=210000))
await client.header.float64_milliseconds(duration=datetime.timedelta(milliseconds=35625))
Loading
Loading