From cf0e86b636e3008ea8d7a939de2a16c7b10da91d Mon Sep 17 00:00:00 2001 From: supermario_leo Date: Sun, 14 Jun 2026 03:29:57 +0800 Subject: [PATCH] test(sampling_params): repair broken import and add verify() coverage The test module failed to collect because it imported and tested DecodeNode, which no longer exists after the nccl pd-mode removal. Drop the dead import and its test, remove the stale move_kv_to_decode_node kwarg, and add coverage for SamplingParams.verify(), which previously had none: the n == best_of constraint, best_of/n range limits, penalty and temperature/top_p/top_k bounds, min/max_new_tokens, and constraint mutual-exclusion. --- .../server/core/objs/test_sampling_params.py | 105 ++++++++++++++---- 1 file changed, 85 insertions(+), 20 deletions(-) diff --git a/unit_tests/server/core/objs/test_sampling_params.py b/unit_tests/server/core/objs/test_sampling_params.py index ef3f08d2fc..281741e666 100644 --- a/unit_tests/server/core/objs/test_sampling_params.py +++ b/unit_tests/server/core/objs/test_sampling_params.py @@ -5,7 +5,6 @@ RegularConstraint, AllowedTokenIds, ExponentialDecayLengthPenalty, - DecodeNode, SamplingParams, GuidedGrammar, GuidedJsonSchema, @@ -14,6 +13,7 @@ ALLOWED_TOKEN_IDS_MAX_LENGTH, JSON_SCHEMA_MAX_LENGTH, GRAMMAR_CONSTRAINT_MAX_LENGTH, + MAX_BEST_OF, ) grammar_str = r"""root ::= (expr "=" term)+ @@ -117,24 +117,6 @@ def test_exponential_decay_length_penalty_initialization(): penalty.initialize((5, 0.5)) -def test_decode_node_initialization(): - node = DecodeNode() - data = { - "node_id": 12345678901234567890, # 示例 UUID - "ip": "192.168.1.1", - "rpyc_port": 8080, - "max_new_tokens": 10, - } - node.initialize(data) - assert node.exists is True - assert node.node_id.node_id_high == (12345678901234567890 >> 64) & 0xFFFFFFFFFFFFFFFF - assert node.node_id.node_id_low == 12345678901234567890 & 0xFFFFFFFFFFFFFFFF - assert node.ip[0] == 192 - assert node.ip[1] == 168 - assert node.ip[2] == 1 - assert node.ip[3] == 1 - - def test_sampling_params_initialization(): params = SamplingParams() data = { @@ -161,7 +143,6 @@ def test_sampling_params_initialization(): "allowed_token_ids": [1, 2, 3], "stop_sequences": [[2, 1], [3, 4]], "exponential_decay_length_penalty": (1, 1.0), - "move_kv_to_decode_node": None, } params.init(None, **data) @@ -173,6 +154,90 @@ def test_sampling_params_initialization(): assert params.stop_sequences.size == 2 +def _make_params(**overrides): + """Build a SamplingParams whose fields are valid by default, applying overrides. + + ``do_sample=True`` is used so that the sampling-related fields (temperature, top_p, + top_k) are kept as provided; with greedy decoding ``init`` overrides them to defaults. + """ + data = { + "best_of": 1, + "n": 1, + "do_sample": True, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "max_new_tokens": 16, + "min_new_tokens": 1, + } + data.update(overrides) + params = SamplingParams() + params.init(None, **data) + return params + + +def test_verify_accepts_valid_defaults(): + # A minimally specified, valid configuration must pass verification. + _make_params().verify() + + +def test_verify_accepts_n_equal_best_of_greater_than_one(): + params = _make_params(best_of=2, n=2) + params.verify() + assert params.n == params.best_of == 2 + + +def test_verify_rejects_n_not_equal_best_of(): + # The engine currently only supports n == best_of; a mismatch must be rejected. + with pytest.raises(ValueError): + _make_params(best_of=2, n=1).verify() + + +@pytest.mark.parametrize("best_of", [0, -1, MAX_BEST_OF + 1]) +def test_verify_rejects_best_of_out_of_range(best_of): + with pytest.raises(ValueError): + _make_params(best_of=best_of, n=best_of).verify() + + +@pytest.mark.parametrize( + "field, value", + [ + ("presence_penalty", -0.1), + ("frequency_penalty", -0.1), + ("repetition_penalty", 0.5), + ("temperature", -1.0), + ("top_p", 0.0), + ("top_p", 1.5), + ("top_k", 0), + ("top_k", -2), + ("max_new_tokens", 0), + ("min_new_tokens", 0), + ], +) +def test_verify_rejects_invalid_sampling_fields(field, value): + with pytest.raises(ValueError): + _make_params(**{field: value}).verify() + + +def test_verify_rejects_min_new_tokens_greater_than_max(): + with pytest.raises(ValueError): + _make_params(min_new_tokens=8, max_new_tokens=4).verify() + + +@pytest.mark.parametrize("top_k", [-1, 1, 50]) +def test_verify_accepts_valid_top_k(top_k): + _make_params(top_k=top_k).verify() + + +def test_verify_rejects_regular_constraint_with_allowed_token_ids(): + # regular_constraint and allowed_token_ids are mutually exclusive. + with pytest.raises(ValueError): + _make_params(regular_constraint="[a-z]+", allowed_token_ids=[1, 2, 3]).verify() + + # Mock tokenizer for testing class MockTokenizer: def encode(self, text, add_special_tokens=False):