diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 1b2391bd4b0..a1101c971bd 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -965,8 +965,8 @@ void NgramMatch(const paddle::Tensor& input_ids, const int max_ngram_size, const int max_draft_tokens); -void HybridMtpNgram(const paddle::Tensor& input_ids, - const paddle::Tensor& input_ids_len, +void HybridMtpNgram(const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, const paddle::Tensor& pre_ids, const paddle::Tensor& step_idx, const paddle::Tensor& draft_token_num, diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu index 778a6112367..b1dccf45a65 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu @@ -29,8 +29,8 @@ // Also copies tentative matched tokens to scratch buffers. // ============================================================ __global__ void ngram_match_mixed_search_kernel( - const int64_t *input_ids, - const int64_t *input_ids_len, + const int64_t *token_ids_all, + const int64_t *prompt_lens, const int64_t *pre_ids, const int64_t *step_idx, const int *draft_token_num, @@ -38,7 +38,7 @@ __global__ void ngram_match_mixed_search_kernel( const int64_t *max_dec_len, int64_t *draft_tokens_copy, int32_t *seq_lens_this_time_copy, - int64_t input_ids_stride, + int64_t max_model_len, int64_t pre_ids_stride, int64_t draft_tokens_stride, int64_t max_batch_size, @@ -69,8 +69,9 @@ __global__ void ngram_match_mixed_search_kernel( if (draft_budget <= 0 || remaining_dec <= 0) return; int max_draft_tokens = static_cast(min(draft_budget, remaining_dec)); - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + const int64_t prompt_len = prompt_lens[batch_idx]; + const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len; + const int64_t cur_input_ids_len = prompt_len; const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride; const int64_t cur_step_idx = step_idx[batch_idx]; @@ -228,8 +229,8 @@ static int sum_mixed_cpu(const int *value, int num) { return sum_value; } -static void find_candidate_pred_tokens_mixed(const int64_t *input_ids, - const int64_t *input_ids_len, +static void find_candidate_pred_tokens_mixed(const int64_t *token_ids_all, + const int64_t *prompt_lens, const int64_t *pre_ids, const int64_t *step_idx, const int *draft_token_num, @@ -237,7 +238,7 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids, int32_t *seq_lens_this_time, int32_t *seq_lens_decoder, int64_t *max_dec_len, - int64_t input_ids_stride, + int64_t max_model_len, int64_t pre_ids_stride, int64_t draft_tokens_stride, int64_t max_batch_size, @@ -268,11 +269,12 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids, int max_draft_tokens_query = static_cast(std::min(draft_budget, remaining_dec)); - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + const int64_t prompt_len = prompt_lens[batch_idx]; + const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len; int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride; const int64_t cur_step_idx = step_idx[batch_idx]; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + const int64_t cur_input_ids_len = prompt_len; unprocessed_batch_size--; auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx); @@ -363,8 +365,8 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids, // threshold enforcement + final token copy. // ============================================================ -void HybridMtpNgram(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, +void HybridMtpNgram(const paddle::Tensor &token_ids_all, + const paddle::Tensor &prompt_lens, const paddle::Tensor &pre_ids, const paddle::Tensor &step_idx, const paddle::Tensor &draft_token_num, @@ -375,8 +377,7 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, const int max_ngram_size, const int min_ngram_size, const int max_draft_tokens) { - auto input_ids_shape = input_ids.shape(); - const int64_t input_ids_stride = input_ids_shape[1]; + const int64_t max_model_len = token_ids_all.shape()[1]; auto pre_ids_shape = pre_ids.shape(); const int64_t pre_ids_stride = pre_ids_shape[1]; @@ -392,8 +393,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, threshold = std::stoi(env_var); } - if (input_ids.is_gpu()) { - auto stream = input_ids.stream(); + if (token_ids_all.is_gpu()) { + auto stream = token_ids_all.stream(); // NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed // variant uses ori_seq_len_this_time == 0 to skip inactive items. This @@ -408,16 +409,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, auto draft_tokens_copy = paddle::empty({max_batch_size, draft_tokens_stride}, paddle::DataType::INT64, - input_ids.place()); + token_ids_all.place()); // Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts) auto seq_lens_this_time_copy = paddle::empty( - {max_batch_size}, paddle::DataType::INT32, input_ids.place()); + {max_batch_size}, paddle::DataType::INT32, token_ids_all.place()); // Save a copy of original seq_lens_this_time for Phase 2 // (Phase 1 reads from the original, Phase 2 needs ori values) auto seq_lens_this_time_orig = paddle::empty( - {max_batch_size}, paddle::DataType::INT32, input_ids.place()); + {max_batch_size}, paddle::DataType::INT32, token_ids_all.place()); cudaMemcpyAsync(seq_lens_this_time_orig.data(), seq_lens_this_time.data(), max_batch_size * sizeof(int32_t), @@ -434,8 +435,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, NGRAM_BLOCK_THREADS, 0, stream>>>( - input_ids.data(), - input_ids_len.data(), + token_ids_all.data(), + prompt_lens.data(), pre_ids.data(), step_idx.data(), draft_token_num.data(), @@ -443,7 +444,7 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, max_dec_len.data(), draft_tokens_copy.data(), seq_lens_this_time_copy.data(), - input_ids_stride, + max_model_len, pre_ids_stride, draft_tokens_stride, max_batch_size, @@ -464,8 +465,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, threshold); } else { find_candidate_pred_tokens_mixed( - input_ids.data(), - input_ids_len.data(), + token_ids_all.data(), + prompt_lens.data(), pre_ids.data(), step_idx.data(), draft_token_num.data(), @@ -473,7 +474,7 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, const_cast(seq_lens_this_time.data()), const_cast(seq_lens_decoder.data()), const_cast(max_dec_len.data()), - input_ids_stride, + max_model_len, pre_ids_stride, draft_tokens_stride, max_batch_size, @@ -484,8 +485,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, } PD_BUILD_STATIC_OP(hybrid_mtp_ngram) - .Inputs({"input_ids", - "input_ids_len", + .Inputs({"token_ids_all", + "prompt_lens", "pre_ids", "step_idx", "draft_token_num", diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index cf614ad399e..868f60553c1 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -470,16 +470,10 @@ def insert_tasks_v1( input_ids = request.prompt_token_ids + request.output_token_ids - self.model_inputs["input_ids_len"][idx] = length - 1 async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1) self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ idx : idx + 1, 1:length ] - # TODO: use token_all_ids replace with input_ids_cpu - if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs: - self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ - "input_ids" - ][idx : idx + 1, 1:length].cpu() encoder_block_num = len(request.block_tables) async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) @@ -567,7 +561,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: request = req_dicts[i] idx = request.idx length = len(request.prompt_token_ids) - self.model_inputs.input_ids_len[idx] = length - 1 if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode": length = len(request.prompt_token_ids) @@ -575,9 +568,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ "input_ids" ][idx : idx + 1, 1:length] - self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( - request.prompt_token_ids - )[1:] self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1] prefill_token_num = self.max_draft_token_num + 1 self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor( @@ -606,9 +596,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ "input_ids" ][idx : idx + 1, 1:length] - self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( - request.prompt_token_ids - )[1:] self.model_inputs["pre_ids"][idx : idx + 1] = -1 self.model_inputs["step_idx"][idx : idx + 1] = 0 if self.cache_config.enable_chunked_prefill: diff --git a/fastdeploy/spec_decode/mtp_cuda.py b/fastdeploy/spec_decode/mtp_cuda.py index eb74b371069..cbe7ad1aa19 100644 --- a/fastdeploy/spec_decode/mtp_cuda.py +++ b/fastdeploy/spec_decode/mtp_cuda.py @@ -393,10 +393,9 @@ def _update_status(self): ) def _extend_draft_token_with_ngram_match(self): - # TODO: replace with gpu tensor hybrid_mtp_ngram( - self.model_inputs["input_ids_cpu"].cuda(), - self.model_inputs["input_ids_len"].cuda(), + self.model_inputs["token_ids_all"], + self.model_inputs["prompt_lens"], self.model_inputs["pre_ids"], self.model_inputs["step_idx"], self.target_model_inputs["actual_draft_token_num"], diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 56dc4ae9aa8..f5994c6341c 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -758,12 +758,6 @@ def init_share_inputs(self): self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) - self.input_ids_cpu = paddle.full( - shape=[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], - fill_value=-1, - dtype="int64", - device="cpu", - ) self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"]) self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"]) @@ -776,7 +770,7 @@ def init_share_inputs(self): self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"]) self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"]) if "token_ids_all" in self.target_model_input_batch: - self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"]) + self.token_ids_all = self.target_model_input_batch["token_ids_all"] # TODO: delete pre_ids in mtp self.pre_ids = paddle.full( [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], @@ -886,7 +880,6 @@ def init_share_inputs(self): self.last_seq_lens_this_time = paddle.full_like( self.target_model_input_batch["seq_lens_this_time"], fill_value=-1, dtype="int32" ) - self.input_ids_len = paddle.zeros(shape=[self.scheduler_config.max_num_seqs, 1], dtype="int64", device="cpu") self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"] self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"] self.accept_num = self.target_model_input_batch["accept_num"] @@ -936,14 +929,12 @@ def swap_data(tensor, idx1, idx2): self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1] swap_data(self.block_tables, i1, i2) swap_data(self.input_ids, i1, i2) - swap_data(self.input_ids_cpu, i1, i2) swap_data(self.seq_lens_this_time_buffer, i1, i2) swap_data(self.seq_lens_encoder, i1, i2) swap_data(self.seq_lens_decoder, i1, i2) swap_data(self.step_idx, i1, i2) swap_data(self.pre_ids, i1, i2) swap_data(self.encoder_block_lens, i1, i2) - swap_data(self.input_ids_len, i1, i2) swap_data(self.mask_rollback, i1, i2) swap_data(self.recompute_token_num, i1, i2) if self.enable_mm: @@ -966,7 +957,6 @@ def reset_model_inputs(self) -> None: # Clone the target model inputs to restore initial values self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) - fill_paddle_tensor(self, "input_ids_cpu", -1) # acceptance rate decline when reset seq_lens_this_time # self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"]) @@ -980,7 +970,7 @@ def reset_model_inputs(self) -> None: self.index_to_batch_id = {} if current_platform.is_cuda(): if "token_ids_all" in self.target_model_input_batch: - self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"]) + self.token_ids_all = self.target_model_input_batch["token_ids_all"] # TODO: delete pre_ids in mtp self.pre_ids = paddle.full( [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], @@ -1062,9 +1052,6 @@ def reset_model_inputs(self) -> None: if self.num_model_steps > 1: fill_paddle_tensor(self, "last_seq_lens_this_time", -1) - # Reset input IDs length - fill_paddle_tensor(self, "input_ids_len", 0) - # Reset various scores and flags self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"] self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"] diff --git a/tests/e2e/test_ernie_21b_mtp_ngram.py b/tests/e2e/test_ernie_21b_mtp_ngram.py new file mode 100644 index 00000000000..2ee3b694a16 --- /dev/null +++ b/tests/e2e/test_ernie_21b_mtp_ngram.py @@ -0,0 +1,345 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + extract_logprobs, + get_stream_chunks, + is_port_open, + send_request, +) + + +def _build_speculate_metrics_baseline( + accepted_tokens, + rejected_tokens, + accept_ratio, + average_accept_length, + accepted_tokens_per_head, + accept_ratio_per_head, +): + """ + Build a tolerance-based baseline for speculate metrics. + + Integer counters remain strict, while floating-point fields and + per-head metric arrays use approximate comparison to reduce + environment-sensitive test flakiness. + """ + return { + "accepted_tokens": accepted_tokens, + "rejected_tokens": rejected_tokens, + "accept_ratio": pytest.approx(accept_ratio, abs=0.02), + "average_accept_length": pytest.approx(average_accept_length, abs=0.1), + "accepted_tokens_per_head": pytest.approx(accepted_tokens_per_head, abs=2), + "accept_ratio_per_head": pytest.approx(accept_ratio_per_head, abs=0.05), + } + + +BASELINE_SPECULATE_METRICS = _build_speculate_metrics_baseline( + accepted_tokens=100, + rejected_tokens=176, + accept_ratio=0.54, + average_accept_length=2.1739130434782608, + accepted_tokens_per_head=[46, 25, 15, 8, 6, 0], + accept_ratio_per_head=[0.5434782608695652, 0.6, 0.5333333333333333, 0.75, 0.0], +) +BASELINE_SPECULATE_METRICS_WITH_LOGPROBS = _build_speculate_metrics_baseline( + accepted_tokens=100, + rejected_tokens=182, + accept_ratio=0.53, + average_accept_length=2.127659574468085, + accepted_tokens_per_head=[47, 29, 16, 5, 3, 0], + accept_ratio_per_head=[0.6170212765957447, 0.5517241379310345, 0.3125, 0.6, 0.0], +) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server in hybrid MTP-Ngram mode as a subprocess + - Waits for server port to open (up to 5 minutes) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean() + + if os.path.exists("log") and os.path.isdir("log"): + shutil.rmtree("log") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle") + else: + model_path = "./ernie-4_5-21b-a3b-bf16-paddle" + mtp_model_path = os.path.join(model_path, "mtp") + + speculative_config = { + "method": "mtp", + "model": mtp_model_path, + "num_speculative_tokens": 5, + "num_model_steps": 3, + "mtp_strategy": "with_ngram", + "max_ngram_size": 3, + "min_ngram_size": 1, + } + + log_path = "server.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "4096", + "--max-num-seqs", + "8", + "--quantization", + "wint4", + "--enable-overlap-schedule", + "--enable-logprob", + "--speculative-config", + json.dumps(speculative_config), + "--graph-optimization-config", + '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + ] + + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + for _ in range(300): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"Server is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + print(f"server (pid={process.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(): + return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(): + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +def test_mtp_ngram_stream(api_url): + """Hybrid MTP-Ngram streaming generation returns non-empty result with valid token counts.""" + payload = { + "model": "default", + "messages": [{"role": "user", "content": "牛顿的三大运动定律是什么?"}], + "max_tokens": 50, + "min_tokens": 10, + "temperature": 0, + "top_p": 0, + "seed": 42, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + } + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join(x["choices"][0]["delta"]["content"] for x in chunks[:-1]) + assert result != "", "Generation result is empty" + usage = chunks[-1]["usage"] + assert usage["completion_tokens"] <= payload["max_tokens"] + assert usage["completion_tokens"] >= payload["min_tokens"] + assert usage["total_tokens"] == usage["completion_tokens"] + usage["prompt_tokens"] + + +def test_mtp_ngram_non_stream(api_url): + """Hybrid MTP-Ngram non-streaming generation returns non-empty result with valid token counts.""" + payload = { + "model": "default", + "messages": [{"role": "user", "content": "牛顿的三大运动定律是什么?"}], + "max_tokens": 50, + "min_tokens": 10, + "temperature": 0, + "top_p": 0, + "seed": 42, + "stream": False, + } + response = send_request(url=api_url, payload=payload).json() + result = response["choices"][0]["message"]["content"] + assert result != "", "Generation result is empty" + usage = response["usage"] + assert usage["completion_tokens"] <= payload["max_tokens"] + assert usage["completion_tokens"] >= payload["min_tokens"] + assert usage["total_tokens"] == usage["completion_tokens"] + usage["prompt_tokens"] + + +def test_mtp_ngram_speculate_metrics(api_url): + """speculate_metrics contains the MTP + ngram acceptance stats for a repeated-prompt request.""" + # Prompt with repeated fragments to increase ngram match rate + content = ( + "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、" + "公共管理等多领域。在理论层面,COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性" + "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估" + "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等,并衍生出风险共享、" + "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展," + "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理" + "的文献研究综述,300字以内" + ) + payload = { + "model": "default", + "messages": [{"role": "user", "content": content}], + "max_tokens": 100, + "min_tokens": 20, + "temperature": 0, + "top_p": 0, + "seed": 42, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + } + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + # chunks[-1] is the usage chunk; chunks[-2] is the last content chunk containing speculate_metrics + speculate_metrics = chunks[-2]["choices"][0].get("speculate_metrics") + # print(f"\n[test_mtp_ngram_speculate_metrics] speculate_metrics: {json.dumps(speculate_metrics, indent=2)}") + assert speculate_metrics is not None, "speculate_metrics is missing from last chunk" + + # accepted_tokens_per_head length should equal num_speculative_tokens + 1 (head 0 is the verified token). + # num_speculative_tokens=5 → 6 entries. + accepted_per_head = speculate_metrics.get("accepted_tokens_per_head") + assert accepted_per_head is not None, "accepted_tokens_per_head is missing" + assert len(accepted_per_head) == 6, ( + f"Expected 6 entries in accepted_tokens_per_head (num_speculative_tokens=5 + 1), " + f"got {len(accepted_per_head)}: {accepted_per_head}" + ) + # Monotonically non-increasing (each subsequent draft head is harder to accept) + for i in range(len(accepted_per_head) - 1): + assert ( + accepted_per_head[i] >= accepted_per_head[i + 1] + ), f"accepted_tokens_per_head is not monotonically non-increasing at index {i}: {accepted_per_head}" + + # Cross-field consistency: total accepted == sum of per-head + assert speculate_metrics["accepted_tokens"] == sum(accepted_per_head), ( + f"accepted_tokens ({speculate_metrics['accepted_tokens']}) != " + f"sum(accepted_tokens_per_head) ({sum(accepted_per_head)})" + ) + + # Baseline comparison — exact match against the values captured in the reference environment. + if BASELINE_SPECULATE_METRICS is not None: + assert speculate_metrics == BASELINE_SPECULATE_METRICS, ( + f"speculate_metrics mismatch\n" + f"got: {json.dumps(speculate_metrics, indent=2)}\n" + f"baseline: {json.dumps(BASELINE_SPECULATE_METRICS, indent=2)}" + ) + + +def test_mtp_ngram_speculate_metrics_with_logprobs(api_url): + """speculate_metrics and logprobs coexist correctly when hybrid mode + logprobs are both enabled.""" + content = ( + "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、" + "公共管理等多领域。在理论层面,COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性" + "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估" + "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等,并衍生出风险共享、" + "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展," + "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理" + "的文献研究综述,300字以内" + ) + payload = { + "model": "default", + "messages": [{"role": "user", "content": content}], + "max_tokens": 100, + "min_tokens": 20, + "temperature": 0, + "top_p": 0, + "seed": 42, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "logprobs": True, + "top_logprobs": 5, + } + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + + # logprobs are present in each content chunk + logprobs_list = extract_logprobs(chunks) + assert len(logprobs_list) > 0, "No logprobs received" + for logprobs in logprobs_list: + assert "content" in logprobs + for item in logprobs["content"]: + assert "token" in item + assert "logprob" in item + assert "top_logprobs" in item + assert len(item["top_logprobs"]) <= 5 + + # speculate_metrics appears in the last content chunk and is consistent + speculate_metrics = chunks[-2]["choices"][0].get("speculate_metrics") + # print( + # f"\n[test_mtp_ngram_speculate_metrics_with_logprobs] " + # f"speculate_metrics: {json.dumps(speculate_metrics, indent=2)}" + # ) + assert speculate_metrics is not None, "speculate_metrics is missing from last chunk" + + accepted_per_head = speculate_metrics.get("accepted_tokens_per_head") + assert accepted_per_head is not None, "accepted_tokens_per_head is missing" + assert len(accepted_per_head) == 6 + assert speculate_metrics["accepted_tokens"] == sum(accepted_per_head) + + # Baseline comparison — exact match against the values captured in the reference environment. + if BASELINE_SPECULATE_METRICS_WITH_LOGPROBS is not None: + assert speculate_metrics == BASELINE_SPECULATE_METRICS_WITH_LOGPROBS, ( + f"speculate_metrics mismatch\n" + f"got: {json.dumps(speculate_metrics, indent=2)}\n" + f"baseline: {json.dumps(BASELINE_SPECULATE_METRICS_WITH_LOGPROBS, indent=2)}" + ) diff --git a/tests/operators/test_hybrid_mtp_ngram.py b/tests/operators/test_hybrid_mtp_ngram.py index 6c111f93763..72794557265 100644 --- a/tests/operators/test_hybrid_mtp_ngram.py +++ b/tests/operators/test_hybrid_mtp_ngram.py @@ -26,15 +26,19 @@ class TestNgramMatchMixed(unittest.TestCase): def setUp(self): self.max_bsz = 2 self.max_draft_tokens = 5 - self.max_len = 32 + self.max_model_len = 32 + self.prompt_len = 10 self.max_dec_len = 10 self.max_ngram_size = 5 self.min_ngram_size = 2 - # 初始化输入 tensor - self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu() - self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu() - self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu() + # token_ids_all layout: + # [0 .. prompt_len-1] <- prompt (Phase 1 search source) + # [prompt_len .. ] <- pad (-1) + # pre_ids carries the generated tokens used as Phase 2 search source. + self.token_ids_all = paddle.full(shape=[self.max_bsz, self.max_model_len], fill_value=-1, dtype="int64").cpu() + self.prompt_lens = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu() + self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_model_len], fill_value=-1, dtype="int64").cpu() self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu() self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu() self.draft_tokens = paddle.full( @@ -50,9 +54,10 @@ def setUp(self): dtype="int64", ).cpu() - # 设置具体数据 - self.input_ids[:, :10] = np.arange(0, 10) - self.input_ids_len[:] = 10 + # Fill prompt 0..9 into token_ids_all (Phase 1 search source) + self.token_ids_all[:, : self.prompt_len] = np.arange(0, self.prompt_len) + self.prompt_lens[:] = self.prompt_len + pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int32") self.pre_ids[:, : pre_ids_np.shape[0]] = pre_ids_np self.step_idx[:] = 8 @@ -63,14 +68,16 @@ def setUp(self): self.seq_lens_decoder[:] = 12 self.max_dec_len[:] = 512 - # 期望结果 + # Expected results (unchanged: kernel matching logic is identical; + # only the data source for prompt tokens moved from input_ids to + # token_ids_all[:, :prompt_len]). self.ref_seq_lens_this_time = np.array([[6], [6]], dtype="int32") self.ref_draft_tokens = np.array([[8, 7, 6, 10, 9, 8], [8, 7, 6, 10, 9, 8]], dtype="int64") def test_ngram_match_mixed(self): hybrid_mtp_ngram( - self.input_ids, - self.input_ids_len, + self.token_ids_all, + self.prompt_lens, self.pre_ids, self.step_idx, self.draft_token_num, diff --git a/tests/spec_decode/test_ngram_gpu_kernel.py b/tests/spec_decode/test_ngram_gpu_kernel.py index ad1f5ea845b..87868c67436 100644 --- a/tests/spec_decode/test_ngram_gpu_kernel.py +++ b/tests/spec_decode/test_ngram_gpu_kernel.py @@ -126,8 +126,8 @@ def _cpu_ngram_match( def _cpu_hybrid_mtp_ngram( - input_ids, - input_ids_len, + token_ids_all, + prompt_lens, pre_ids, step_idx, draft_token_num, @@ -145,7 +145,7 @@ def _cpu_hybrid_mtp_ngram( max_dec_len = max_dec_len.ravel() step_idx = step_idx.ravel() draft_token_num = draft_token_num.ravel() - input_ids_len = input_ids_len.ravel() + prompt_lens = prompt_lens.ravel() max_batch_size = seq_lens_this_time.shape[0] unprocessed = sum(1 for b in range(max_batch_size) if seq_lens_decoder[b] > 0) @@ -158,11 +158,11 @@ def _cpu_hybrid_mtp_ngram( if ori_slt == 0 or max_q <= 0: continue - cur_input_ids = input_ids[batch_idx] + cur_input_ids = token_ids_all[batch_idx] cur_draft = draft_tokens[batch_idx] cur_pre = pre_ids[batch_idx] cur_step = int(step_idx[batch_idx]) - cur_ids_len = int(input_ids_len[batch_idx]) + cur_ids_len = int(prompt_lens[batch_idx]) unprocessed -= 1 sum_tok = sum(int(seq_lens_this_time[i]) for i in range(batch_idx + 1)) @@ -263,9 +263,14 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft """Create realistic test tensors for hybrid_mtp_ngram op.""" rng = np.random.RandomState(seed) vocab_size = 1000 + # token_ids_all must be at least as large as the per-request budget; + # use pre_ids_len as a stand-in for max_model_len in tests. + max_model_len = max(pre_ids_len, input_len + 64) - input_ids = rng.randint(0, vocab_size, (batch_size, input_len)).astype(np.int64) - input_ids_len = np.full((batch_size, 1), input_len, dtype=np.int64) + prompt_tokens = rng.randint(0, vocab_size, (batch_size, input_len)).astype(np.int64) + token_ids_all = np.full((batch_size, max_model_len), -1, dtype=np.int64) + token_ids_all[:, :input_len] = prompt_tokens + prompt_lens = np.full((batch_size, 1), input_len, dtype=np.int64) pre_ids = np.zeros((batch_size, pre_ids_len), dtype=np.int64) step_idx = np.zeros((batch_size, 1), dtype=np.int64) @@ -280,13 +285,13 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft # Copy contiguous blocks from prompt to guarantee ngram matches gen_len = 20 src = rng.randint(0, max(1, input_len - gen_len)) - pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len] + pre_ids[b, :gen_len] = prompt_tokens[b, src : src + gen_len] # step_idx = last valid position (0-based index) step_idx[b] = gen_len - 1 return { - "input_ids": input_ids, - "input_ids_len": input_ids_len, + "token_ids_all": token_ids_all, + "prompt_lens": prompt_lens, "pre_ids": pre_ids, "step_idx": step_idx, "draft_token_num": draft_token_num, @@ -835,8 +840,8 @@ def test_correctness_basic(self): cpu_draft = data["draft_tokens"].copy() cpu_slt = data["seq_lens_this_time"].copy() _cpu_hybrid_mtp_ngram( - data["input_ids"], - data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], data["pre_ids"], data["step_idx"], data["draft_token_num"], @@ -851,8 +856,8 @@ def test_correctness_basic(self): gpu_data = _to_gpu(data) self.hybrid_mtp_ngram( - gpu_data["input_ids"], - gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], gpu_data["pre_ids"], gpu_data["step_idx"], gpu_data["draft_token_num"], @@ -879,8 +884,8 @@ def test_correctness_varied_seeds(self): cpu_draft = data["draft_tokens"].copy() cpu_slt = data["seq_lens_this_time"].copy() _cpu_hybrid_mtp_ngram( - data["input_ids"], - data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], data["pre_ids"], data["step_idx"], data["draft_token_num"], @@ -894,8 +899,8 @@ def test_correctness_varied_seeds(self): ) gpu_data = _to_gpu(data) self.hybrid_mtp_ngram( - gpu_data["input_ids"], - gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], gpu_data["pre_ids"], gpu_data["step_idx"], gpu_data["draft_token_num"], @@ -922,8 +927,8 @@ def test_large_batch_long_seq(self): cpu_draft = data["draft_tokens"].copy() cpu_slt = data["seq_lens_this_time"].copy() _cpu_hybrid_mtp_ngram( - data["input_ids"], - data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], data["pre_ids"], data["step_idx"], data["draft_token_num"], @@ -941,8 +946,8 @@ def test_large_batch_long_seq(self): os.environ["SPEC_TOKENUM_THRESHOLD"] = str(high_threshold) try: self.hybrid_mtp_ngram( - gpu_data["input_ids"], - gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], gpu_data["pre_ids"], gpu_data["step_idx"], gpu_data["draft_token_num"], @@ -969,8 +974,8 @@ def test_single_batch_long_seq(self): cpu_draft = data["draft_tokens"].copy() cpu_slt = data["seq_lens_this_time"].copy() _cpu_hybrid_mtp_ngram( - data["input_ids"], - data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], data["pre_ids"], data["step_idx"], data["draft_token_num"], @@ -984,8 +989,8 @@ def test_single_batch_long_seq(self): ) gpu_data = _to_gpu(data) self.hybrid_mtp_ngram( - gpu_data["input_ids"], - gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], gpu_data["pre_ids"], gpu_data["step_idx"], gpu_data["draft_token_num"], @@ -1008,8 +1013,8 @@ def test_many_short_seqs(self): cpu_draft = data["draft_tokens"].copy() cpu_slt = data["seq_lens_this_time"].copy() _cpu_hybrid_mtp_ngram( - data["input_ids"], - data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], data["pre_ids"], data["step_idx"], data["draft_token_num"], @@ -1027,8 +1032,8 @@ def test_many_short_seqs(self): os.environ["SPEC_TOKENUM_THRESHOLD"] = str(high_threshold) try: self.hybrid_mtp_ngram( - gpu_data["input_ids"], - gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], gpu_data["pre_ids"], gpu_data["step_idx"], gpu_data["draft_token_num"],