diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index c76fd4167f1..0f847a0d7f3 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -380,55 +380,82 @@ def save_output_normal(
     # In the future, we will abandon this approach.
     if envs.FD_USE_GET_SAVE_OUTPUT_V1:
         if save_each_rank or model_output.mp_rank == 0:
-            recover_share_inputs_map = recover_batch_index_for_output(
-                share_inputs,
-                model_output.index_to_batch_id,
-                model_output.enable_pd_reorder,
-                ["sampled_token_ids"],
-            )
-            recover_batch_index_for_sampler_output(
-                sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
-            )
-            output = _build_stream_transfer_data(
-                recover_share_inputs_map["sampled_token_ids"],
-                logprobs=sampler_output.logprobs_tensors,
-                prompt_logprobs_list=model_output.prompt_logprobs_list,
-            )
+            if not model_output.enable_pd_reorder:
+                output = _build_stream_transfer_data(
+                    share_inputs["sampled_token_ids"],
+                    logprobs=sampler_output.logprobs_tensors,
+                    prompt_logprobs_list=model_output.prompt_logprobs_list,
+                )
+            else:
+                recover_share_inputs_map = recover_batch_index_for_output(
+                    share_inputs,
+                    model_output.index_to_batch_id,
+                    model_output.enable_pd_reorder,
+                    ["sampled_token_ids"],
+                )
+                recover_batch_index_for_sampler_output(
+                    sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
+                )
+                output = _build_stream_transfer_data(
+                    recover_share_inputs_map["sampled_token_ids"],
+                    logprobs=sampler_output.logprobs_tensors,
+                    prompt_logprobs_list=model_output.prompt_logprobs_list,
+                )
             async_output_queue.put(output)
     else:
         if sampler_output.logprobs_tensors is None:
-            recover_share_inputs_map = recover_batch_index_for_output(
-                share_inputs,
-                model_output.index_to_batch_id,
-                model_output.enable_pd_reorder,
-                ["last_preempted_idx", "sampled_token_ids"],
-            )
-            save_output(
-                recover_share_inputs_map["sampled_token_ids"],
-                model_output.not_need_stop,
-                recover_share_inputs_map["last_preempted_idx"],
-                model_output.mp_rank,
-                save_each_rank,
-            )
+            if not model_output.enable_pd_reorder:
+                save_output(
+                    share_inputs["sampled_token_ids"],
+                    model_output.not_need_stop,
+                    share_inputs["last_preempted_idx"],
+                    model_output.mp_rank,
+                    save_each_rank,
+                )
+            else:
+                recover_share_inputs_map = recover_batch_index_for_output(
+                    share_inputs,
+                    model_output.index_to_batch_id,
+                    model_output.enable_pd_reorder,
+                    ["last_preempted_idx", "sampled_token_ids"],
+                )
+                save_output(
+                    recover_share_inputs_map["sampled_token_ids"],
+                    model_output.not_need_stop,
+                    recover_share_inputs_map["last_preempted_idx"],
+                    model_output.mp_rank,
+                    save_each_rank,
+                )
         else:
-            recover_share_inputs_map = recover_batch_index_for_output(
-                share_inputs,
-                model_output.index_to_batch_id,
-                model_output.enable_pd_reorder,
-                ["last_preempted_idx"],
-            )
-            recover_batch_index_for_sampler_output(
-                sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
-            )
-            save_output_topk(
-                share_inputs["sampled_token_ids"],
-                sampler_output.logprobs_tensors.logprob_token_ids,
-                sampler_output.logprobs_tensors.logprobs,
-                sampler_output.logprobs_tensors.selected_token_ranks,
-                model_output.not_need_stop,
-                recover_share_inputs_map["last_preempted_idx"],
-                model_output.mp_rank,
-            )
+            if not model_output.enable_pd_reorder:
+                save_output_topk(
+                    share_inputs["sampled_token_ids"],
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    model_output.not_need_stop,
+                    share_inputs["last_preempted_idx"],
+                    model_output.mp_rank,
+                )
+            else:
+                recover_share_inputs_map = recover_batch_index_for_output(
+                    share_inputs,
+                    model_output.index_to_batch_id,
+                    model_output.enable_pd_reorder,
+                    ["last_preempted_idx"],
+                )
+                recover_batch_index_for_sampler_output(
+                    sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
+                )
+                save_output_topk(
+                    share_inputs["sampled_token_ids"],
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    model_output.not_need_stop,
+                    recover_share_inputs_map["last_preempted_idx"],
+                    model_output.mp_rank,
+                )

         share_inputs["last_preempted_idx"][:] = 0
@@ -529,54 +556,84 @@ def post_process_specualate(

     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
-            recover_model_output_map = recover_batch_index_for_output(
-                model_output,
-                model_output.index_to_batch_id,
-                model_output.enable_pd_reorder,
-                ["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
-            )
-            recover_share_inputs = recover_batch_index_for_output(
-                share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
-            )
-            speculate_save_output(
-                recover_model_output_map["accept_tokens"],
-                recover_model_output_map["accept_num"],
-                model_output.not_need_stop,
-                recover_model_output_map["seq_lens_decoder"],
-                recover_model_output_map["prompt_lens"],
-                recover_share_inputs["preempted_idx"],
-                model_output.mp_rank,
-                save_each_rank,
-                bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
-            )
+            if not model_output.enable_pd_reorder:
+                speculate_save_output(
+                    model_output["accept_tokens"],
+                    model_output["accept_num"],
+                    model_output.not_need_stop,
+                    model_output["seq_lens_decoder"],
+                    model_output["prompt_lens"],
+                    share_inputs["preempted_idx"],
+                    model_output.mp_rank,
+                    save_each_rank,
+                    bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
+                )
+            else:
+                recover_model_output_map = recover_batch_index_for_output(
+                    model_output,
+                    model_output.index_to_batch_id,
+                    model_output.enable_pd_reorder,
+                    ["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
+                )
+                recover_share_inputs = recover_batch_index_for_output(
+                    share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
+                )
+                speculate_save_output(
+                    recover_model_output_map["accept_tokens"],
+                    recover_model_output_map["accept_num"],
+                    model_output.not_need_stop,
+                    recover_model_output_map["seq_lens_decoder"],
+                    recover_model_output_map["prompt_lens"],
+                    recover_share_inputs["preempted_idx"],
+                    model_output.mp_rank,
+                    save_each_rank,
+                    bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
+                )
         else:
-            recover_batch_index_for_sampler_output(
-                sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
-            )
-            recover_model_output_map = recover_batch_index_for_output(
-                model_output,
-                model_output.index_to_batch_id,
-                model_output.enable_pd_reorder,
-                ["seq_lens_decoder", "prompt_lens"],
-            )
-            recover_share_inputs = recover_batch_index_for_output(
-                share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
-            )
-            speculate_save_output_topk(
-                sampler_output.sampled_token_ids,
-                sampler_output.logprobs_tensors.logprob_token_ids,
-                sampler_output.logprobs_tensors.logprobs,
-                sampler_output.logprobs_tensors.selected_token_ranks,
-                sampler_output.token_num_per_batch,
-                sampler_output.cu_batch_token_offset,
-                model_output.not_need_stop,
-                recover_model_output_map["seq_lens_decoder"],
-                recover_model_output_map["prompt_lens"],
-                recover_share_inputs["preempted_idx"],
-                3,  # mtype
-                model_output.mp_rank,
-                save_each_rank,
-            )
+            if not model_output.enable_pd_reorder:
+                speculate_save_output_topk(
+                    sampler_output.sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    sampler_output.token_num_per_batch,
+                    sampler_output.cu_batch_token_offset,
+                    model_output.not_need_stop,
+                    model_output["seq_lens_decoder"],
+                    model_output["prompt_lens"],
+                    share_inputs["preempted_idx"],
+                    3,  # mtype
+                    model_output.mp_rank,
+                    save_each_rank,
+                )
+            else:
+                recover_batch_index_for_sampler_output(
+                    sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
+                )
+                recover_model_output_map = recover_batch_index_for_output(
+                    model_output,
+                    model_output.index_to_batch_id,
+                    model_output.enable_pd_reorder,
+                    ["seq_lens_decoder", "prompt_lens"],
+                )
+                recover_share_inputs = recover_batch_index_for_output(
+                    share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
+                )
+                speculate_save_output_topk(
+                    sampler_output.sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    sampler_output.token_num_per_batch,
+                    sampler_output.cu_batch_token_offset,
+                    model_output.not_need_stop,
+                    recover_model_output_map["seq_lens_decoder"],
+                    recover_model_output_map["prompt_lens"],
+                    recover_share_inputs["preempted_idx"],
+                    3,  # mtype
+                    model_output.mp_rank,
+                    save_each_rank,
+                )


 def post_process(
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 7ebe86b852f..b7264790157 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -747,9 +747,9 @@ def _prepare_inputs(self, full_hidden_states):

         self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

-    def _post_process(self, sampled_token_ids):
+    def _post_process_pd_reorder(self, sampled_token_ids):
         """
-        PostProcess for generation
+        PostProcess for generation with PD reorder enabled
         """
         draft_model_update(
             sampled_token_ids,
@@ -804,6 +804,98 @@ def _post_process(self, sampled_token_ids):
             self.model_inputs["step_idx"],
         )

+    def _post_process(self, sampled_token_ids):
+        """
+        PostProcess for generation
+        """
+        draft_model_update(
+            sampled_token_ids,
+            self.model_inputs["draft_tokens"],
+            self.model_inputs["pre_ids"],
+            self.model_inputs["seq_lens_this_time"],
+            self.model_inputs["seq_lens_encoder"],
+            self.model_inputs["seq_lens_decoder"],
+            self.model_inputs["step_idx"],
+            # Note(ZKK):
+            # I strongly advise the XPU maintainers to drop the `output_cum_offsets` name from the XPU backend,
+            # as was done in my PR https://github.com/PaddlePaddle/FastDeploy/pull/6358
+            (
+                self.model_inputs["cu_seqlens_q_output"]
+                if current_platform.is_cuda()
+                else self.model_inputs["output_cum_offsets"]
+            ),
+            self.model_inputs["stop_flags"],
+            self.model_inputs["not_need_stop"],
+            self.model_inputs["max_dec_len"],
+            self.model_inputs["eos_token_id"],
+            self.model_inputs["base_model_draft_tokens"],
+            self.max_model_len,
+            self.model_inputs["substep"],
+        )
+
+        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
+            skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
+            mtp_save_first_token(
+                self.model_inputs["base_model_draft_tokens"],
+                self.model_inputs["not_need_stop"],
+                self.model_inputs["seq_lens_decoder"],
+                self.model_inputs["prompt_lens"],
+                self.model_inputs["step_idx"],
+                self.local_rank,
+                self.parallel_config.use_ep,
+                skip_save,
+            )
+            # Ensure only save first token once.
+            paddle.assign(
+                paddle.where(
+                    self.model_inputs["stop_flags"],
+                    paddle.zeros_like(self.model_inputs["step_idx"]),
+                    self.model_inputs["step_idx"],
+                ),
+                self.model_inputs["step_idx"],
+            )
+
+    def _speculate_save_output_for_cuda(self, sampler_output):
+        real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
+        if not self.model_inputs.enable_pd_reorder:
+            speculate_save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                self.model_inputs["batch_token_num"][:real_bsz],
+                self.model_inputs["cu_batch_token_offset"][:real_bsz],
+                self.model_inputs["not_need_stop"],
+                self.model_inputs["seq_lens_decoder"],
+                self.model_inputs["prompt_lens"],
+                4,  # mtype
+                self.local_rank,
+                self.parallel_config.use_ep,
+            )
+        else:
+            recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
+            recover_model_output_map = recover_batch_index_for_output(
+                self.model_inputs,
+                self.model_inputs.index_to_batch_id,
+                self.model_inputs.enable_pd_reorder,
+                ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
+            )
+            speculate_save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                recover_model_output_map["batch_token_num"][:real_bsz],
+                recover_model_output_map["cu_batch_token_offset"][:real_bsz],
+                self.model_inputs["not_need_stop"],
+                recover_model_output_map["seq_lens_decoder"],
+                recover_model_output_map["prompt_lens"],
+                4,  # mtype
+                self.local_rank,
+                self.parallel_config.use_ep,
+            )
+
     def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
         """
         Main process for MTP inference.
@@ -955,29 +1047,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
                 and substep == 0
                 and sampler_output.logprobs_tensors is not None
             ):
-                real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
-                recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
-                recover_model_output_map = recover_batch_index_for_output(
-                    self.model_inputs,
-                    self.model_inputs.index_to_batch_id,
-                    self.model_inputs.enable_pd_reorder[
-                        "batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"
-                    ],
-                )
-                speculate_save_output_topk(
-                    sampler_output.sampled_token_ids,
-                    sampler_output.logprobs_tensors.logprob_token_ids,
-                    sampler_output.logprobs_tensors.logprobs,
-                    sampler_output.logprobs_tensors.selected_token_ranks,
-                    recover_model_output_map["batch_token_num"][:real_bsz],
-                    recover_model_output_map["cu_batch_token_offset"][:real_bsz],
-                    self.model_inputs["not_need_stop"],
-                    recover_model_output_map["seq_lens_decoder"],
-                    recover_model_output_map["prompt_lens"],
-                    4,  # mtype
-                    self.local_rank,
-                    self.parallel_config.use_ep,
-                )
+                self._speculate_save_output_for_cuda(sampler_output)

             if self.parallel_config.tensor_parallel_size > 1:
                 paddle.distributed.broadcast(
@@ -985,14 +1055,49 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
                     self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                     group=self.parallel_config.tp_group,
                 )
-
-            self._post_process(sampled_token_ids)
+            if not self.model_inputs.enable_pd_reorder:
+                self._post_process(sampled_token_ids)
+            else:
+                self._post_process_pd_reorder(sampled_token_ids)

             if substep != self.num_model_steps - 1:
                 self._get_self_hidden_states(hidden_states)
         else:
             if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
                 self.model.empty_input_forward(forward_meta=self.forward_meta)

+    def _speculate_save_output_for_xpu(self, sampler_output):
+        real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
+        if not self.model_inputs.enable_pd_reorder:
+            speculate_save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                self.model_inputs["batch_token_num"][:real_bsz],
+                self.model_inputs["cu_batch_token_offset"][:real_bsz],
+                self.model_inputs["not_need_stop"],
+                4,  # mtype
+                self.local_rank,
+            )
+        else:
+            recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
+            recover_model_output_map = recover_batch_index_for_output(
+                self.model_inputs,
+                self.model_inputs.index_to_batch_id,
+                self.model_inputs.enable_pd_reorder,
+                ["batch_token_num", "cu_batch_token_offset"],
+            )
+            speculate_save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                recover_model_output_map["batch_token_num"][:real_bsz],
+                recover_model_output_map["cu_batch_token_offset"][:real_bsz],
+                self.model_inputs["not_need_stop"],
+                4,  # mtype
+                self.local_rank,
+            )
+
     def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
         """
         Main process for MTP inference.
@@ -1074,24 +1179,7 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa
                 )

             if substep == 0 and sampler_output.logprobs_tensors is not None:
-                real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
-                recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
-                recover_model_output_map = recover_batch_index_for_output(
-                    self.model_inputs,
-                    self.model_inputs.index_to_batch_id,
-                    self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"],
-                )
-                speculate_save_output_topk(
-                    sampler_output.sampled_token_ids,
-                    sampler_output.logprobs_tensors.logprob_token_ids,
-                    sampler_output.logprobs_tensors.logprobs,
-                    sampler_output.logprobs_tensors.selected_token_ranks,
-                    recover_model_output_map["batch_token_num"][:real_bsz],
-                    recover_model_output_map["cu_batch_token_offset"][:real_bsz],
-                    self.model_inputs["not_need_stop"],
-                    4,  # mtype
-                    self.local_rank,
-                )
+                self._speculate_save_output_for_xpu(sampler_output)

             if self.parallel_config.tensor_parallel_size > 1:
                 paddle.distributed.broadcast(
@@ -1100,7 +1188,10 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa
                     group=self.parallel_config.tp_group,
                 )

-            self._post_process(sampled_token_ids)
+            if not self.model_inputs.enable_pd_reorder:
+                self._post_process(sampled_token_ids)
+            else:
+                self._post_process_pd_reorder(sampled_token_ids)

             if substep != self.num_model_steps - 1:
                 self._get_self_hidden_states(hidden_states)
         else:
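
A minimal standalone sketch (not part of the patch) of the dispatch pattern the diff introduces in save_output_normal, post_process_specualate, and the MTP helpers: when enable_pd_reorder is off, outputs are saved straight from the runtime buffers; when it is on, recover_batch_index_for_output first restores the original batch order via index_to_batch_id. The helper names mirror the diff, but the list-based implementation and the direction of the index_to_batch_id mapping are assumptions made for illustration only.

from typing import Dict, List


def recover_batch_index_for_output(
    data: Dict[str, List[int]],
    index_to_batch_id: List[int],
    enable_pd_reorder: bool,
    keys: List[str],
) -> Dict[str, List[int]]:
    """Return copies of the requested keys restored to the original batch order.

    Assumption for this sketch: runtime slot i holds the request that originally
    sat at batch position index_to_batch_id[i].
    """
    if not enable_pd_reorder:
        return {key: list(data[key]) for key in keys}
    recovered: Dict[str, List[int]] = {}
    for key in keys:
        restored = [0] * len(data[key])
        for runtime_idx, batch_id in enumerate(index_to_batch_id):
            restored[batch_id] = data[key][runtime_idx]
        recovered[key] = restored
    return recovered


def save_sampled_tokens(share_inputs, enable_pd_reorder, index_to_batch_id):
    """Fast path when no PD reorder happened; otherwise recover the order first."""
    if not enable_pd_reorder:
        return share_inputs["sampled_token_ids"]
    recovered = recover_batch_index_for_output(
        share_inputs, index_to_batch_id, enable_pd_reorder, ["sampled_token_ids"]
    )
    return recovered["sampled_token_ids"]


if __name__ == "__main__":
    share_inputs = {"sampled_token_ids": [11, 22, 33]}
    print(save_sampled_tokens(share_inputs, False, [0, 1, 2]))  # [11, 22, 33] (fast path)
    print(save_sampled_tokens(share_inputs, True, [2, 0, 1]))   # [22, 33, 11] (order recovered)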