From 6f20ea09e5eb88bf80b732919bcedb4e3d8356d8 Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Wed, 1 Apr 2026 16:18:12 +0800 Subject: [PATCH 1/2] refactor mtp.py --- fastdeploy/spec_decode/mtp.py | 725 ++----------------------- fastdeploy/spec_decode/mtp_cuda.py | 444 +++++++++++++++ fastdeploy/spec_decode/mtp_xpu.py | 276 ++++++++++ fastdeploy/spec_decode/types.py | 4 +- tests/spec_decode/test_mtp_proposer.py | 177 +----- 5 files changed, 803 insertions(+), 823 deletions(-) create mode 100644 fastdeploy/spec_decode/mtp_cuda.py create mode 100644 fastdeploy/spec_decode/mtp_xpu.py diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 38e4ed13064..02cfadd6e62 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -16,13 +16,13 @@ import os import time +from abc import abstractmethod from typing import TYPE_CHECKING, List import numpy as np import paddle from paddleformers.utils.log import logger -from fastdeploy import envs from fastdeploy.engine.request import Request, RequestType from fastdeploy.inter_communicator import IPCSignal from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -30,53 +30,19 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) -from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import MTPSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models import ModelForCasualLM from fastdeploy.platforms import current_platform if current_platform.is_xpu(): - from fastdeploy.model_executor.ops.xpu import ( - draft_model_postprocess, - draft_model_preprocess, - draft_model_update, - eagle_get_hidden_states, - eagle_get_self_hidden_states, - mtp_save_first_token, - mtp_step_paddle, - set_data_ipc, - share_external_data, - update_attn_mask_offsets, - ) - from 
fastdeploy.model_executor.xpu_pre_and_post_process import ( - xpu_pre_process, - xpu_process_output, - ) + from fastdeploy.model_executor.ops.xpu import set_data_ipc, share_external_data else: - from fastdeploy.model_executor.ops.gpu import ( - draft_model_postprocess, - draft_model_preprocess, - draft_model_update, - eagle_get_hidden_states, - eagle_get_self_hidden_states, - eagle_gather_hidden_states, - hybrid_mtp_ngram, - mtp_save_first_token, - mtp_step_paddle, - share_external_data, - speculate_get_logits, - speculate_save_output_topk, - update_attn_mask_offsets, - set_data_ipc, - unset_data_ipc, - ) - from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process + from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data + from fastdeploy.model_executor.pre_and_post_process import async_set_value from fastdeploy.worker.input_batch import ( ProposerInputBatch, - recover_batch_index_for_output, - recover_batch_index_for_sampler_output, reorder_split_prefill_and_decode_form_index_to_batch_id, ) @@ -89,6 +55,9 @@ class MTPProposer(Proposer): """ Proposer for Multi-Token-Prediction(MTP) + + Base class containing common logic. Platform-specific behavior is + implemented in MTPProposerCUDA and MTPProposerXPU subclasses. """ def __init__( @@ -118,17 +87,6 @@ def __init__( self.role = self.scheduler_config.splitwise_role self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode - if current_platform.is_xpu(): - self._prepare_inputs = self._prepare_inputs_xpu - self._propose = self._propose_xpu - elif current_platform.is_cuda() or current_platform.is_maca(): - self._prepare_inputs = self._prepare_inputs_cuda - self._propose = self._propose_cuda - else: - raise RuntimeError( - f"Unsupported platform for MTP: {current_platform}. 
" f"Supported platforms: CUDA, MACA, XPU" - ) - self.sampler = MTPSampler(fd_config) self.model_inputs = ProposerInputBatch(self.fd_config, self.target_model_inputs) self.model_inputs.init_share_inputs() @@ -145,6 +103,43 @@ def __init__( self.forward_meta = None self.exist_prefill_flag = False + # ======================== Abstract methods ======================== + # Subclasses (MTPProposerCUDA / MTPProposerXPU) must implement these. + + @abstractmethod + def _initialize_forward_meta( + self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0 + ) -> None: + """Initialize forward meta and attention metadata for a substep.""" + ... + + @abstractmethod + def _prepare_inputs(self, full_hidden_states: paddle.Tensor) -> None: + """Prepare MTP inputs from target model hidden states (whole-proposer preprocessing).""" + ... + + @abstractmethod + def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0) -> None: + """Execute the multi-step MTP inference loop (per-substep preprocessing / forward / sampling).""" + ... + + @abstractmethod + def _post_process(self, sampled_token_ids) -> None: + """Per-substep post-processing after sampling.""" + ... + + @abstractmethod + def _update_status(self) -> None: + """Whole-proposer post-processing: update main-model forward info and manage MTP block allocation.""" + ... + + # ======================== Overridable hooks ======================== + def _extend_draft_token_with_ngram_match(self): + """Extend draft tokens with ngram matching. 
CUDA-only feature; no-op by default.""" + pass + + # ======================== Common methods ======================== + def _update_mtp_config(self, main_model): """ Update config for MTP from global config @@ -419,23 +414,6 @@ def _initialize_attn_backend( ) self.attn_backends.append(attn_backend) - def clear_mtp_cache(self, profile=False): - """ - Clear allocated cacheKV - """ - create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend - or self.fd_config.scheduler_config.splitwise_role != "mixed" - ) - if not create_cache_tensor: - for name, tensor in self.cache_kvs_map.items(): - unset_data_ipc(tensor, name, True, False) - self.cache_kvs_map.clear() - del self.model_inputs["caches"] - if self.forward_meta is not None: - del self.forward_meta.caches - def update_mtp_block_num(self, num_gpu_blocks) -> None: """ Update MTP block num by theoretical calculation @@ -641,546 +619,12 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: ) self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] - def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0): - """ - Initialize forward meta and attention meta data - """ - # Initialize forward meta - self.forward_meta = ForwardMeta( - ids_remove_padding=self.model_inputs["ids_remove_padding"], - rotary_embs=self.model_inputs["rope_emb"], - attn_backend=self.attn_backends[0], - decoder_batch_ids=self.model_inputs["decoder_batch_ids"], - decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], - decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"], - decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"], - decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"], - max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"], - 
seq_lens_encoder=self.model_inputs["seq_lens_encoder"], - seq_lens_decoder=self.model_inputs["seq_lens_decoder"], - seq_lens_this_time=self.model_inputs["seq_lens_this_time"], - batch_id_per_token=self.model_inputs["batch_id_per_token"], - cu_seqlens_q=self.model_inputs["cu_seqlens_q"], - cu_seqlens_k=self.model_inputs["cu_seqlens_k"], - block_tables=self.model_inputs["block_tables"], - caches=self.model_inputs["caches"], - encoder_batch_ids=self.model_inputs["encoder_batch_ids"], - encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"], - encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"], - kv_batch_ids=self.model_inputs["kv_batch_ids"], - kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], - kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], - attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, - ) - - # Initialzie attention meta data - for attn_backend in self.attn_backends: - attn_backend.init_attention_metadata(self.forward_meta) - - # Notes(liuzichang): - # 1. CUDA Graph capture sizes must be recorded in descending order (large → small). - # 2. In multi-step execution, only the first step should be captured. 
- self.forward_meta.step_use_cudagraph = ( - step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run) - ) - - def _initialize_forward_meta_xpu(self): - - self.forward_meta.decoder_batch_ids = (self.model_inputs["decoder_batch_ids"],) - self.forward_meta.decoder_tile_ids_per_batch = (self.model_inputs["decoder_tile_ids_per_batch"],) - self.forward_meta.decoder_num_blocks_cpu = (self.model_inputs["decoder_num_blocks_cpu"],) - self.forward_meta.decoder_num_blocks_device = (self.model_inputs["decoder_num_blocks_device"],) - self.forward_meta.decoder_chunk_size_device = (self.model_inputs["decoder_chunk_size_device"],) - self.forward_meta.max_len_tensor_cpu = (self.model_inputs["max_len_tensor_cpu"],) - - self.forward_meta.encoder_batch_ids = (self.model_inputs["encoder_batch_ids"],) - self.forward_meta.encoder_tile_ids_per_batch = (self.model_inputs["encoder_tile_ids_per_batch"],) - self.forward_meta.encoder_num_blocks_x_cpu = (self.model_inputs["encoder_num_blocks_x_cpu"],) - self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],) - self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],) - self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],) - self.forward_meta.attn_backend = self.attn_backends[0] - if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": - self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"] - - self.forward_meta.is_draft = True - - # Initialzie attention meta data - for attn_backend in self.attn_backends: - attn_backend.init_attention_metadata(self.forward_meta) - def exist_prefill(self): """ check whether prefill stage exist """ return self.exist_prefill_flag - def _prepare_inputs_cuda(self, full_hidden_states): - """ - Prepare MTP inputs - - MTP state (seq_lens_decoder, step_idx) is "shadow state": - - Initialized from target model state each round - - Used for MTP 
forward, but not committed until verify - - No rollback needed since it's always re-initialized - """ - - draft_model_preprocess( - self.model_inputs["draft_tokens"], - self.model_inputs["input_ids"], - self.model_inputs["stop_flags"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["step_idx"], - self.model_inputs["not_need_stop_device"], - self.model_inputs["pre_ids"], - self.target_model_inputs["accept_tokens"], - self.target_model_inputs["accept_num"], - self.target_model_inputs["seq_lens_encoder"], - self.target_model_inputs["seq_lens_decoder"], - self.target_model_inputs["step_idx"], - self.target_model_inputs["stop_flags"], - self.model_inputs["max_dec_len"], - self.target_model_inputs["draft_tokens"], - self.num_model_steps, - self.role == "prefill", # is_splitwise_prefill - ) - - target_hidden_states, _ = eagle_get_hidden_states( - full_hidden_states, - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["stop_flags"], - self.target_model_inputs["accept_num"], - self.target_model_inputs["seq_lens_this_time"], - self.target_model_inputs["seq_lens_encoder"], - self.num_model_steps, - ) - - self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) - - def _prepare_inputs_xpu(self, full_hidden_states): - use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER) - draft_model_preprocess( - self.model_inputs["draft_tokens"], - self.model_inputs["input_ids"], - self.model_inputs["stop_flags"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["step_idx"], - self.model_inputs["not_need_stop"], - self.model_inputs["batch_drop"], - self.model_inputs["is_block_step"], - self.model_inputs["pre_ids"], - self.model_inputs["mask_rollback"], - 
self.model_inputs["recompute_token_num"], - self.target_model_inputs["accept_tokens"], - self.target_model_inputs["accept_num"], - self.target_model_inputs["seq_lens_this_time"], - self.target_model_inputs["seq_lens_encoder"], - self.target_model_inputs["seq_lens_decoder"], - self.target_model_inputs["step_idx"], - self.target_model_inputs["stop_flags"], - self.target_model_inputs["is_block_step"], - self.target_model_inputs["draft_tokens"], - self.num_model_steps, - True, - self.role == "prefill", - use_v1_cache_scheduler, - ) - - target_hidden_states = eagle_get_hidden_states( - full_hidden_states, - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["stop_flags"], - self.target_model_inputs["accept_num"], - self.target_model_inputs["seq_lens_this_time"], - self.target_model_inputs["seq_lens_encoder"], - self.num_model_steps, - ) - self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) - - def _post_process(self, sampled_token_ids): - """ - PostProcess for generation - """ - draft_model_update( - sampled_token_ids, - self.model_inputs["draft_tokens"], - self.model_inputs["pre_ids"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["step_idx"], - # Note(ZKK): - # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend - # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 - ( - self.model_inputs["cu_seqlens_q_output"] - if current_platform.is_cuda() - else self.model_inputs["output_cum_offsets"] - ), - self.model_inputs["stop_flags"], - ( - self.model_inputs["not_need_stop_device"] - if current_platform.is_cuda() - else self.model_inputs["not_need_stop"] - ), - self.model_inputs["max_dec_len"], - self.model_inputs["eos_token_id"], - self.model_inputs["base_model_draft_tokens"], - self.max_model_len, - 
self.model_inputs["substep"], - ) - - if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: - skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], - ) - mtp_save_first_token( - recover_model_output_map["base_model_draft_tokens"], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_model_output_map["step_idx"], - self.local_rank, - self.parallel_config.use_ep, - skip_save, - ) - # Ensure only save first token once. - paddle.assign( - paddle.where( - self.model_inputs["stop_flags"], - paddle.zeros_like(self.model_inputs["step_idx"]), - self.model_inputs["step_idx"], - ), - self.model_inputs["step_idx"], - ) - - def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): - """ - Main process for MTP inference. - Args: - step_use_cudagraph: bool - Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. 
- """ - is_blocking = ( - (not self.fd_config.scheduler_config.enable_overlap_schedule) - or is_dummy_run - or self.exist_prefill() - or real_bsz == 0 - ) - for substep in range(self.num_model_steps): - if is_blocking: - token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() - else: - if substep == 0: - token_num_cpu = real_bsz * (self.max_draft_token_num + 1) - else: - token_num_cpu = real_bsz - if token_num_cpu > 0: - self.model_inputs["substep"] = substep - # Remove padding - ( - ids_remove_padding, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_q_output, - batch_id_per_token_output, - real_output_token_num, - ) = pre_process( - token_num_cpu, - self.model_inputs["input_ids"], - self.model_inputs["seq_lens_this_time"], - True, - self.model_inputs["draft_tokens"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - ) - - if self.use_attn_mask_offset: - attn_mask_offsets = update_attn_mask_offsets( - ids_remove_padding, - getattr( - self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"] - ), - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - cu_seqlens_q, - self.model_inputs["attn_mask_offsets_full"], - self.model_inputs["is_block_step"], - self.model_inputs["decode_states"], - ) - self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) - - # Initialize forward meta data - self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) - self.model_inputs["batch_id_per_token"][:] = -1 - self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) - self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) - - # For speculative decoding - self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) - self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) - - # Initialize forward meta data - self._initialize_forward_meta( - 
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep - ) - self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) - self.forward_meta.real_bsz = real_bsz - - # Padding inputs for cuda graph - self.padding_cudagraph_inputs() - - # Get sampling metadata - self.sampling_metadata = SamplingMetadata( - temperature=self.model_inputs["temperature"], - top_p=self.model_inputs["top_p"], - top_k=self.model_inputs["top_k"], - seed=self.model_inputs["infer_seed"], - step_idx=self.model_inputs["step_idx"], - token_ids_all=self.model_inputs["token_ids_all"], - pre_token_ids=self.model_inputs["pre_ids"], - prompt_lens=self.model_inputs["prompt_lens"], - fake_prompt_lens=self.model_inputs["fake_prompt_lens"], - frequency_penalties=self.model_inputs["frequency_score"], - presence_penalties=self.model_inputs["presence_score"], - repetition_penalties=self.model_inputs["penalty_score"], - min_dec_lens=self.model_inputs["min_dec_len"], - bad_words_token_ids=self.model_inputs["bad_tokens"], - bad_words_token_len=self.model_inputs["bad_tokens_len"], - eos_token_ids=self.model_inputs["eos_token_id"], - max_num_logprobs=20 if self.enable_logprob else None, - temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], - top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], - share_inputs=self.model_inputs, - ) - - real_num = self.model_inputs["ids_remove_padding"].shape[0] - target_hidden_states = self.model_inputs["target_hidden_states"][:real_num] - model_output = self.model( - ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=target_hidden_states, - forward_meta=self.forward_meta, - ) - if self.forward_meta.step_use_cudagraph: - model_output = model_output[: self.real_token_num] - - hidden_states, _ = eagle_gather_hidden_states( - model_output, - self.model_inputs["cu_seqlens_q"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_decoder"], - 
self.model_inputs["seq_lens_encoder"], - self.model_inputs["batch_id_per_token_output"], - self.model_inputs["cu_seqlens_q_output"], - real_output_token_num, - ) - - # 4. Compute logits, Sample - logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) - if self.enable_logprob and self.enable_draft_logprob and substep == 0: - first_token_logits = self.model.compute_logits( - self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta - ) - - speculate_get_logits( - self.model_inputs["draft_logits"], - self.model_inputs["next_token_num"], - self.model_inputs["batch_token_num"], - self.model_inputs["cu_next_token_offset"], - self.model_inputs["cu_batch_token_offset"], - logits, - first_token_logits, - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - ) - - sampled_token_ids, sampler_output = self.sampler( - logits, - self.sampling_metadata, - self.max_model_len, - self.model_inputs, - ) - - if ( - not is_dummy_run - and self.parallel_config.tensor_parallel_rank == 0 - and substep == 0 - and sampler_output.logprobs_tensors is not None - ): - real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] - recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"], - ) - speculate_save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - recover_model_output_map["batch_token_num"][:real_bsz], - recover_model_output_map["cu_batch_token_offset"][:real_bsz], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - 
recover_model_output_map["prompt_lens"], - 4, # mtype - self.local_rank, - self.parallel_config.use_ep, - ) - - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast( - sampled_token_ids, - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - - self._post_process(sampled_token_ids) - self.model_inputs["target_hidden_states"].copy_(hidden_states, False) - else: - if hasattr(self.model, "empty_input_forward") and not is_dummy_run: - self.model.empty_input_forward(forward_meta=self.forward_meta) - self.exist_prefill_flag = False - - def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): - """ - Main process for MTP inference. - Args: - step_use_cudagraph: bool - Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. - """ - for substep in range(self.num_model_steps): - if self.model_inputs["not_need_stop"]: - self.model_inputs["substep"] = substep - # Remove padding - self.forward_meta = xpu_pre_process( - self.model_inputs["input_ids"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs, - True, - self.cache_config.block_size, - self.model_inputs["draft_tokens"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - ) - - if self.enable_mm: - attn_mask_offsets = update_attn_mask_offsets( - self.model_inputs["ids_remove_padding"], - getattr( - self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"] - ), - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["cu_seqlens_q"], - self.model_inputs["attn_mask_offsets_full"], - self.model_inputs["attn_mask_offsets_decoder"], - self.model_inputs["is_block_step"], - self.model_inputs["decode_states"], - self.model_inputs["mask_rollback"], - ) - self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) - - 
self._initialize_forward_meta_xpu() - # Get sampling metadata - self.sampling_metadata = SamplingMetadata( - temperature=self.model_inputs["temperature"], - top_p=self.model_inputs["top_p"], - top_k=self.model_inputs["top_k"], - seed=self.model_inputs["infer_seed"], - step_idx=self.model_inputs["step_idx"], - pre_token_ids=self.model_inputs["pre_ids"], - frequency_penalties=self.model_inputs["frequency_score"], - presence_penalties=self.model_inputs["presence_score"], - repetition_penalties=self.model_inputs["penalty_score"], - min_dec_lens=self.model_inputs["min_dec_len"], - bad_words_token_ids=self.model_inputs["bad_tokens"], - eos_token_ids=self.model_inputs["eos_token_id"], - max_num_logprobs=20 if self.enable_logprob else None, - temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], - top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], - share_inputs=self.model_inputs, - ) - - if self.num_model_steps > 1: - self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) - - model_output = self.model( - ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=self.model_inputs["target_hidden_states"], - forward_meta=self.forward_meta, - ) - hidden_states = xpu_process_output( - model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs - ) - # 4. 
Compute logits, Sample - logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) - sampled_token_ids, sampler_output = self.sampler( - logits, - self.sampling_metadata, - self.max_model_len, - self.model_inputs, - ) - - if substep == 0 and sampler_output.logprobs_tensors is not None: - real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] - recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["batch_token_num", "cu_batch_token_offset"], - ) - speculate_save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - recover_model_output_map["batch_token_num"][:real_bsz], - recover_model_output_map["cu_batch_token_offset"][:real_bsz], - self.model_inputs["not_need_stop"], - 4, # mtype - self.local_rank, - ) - - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast( - sampled_token_ids, - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - - self._post_process(sampled_token_ids) - if substep != self.num_model_steps - 1: - self._get_self_hidden_states_xpu(hidden_states) - else: - if hasattr(self.model, "empty_input_forward") and not is_dummy_run: - self.model.empty_input_forward(self.forward_meta) - - def _get_self_hidden_states_xpu(self, hidden_states): - target_hidden_states = eagle_get_self_hidden_states( - hidden_states, - self.model_inputs.last_seq_lens_this_time, - self.model_inputs["seq_lens_this_time"], - self.model_inputs["step_idx"], - ) - self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) - def update_task_chunk_prefill(self, task): """ Update single task's 
chunk_prefill info @@ -1210,58 +654,6 @@ def update_task_chunk_prefill(self, task): self.model_inputs["step_idx"][idx : idx + 1] = 0 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0) - def _update_status(self): - """ - Update main-model's forward info in next step. - Allocate/Free block of MPT. - """ - draft_model_postprocess( - self.target_model_inputs["draft_tokens"], - self.target_model_inputs["seq_lens_this_time"], - self.target_model_inputs["seq_lens_encoder"], - self.target_model_inputs["stop_flags"], - ) - if not envs.ENABLE_V1_KVCACHE_SCHEDULER: - mtp_step_paddle( - self.target_model_inputs["stop_flags"], - self.model_inputs["stop_flags"], - self.model_inputs["batch_drop"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["block_tables"], - self.model_inputs["encoder_block_lens"], - self.model_inputs["used_list_len"], - self.model_inputs["free_list"], - self.model_inputs["free_list_len"], - self.cache_config.block_size, - self.max_draft_token_num, - ) - - def _extend_draft_token_with_ngram_match(self): - # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency - device = paddle.CUDAPinnedPlace() - - draft_tokens = self.target_model_inputs["draft_tokens"].cpu() - seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() - seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() - hybrid_mtp_ngram( - self.model_inputs["input_ids_cpu"], - self.model_inputs["input_ids_len"], - self.model_inputs["pre_ids"]._copy_to(device, True), - self.model_inputs["step_idx"].cpu(), - self.target_model_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_decoder, - self.model_inputs["max_dec_len"].cpu(), - self.max_ngram_size, - self.min_ngram_size, - self.max_draft_token_num, - ) - self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() - 
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() - def _run_impl( self, full_hidden_states: paddle.Tensor, @@ -1280,19 +672,6 @@ def is_chunk_prefill_enabled(self): """""" return True - def padding_cudagraph_inputs(self) -> None: - """ - Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. - In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. - """ - # In init_attention_metadata, the decode buffer has already been cleared - - # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. - if self.forward_meta.step_use_cudagraph: - self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] - self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] - return - def _empty_cache(self): if current_platform.is_cuda(): paddle.device.cuda.empty_cache() @@ -1322,3 +701,19 @@ def _share_external_data(self, cache, cache_name, cache_shape): return share_external_data(cache, cache_name, cache_shape, False) else: return share_external_data(cache, cache_name, cache_shape) + + +def create_mtp_proposer(fd_config, main_model, local_rank, device_id, share_inputs): + """Factory function that returns the platform-specific MTPProposer subclass.""" + if current_platform.is_xpu(): + from fastdeploy.spec_decode.mtp_xpu import MTPProposerXPU + + return MTPProposerXPU(fd_config, main_model, local_rank, device_id, share_inputs) + elif current_platform.is_cuda() or current_platform.is_maca(): + from fastdeploy.spec_decode.mtp_cuda import MTPProposerCUDA + + return MTPProposerCUDA(fd_config, main_model, local_rank, device_id, share_inputs) + else: + raise RuntimeError( + f"Unsupported platform for MTP: {current_platform}. 
" f"Supported platforms: CUDA, MACA, XPU" + ) diff --git a/fastdeploy/spec_decode/mtp_cuda.py b/fastdeploy/spec_decode/mtp_cuda.py new file mode 100644 index 00000000000..bd76ac080c7 --- /dev/null +++ b/fastdeploy/spec_decode/mtp_cuda.py @@ -0,0 +1,444 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy import envs +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.ops.gpu import ( + draft_model_postprocess, + draft_model_preprocess, + draft_model_update, + eagle_gather_hidden_states, + eagle_get_hidden_states, + hybrid_mtp_ngram, + mtp_save_first_token, + speculate_get_logits, + speculate_save_output_topk, + unset_data_ipc, + update_attn_mask_offsets, +) +from fastdeploy.model_executor.pre_and_post_process import pre_process +from fastdeploy.worker.input_batch import ( + recover_batch_index_for_output, + recover_batch_index_for_sampler_output, +) + +from .mtp import MTPProposer + + +class MTPProposerCUDA(MTPProposer): + """ + CUDA-specific MTPProposer implementation. 
+ """ + + def _prepare_inputs(self, full_hidden_states): + """ + Prepare MTP inputs + + MTP state (seq_lens_decoder, step_idx) is "shadow state": + - Initialized from target model state each round + - Used for MTP forward, but not committed until verify + - No rollback needed since it's always re-initialized + """ + + draft_model_preprocess( + self.model_inputs["draft_tokens"], + self.model_inputs["input_ids"], + self.model_inputs["stop_flags"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["step_idx"], + self.model_inputs["not_need_stop_device"], + self.model_inputs["pre_ids"], + self.target_model_inputs["accept_tokens"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["seq_lens_decoder"], + self.target_model_inputs["step_idx"], + self.target_model_inputs["stop_flags"], + self.model_inputs["max_dec_len"], + self.target_model_inputs["draft_tokens"], + self.num_model_steps, + self.role == "prefill", # is_splitwise_prefill + ) + + target_hidden_states, _ = eagle_get_hidden_states( + full_hidden_states, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["stop_flags"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.num_model_steps, + ) + + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) + + def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = ForwardMeta( + ids_remove_padding=self.model_inputs["ids_remove_padding"], + rotary_embs=self.model_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + 
decoder_batch_ids=self.model_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], + decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"], + decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"], + decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"], + max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"], + seq_lens_encoder=self.model_inputs["seq_lens_encoder"], + seq_lens_decoder=self.model_inputs["seq_lens_decoder"], + seq_lens_this_time=self.model_inputs["seq_lens_this_time"], + batch_id_per_token=self.model_inputs["batch_id_per_token"], + cu_seqlens_q=self.model_inputs["cu_seqlens_q"], + cu_seqlens_k=self.model_inputs["cu_seqlens_k"], + block_tables=self.model_inputs["block_tables"], + caches=self.model_inputs["caches"], + encoder_batch_ids=self.model_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.model_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], + attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, + ) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + # Notes(liuzichang): + # 1. CUDA Graph capture sizes must be recorded in descending order (large → small). + # 2. In multi-step execution, only the first step should be captured. + self.forward_meta.step_use_cudagraph = ( + step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run) + ) + + def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): + """ + Main process for MTP inference. 
+ Args: + step_use_cudagraph: bool + Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. + """ + is_blocking = ( + (not self.fd_config.scheduler_config.enable_overlap_schedule) + or is_dummy_run + or self.exist_prefill() + or real_bsz == 0 + ) + for substep in range(self.num_model_steps): + if is_blocking: + token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() + else: + if substep == 0: + token_num_cpu = real_bsz * (self.max_draft_token_num + 1) + else: + token_num_cpu = real_bsz + if token_num_cpu > 0: + self.model_inputs["substep"] = substep + # Remove padding + token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() + ( + ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_q_output, + batch_id_per_token_output, + real_output_token_num, + ) = pre_process( + token_num_cpu, + self.model_inputs["input_ids"], + self.model_inputs["seq_lens_this_time"], + True, + self.model_inputs["draft_tokens"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + ) + + if self.use_attn_mask_offset: + attn_mask_offsets = update_attn_mask_offsets( + ids_remove_padding, + getattr( + self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"] + ), + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + cu_seqlens_q, + self.model_inputs["attn_mask_offsets_full"], + self.model_inputs["is_block_step"], + self.model_inputs["decode_states"], + ) + self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) + + # Initialize forward meta data + self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) + self.model_inputs["batch_id_per_token"][:] = -1 + self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) + self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) + + # For speculative decoding + 
self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) + self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) + + # Initialize forward meta data + self._initialize_forward_meta( + step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep + ) + self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) + self.forward_meta.real_bsz = real_bsz + + # Padding inputs for cuda graph + self.padding_cudagraph_inputs() + + # Get sampling metadata + self.sampling_metadata = SamplingMetadata( + temperature=self.model_inputs["temperature"], + top_p=self.model_inputs["top_p"], + top_k=self.model_inputs["top_k"], + seed=self.model_inputs["infer_seed"], + step_idx=self.model_inputs["step_idx"], + token_ids_all=self.model_inputs["token_ids_all"], + pre_token_ids=self.model_inputs["pre_ids"], + prompt_lens=self.model_inputs["prompt_lens"], + fake_prompt_lens=self.model_inputs["fake_prompt_lens"], + frequency_penalties=self.model_inputs["frequency_score"], + presence_penalties=self.model_inputs["presence_score"], + repetition_penalties=self.model_inputs["penalty_score"], + min_dec_lens=self.model_inputs["min_dec_len"], + bad_words_token_ids=self.model_inputs["bad_tokens"], + bad_words_token_len=self.model_inputs["bad_tokens_len"], + eos_token_ids=self.model_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, + ) + + real_num = self.model_inputs["ids_remove_padding"].shape[0] + target_hidden_states = self.model_inputs["target_hidden_states"][:real_num] + model_output = self.model( + ids_remove_padding=self.model_inputs["ids_remove_padding"], + previous_hidden_states=target_hidden_states, + forward_meta=self.forward_meta, + ) + if self.forward_meta.step_use_cudagraph: + model_output = 
model_output[: self.real_token_num] + + hidden_states, _ = eagle_gather_hidden_states( + model_output, + self.model_inputs["cu_seqlens_q"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["batch_id_per_token_output"], + self.model_inputs["cu_seqlens_q_output"], + real_output_token_num, + ) + + # 4. Compute logits, Sample + logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) + if self.enable_logprob and self.enable_draft_logprob and substep == 0: + first_token_logits = self.model.compute_logits( + self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta + ) + + speculate_get_logits( + self.model_inputs["draft_logits"], + self.model_inputs["next_token_num"], + self.model_inputs["batch_token_num"], + self.model_inputs["cu_next_token_offset"], + self.model_inputs["cu_batch_token_offset"], + logits, + first_token_logits, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + ) + + sampled_token_ids, sampler_output = self.sampler( + logits, + self.sampling_metadata, + self.max_model_len, + self.model_inputs, + ) + + if ( + not is_dummy_run + and self.parallel_config.tensor_parallel_rank == 0 + and substep == 0 + and sampler_output.logprobs_tensors is not None + ): + real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"], + ) + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + 
recover_model_output_map["batch_token_num"][:real_bsz], + recover_model_output_map["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + 4, # mtype + self.local_rank, + self.parallel_config.use_ep, + ) + + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + + self._post_process(sampled_token_ids) + self.model_inputs["target_hidden_states"].copy_(hidden_states, False) + else: + if hasattr(self.model, "empty_input_forward") and not is_dummy_run: + self.model.empty_input_forward(forward_meta=self.forward_meta) + self.exist_prefill_flag = False + + def _post_process(self, sampled_token_ids): + """ + PostProcess for generation + """ + draft_model_update( + sampled_token_ids, + self.model_inputs["draft_tokens"], + self.model_inputs["pre_ids"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["step_idx"], + self.model_inputs["cu_seqlens_q_output"], + self.model_inputs["stop_flags"], + self.model_inputs["not_need_stop_device"], + self.model_inputs["max_dec_len"], + self.model_inputs["eos_token_id"], + self.model_inputs["base_model_draft_tokens"], + self.max_model_len, + self.model_inputs["substep"], + ) + + if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], + ) + mtp_save_first_token( + recover_model_output_map["base_model_draft_tokens"], + 
self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_model_output_map["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) + # Ensure only save first token once. + paddle.assign( + paddle.where( + self.model_inputs["stop_flags"], + paddle.zeros_like(self.model_inputs["step_idx"]), + self.model_inputs["step_idx"], + ), + self.model_inputs["step_idx"], + ) + + def _update_status(self): + """ + Update main-model's forward info in next step. + Allocate/Free block of MPT. + """ + draft_model_postprocess( + self.target_model_inputs["draft_tokens"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["stop_flags"], + ) + + def _extend_draft_token_with_ngram_match(self): + # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency + device = paddle.CUDAPinnedPlace() + + draft_tokens = self.target_model_inputs["draft_tokens"].cpu() + seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() + seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() + hybrid_mtp_ngram( + self.model_inputs["input_ids_cpu"], + self.model_inputs["input_ids_len"], + self.model_inputs["pre_ids"]._copy_to(device, True), + self.model_inputs["step_idx"].cpu(), + self.target_model_inputs["actual_draft_token_num"].cpu(), + draft_tokens, + seq_lens_this_time, + seq_lens_decoder, + self.model_inputs["max_dec_len"].cpu(), + self.max_ngram_size, + self.min_ngram_size, + self.max_draft_token_num, + ) + self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() + self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. 
So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. + if self.forward_meta.step_use_cudagraph: + self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] + return + + def clear_mtp_cache(self, profile=False): + """ + Clear allocated cacheKV + """ + create_cache_tensor = profile or not ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend + or self.fd_config.scheduler_config.splitwise_role != "mixed" + ) + if not create_cache_tensor: + for name, tensor in self.cache_kvs_map.items(): + unset_data_ipc(tensor, name, True, False) + self.cache_kvs_map.clear() + del self.model_inputs["caches"] + if self.forward_meta is not None: + del self.forward_meta.caches diff --git a/fastdeploy/spec_decode/mtp_xpu.py b/fastdeploy/spec_decode/mtp_xpu.py new file mode 100644 index 00000000000..bc5107799b9 --- /dev/null +++ b/fastdeploy/spec_decode/mtp_xpu.py @@ -0,0 +1,276 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy import envs +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.ops.xpu import ( + draft_model_postprocess, + draft_model_preprocess, + draft_model_update, + eagle_get_hidden_states, + eagle_get_self_hidden_states, + mtp_save_first_token, + update_attn_mask_offsets, +) +from fastdeploy.model_executor.xpu_pre_and_post_process import ( + xpu_pre_process, + xpu_process_output, +) +from fastdeploy.worker.input_batch import recover_batch_index_for_output + +from .mtp import MTPProposer + + +class MTPProposerXPU(MTPProposer): + """ + XPU-specific MTPProposer implementation. + """ + + def _prepare_inputs(self, full_hidden_states): + use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER) + draft_model_preprocess( + self.model_inputs["draft_tokens"], + self.model_inputs["input_ids"], + self.model_inputs["stop_flags"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["step_idx"], + self.model_inputs["not_need_stop"], + self.model_inputs["batch_drop"], + self.model_inputs["is_block_step"], + self.model_inputs["pre_ids"], + self.model_inputs["mask_rollback"], + self.model_inputs["recompute_token_num"], + self.target_model_inputs["accept_tokens"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["seq_lens_decoder"], + self.target_model_inputs["step_idx"], + self.target_model_inputs["stop_flags"], + self.target_model_inputs["is_block_step"], + self.target_model_inputs["draft_tokens"], + self.num_model_steps, + True, + self.role == "prefill", + use_v1_cache_scheduler, + ) + + target_hidden_states = eagle_get_hidden_states( + full_hidden_states, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + 
self.model_inputs["seq_lens_decoder"],
+            self.model_inputs["stop_flags"],
+            self.target_model_inputs["accept_num"],
+            self.target_model_inputs["seq_lens_this_time"],
+            self.target_model_inputs["seq_lens_encoder"],
+            self.num_model_steps,
+        )
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
+
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
+
+        self.forward_meta.decoder_batch_ids = self.model_inputs["decoder_batch_ids"]
+        self.forward_meta.decoder_tile_ids_per_batch = self.model_inputs["decoder_tile_ids_per_batch"]
+        self.forward_meta.decoder_num_blocks_cpu = self.model_inputs["decoder_num_blocks_cpu"]
+        self.forward_meta.decoder_num_blocks_device = self.model_inputs["decoder_num_blocks_device"]
+        self.forward_meta.decoder_chunk_size_device = self.model_inputs["decoder_chunk_size_device"]
+        self.forward_meta.max_len_tensor_cpu = self.model_inputs["max_len_tensor_cpu"]
+
+        self.forward_meta.encoder_batch_ids = self.model_inputs["encoder_batch_ids"]
+        self.forward_meta.encoder_tile_ids_per_batch = self.model_inputs["encoder_tile_ids_per_batch"]
+        self.forward_meta.encoder_num_blocks_x_cpu = self.model_inputs["encoder_num_blocks_x_cpu"]
+        self.forward_meta.kv_batch_ids = self.model_inputs["kv_batch_ids"]
+        self.forward_meta.kv_tile_ids_per_batch = self.model_inputs["kv_tile_ids_per_batch"]
+        self.forward_meta.kv_num_blocks_x_cpu = self.model_inputs["kv_num_blocks_x_cpu"]
+        self.forward_meta.attn_backend = self.attn_backends[0]
+        if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
+            self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"]
+
+        self.forward_meta.is_draft = True
+
+        # Initialize attention meta data
+        for attn_backend in self.attn_backends:
+            attn_backend.init_attention_metadata(self.forward_meta)
+
+    def _propose(self, step_use_cudagraph: bool = False, 
is_dummy_run: bool = False, real_bsz: int = 0):
+        """
+        Main process for MTP inference.
+        Args:
+            step_use_cudagraph: bool
+                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
+        """
+        for substep in range(self.num_model_steps):
+            if self.model_inputs["not_need_stop"]:
+                self.model_inputs["substep"] = substep
+                # Remove padding
+                self.forward_meta = xpu_pre_process(
+                    self.model_inputs["input_ids"],
+                    self.model_inputs["seq_lens_this_time"],
+                    self.model_inputs,
+                    True,
+                    self.cache_config.block_size,
+                    self.model_inputs["draft_tokens"],
+                    self.model_inputs["seq_lens_encoder"],
+                    self.model_inputs["seq_lens_decoder"],
+                )
+
+                if self.enable_mm:
+                    attn_mask_offsets = update_attn_mask_offsets(
+                        self.model_inputs["ids_remove_padding"],
+                        getattr(
+                            self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
+                        ),
+                        self.model_inputs["seq_lens_encoder"],
+                        self.model_inputs["seq_lens_decoder"],
+                        self.model_inputs["cu_seqlens_q"],
+                        self.model_inputs["attn_mask_offsets_full"],
+                        self.model_inputs["attn_mask_offsets_decoder"],
+                        self.model_inputs["is_block_step"],
+                        self.model_inputs["decode_states"],
+                        self.model_inputs["mask_rollback"],
+                    )
+                    self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
+
+                self._initialize_forward_meta()
+                # Get sampling metadata
+                self.sampling_metadata = SamplingMetadata(
+                    temperature=self.model_inputs["temperature"],
+                    top_p=self.model_inputs["top_p"],
+                    top_k=self.model_inputs["top_k"],
+                    seed=self.model_inputs["infer_seed"],
+                    step_idx=self.model_inputs["step_idx"],
+                    pre_token_ids=self.model_inputs["pre_ids"],
+                    frequency_penalties=self.model_inputs["frequency_score"],
+                    presence_penalties=self.model_inputs["presence_score"],
+                    repetition_penalties=self.model_inputs["penalty_score"],
+                    min_dec_lens=self.model_inputs["min_dec_len"],
+                    bad_words_token_ids=self.model_inputs["bad_tokens"],
+                    eos_token_ids=self.model_inputs["eos_token_id"],
+ 
max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, + ) + + if self.num_model_steps > 1: + self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) + + model_output = self.model( + ids_remove_padding=self.model_inputs["ids_remove_padding"], + previous_hidden_states=self.model_inputs["target_hidden_states"], + forward_meta=self.forward_meta, + ) + hidden_states = xpu_process_output( + model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs + ) + # 4. Compute logits, Sample + logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) + sampled_token_ids, sampler_output = self.sampler( + logits, + self.sampling_metadata, + self.max_model_len, + self.model_inputs, + ) + + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + + self._post_process(sampled_token_ids) + if substep != self.num_model_steps - 1: + self._get_self_hidden_states(hidden_states) + else: + if hasattr(self.model, "empty_input_forward") and not is_dummy_run: + self.model.empty_input_forward(self.forward_meta) + + def _get_self_hidden_states(self, hidden_states): + target_hidden_states = eagle_get_self_hidden_states( + hidden_states, + self.model_inputs.last_seq_lens_this_time, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["step_idx"], + ) + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) + + def _post_process(self, sampled_token_ids): + """ + PostProcess for generation + """ + draft_model_update( + sampled_token_ids, + self.model_inputs["draft_tokens"], + self.model_inputs["pre_ids"], + 
self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["step_idx"], + self.model_inputs["output_cum_offsets"], + self.model_inputs["stop_flags"], + self.model_inputs["not_need_stop"], + self.model_inputs["max_dec_len"], + self.model_inputs["eos_token_id"], + self.model_inputs["base_model_draft_tokens"], + self.max_model_len, + self.model_inputs["substep"], + ) + + if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], + ) + mtp_save_first_token( + recover_model_output_map["base_model_draft_tokens"], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_model_output_map["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) + # Ensure only save first token once. + paddle.assign( + paddle.where( + self.model_inputs["stop_flags"], + paddle.zeros_like(self.model_inputs["step_idx"]), + self.model_inputs["step_idx"], + ), + self.model_inputs["step_idx"], + ) + + def _update_status(self): + """ + Update main-model's forward info in next step. + Allocate/Free block of MPT. 
+ """ + draft_model_postprocess( + self.target_model_inputs["draft_tokens"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["stop_flags"], + ) diff --git a/fastdeploy/spec_decode/types.py b/fastdeploy/spec_decode/types.py index 3473d810bb9..47a66341306 100644 --- a/fastdeploy/spec_decode/types.py +++ b/fastdeploy/spec_decode/types.py @@ -91,9 +91,9 @@ def create_proposer(self, fd_config, **kwargs) -> Optional["Proposer"]: if self == SpecMethod.NAIVE: return None elif self == SpecMethod.MTP: - from fastdeploy.spec_decode.mtp import MTPProposer + from fastdeploy.spec_decode.mtp import create_mtp_proposer - return MTPProposer( + return create_mtp_proposer( fd_config, kwargs["main_model"], kwargs["local_rank"], diff --git a/tests/spec_decode/test_mtp_proposer.py b/tests/spec_decode/test_mtp_proposer.py index 160e4011df1..2865d9d605a 100644 --- a/tests/spec_decode/test_mtp_proposer.py +++ b/tests/spec_decode/test_mtp_proposer.py @@ -25,7 +25,7 @@ from fastdeploy.config import SpeculativeConfig from fastdeploy.engine.request import Request, RequestType -from fastdeploy.spec_decode.mtp import MTPProposer +from fastdeploy.spec_decode.mtp_cuda import MTPProposerCUDA as MTPProposer class TestMTPProposer(unittest.TestCase): @@ -121,6 +121,7 @@ def setUp(self): @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") def test_init_and_config_methods(self, mock_rope, mock_attn_backend, mock_model_loader): + # Note: get_model_loader/get_attention_backend are still in mtp.py base class """Test initialization and config update methods""" mock_model = Mock() mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) @@ -352,7 +353,7 @@ def test_insert_prefill_inputs(self, mock_rope, mock_attn_backend, mock_model_lo def test_forward_meta_and_exist_prefill( self, mock_rope, mock_attn_backend, mock_model_loader, mock_ipc_signal_cls ): - """Test 
_initialize_forward_meta, _initialize_forward_meta_xpu, and exist_prefill""" + """Test _initialize_forward_meta and exist_prefill""" mock_ipc_signal = Mock() mock_ipc_signal.value = [0] * self.fd_config.parallel_config.tensor_parallel_size mock_ipc_signal_cls.return_value = mock_ipc_signal @@ -370,14 +371,12 @@ def test_forward_meta_and_exist_prefill( proposer.initialize_kv_cache(main_model_num_blocks=10) proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] - # Test _initialize_forward_meta + # Test _initialize_forward_meta (CUDA version creates a new ForwardMeta) proposer._initialize_forward_meta(step_use_cudagraph=False) self.assertIsNotNone(proposer.forward_meta) - # Test _initialize_forward_meta_xpu - proposer._initialize_forward_meta_xpu() - if hasattr(proposer.forward_meta, "pos_emb_type") and proposer.forward_meta.pos_emb_type is not None: - self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL") + # NOTE: _initialize_forward_meta_xpu is now in MTPProposerXPU (mtp_xpu.py). + # XPU-specific forward_meta initialization is tested separately in XPU environment. 
# Test exist_prefill proposer.exist_prefill_flag = True @@ -391,8 +390,8 @@ def test_forward_meta_and_exist_prefill( @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") - @patch("fastdeploy.spec_decode.mtp.draft_model_preprocess") - @patch("fastdeploy.spec_decode.mtp.eagle_get_hidden_states") + @patch("fastdeploy.spec_decode.mtp_cuda.draft_model_preprocess") + @patch("fastdeploy.spec_decode.mtp_cuda.eagle_get_hidden_states") def test_prepare_inputs_and_post_process( self, mock_eagle, mock_preprocess, mock_rope, mock_attn_backend, mock_model_loader ): @@ -423,44 +422,6 @@ def test_prepare_inputs_and_post_process( sampled_token_ids = paddle.ones([2, 1], dtype="int64") proposer._post_process(sampled_token_ids) - @patch("fastdeploy.spec_decode.mtp.current_platform") - @patch("fastdeploy.spec_decode.mtp.get_model_loader") - @patch("fastdeploy.spec_decode.mtp.get_attention_backend") - @patch("fastdeploy.worker.input_batch.get_rope") - @patch("fastdeploy.spec_decode.mtp.draft_model_preprocess") - @patch("fastdeploy.spec_decode.mtp.eagle_get_hidden_states") - def test_prepare_inputs_xpu_branch( - self, mock_eagle, mock_preprocess, mock_rope, mock_attn_backend, mock_model_loader, mock_platform - ): - """Test _prepare_inputs XPU branch (line 754)""" - mock_platform.is_cuda.return_value = False - mock_platform.is_maca.return_value = False - mock_platform.is_xpu.return_value = True - - mock_model = Mock() - mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) - mock_model_loader.return_value.load_model.return_value = mock_model - mock_attn = Mock() - mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64]) - mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn - mock_rope.return_value = paddle.zeros([1, 2048, 64]) - # XPU branch returns only target_hidden_states, not tuple - mock_eagle.return_value = 
paddle.zeros([2, 768], dtype="bfloat16") - mock_preprocess.return_value = None - - proposer = MTPProposer( - self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs - ) - full_hidden_states = paddle.zeros([2, 768], dtype="bfloat16") - proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] - - # Test _prepare_inputs with XPU platform (covers line 754) - proposer._prepare_inputs(full_hidden_states) - mock_preprocess.assert_called() - mock_eagle.assert_called() - # Verify eagle_get_hidden_states was called without returning output_token_num - self.assertEqual(mock_eagle.call_count, 1) - @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") @@ -497,88 +458,11 @@ def test_update_task_chunk_prefill(self, mock_rope, mock_attn_backend, mock_mode task.chunk_idx = 2 proposer.update_task_chunk_prefill(task) - # NOTE: Temporarily skipped - _get_self_hidden_states_cuda method does not exist - # @patch("fastdeploy.spec_decode.mtp.eagle_get_self_hidden_states") - # @patch("fastdeploy.spec_decode.mtp.get_model_loader") - # @patch("fastdeploy.spec_decode.mtp.get_attention_backend") - # @patch("fastdeploy.worker.input_batch.get_rope") - # def test_get_self_hidden_states_cuda( - # self, mock_rope, mock_attn_backend, mock_model_loader, mock_eagle_self_hidden - # ): - # """Test _get_self_hidden_states_cuda method (lines 1140-1148)""" - # mock_model = Mock() - # mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) - # mock_model_loader.return_value.load_model.return_value = mock_model - # mock_attn = Mock() - # mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64]) - # mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn - # mock_rope.return_value = paddle.zeros([1, 2048, 64]) - # mock_eagle_self_hidden.return_value = ( - # paddle.zeros([2, 768], 
dtype="bfloat16"), - # paddle.to_tensor([2], dtype="int32"), - # ) - # - # # Use num_speculative_tokens=2 to ensure num_model_steps > 1 - # self.fd_config.speculative_config.num_speculative_tokens = 2 - # proposer = MTPProposer( - # self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs - # ) - # proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] - # proposer.model_inputs.last_seq_lens_this_time = paddle.zeros([2, 1], dtype="int32") - # proposer.model_inputs["step_idx"] = paddle.ones([2, 1], dtype="int64") - # - # hidden_states = paddle.zeros([2, 768], dtype="bfloat16") - # - # # Test _get_self_hidden_states_cuda directly (covers lines 1140-1148) - # proposer._get_self_hidden_states_cuda(hidden_states) - # - # # Verify eagle_get_self_hidden_states was called - # mock_eagle_self_hidden.assert_called_once() - - @patch("fastdeploy.spec_decode.mtp.eagle_get_self_hidden_states") - @patch("fastdeploy.spec_decode.mtp.current_platform") - @patch("fastdeploy.spec_decode.mtp.get_model_loader") - @patch("fastdeploy.spec_decode.mtp.get_attention_backend") - @patch("fastdeploy.worker.input_batch.get_rope") - def test_get_self_hidden_states_xpu( - self, mock_rope, mock_attn_backend, mock_model_loader, mock_platform, mock_eagle_self_hidden - ): - """Test _get_self_hidden_states_xpu method (lines 1130-1137, 1125)""" - mock_platform.is_cuda.return_value = False - mock_platform.is_maca.return_value = False - mock_platform.is_xpu.return_value = True - - mock_model = Mock() - mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) - mock_model_loader.return_value.load_model.return_value = mock_model - mock_attn = Mock() - mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64]) - mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn - mock_rope.return_value = paddle.zeros([1, 2048, 64]) - mock_eagle_self_hidden.return_value = paddle.zeros([2, 768], 
dtype="bfloat16") - - # Use num_speculative_tokens=2 to ensure num_model_steps > 1 - self.fd_config.speculative_config.num_speculative_tokens = 2 - proposer = MTPProposer( - self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs - ) - proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] - proposer.model_inputs.last_seq_lens_this_time = paddle.zeros([2, 1], dtype="int32") - proposer.model_inputs["step_idx"] = paddle.ones([2, 1], dtype="int64") - - hidden_states = paddle.zeros([2, 768], dtype="bfloat16") - - # Test _get_self_hidden_states_xpu directly (covers lines 1130-1137) - proposer._get_self_hidden_states_xpu(hidden_states) - - # Verify eagle_get_self_hidden_states was called - mock_eagle_self_hidden.assert_called_once() - @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") - @patch("fastdeploy.spec_decode.mtp.draft_model_postprocess") - @patch("fastdeploy.spec_decode.mtp.mtp_step_paddle") + @patch("fastdeploy.spec_decode.mtp_cuda.draft_model_postprocess") + @patch("fastdeploy.spec_decode.mtp_cuda.mtp_step_paddle") def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader): """Test _update_status""" mock_model = Mock() @@ -597,14 +481,14 @@ def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_at proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] # Test with ENABLE_V1_KVCACHE_SCHEDULER=False - with patch("fastdeploy.spec_decode.mtp.envs.ENABLE_V1_KVCACHE_SCHEDULER", False): + with patch("fastdeploy.spec_decode.mtp_cuda.envs.ENABLE_V1_KVCACHE_SCHEDULER", False): proposer._update_status() mock_mtp_step.assert_called() @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") 
@patch("fastdeploy.worker.input_batch.get_rope") - @patch("fastdeploy.spec_decode.mtp.hybrid_mtp_ngram") + @patch("fastdeploy.spec_decode.mtp_cuda.hybrid_mtp_ngram") def test_extend_draft_token_and_run_impl(self, mock_ngram, mock_rope, mock_attn_backend, mock_model_loader): """Test _extend_draft_token_with_ngram_match and _run_impl""" mock_model = Mock() @@ -675,9 +559,8 @@ def test_padding_cudagraph_inputs_and_empty_cache( @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") - @patch("fastdeploy.spec_decode.mtp.current_platform") - def test_cache_type_branches(self, mock_platform, mock_rope, mock_attn_backend, mock_model_loader): - """Cover _get_cache_type CUDA/XPU/unsupported branches""" + def test_cache_type_branches(self, mock_rope, mock_attn_backend, mock_model_loader): + """Cover _get_cache_type for MTPProposerCUDA (always returns 'uint8').""" mock_model = Mock() mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) mock_model_loader.return_value.load_model.return_value = mock_model @@ -686,28 +569,11 @@ def test_cache_type_branches(self, mock_platform, mock_rope, mock_attn_backend, mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn mock_rope.return_value = paddle.zeros([1, 2048, 64]) - # CUDA branch - mock_platform.is_cuda.return_value = True - mock_platform.is_xpu.return_value = False proposer = MTPProposer( self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs ) self.assertEqual(proposer._get_cache_type(), "uint8") - # XPU branch - mock_platform.is_cuda.return_value = False - mock_platform.is_xpu.return_value = True - proposer = MTPProposer( - self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs - ) - self.assertEqual(proposer._get_cache_type(), "int8") - - # Unsupported branch: reuse existing proposer to avoid RuntimeError in __init__ - 
mock_platform.is_cuda.return_value = False - mock_platform.is_xpu.return_value = False - with self.assertRaises(NotImplementedError): - proposer._get_cache_type() - @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") @@ -764,21 +630,20 @@ def test_insert_tasks_v1_preempted(self, mock_rope, mock_attn_backend, mock_mode self.assertTrue(proposer.model_inputs["stop_flags"][0].item()) self.assertEqual(proposer.model_inputs["seq_lens_this_time_buffer"][0].item(), 0) - @patch("fastdeploy.spec_decode.mtp.get_model_loader") - @patch("fastdeploy.spec_decode.mtp.get_attention_backend") - @patch("fastdeploy.worker.input_batch.get_rope") @patch("fastdeploy.spec_decode.mtp.current_platform") - def test_unsupported_platform_raises_runtime_error( - self, mock_platform, mock_rope, mock_attn_backend, mock_model_loader - ): - """Cover RuntimeError in __init__ when platform is unsupported (line 120).""" + def test_unsupported_platform_raises_runtime_error(self, mock_platform): + """Cover RuntimeError in create_mtp_proposer when platform is unsupported.""" mock_platform.is_xpu.return_value = False mock_platform.is_cuda.return_value = False mock_platform.is_maca.return_value = False mock_platform.__str__ = lambda self: "UnsupportedPlatform" + from fastdeploy.spec_decode.mtp import create_mtp_proposer + with self.assertRaises(RuntimeError) as ctx: - MTPProposer(self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs) + create_mtp_proposer( + self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs + ) self.assertIn("Unsupported platform for MTP", str(ctx.exception)) From ecd9f7c7cea3bb167498f009a68bd49b56d58dab Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Thu, 2 Apr 2026 15:51:26 +0800 Subject: [PATCH 2/2] fix ut --- tests/spec_decode/test_mtp_proposer.py | 9 +-------- 1 file changed, 1 
insertion(+), 8 deletions(-) diff --git a/tests/spec_decode/test_mtp_proposer.py b/tests/spec_decode/test_mtp_proposer.py index 2865d9d605a..0812aecd010 100644 --- a/tests/spec_decode/test_mtp_proposer.py +++ b/tests/spec_decode/test_mtp_proposer.py @@ -462,8 +462,7 @@ def test_update_task_chunk_prefill(self, mock_rope, mock_attn_backend, mock_mode @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope") @patch("fastdeploy.spec_decode.mtp_cuda.draft_model_postprocess") - @patch("fastdeploy.spec_decode.mtp_cuda.mtp_step_paddle") - def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader): + def test_update_status(self, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader): """Test _update_status""" mock_model = Mock() mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000])) @@ -473,18 +472,12 @@ def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_at mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn mock_rope.return_value = paddle.zeros([1, 2048, 64]) mock_postprocess.return_value = None - mock_mtp_step.return_value = None proposer = MTPProposer( self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs ) proposer.model_inputs.seq_lens_this_time = proposer.model_inputs["seq_lens_this_time_buffer"] - # Test with ENABLE_V1_KVCACHE_SCHEDULER=False - with patch("fastdeploy.spec_decode.mtp_cuda.envs.ENABLE_V1_KVCACHE_SCHEDULER", False): - proposer._update_status() - mock_mtp_step.assert_called() - @patch("fastdeploy.spec_decode.mtp.get_model_loader") @patch("fastdeploy.spec_decode.mtp.get_attention_backend") @patch("fastdeploy.worker.input_batch.get_rope")