From d140e9567e9b62fc36d9d94d2e7b5c5c03e980eb Mon Sep 17 00:00:00 2001 From: xodn348 Date: Thu, 21 May 2026 07:20:08 +0000 Subject: [PATCH] fix(types): add top_logprobs field to LogprobsPart The Together API returns `top_logprobs` as a List[Dict[str, float]] (one dict per token mapping token string to log-probability), but LogprobsPart had no field for it. Since BaseModel uses extra="allow", the data was stored as an extra field with no declared type, causing: - Missing type information for static analysis / IDE autocompletion - PydanticSerializationUnexpectedValue warnings on model_dump() Add `top_logprobs: List[Dict[str, float]] | None = None` to LogprobsPart and add three unit tests covering the list round-trip, warning-free serialization, and the default-None path. Fixes #443 Signed-off-by: xodn348 --- src/together/types/common.py | 2 ++ tests/unit/test_logprobs_type.py | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 tests/unit/test_logprobs_type.py diff --git a/src/together/types/common.py b/src/together/types/common.py index e7ffa29..9c95fc5 100644 --- a/src/together/types/common.py +++ b/src/together/types/common.py @@ -42,6 +42,8 @@ class LogprobsPart(BaseModel): tokens: List[str | None] | None = None # token logprob list token_logprobs: List[float | None] | None = None + # top-k logprobs per token: one dict per token, mapping token string to log-probability + top_logprobs: List[Dict[str, float]] | None = None class PromptPart(BaseModel): diff --git a/tests/unit/test_logprobs_type.py b/tests/unit/test_logprobs_type.py new file mode 100644 index 0000000..b97d9a8 --- /dev/null +++ b/tests/unit/test_logprobs_type.py @@ -0,0 +1,39 @@ +import warnings + +from together.types.common import LogprobsPart + + +def test_logprobs_part_top_logprobs_is_list(): + """LogprobsPart.top_logprobs must accept a list of per-token dicts (issue #443).""" + lp = LogprobsPart( + tokens=["Hello", "."], + token_logprobs=[-2.6e-06, -4.8e-05], + top_logprobs=[ + {"Hello": -2.6e-06, "hello": -13.5, " Hello": -13.875}, + {".": -4.8e-05, "!": -10.2}, + ], + ) + assert isinstance(lp.top_logprobs, list) + assert len(lp.top_logprobs) == 2 + assert isinstance(lp.top_logprobs[0], dict) + assert lp.top_logprobs[0]["Hello"] == -2.6e-06 + + +def test_logprobs_part_model_dump_no_warning(): + """model_dump() must not emit PydanticSerializationUnexpectedValue for top_logprobs.""" + lp = LogprobsPart( + tokens=["Hello"], + token_logprobs=[-2.6e-06], + top_logprobs=[{"Hello": -2.6e-06, "hello": -13.5}], + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + dumped = lp.model_dump() + assert isinstance(dumped["top_logprobs"], list) + assert dumped["top_logprobs"][0]["Hello"] == -2.6e-06 + + +def test_logprobs_part_top_logprobs_optional(): + """top_logprobs defaults to None when not supplied.""" + lp = LogprobsPart(tokens=["hi"], token_logprobs=[-0.5]) + assert lp.top_logprobs is None