diff --git a/src/together/types/common.py b/src/together/types/common.py index e7ffa29..9c95fc5 100644 --- a/src/together/types/common.py +++ b/src/together/types/common.py @@ -42,6 +42,8 @@ class LogprobsPart(BaseModel): tokens: List[str | None] | None = None # token logprob list token_logprobs: List[float | None] | None = None + # top-k logprobs per token: one dict per token, mapping token string to log-probability + top_logprobs: List[Dict[str, float]] | None = None class PromptPart(BaseModel): diff --git a/tests/unit/test_logprobs_type.py b/tests/unit/test_logprobs_type.py new file mode 100644 index 0000000..b97d9a8 --- /dev/null +++ b/tests/unit/test_logprobs_type.py @@ -0,0 +1,39 @@ +import warnings + +from together.types.common import LogprobsPart + + +def test_logprobs_part_top_logprobs_is_list(): + """LogprobsPart.top_logprobs must accept a list of per-token dicts (issue #443).""" + lp = LogprobsPart( + tokens=["Hello", "."], + token_logprobs=[-2.6e-06, -4.8e-05], + top_logprobs=[ + {"Hello": -2.6e-06, "hello": -13.5, " Hello": -13.875}, + {".": -4.8e-05, "!": -10.2}, + ], + ) + assert isinstance(lp.top_logprobs, list) + assert len(lp.top_logprobs) == 2 + assert isinstance(lp.top_logprobs[0], dict) + assert lp.top_logprobs[0]["Hello"] == -2.6e-06 + + +def test_logprobs_part_model_dump_no_warning(): + """model_dump() must not emit PydanticSerializationUnexpectedValue for top_logprobs.""" + lp = LogprobsPart( + tokens=["Hello"], + token_logprobs=[-2.6e-06], + top_logprobs=[{"Hello": -2.6e-06, "hello": -13.5}], + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + dumped = lp.model_dump() + assert isinstance(dumped["top_logprobs"], list) + assert dumped["top_logprobs"][0]["Hello"] == -2.6e-06 + + +def test_logprobs_part_top_logprobs_optional(): + """top_logprobs defaults to None when not supplied.""" + lp = LogprobsPart(tokens=["hi"], token_logprobs=[-0.5]) + assert lp.top_logprobs is None