Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 85 additions & 20 deletions unit_tests/server/core/objs/test_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
RegularConstraint,
AllowedTokenIds,
ExponentialDecayLengthPenalty,
DecodeNode,
SamplingParams,
GuidedGrammar,
GuidedJsonSchema,
Expand All @@ -14,6 +13,7 @@
ALLOWED_TOKEN_IDS_MAX_LENGTH,
JSON_SCHEMA_MAX_LENGTH,
GRAMMAR_CONSTRAINT_MAX_LENGTH,
MAX_BEST_OF,
)

grammar_str = r"""root ::= (expr "=" term)+
Expand Down Expand Up @@ -117,24 +117,6 @@ def test_exponential_decay_length_penalty_initialization():
penalty.initialize((5, 0.5))


def test_decode_node_initialization():
node = DecodeNode()
data = {
"node_id": 12345678901234567890, # 示例 UUID
"ip": "192.168.1.1",
"rpyc_port": 8080,
"max_new_tokens": 10,
}
node.initialize(data)
assert node.exists is True
assert node.node_id.node_id_high == (12345678901234567890 >> 64) & 0xFFFFFFFFFFFFFFFF
assert node.node_id.node_id_low == 12345678901234567890 & 0xFFFFFFFFFFFFFFFF
assert node.ip[0] == 192
assert node.ip[1] == 168
assert node.ip[2] == 1
assert node.ip[3] == 1


def test_sampling_params_initialization():
params = SamplingParams()
data = {
Expand All @@ -161,7 +143,6 @@ def test_sampling_params_initialization():
"allowed_token_ids": [1, 2, 3],
"stop_sequences": [[2, 1], [3, 4]],
"exponential_decay_length_penalty": (1, 1.0),
"move_kv_to_decode_node": None,
}
params.init(None, **data)

Expand All @@ -173,6 +154,90 @@ def test_sampling_params_initialization():
assert params.stop_sequences.size == 2


def _make_params(**overrides):
"""Build a SamplingParams whose fields are valid by default, applying overrides.

``do_sample=True`` is used so that the sampling-related fields (temperature, top_p,
top_k) are kept as provided; with greedy decoding ``init`` overrides them to defaults.
"""
data = {
"best_of": 1,
"n": 1,
"do_sample": True,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"max_new_tokens": 16,
"min_new_tokens": 1,
}
data.update(overrides)
params = SamplingParams()
params.init(None, **data)
return params


def test_verify_accepts_valid_defaults():
# A minimally specified, valid configuration must pass verification.
_make_params().verify()


def test_verify_accepts_n_equal_best_of_greater_than_one():
params = _make_params(best_of=2, n=2)
params.verify()
assert params.n == params.best_of == 2


def test_verify_rejects_n_not_equal_best_of():
# The engine currently only supports n == best_of; a mismatch must be rejected.
with pytest.raises(ValueError):
_make_params(best_of=2, n=1).verify()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since _make_params internally calls params.init(), which automatically runs self.verify(), any invalid parameter configuration will raise a ValueError during initialization. Consequently, the trailing .verify() call is never reached and is dead code. Removing .verify() makes the test cleaner and accurately reflects where the exception is raised.

Suggested change
_make_params(best_of=2, n=1).verify()
_make_params(best_of=2, n=1)



@pytest.mark.parametrize("best_of", [0, -1, MAX_BEST_OF + 1])
def test_verify_rejects_best_of_out_of_range(best_of):
with pytest.raises(ValueError):
_make_params(best_of=best_of, n=best_of).verify()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since _make_params internally calls params.init(), which automatically runs self.verify(), any invalid parameter configuration will raise a ValueError during initialization. Consequently, the trailing .verify() call is never reached and is dead code. Removing .verify() makes the test cleaner and accurately reflects where the exception is raised.

Suggested change
_make_params(best_of=best_of, n=best_of).verify()
_make_params(best_of=best_of, n=best_of)



@pytest.mark.parametrize(
"field, value",
[
("presence_penalty", -0.1),
("frequency_penalty", -0.1),
("repetition_penalty", 0.5),
("temperature", -1.0),
("top_p", 0.0),
("top_p", 1.5),
("top_k", 0),
("top_k", -2),
("max_new_tokens", 0),
("min_new_tokens", 0),
],
)
def test_verify_rejects_invalid_sampling_fields(field, value):
with pytest.raises(ValueError):
_make_params(**{field: value}).verify()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since _make_params internally calls params.init(), which automatically runs self.verify(), any invalid parameter configuration will raise a ValueError during initialization. Consequently, the trailing .verify() call is never reached and is dead code. Removing .verify() makes the test cleaner and accurately reflects where the exception is raised.

Suggested change
_make_params(**{field: value}).verify()
_make_params(**{field: value})



def test_verify_rejects_min_new_tokens_greater_than_max():
with pytest.raises(ValueError):
_make_params(min_new_tokens=8, max_new_tokens=4).verify()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since _make_params internally calls params.init(), which automatically runs self.verify(), any invalid parameter configuration will raise a ValueError during initialization. Consequently, the trailing .verify() call is never reached and is dead code. Removing .verify() makes the test cleaner and accurately reflects where the exception is raised.

Suggested change
_make_params(min_new_tokens=8, max_new_tokens=4).verify()
_make_params(min_new_tokens=8, max_new_tokens=4)



@pytest.mark.parametrize("top_k", [-1, 1, 50])
def test_verify_accepts_valid_top_k(top_k):
_make_params(top_k=top_k).verify()


def test_verify_rejects_regular_constraint_with_allowed_token_ids():
# regular_constraint and allowed_token_ids are mutually exclusive.
with pytest.raises(ValueError):
_make_params(regular_constraint="[a-z]+", allowed_token_ids=[1, 2, 3]).verify()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since _make_params internally calls params.init(), which automatically runs self.verify(), any invalid parameter configuration will raise a ValueError during initialization. Consequently, the trailing .verify() call is never reached and is dead code. Removing .verify() makes the test cleaner and accurately reflects where the exception is raised.

Suggested change
_make_params(regular_constraint="[a-z]+", allowed_token_ids=[1, 2, 3]).verify()
_make_params(regular_constraint="[a-z]+", allowed_token_ids=[1, 2, 3])



# Mock tokenizer for testing
class MockTokenizer:
def encode(self, text, add_special_tokens=False):
Expand Down
Loading