Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/together/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class LogprobsPart(BaseModel):
tokens: List[str | None] | None = None
# token logprob list
token_logprobs: List[float | None] | None = None
# top-k logprobs per token: one dict per token, mapping token string to log-probability
top_logprobs: List[Dict[str, float]] | None = None


class PromptPart(BaseModel):
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_logprobs_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import warnings

from together.types.common import LogprobsPart


def test_logprobs_part_top_logprobs_is_list():
"""LogprobsPart.top_logprobs must accept a list of per-token dicts (issue #443)."""
lp = LogprobsPart(
tokens=["Hello", "."],
token_logprobs=[-2.6e-06, -4.8e-05],
top_logprobs=[
{"Hello": -2.6e-06, "hello": -13.5, " Hello": -13.875},
{".": -4.8e-05, "!": -10.2},
],
)
assert isinstance(lp.top_logprobs, list)
assert len(lp.top_logprobs) == 2
assert isinstance(lp.top_logprobs[0], dict)
assert lp.top_logprobs[0]["Hello"] == -2.6e-06


def test_logprobs_part_model_dump_no_warning():
"""model_dump() must not emit PydanticSerializationUnexpectedValue for top_logprobs."""
lp = LogprobsPart(
tokens=["Hello"],
token_logprobs=[-2.6e-06],
top_logprobs=[{"Hello": -2.6e-06, "hello": -13.5}],
)
with warnings.catch_warnings():
warnings.simplefilter("error")
dumped = lp.model_dump()
assert isinstance(dumped["top_logprobs"], list)
assert dumped["top_logprobs"][0]["Hello"] == -2.6e-06


def test_logprobs_part_top_logprobs_optional():
"""top_logprobs defaults to None when not supplied."""
lp = LogprobsPart(tokens=["hi"], token_logprobs=[-0.5])
assert lp.top_logprobs is None