From 6f0a6fcffc3343104095b2c6cd022b15ea5cdb49 Mon Sep 17 00:00:00 2001 From: songyuxing Date: Mon, 18 May 2026 10:59:39 +0800 Subject: [PATCH] [BugFix] Fix attention mask for multimodal models --- fastdeploy/spec_decode/mtp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 78cf8029504..f28c8352c9e 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -523,10 +523,16 @@ def insert_tasks_v1( ) if self.use_attn_mask_offset: inputs = request.multimodal_inputs - self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = ( - paddle.to_tensor( - inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32" + attn_offset_len = prefill_end_index - prefill_start_index + if inputs.get("attention_mask_offset", None) is None: + attention_mask_offset_slice = np.arange(prefill_start_index, prefill_end_index, dtype=np.int32) + else: + attention_mask_offset_slice = np.asarray( + inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], + dtype=np.int32, ) + self.model_inputs["attn_mask_offsets_full"][idx][0:attn_offset_len] = paddle.to_tensor( + attention_mask_offset_slice, dtype="int32" ) # GPU don't need it anymore # NOTE: XPU backend needs decoder attention mask offset; GPU backend does not use it