From 720a9fbebc17eb3a05c847000350410fc1c09352 Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Tue, 31 Mar 2026 19:32:01 +0800 Subject: [PATCH 1/3] code style mtp --- fastdeploy/spec_decode/mtp.py | 201 ++++++++++++++++++++++++++-------- 1 file changed, 155 insertions(+), 46 deletions(-) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 7ebe86b852f..2e13b2d7458 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -747,9 +747,9 @@ def _prepare_inputs(self, full_hidden_states): self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) - def _post_process(self, sampled_token_ids): + def _post_process_pd_reorder(self, sampled_token_ids): """ - PostProcess for generation + PostProcess for generation for pd reorder """ draft_model_update( sampled_token_ids, @@ -804,6 +804,98 @@ def _post_process(self, sampled_token_ids): self.model_inputs["step_idx"], ) + def _post_process(self, sampled_token_ids): + """ + PostProcess for generation + """ + draft_model_update( + sampled_token_ids, + self.model_inputs["draft_tokens"], + self.model_inputs["pre_ids"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["step_idx"], + # Note(ZKK): + # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend + # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 + ( + self.model_inputs["cu_seqlens_q_output"] + if current_platform.is_cuda() + else self.model_inputs["output_cum_offsets"] + ), + self.model_inputs["stop_flags"], + self.model_inputs["not_need_stop"], + self.model_inputs["max_dec_len"], + self.model_inputs["eos_token_id"], + self.model_inputs["base_model_draft_tokens"], + self.max_model_len, + self.model_inputs["substep"], + ) + + if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + mtp_save_first_token( + self.model_inputs["base_model_draft_tokens"], + self.model_inputs["not_need_stop"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["prompt_lens"], + self.model_inputs["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) + # Ensure only save first token once. + paddle.assign( + paddle.where( + self.model_inputs["stop_flags"], + paddle.zeros_like(self.model_inputs["step_idx"]), + self.model_inputs["step_idx"], + ), + self.model_inputs["step_idx"], + ) + + def _speculate_save_output_for_cuda(self, sampler_output): + real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + if not self.model_inputs.enable_pd_reorder: + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + self.model_inputs["batch_token_num"][:real_bsz], + self.model_inputs["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["prompt_lens"], + 4, # mtype + self.local_rank, + self.parallel_config.use_ep, + ) + else: + recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder[ + "batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens" + ], + ) + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_model_output_map["batch_token_num"][:real_bsz], + recover_model_output_map["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + 4, # mtype + self.local_rank, + self.parallel_config.use_ep, + ) + def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): """ Main process for MTP inference. @@ -955,29 +1047,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F and substep == 0 and sampler_output.logprobs_tensors is not None ): - real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] - recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder[ - "batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens" - ], - ) - speculate_save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - recover_model_output_map["batch_token_num"][:real_bsz], - recover_model_output_map["cu_batch_token_offset"][:real_bsz], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - 4, # mtype - self.local_rank, - self.parallel_config.use_ep, - ) + self._speculate_save_output_for_cuda(sampler_output) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -985,14 +1055,49 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) - - self._post_process(sampled_token_ids) + if not self.model_inputs.enable_pd_reorder: + self._post_process(sampled_token_ids) + else: + self._post_process_pd_reorder(sampled_token_ids) if substep != self.num_model_steps - 1: self._get_self_hidden_states(hidden_states) else: if hasattr(self.model, "empty_input_forward") and not is_dummy_run: self.model.empty_input_forward(forward_meta=self.forward_meta) + def _speculate_save_output_for_xpu(self, sampler_output): + real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + if not self.model_inputs.enable_pd_reorder: + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + self.model_inputs["batch_token_num"][:real_bsz], + self.model_inputs["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + 4, # mtype + self.local_rank, + ) + else: + recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"], + ) + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_model_output_map["batch_token_num"][:real_bsz], + recover_model_output_map["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + 4, # mtype + self.local_rank, + ) + def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): """ Main process for MTP inference. @@ -1074,24 +1179,25 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa ) if substep == 0 and sampler_output.logprobs_tensors is not None: - real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] - recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"], - ) - speculate_save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - recover_model_output_map["batch_token_num"][:real_bsz], - recover_model_output_map["cu_batch_token_offset"][:real_bsz], - self.model_inputs["not_need_stop"], - 4, # mtype - self.local_rank, - ) + self._speculate_save_output_for_xpu(sampler_output) + # real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + # recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) + # recover_model_output_map = recover_batch_index_for_output( + # self.model_inputs, + # self.model_inputs.index_to_batch_id, + # self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"], + # ) + # speculate_save_output_topk( + # sampler_output.sampled_token_ids, + # sampler_output.logprobs_tensors.logprob_token_ids, + # sampler_output.logprobs_tensors.logprobs, + # sampler_output.logprobs_tensors.selected_token_ranks, + # recover_model_output_map["batch_token_num"][:real_bsz], + # recover_model_output_map["cu_batch_token_offset"][:real_bsz], + # self.model_inputs["not_need_stop"], + # 4, # mtype + # self.local_rank, + # ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -1100,7 +1206,10 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa group=self.parallel_config.tp_group, ) - self._post_process(sampled_token_ids) + if not self.model_inputs.enable_pd_reorder: + self._post_process(sampled_token_ids) + else: + self._post_process_pd_reorder(sampled_token_ids) if substep != self.num_model_steps - 1: self._get_self_hidden_states(hidden_states) else: From 0c2ec5c47a3c7b488e136c1de3d85d919ccb18fd Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Tue, 31 Mar 2026 19:35:09 +0800 Subject: [PATCH 2/3] update --- fastdeploy/spec_decode/mtp.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 2e13b2d7458..b7264790157 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1180,24 +1180,6 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa if substep == 0 and sampler_output.logprobs_tensors is not None: self._speculate_save_output_for_xpu(sampler_output) - # real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] - # recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) - # recover_model_output_map = recover_batch_index_for_output( - # self.model_inputs, - # self.model_inputs.index_to_batch_id, - # self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"], - # ) - # speculate_save_output_topk( - # sampler_output.sampled_token_ids, - # sampler_output.logprobs_tensors.logprob_token_ids, - # sampler_output.logprobs_tensors.logprobs, - # sampler_output.logprobs_tensors.selected_token_ranks, - # recover_model_output_map["batch_token_num"][:real_bsz], - # recover_model_output_map["cu_batch_token_offset"][:real_bsz], - # self.model_inputs["not_need_stop"], - # 4, # mtype - # self.local_rank, - # ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( From f1e1b5bbc142a5a506c205c93a0a76c2b4693c0d Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Tue, 31 Mar 2026 19:47:25 +0800 Subject: [PATCH 3/3] code style --- .../model_executor/pre_and_post_process.py | 241 +++++++++++------- 1 file changed, 149 insertions(+), 92 deletions(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index c76fd4167f1..0f847a0d7f3 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -380,55 +380,82 @@ def save_output_normal( # In the future, we will abandon this approach. if envs.FD_USE_GET_SAVE_OUTPUT_V1: if save_each_rank or model_output.mp_rank == 0: - recover_share_inputs_map = recover_batch_index_for_output( - share_inputs, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - ["sampled_token_ids"], - ) - recover_batch_index_for_sampler_output( - sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder - ) - output = _build_stream_transfer_data( - recover_share_inputs_map["sampled_token_ids"], - logprobs=sampler_output.logprobs_tensors, - prompt_logprobs_list=model_output.prompt_logprobs_list, - ) + if not model_output.enable_pd_reorder: + output = _build_stream_transfer_data( + share_inputs["sampled_token_ids"], + logprobs=sampler_output.logprobs_tensors, + prompt_logprobs_list=model_output.prompt_logprobs_list, + ) + else: + recover_share_inputs_map = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + ["sampled_token_ids"], + ) + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + output = _build_stream_transfer_data( + recover_share_inputs_map["sampled_token_ids"], + logprobs=sampler_output.logprobs_tensors, + prompt_logprobs_list=model_output.prompt_logprobs_list, + ) async_output_queue.put(output) else: if sampler_output.logprobs_tensors is None: - recover_share_inputs_map = recover_batch_index_for_output( - share_inputs, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - ["last_preempted_idx", "sampled_token_ids"], - ) - save_output( - recover_share_inputs_map["sampled_token_ids"], - model_output.not_need_stop, - recover_share_inputs_map["last_preempted_idx"], - model_output.mp_rank, - save_each_rank, - ) + if not model_output.enable_pd_reorder: + save_output( + share_inputs["sampled_token_ids"], + model_output.not_need_stop, + share_inputs["last_preempted_idx"], + model_output.mp_rank, + save_each_rank, + ) + else: + recover_share_inputs_map = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + ["last_preempted_idx", "sampled_token_ids"], + ) + save_output( + recover_share_inputs_map["sampled_token_ids"], + model_output.not_need_stop, + recover_share_inputs_map["last_preempted_idx"], + model_output.mp_rank, + save_each_rank, + ) else: - recover_share_inputs_map = recover_batch_index_for_output( - share_inputs, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - ["last_preempted_idx"], - ) - recover_batch_index_for_sampler_output( - sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder - ) - save_output_topk( - share_inputs["sampled_token_ids"], - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - model_output.not_need_stop, - recover_share_inputs_map["last_preempted_idx"], - model_output.mp_rank, - ) + if not model_output.enable_pd_reorder: + save_output_topk( + share_inputs["sampled_token_ids"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + model_output.not_need_stop, + share_inputs["last_preempted_idx"], + model_output.mp_rank, + ) + else: + recover_share_inputs_map = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + ["last_preempted_idx"], + ) + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + save_output_topk( + share_inputs["sampled_token_ids"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + model_output.not_need_stop, + recover_share_inputs_map["last_preempted_idx"], + model_output.mp_rank, + ) share_inputs["last_preempted_idx"][:] = 0 @@ -529,54 +556,84 @@ def post_process_specualate( if not skip_save_output: if sampler_output.logprobs_tensors is None: - recover_model_output_map = recover_batch_index_for_output( - model_output, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - ["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"], - ) - recover_share_inputs = recover_batch_index_for_output( - share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"] - ) - speculate_save_output( - recover_model_output_map["accept_tokens"], - recover_model_output_map["accept_num"], - model_output.not_need_stop, - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_share_inputs["preempted_idx"], - model_output.mp_rank, - save_each_rank, - bool(envs.ENABLE_V1_KVCACHE_SCHEDULER), - ) + if not model_output.enable_pd_reorder: + speculate_save_output( + model_output["accept_tokens"], + model_output["accept_num"], + model_output.not_need_stop, + model_output["seq_lens_decoder"], + model_output["prompt_lens"], + share_inputs["preempted_idx"], + model_output.mp_rank, + save_each_rank, + bool(envs.ENABLE_V1_KVCACHE_SCHEDULER), + ) + else: + recover_model_output_map = recover_batch_index_for_output( + model_output, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + ["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"], + ) + recover_share_inputs = recover_batch_index_for_output( + share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"] + ) + speculate_save_output( + recover_model_output_map["accept_tokens"], + recover_model_output_map["accept_num"], + model_output.not_need_stop, + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_share_inputs["preempted_idx"], + model_output.mp_rank, + save_each_rank, + bool(envs.ENABLE_V1_KVCACHE_SCHEDULER), + ) else: - recover_batch_index_for_sampler_output( - sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder - ) - recover_model_output_map = recover_batch_index_for_output( - model_output, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - ["seq_lens_decoder", "prompt_lens"], - ) - recover_share_inputs = recover_batch_index_for_output( - share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"] - ) - speculate_save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - sampler_output.token_num_per_batch, - sampler_output.cu_batch_token_offset, - model_output.not_need_stop, - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_share_inputs["preempted_idx"], - 3, # mtype - model_output.mp_rank, - save_each_rank, - ) + if not model_output.enable_pd_reorder: + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + sampler_output.token_num_per_batch, + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + model_output["seq_lens_decoder"], + model_output["prompt_lens"], + share_inputs["preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) + else: + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + recover_model_output_map = recover_batch_index_for_output( + model_output, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + ["seq_lens_decoder", "prompt_lens"], + ) + recover_share_inputs = recover_batch_index_for_output( + share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"] + ) + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + sampler_output.token_num_per_batch, + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_share_inputs["preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) def post_process(