From b547fcf16c2963fa196e9dd5a083c8ae035d1e3d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 23 Nov 2025 18:02:23 +0000 Subject: [PATCH 01/38] Fix QwenImage txt_seq_lens handling --- .../controlnets/controlnet_qwenimage.py | 15 ++++++- .../transformers/transformer_qwenimage.py | 45 ++++++++++++++++--- .../qwenimage/before_denoise.py | 41 ----------------- .../modular_pipelines/qwenimage/denoise.py | 11 +---- .../pipelines/qwenimage/pipeline_qwenimage.py | 7 --- .../pipeline_qwenimage_controlnet.py | 3 -- .../pipeline_qwenimage_controlnet_inpaint.py | 3 -- .../qwenimage/pipeline_qwenimage_edit.py | 7 --- .../pipeline_qwenimage_edit_inpaint.py | 7 --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 7 --- .../qwenimage/pipeline_qwenimage_img2img.py | 7 --- .../qwenimage/pipeline_qwenimage_inpaint.py | 7 --- .../test_models_transformer_qwenimage.py | 29 +++++++++++- 13 files changed, 82 insertions(+), 107 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 7c4955eb5828..b04a91f012c7 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -211,6 +211,9 @@ def forward( Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. + txt_seq_lens (`List[int]`, *optional*): + Optional text sequence lengths. If omitted, or shorter than the encoder hidden states length, the model + derives the length from the encoder hidden states (or their mask). joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -244,7 +247,17 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if txt_seq_lens is not None: + if len(txt_seq_lens) != batch_size: + raise ValueError( + f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead." + ) + text_seq_len = max(text_seq_len, max(txt_seq_lens)) + elif encoder_hidden_states_mask is not None: + text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) + + image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c0fa031b9faf..9aa57f994551 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -197,15 +197,15 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: List[int], + txt_seq_len: int, device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`): - A list of integers of length batch_size representing the length of each text prompt. + txt_seq_len (`int`): + The length of the text sequence. 
This should match the encoder hidden states length. device: (`torch.device`): The device on which to perform the RoPE computation. """ @@ -232,8 +232,7 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_seq_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -330,6 +329,27 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) + # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask. + if encoder_hidden_states_mask is not None and attention_mask is None: + batch_size, image_seq_len = hidden_states.shape[:2] + text_seq_len = encoder_hidden_states.shape[1] + + if encoder_hidden_states_mask.shape[0] != batch_size: + raise ValueError( + f"`encoder_hidden_states_mask` batch size ({encoder_hidden_states_mask.shape[0]}) " + f"must match hidden_states batch size ({batch_size})." + ) + if encoder_hidden_states_mask.shape[1] != text_seq_len: + raise ValueError( + f"`encoder_hidden_states_mask` sequence length ({encoder_hidden_states_mask.shape[1]}) " + f"must match encoder_hidden_states sequence length ({text_seq_len})." + ) + + text_attention_mask = encoder_hidden_states_mask.to(dtype=torch.bool, device=hidden_states.device) + image_attention_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + attention_mask = joint_attention_mask[:, None, None, :] + # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, @@ -588,6 +608,9 @@ def forward( Mask of the input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. + txt_seq_lens (`List[int]`, *optional*): + Optional text sequence lengths. If not provided, or if any provided values are shorter than the + encoder hidden states length, the model falls back to the encoder hidden states length. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -621,6 +644,16 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if txt_seq_lens is not None: + if len(txt_seq_lens) != batch_size: + raise ValueError( + f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead." 
+ ) + text_seq_len = max(text_seq_len, max(txt_seq_lens)) + elif encoder_hidden_states_mask is not None: + text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -630,7 +663,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 0e470332c6f4..05cb19afb674 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -525,18 +525,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -551,14 +539,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ) ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) self.set_block_state(state, block_state) @@ -592,18 +572,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -626,15 +594,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 49acd2dc0295..2faa34ada329 100644 --- 
a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -149,7 +149,7 @@ def inputs(self) -> List[InputParam]: kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + "It should contain prompt_embeds/negative_prompt_embeds." ), ), ] @@ -176,7 +176,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -247,10 +246,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) @@ -345,10 +340,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..bc3ce84e1019 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,11 +672,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +690,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +703,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..ce6fc974a56e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,7 +909,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +919,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +933,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..77d78a5ca7a1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,7 +852,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +862,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +876,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..dd723460a59e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,11 +793,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - 
negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +816,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +830,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..cf467203a9d2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,11 +1008,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1030,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1044,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..942ee348508c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -777,11 +777,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +800,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +814,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..e0b41b8b8799 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,11 +775,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +792,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +805,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..83f02539b1ba 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,11 +944,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +961,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +974,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503ef..cd4792f889b6 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -68,7 +68,6 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,34 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_accepts_short_txt_seq_lens(self): + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Provide a deliberately short txt_seq_lens to ensure the model falls back to the embedding length. + inputs["txt_seq_lens"] = [2] * inputs["encoder_hidden_states"].shape[0] + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_builds_attention_mask_from_encoder_mask(self): + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a mask with padding on the last two tokens. 
+ encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, -2:] = 0 + + inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + inputs.pop("txt_seq_lens", None) + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel From 72a80c66256f7a4a5b03c55134f5f69395fcd3df Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 23 Nov 2025 18:24:41 +0000 Subject: [PATCH 02/38] formatting --- .../models/transformers/transformer_qwenimage.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 9aa57f994551..a39fbf0073de 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -346,7 +346,9 @@ def __call__( ) text_attention_mask = encoder_hidden_states_mask.to(dtype=torch.bool, device=hidden_states.device) - image_attention_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + image_attention_mask = torch.ones( + (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device + ) joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) attention_mask = joint_attention_mask[:, None, None, :] @@ -609,8 +611,8 @@ def forward( timestep ( `torch.LongTensor`): Used to indicate denoising step. txt_seq_lens (`List[int]`, *optional*): - Optional text sequence lengths. If not provided, or if any provided values are shorter than the - encoder hidden states length, the model falls back to the encoder hidden states length. + Optional text sequence lengths. If not provided, or if any provided values are shorter than the encoder + hidden states length, the model falls back to the encoder hidden states length. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -647,9 +649,7 @@ def forward( batch_size, text_seq_len = encoder_hidden_states.shape[:2] if txt_seq_lens is not None: if len(txt_seq_lens) != batch_size: - raise ValueError( - f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead." 
- ) + raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.") text_seq_len = max(text_seq_len, max(txt_seq_lens)) elif encoder_hidden_states_mask is not None: text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) From 88cee8b5a8a3e4de65f5f590ac5f5642c87244fa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 23 Nov 2025 18:38:11 +0000 Subject: [PATCH 03/38] formatting --- src/diffusers/models/controlnets/controlnet_qwenimage.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index b04a91f012c7..e09a40f5fb5b 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -250,9 +250,7 @@ def forward( batch_size, text_seq_len = encoder_hidden_states.shape[:2] if txt_seq_lens is not None: if len(txt_seq_lens) != batch_size: - raise ValueError( - f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead." - ) + raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.") text_seq_len = max(text_seq_len, max(txt_seq_lens)) elif encoder_hidden_states_mask is not None: text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) From ac5ac24d9895478c85a5d4ed943f779e705bc702 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 29 Nov 2025 15:37:30 +0000 Subject: [PATCH 04/38] remove txt_seq_lens and use bool mask --- .../train_dreambooth_lora_qwen_image.py | 2 -- .../controlnets/controlnet_qwenimage.py | 35 +++++++------------ .../transformers/transformer_qwenimage.py | 32 +++++++++-------- .../test_models_transformer_qwenimage.py | 32 ++++++++++++++--- 4 files changed, 59 insertions(+), 42 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 56de160d6f29..ecdbe8302ab8 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index e09a40f5fb5b..a5747b16cf0c 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -189,12 +189,11 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, - txt_seq_lens: Optional[List[int]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. 
Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -205,26 +204,24 @@ def forward( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - txt_seq_lens (`List[int]`, *optional*): - Optional text sequence lengths. If omitted, or shorter than the encoder hidden states length, the model - derives the length from the encoder hidden states (or their mask). + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. 
""" if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() @@ -247,13 +244,9 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - batch_size, text_seq_len = encoder_hidden_states.shape[:2] - if txt_seq_lens is not None: - if len(txt_seq_lens) != batch_size: - raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.") - text_seq_len = max(text_seq_len, max(txt_seq_lens)) - elif encoder_hidden_states_mask is not None: - text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) + # Use the encoder_hidden_states sequence length for RoPE computation + # The mask is used for attention masking in the attention processor + _, text_seq_len = encoder_hidden_states.shape[:2] image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) @@ -332,7 +325,6 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, - txt_seq_lens: Optional[List[int]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: @@ -350,7 +342,6 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index a39fbf0073de..932f05977166 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -330,6 +330,8 @@ def __call__( joint_value = torch.cat([txt_value, img_value], dim=1) # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask. + # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. + # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). if encoder_hidden_states_mask is not None and attention_mask is None: batch_size, image_seq_len = hidden_states.shape[:2] text_seq_len = encoder_hidden_states.shape[1] @@ -345,7 +347,9 @@ def __call__( f"must match encoder_hidden_states sequence length ({text_seq_len})." ) - text_attention_mask = encoder_hidden_states_mask.to(dtype=torch.bool, device=hidden_states.device) + # Convert mask to boolean: 1/1.0 -> True (attend), 0/0.0 -> False (don't attend) + # This is the correct semantics for PyTorch's scaled_dot_product_attention with boolean masks. + text_attention_mask = encoder_hidden_states_mask.bool() image_attention_mask = torch.ones( (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device ) @@ -592,7 +596,6 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, - txt_seq_lens: Optional[List[int]] = None, guidance: torch.Tensor = None, # TODO: this should probably be removed attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, @@ -606,17 +609,22 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
- encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - txt_seq_lens (`List[int]`, *optional*): - Optional text sequence lengths. If not provided, or if any provided values are shorter than the encoder - hidden states length, the model falls back to the encoder hidden states length. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -646,13 +654,9 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - batch_size, text_seq_len = encoder_hidden_states.shape[:2] - if txt_seq_lens is not None: - if len(txt_seq_lens) != batch_size: - raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.") - text_seq_len = max(text_seq_len, max(txt_seq_lens)) - elif encoder_hidden_states_mask is not None: - text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item())) + # Use the encoder_hidden_states sequence length for RoPE computation + # The mask is used for attention masking in the attention processor + _, text_seq_len = encoder_hidden_states.shape[:2] if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index cd4792f889b6..0999e2b470df 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -90,16 +90,20 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def test_accepts_short_txt_seq_lens(self): + def test_infers_text_seq_len_from_mask(self): init_dict, inputs = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) - # Provide a deliberately short txt_seq_lens to ensure the model falls back to the embedding length. 
- inputs["txt_seq_lens"] = [2] * inputs["encoder_hidden_states"].shape[0] + # Create a mask with only 2 valid tokens (rest are padding) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask with torch.no_grad(): output = model(**inputs) + # The model should infer text_seq_len=2 from the mask for RoPE computation self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) def test_builds_attention_mask_from_encoder_mask(self): @@ -111,13 +115,33 @@ def test_builds_attention_mask_from_encoder_mask(self): encoder_hidden_states_mask[:, -2:] = 0 inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask - inputs.pop("txt_seq_lens", None) with torch.no_grad(): output = model(**inputs) self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + + with torch.no_grad(): + output = model(**inputs) + + # The model should handle non-contiguous masks correctly + # RoPE uses the full sequence length, attention masking handles the pattern + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel From 18efddeda3e81698edb2ab07d284cec0f3d1b20a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 30 Nov 2025 14:11:51 +0000 Subject: [PATCH 05/38] use compute_text_seq_len_from_mask --- .../controlnets/controlnet_qwenimage.py | 8 +++-- .../transformers/transformer_qwenimage.py | 35 +++++++++++++++++-- .../test_models_transformer_qwenimage.py | 11 ++++-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index a5747b16cf0c..513a04a99ec7 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -31,6 +31,7 @@ QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -244,9 +245,10 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - # Use the encoder_hidden_states sequence length for RoPE computation - # The mask is used for attention masking in the attention processor - _, text_seq_len = encoder_hidden_states.shape[:2] + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py 
b/src/diffusers/models/transformers/transformer_qwenimage.py index 932f05977166..5dc69edbb739 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -141,6 +141,34 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) +def compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] +) -> Tuple[int, Optional[torch.Tensor]]: + """ + Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. + """ + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item())) + + return rope_text_seq_len, encoder_hidden_states_mask + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() @@ -654,9 +682,10 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Use the encoder_hidden_states sequence length for RoPE computation - # The mask is used for attention masking in the attention processor - _, text_seq_len = encoder_hidden_states.shape[:2] + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 0999e2b470df..54cdbf4b6a35 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -19,6 +19,7 @@ import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -133,13 +134,17 @@ def test_non_contiguous_attention_mask(self): encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 5:] = 0 - inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + inferred_rope_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask with 
torch.no_grad(): output = model(**inputs) - # The model should handle non-contiguous masks correctly - # RoPE uses the full sequence length, attention masking handles the pattern self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) From 6a549d45ef2ec23bca8a5746a5820913f705a804 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 30 Nov 2025 16:14:26 +0000 Subject: [PATCH 06/38] add seq_lens to dispatch_attention_fn --- src/diffusers/models/attention_dispatch.py | 70 ++++++++++++++----- .../controlnets/controlnet_qwenimage.py | 6 +- .../transformers/transformer_qwenimage.py | 12 +++- .../test_models_transformer_qwenimage.py | 3 +- 4 files changed, 68 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0c247b76d039..fb2d6287ea73 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -305,6 +305,7 @@ def dispatch_attention_fn( *, backend: Optional[AttentionBackendName] = None, parallel_config: Optional["ParallelConfig"] = None, + seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} @@ -327,6 +328,8 @@ def dispatch_attention_fn( **attention_kwargs, "_parallel_config": parallel_config, } + if seq_lens is not None: + kwargs["seq_lens"] = seq_lens if is_torch_version(">=", "2.5.0"): kwargs["enable_gqa"] = enable_gqa @@ -1400,18 +1403,29 @@ def _flash_varlen_attention( is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, + seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if seq_lens is not None: + seq_lens = seq_lens.to(query.device) + # use the same lengths for Q and KV + seqlens_k = seq_lens + cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = int(seq_lens.max().item()) + max_seqlen_k = max_seqlen_q + attn_mask = None # varlen uses lengths + else: + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) - ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -1521,18 +1535,28 @@ def _flash_varlen_attention_3( is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, + seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if seq_lens is not None: + seq_lens = seq_lens.to(query.device) + seqlens_k = seq_lens + cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = int(seq_lens.max().item()) + max_seqlen_k = max_seqlen_q + attn_mask = None # varlen uses lengths + else: + if attn_mask is not None: + attn_mask = 
_normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) - ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -2023,6 +2047,7 @@ def _sage_varlen_attention( scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, + seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: if return_lse: raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") @@ -2030,14 +2055,23 @@ def _sage_varlen_attention( batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if seq_lens is not None: + seq_lens = seq_lens.to(query.device) + seqlens_k = seq_lens + cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = int(seq_lens.max().item()) + max_seqlen_k = max_seqlen_q + attn_mask = None # varlen uses lengths + else: + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) - ) key_valid, value_valid = [], [] for b in range(batch_size): diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 513a04a99ec7..7e67168248c8 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -228,6 +228,7 @@ def forward( joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: + joint_attention_kwargs = {} lora_scale = 1.0 if USE_PEFT_BACKEND: @@ -246,10 +247,13 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) + if text_seq_lens_per_sample is not None: + joint_attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample) + image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 5dc69edbb739..da719645975b 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -143,7 +143,7 @@ def apply_rotary_emb_qwen( def compute_text_seq_len_from_mask( 
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] -) -> Tuple[int, Optional[torch.Tensor]]: +) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. """ @@ -166,7 +166,7 @@ def compute_text_seq_len_from_mask( per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item())) - return rope_text_seq_len, encoder_hidden_states_mask + return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask class QwenTimestepProjEmbeddings(nn.Module): @@ -308,6 +308,7 @@ def __call__( encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_lens: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") @@ -394,6 +395,7 @@ def __call__( is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, + seq_lens=text_seq_lens, ) # Reshape back @@ -665,6 +667,7 @@ def forward( attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: + attention_kwargs = {} lora_scale = 1.0 if USE_PEFT_BACKEND: @@ -683,10 +686,13 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) + if text_seq_lens_per_sample is not None: + attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 54cdbf4b6a35..e56f7ab47deb 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -134,9 +134,10 @@ def test_non_contiguous_attention_mask(self): encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 5:] = 0 - inferred_rope_len, normalized_mask = compute_text_seq_len_from_mask( + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) + self.assertEqual(int(per_sample_len.max().item()), 5) self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) self.assertTrue(normalized_mask.dtype == torch.bool) From 2d424e037124b3e970f1c35f7435c578632d4f29 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 30 Nov 2025 16:45:50 +0000 Subject: [PATCH 07/38] use joint_seq_lens --- .../models/transformers/transformer_qwenimage.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index da719645975b..c905563b56e9 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -361,6 
+361,7 @@ def __call__( # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask. # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). + joint_seq_lens = None if encoder_hidden_states_mask is not None and attention_mask is None: batch_size, image_seq_len = hidden_states.shape[:2] text_seq_len = encoder_hidden_states.shape[1] @@ -385,6 +386,12 @@ def __call__( joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) attention_mask = joint_attention_mask[:, None, None, :] + # For varlen flash attention, we need the JOINT sequence lengths (text + image), not just text + if text_seq_lens is not None: + # text_seq_lens contains per-sample text lengths + # Add the image sequence length to get total joint sequence length + joint_seq_lens = text_seq_lens + image_seq_len + # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, @@ -395,7 +402,7 @@ def __call__( is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, - seq_lens=text_seq_lens, + seq_lens=joint_seq_lens, ) # Reshape back From 30b5f98842f427ba9eb4666a3955f6a4f858a004 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 30 Nov 2025 16:48:52 +0000 Subject: [PATCH 08/38] remove unused index_block --- src/diffusers/models/controlnets/controlnet_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 7e67168248c8..c970d0270445 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -261,7 +261,7 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, From f1c2d99688f0859c9c30337dbf12836e7b2466c1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 6 Dec 2025 18:15:20 +0000 Subject: [PATCH 09/38] WIP: Remove seq_lens parameter and use mask-based approach - Remove seq_lens parameter from dispatch_attention_fn - Update varlen backends to extract seqlens from masks - Update QwenImage to pass 2D joint_attention_mask - Fix native backend to handle 2D boolean masks - Fix sage_varlen seqlens_q to match seqlens_k for self-attention Note: sage_varlen still producing black images, needs further investigation --- src/diffusers/models/attention_dispatch.py | 222 +++++++++--------- .../transformers/transformer_qwenimage.py | 20 +- 2 files changed, 112 insertions(+), 130 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index fb2d6287ea73..e7c4040e45fa 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -305,7 +305,6 @@ def dispatch_attention_fn( *, backend: Optional[AttentionBackendName] = None, parallel_config: Optional["ParallelConfig"] = None, - seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} @@ -328,8 +327,6 @@ def dispatch_attention_fn( **attention_kwargs, "_parallel_config": parallel_config, } - if seq_lens is 
not None: - kwargs["seq_lens"] = seq_lens if is_torch_version(">=", "2.5.0"): kwargs["enable_gqa"] = enable_gqa @@ -502,8 +499,10 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask( attn_mask: torch.Tensor, device: Optional[torch.device] = None, ): - seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + # For self-attention (Q, K, V from same sequence), seqlens_q should equal seqlens_k + # Both are computed from the mask which indicates valid tokens seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + seqlens_q = seqlens_k # In self-attention, query and key lengths are the same cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1403,29 +1402,18 @@ def _flash_varlen_attention( is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, - seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if seq_lens is not None: - seq_lens = seq_lens.to(query.device) - # use the same lengths for Q and KV - seqlens_k = seq_lens - cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) - cu_seqlens_k = cu_seqlens_q - max_seqlen_q = int(seq_lens.max().item()) - max_seqlen_k = max_seqlen_q - attn_mask = None # varlen uses lengths - else: - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -1450,9 +1438,10 @@ def _flash_varlen_attention( causal=is_causal, return_attn_probs=return_lse, ) + out = out.unflatten(0, (batch_size, -1)) - return out + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -1535,28 +1524,18 @@ def _flash_varlen_attention_3( is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, - seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if seq_lens is not None: - seq_lens = seq_lens.to(query.device) - seqlens_k = seq_lens - cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) - cu_seqlens_k = cu_seqlens_q - max_seqlen_q = int(seq_lens.max().item()) - max_seqlen_k = max_seqlen_q - attn_mask = None # varlen uses lengths - else: - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + 
_prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -1707,6 +1686,20 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Convert 2D boolean mask to 4D additive mask for SDPA + if attn_mask is not None and attn_mask.ndim == 2: + # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out + # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out + batch_size, seq_len_k = attn_mask.shape + # Ensure it's boolean for torch.where + if attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + # Convert to query dtype and reshape to [batch_size, 1, 1, seq_len_k] for broadcasting + attn_mask = attn_mask.to(dtype=query.dtype).view(batch_size, 1, 1, seq_len_k) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2018,68 +2011,57 @@ def _sage_attention_hub( ) -> torch.Tensor: lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn - if _parallel_config is None: - out = func( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - if return_lse: - out, lse, *_ = out + out = func( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + if return_lse: + out, lse, *_ = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_VARLEN, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, - is_causal: bool = False, scale: Optional[float] = None, + is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, - seq_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if return_lse: - raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") - batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if seq_lens is not None: - seq_lens = seq_lens.to(query.device) - seqlens_k = seq_lens - cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32) - cu_seqlens_k = cu_seqlens_q - max_seqlen_q = int(seq_lens.max().item()) - max_seqlen_k = max_seqlen_q - attn_mask = None # varlen uses lengths - else: - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device 
) + ) - key_valid, value_valid = [], [] + # The SageAttention kernel needs the query, key, and values to be of shape + # `[num_tokens, num_heads, head_dim]`. The number of tokens is the total number of tokens in the + # batch. + query_valid, key_valid, value_valid = [], [], [] for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) + query_valid.append(query[b, : seqlens_q[b]]) + key_valid.append(key[b, : seqlens_k[b]]) + value_valid.append(value[b, : seqlens_k[b]]) - query_packed = query.flatten(0, 1) + query_packed = torch.cat(query_valid, dim=0) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) @@ -2091,17 +2073,33 @@ def _sage_varlen_attention( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - is_causal=is_causal, sm_scale=scale, + is_causal=is_causal, + return_lse=return_lse, ) - out = out.unflatten(0, (batch_size, -1)) + lse = None + if return_lse: + out, lse, *_ = out - return out + # The output of the SageAttention kernel is of shape `[num_tokens, num_heads, head_dim]`. + # We need to reshape it to `[batch_size, seq_len_q, num_heads, head_dim]`. + out_padded = torch.zeros( + (batch_size, seq_len_q, *out.shape[1:]), + dtype=out.dtype, + device=out.device, + ) + for b in range(batch_size): + start, end = cu_seqlens_q[b], cu_seqlens_q[b + 1] + len_q = seqlens_q[b] + out_padded[b, :len_q] = out[start:end] + out = out_padded + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], + constraints=[_check_device_cuda_atleast_smXY(8, 9), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, @@ -2112,15 +2110,16 @@ def _sage_qk_int8_pv_fp8_cuda_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - return sageattn_qk_int8_pv_fp8_cuda( + out = sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, - return_lse=return_lse, ) + # The LSE is not returned from this kernel so we cannot support cases where it is needed + return out @_AttentionBackendRegistry.register( @@ -2136,15 +2135,16 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - return sageattn_qk_int8_pv_fp8_cuda_sm90( + out = sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, - return_lse=return_lse, ) + # The LSE is not returned from this kernel so we cannot support cases where it is needed + return out @_AttentionBackendRegistry.register( @@ -2160,15 +2160,16 @@ def _sage_qk_int8_pv_fp16_cuda_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - return sageattn_qk_int8_pv_fp16_cuda( + out = sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, - return_lse=return_lse, ) + # The LSE is not returned from this kernel so we cannot support cases where it is needed + return out @_AttentionBackendRegistry.register( @@ -2184,59 +2185,52 @@ def _sage_qk_int8_pv_fp16_triton_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - return sageattn_qk_int8_pv_fp16_triton( + out = 
sageattn_qk_int8_pv_fp16_triton( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, - return_lse=return_lse, ) + # The LSE is not returned from this kernel so we cannot support cases where it is needed + return out @_AttentionBackendRegistry.register( AttentionBackendName.XFORMERS, - constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _xformers_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, - enable_gqa: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: if return_lse: - raise ValueError("xformers attention backend does not support setting `return_lse=True`.") - - batch_size, seq_len_q, num_heads_q, _ = query.shape - _, seq_len_kv, num_heads_kv, _ = key.shape + raise ValueError("Xformers attention backend does not support setting `return_lse=True`.") + op = xops.MemoryEfficientAttentionCkOp if is_causal: - attn_mask = xops.LowerTriangularMask() - elif attn_mask is not None: - if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) - elif attn_mask.ndim != 4: - raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) - - if enable_gqa: - if num_heads_q % num_heads_kv != 0: - raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") - num_heads_per_group = num_heads_q // num_heads_kv - query = query.unflatten(2, (num_heads_kv, -1)) - key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) - value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + op = op.WITH_AUTOMATIC_CAUSAL_MASK + # Removed the check for attn_mask: Optional[torch.Tensor] = None + # since it's removed from the function signature and is not supported. 
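As context for the mask handling being reworked in this patch: the native SDPA path (and the `_prepare_attn_mask_native` helper introduced later in the series) reduces a key-padding mask to an additive bias. A minimal sketch of that conversion, with illustrative names and shapes rather than the library's exact helpers:

```python
import torch
import torch.nn.functional as F

def bool_mask_to_additive_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: [batch, seq_len_k] boolean, True = attend, False = padding.
    # Returns [batch, 1, 1, seq_len_k]: 0.0 where attention is allowed, -inf where masked.
    bias = torch.where(mask.bool(), 0.0, float("-inf")).to(dtype)
    return bias[:, None, None, :]

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = k = v = torch.randn(batch, heads, seq_len, head_dim)
key_padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)
key_padding_mask[1, 10:] = False  # second sample has 6 padded keys

out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=bool_mask_to_additive_bias(key_padding_mask, q.dtype)
)
```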
- out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + out = xops.memory_efficient_attention( + q=query, + k=key, + v=value, + p=dropout_p, + scale=scale, + op=op, + ) + return out - if enable_gqa: - out = out.flatten(2, 3) - return out +# ===== Default backend ===== +_check_attention_backend_requirements(_AttentionBackendRegistry._active_backend) +_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c905563b56e9..4d3554d35a44 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -308,7 +308,6 @@ def __call__( encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - text_seq_lens: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") @@ -358,10 +357,9 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask. + # If an encoder_hidden_states_mask is provided, create a joint attention mask. # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). - joint_seq_lens = None if encoder_hidden_states_mask is not None and attention_mask is None: batch_size, image_seq_len = hidden_states.shape[:2] text_seq_len = encoder_hidden_states.shape[1] @@ -378,19 +376,13 @@ def __call__( ) # Convert mask to boolean: 1/1.0 -> True (attend), 0/0.0 -> False (don't attend) - # This is the correct semantics for PyTorch's scaled_dot_product_attention with boolean masks. 
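The joint mask construction shown in this hunk can be read in isolation as: keep the text padding mask, give every image token a True entry, and concatenate along the sequence dimension. A small self-contained sketch (shapes are made up for illustration):

```python
import torch

batch_size, text_seq_len, image_seq_len = 2, 7, 64

# Text mask from the tokenizer/encoder: 1.0 for real tokens, 0.0 for padding.
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len)
encoder_hidden_states_mask[1, 5:] = 0.0

text_attention_mask = encoder_hidden_states_mask.bool()
image_attention_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)

# 2D joint mask over the concatenated [text, image] sequence; downstream code
# normalizes it into whatever the selected attention backend expects.
joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
assert joint_attention_mask.shape == (batch_size, text_seq_len + image_seq_len)
```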
text_attention_mask = encoder_hidden_states_mask.bool() image_attention_mask = torch.ones( (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device ) - joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) - attention_mask = joint_attention_mask[:, None, None, :] - - # For varlen flash attention, we need the JOINT sequence lengths (text + image), not just text - if text_seq_lens is not None: - # text_seq_lens contains per-sample text lengths - # Add the image sequence length to get total joint sequence length - joint_seq_lens = text_seq_lens + image_seq_len + # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] + # The attention dispatch will normalize this and extract sequence lengths + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) # Compute joint attention joint_hidden_states = dispatch_attention_fn( @@ -402,7 +394,6 @@ def __call__( is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, - seq_lens=joint_seq_lens, ) # Reshape back @@ -697,9 +688,6 @@ def forward( encoder_hidden_states, encoder_hidden_states_mask ) - if text_seq_lens_per_sample is not None: - attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample) - if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 From beeb0206b7bb870efc3bfce67c3f2dd884cd9961 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 7 Dec 2025 11:41:38 +0000 Subject: [PATCH 10/38] fix formatting --- src/diffusers/models/attention_dispatch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index df328464e7e2..4e30c12a296e 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1497,6 +1497,7 @@ def _flash_varlen_attention( key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) + lse = None out = flash_attn_varlen_func( q=query_packed, k=key_packed, @@ -1510,6 +1511,8 @@ def _flash_varlen_attention( causal=is_causal, return_attn_probs=return_lse, ) + if return_lse: + out, _, lse = out out = out.unflatten(0, (batch_size, -1)) @@ -2359,4 +2362,4 @@ def _xformers_attention( # ===== Default backend ===== _check_attention_backend_requirements(_AttentionBackendRegistry._active_backend) -_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) \ No newline at end of file +_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) From 5c6f8e396945f45962c209ee681170d499e12eaa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 7 Dec 2025 12:18:13 +0000 Subject: [PATCH 11/38] undo sage changes --- src/diffusers/models/attention_dispatch.py | 143 +++++++++------------ 1 file changed, 63 insertions(+), 80 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 4e30c12a296e..dccefbe24b1c 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -514,10 +514,8 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask( attn_mask: torch.Tensor, device: Optional[torch.device] = None, ): - # For self-attention (Q, K, V from same sequence), seqlens_q should equal seqlens_k - # Both are computed from the mask which indicates valid tokens + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) seqlens_k = attn_mask.sum(dim=1, 
dtype=torch.int32) - seqlens_q = seqlens_k # In self-attention, query and key lengths are the same cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1497,7 +1495,6 @@ def _flash_varlen_attention( key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) - lse = None out = flash_attn_varlen_func( q=query_packed, k=key_packed, @@ -1511,12 +1508,9 @@ def _flash_varlen_attention( causal=is_causal, return_attn_probs=return_lse, ) - if return_lse: - out, _, lse = out - out = out.unflatten(0, (batch_size, -1)) - return (out, lse) if return_lse else out + return out @_AttentionBackendRegistry.register( @@ -2140,57 +2134,58 @@ def _sage_attention_hub( ) -> torch.Tensor: lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn - out = func( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - if return_lse: - out, lse, *_ = out + if _parallel_config is None: + out = func( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + if return_lse: + out, lse, *_ = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_VARLEN, - constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, - scale: Optional[float] = None, is_causal: bool = False, + scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") + batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) - # The SageAttention kernel needs the query, key, and values to be of shape - # `[num_tokens, num_heads, head_dim]`. The number of tokens is the total number of tokens in the - # batch. 
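For readers unfamiliar with the varlen convention used by these flash/sage kernels: padded batches are flattened into one packed token tensor plus cumulative sequence offsets (`cu_seqlens`). Roughly, the packing step looks like the following sketch (helper name and shapes are illustrative, not the exact function in the diff):

```python
import torch

def pack_for_varlen(x: torch.Tensor, mask: torch.Tensor):
    # x: [batch, seq_len, num_heads, head_dim], mask: [batch, seq_len] bool (True = valid token)
    seqlens = mask.sum(dim=1, dtype=torch.int32)              # per-sample valid lengths
    cu_seqlens = torch.zeros(x.shape[0] + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)             # e.g. [0, l0, l0 + l1, ...]
    packed = torch.cat([x[b, : seqlens[b]] for b in range(x.shape[0])], dim=0)
    return packed, cu_seqlens, int(seqlens.max())

x = torch.randn(2, 10, 4, 8)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[0, 6:] = False                                           # first sample has 4 padded tokens
packed, cu_seqlens, max_seqlen = pack_for_varlen(x, mask)
print(packed.shape, cu_seqlens.tolist(), max_seqlen)          # torch.Size([16, 4, 8]) [0, 6, 16] 10
```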
- query_valid, key_valid, value_valid = [], [], [] + key_valid, value_valid = [], [] for b in range(batch_size): - query_valid.append(query[b, : seqlens_q[b]]) - key_valid.append(key[b, : seqlens_k[b]]) - value_valid.append(value[b, : seqlens_k[b]]) + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) - query_packed = torch.cat(query_valid, dim=0) + query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) @@ -2202,33 +2197,17 @@ def _sage_varlen_attention( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - sm_scale=scale, is_causal=is_causal, - return_lse=return_lse, - ) - lse = None - if return_lse: - out, lse, *_ = out - - # The output of the SageAttention kernel is of shape `[num_tokens, num_heads, head_dim]`. - # We need to reshape it to `[batch_size, seq_len_q, num_heads, head_dim]`. - out_padded = torch.zeros( - (batch_size, seq_len_q, *out.shape[1:]), - dtype=out.dtype, - device=out.device, + sm_scale=scale, ) - for b in range(batch_size): - start, end = cu_seqlens_q[b], cu_seqlens_q[b + 1] - len_q = seqlens_q[b] - out_padded[b, :len_q] = out[start:end] - out = out_padded + out = out.unflatten(0, (batch_size, -1)) - return (out, lse) if return_lse else out + return out @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - constraints=[_check_device_cuda_atleast_smXY(8, 9), _check_shape], + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, @@ -2239,16 +2218,15 @@ def _sage_qk_int8_pv_fp8_cuda_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = sageattn_qk_int8_pv_fp8_cuda( + return sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, + return_lse=return_lse, ) - # The LSE is not returned from this kernel so we cannot support cases where it is needed - return out @_AttentionBackendRegistry.register( @@ -2264,16 +2242,15 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = sageattn_qk_int8_pv_fp8_cuda_sm90( + return sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, + return_lse=return_lse, ) - # The LSE is not returned from this kernel so we cannot support cases where it is needed - return out @_AttentionBackendRegistry.register( @@ -2289,16 +2266,15 @@ def _sage_qk_int8_pv_fp16_cuda_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = sageattn_qk_int8_pv_fp16_cuda( + return sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, + return_lse=return_lse, ) - # The LSE is not returned from this kernel so we cannot support cases where it is needed - return out @_AttentionBackendRegistry.register( @@ -2314,52 +2290,59 @@ def _sage_qk_int8_pv_fp16_triton_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = sageattn_qk_int8_pv_fp16_triton( + return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, + return_lse=return_lse, ) - # The LSE is not returned from this kernel so we cannot support 
cases where it is needed - return out @_AttentionBackendRegistry.register( AttentionBackendName.XFORMERS, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _xformers_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + enable_gqa: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: if return_lse: - raise ValueError("Xformers attention backend does not support setting `return_lse=True`.") + raise ValueError("xformers attention backend does not support setting `return_lse=True`.") + + batch_size, seq_len_q, num_heads_q, _ = query.shape + _, seq_len_kv, num_heads_kv, _ = key.shape - op = xops.MemoryEfficientAttentionCkOp if is_causal: - op = op.WITH_AUTOMATIC_CAUSAL_MASK - # Removed the check for attn_mask: Optional[torch.Tensor] = None - # since it's removed from the function signature and is not supported. + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) - out = xops.memory_efficient_attention( - q=query, - k=key, - v=value, - p=dropout_p, - scale=scale, - op=op, - ) - return out + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + + if enable_gqa: + out = out.flatten(2, 3) -# ===== Default backend ===== -_check_attention_backend_requirements(_AttentionBackendRegistry._active_backend) -_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) + return out \ No newline at end of file From 5d434f63249ea354d4db0085eb9bec0a87938dc9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 7 Dec 2025 12:59:02 +0000 Subject: [PATCH 12/38] xformers support --- src/diffusers/models/attention_dispatch.py | 32 ++++++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index dccefbe24b1c..d1f469be94b2 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2327,10 +2327,36 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D boolean mask to 4D for xformers + # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask + # xformers requires 4D masks [batch, heads, seq_q, seq_k] + # xformers expects additive bias: 0.0 for attend, -inf for mask + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = 
attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Fill in the actual mask values (converting boolean to additive) + # Expand 2D [batch, seq_k] -> 4D [batch, heads, seq_q, seq_k] + # Convert: True -> 0.0, False -> -inf + mask_2d = attn_mask # [batch, seq_len_k] + mask_additive = torch.where(mask_2d, 0.0, float("-inf")).type_as(query) + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive.view(batch_size, 1, 1, original_seq_len) + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: @@ -2345,4 +2371,4 @@ def _xformers_attention( if enable_gqa: out = out.flatten(2, 3) - return out \ No newline at end of file + return out From 71ba60368cc6b86d9fef2e7f1f0f6a9e0f946fcd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 8 Dec 2025 10:56:18 +0000 Subject: [PATCH 13/38] hub fix --- src/diffusers/models/attention_dispatch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d1f469be94b2..1295b0b96e8f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2372,3 +2372,7 @@ def _xformers_attention( out = out.flatten(2, 3) return out + + +# Initialize: download kernel for the initial backend set via DIFFUSERS_ATTN_BACKEND +_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) From afad3357526a5d089d4e4ca3a47f55d118b52f86 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 8 Dec 2025 17:06:00 +0000 Subject: [PATCH 14/38] fix torch compile issues --- .../models/transformers/transformer_qwenimage.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 4d3554d35a44..3a8fefc0b87c 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -149,7 +149,7 @@ def compute_text_seq_len_from_mask( """ batch_size, text_seq_len = encoder_hidden_states.shape[:2] if encoder_hidden_states_mask is None: - return text_seq_len, None + return text_seq_len, None, None if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): raise ValueError( @@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask( active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) has_active = encoder_hidden_states_mask.any(dim=1) per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) - rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item())) + + # Keep as 
tensor to avoid graph breaks in torch.compile + # torch.maximum works with mixed tensor/scalar and keeps result as tensor + text_seq_len_tensor = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + rope_text_seq_len = torch.maximum(text_seq_len_tensor, per_sample_len.max()) return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask @@ -237,9 +241,9 @@ def forward( device: (`torch.device`): The device on which to perform the RoPE computation. """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Move to device unconditionally to avoid graph breaks in torch.compile + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] From c78a1e96426b5772eb34a3ffb4f66e9ae56a1985 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 9 Dec 2025 08:22:04 +0000 Subject: [PATCH 15/38] fix tests --- .../test_models_transformer_qwenimage.py | 56 +++++++++++++------ 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index e56f7ab47deb..28a0359741bd 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,7 +15,6 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel @@ -92,36 +91,60 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" init_dict, inputs = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) - # Create a mask with only 2 valid tokens (rest are padding) + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid - inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) - with torch.no_grad(): - output = model(**inputs) + # Verify rope_text_seq_len is returned as a tensor (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, torch.Tensor) + self.assertEqual(rope_text_seq_len.ndim, 0) # Should be scalar tensor - # The model should infer text_seq_len=2 from the mask for RoPE computation - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) - def test_builds_attention_mask_from_encoder_mask(self): - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - # Create a mask with padding on the last two tokens. 
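The compile-related changes in this patch follow a general rule of thumb: converting tensor values to Python numbers (`.item()`, `int(...)`) inside the forward pass forces `torch.compile` to break the graph, whereas staying with tensor ops keeps the trace intact. A toy illustration of the two styles (not the model code itself):

```python
import torch

def graph_breaking(per_sample_len: torch.Tensor, text_seq_len: int) -> int:
    # .item() pulls the value back into Python, which torch.compile cannot trace through.
    return max(text_seq_len, int(per_sample_len.max().item()))

def compile_friendly(per_sample_len: torch.Tensor, text_seq_len: int) -> torch.Tensor:
    # Keeping everything as tensors avoids the graph break.
    text_seq_len_t = torch.tensor(text_seq_len, device=per_sample_len.device, dtype=per_sample_len.dtype)
    return torch.maximum(text_seq_len_t, per_sample_len.max())

lengths = torch.tensor([5, 7, 3])
print(graph_breaking(lengths, 6), compile_friendly(lengths, 6))  # 7 tensor(7)
```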
- encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask[:, -2:] = 0 + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values - inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(int(rope_text_seq_len.item()), inputs["encoder_hidden_states"].shape[1]) + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + def test_non_contiguous_attention_mask(self): """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" init_dict, inputs = self.prepare_init_args_and_inputs_for_common() @@ -158,6 +181,5 @@ def prepare_init_args_and_inputs_for_common(self): def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() From d6d4b1d1bd6cee73d2b44856a94a9ab7065584b4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 9 Dec 2025 09:07:38 +0000 Subject: [PATCH 16/38] use _prepare_attn_mask_native --- src/diffusers/models/attention_dispatch.py | 52 ++++++++++++++++------ 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 1295b0b96e8f..c27ce5aed5bc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1790,6 +1790,39 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out +def _prepare_attn_mask_native( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA. 
+ + Args: + attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. + """ + # Ensure it's boolean + if attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -1814,14 +1847,8 @@ def _native_attention( if attn_mask is not None and attn_mask.ndim == 2: # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out - batch_size, seq_len_k = attn_mask.shape - # Ensure it's boolean for torch.where - if attn_mask.dtype != torch.bool: - attn_mask = attn_mask.bool() - # Convert boolean to additive: True -> 0.0, False -> -inf - attn_mask = torch.where(attn_mask, 0.0, float("-inf")) - # Convert to query dtype and reshape to [batch_size, 1, 1, seq_len_k] for broadcasting - attn_mask = attn_mask.to(dtype=query.dtype).view(batch_size, 1, 1, seq_len_k) + # Use helper to convert boolean to additive mask and reshape to 4D + attn_mask = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype) if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) @@ -2342,12 +2369,10 @@ def _xformers_attention( device=query.device, ) # Fill in the actual mask values (converting boolean to additive) - # Expand 2D [batch, seq_k] -> 4D [batch, heads, seq_q, seq_k] - # Convert: True -> 0.0, False -> -inf - mask_2d = attn_mask # [batch, seq_len_k] - mask_additive = torch.where(mask_2d, 0.0, float("-inf")).type_as(query) + # Use helper to convert 2D boolean -> 4D additive mask + mask_additive = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype) # [batch, 1, 1, seq_len_k] # Broadcast to [batch, heads, seq_q, seq_len_k] - aligned_mask[:, :, :, :original_seq_len] = mask_additive.view(batch_size, 1, 1, original_seq_len) + aligned_mask[:, :, :, :original_seq_len] = mask_additive # Mask out the padding (already -inf from zeros -> where with default) aligned_mask[:, :, :, original_seq_len:] = float("-inf") @@ -2374,5 +2399,4 @@ def _xformers_attention( return out -# Initialize: download kernel for the initial backend set via DIFFUSERS_ATTN_BACKEND _maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) From e999b769dd53adc79c6e2b5ce95b4254cd0d28ea Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 9 Dec 2025 11:03:31 +0000 Subject: [PATCH 17/38] proper deprecation notice --- .../controlnets/controlnet_qwenimage.py | 22 ++++++++++--- .../transformers/transformer_qwenimage.py | 33 +++++++++++++++---- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git 
a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 567a6928f011..3fa5ea576a31 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -132,6 +132,7 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: @@ -155,6 +156,9 @@ def forward( Used to indicate denoising step. img_shapes (`List[Tuple[int, int, int]]`, *optional*): Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -166,6 +170,17 @@ def forward( If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where the first element is the controlnet block samples. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.37.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.37.0. 
The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -193,10 +208,7 @@ def forward( encoder_hidden_states, encoder_hidden_states_mask ) - if text_seq_lens_per_sample is not None: - joint_attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample) - - image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, text_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 3a8fefc0b87c..d44aa436f727 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -229,18 +229,39 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_len: int, - device: torch.device, + txt_seq_len: Optional[Union[int, torch.Tensor]] = None, + device: torch.device = None, + txt_seq_lens: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_len (`int`): - The length of the text sequence. This should match the encoder hidden states length. - device: (`torch.device`): + txt_seq_len (`int` or `torch.Tensor`, *optional*): + The length of the text sequence. This should match the encoder hidden states length. Can be either an + int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `txt_seq_len` instead. If provided, the maximum value will be used. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.37.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " + "Please use `txt_seq_len` instead (singular, not plural). 
" + "The new parameter accepts a single int or tensor value instead of a list.", + standard_warn=False, + ) + if txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if txt_seq_len is None: + raise ValueError("Either `txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + # Move to device unconditionally to avoid graph breaks in torch.compile self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) From 8115f0b2c3756f0d82ec327eed5e926f37c9e7d4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 9 Dec 2025 15:54:16 +0000 Subject: [PATCH 18/38] add deprecate to txt_seq_lens --- .../models/controlnets/controlnet_qwenimage.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 3fa5ea576a31..9f2408fbcbc1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -285,9 +285,19 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.37.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.37.0. The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: From 3b1510c3e208daeb408d5856c4afcb21f7a1fcef Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 10 Dec 2025 10:36:21 +0100 Subject: [PATCH 19/38] Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_qwenimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index d44aa436f727..b718cd7bd525 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -229,9 +229,9 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_len: Optional[Union[int, torch.Tensor]] = None, - device: torch.device = None, txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: From 3676d8e7df38f2352b11b4886c8bc2b28ac26bd1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 10 Dec 2025 10:36:32 +0100 Subject: [PATCH 20/38] Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 
b718cd7bd525..b2a924a11b7c 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -722,7 +722,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states) ) - image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, text_seq_len=text_seq_len, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: From 9ed0ffd5adae25f48ab8854a90b513d12a18ccdf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 10 Dec 2025 09:53:44 +0000 Subject: [PATCH 21/38] Only create the mask if there's actual padding --- .../controlnets/controlnet_qwenimage.py | 4 +-- .../transformers/transformer_qwenimage.py | 34 ++++++++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 9f2408fbcbc1..4e9eacc417f9 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -204,11 +204,11 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) - image_rotary_emb = self.pos_embed(img_shapes, text_seq_len=text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index b2a924a11b7c..b85413076e33 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -385,6 +385,7 @@ def __call__( # If an encoder_hidden_states_mask is provided, create a joint attention mask. # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). + # Only create the mask if there's actual padding, otherwise keep attention_mask=None for better SDPA performance. if encoder_hidden_states_mask is not None and attention_mask is None: batch_size, image_seq_len = hidden_states.shape[:2] text_seq_len = encoder_hidden_states.shape[1] @@ -400,14 +401,16 @@ def __call__( f"must match encoder_hidden_states sequence length ({text_seq_len})." 
) - # Convert mask to boolean: 1/1.0 -> True (attend), 0/0.0 -> False (don't attend) + # Only create mask if there's actual padding (i.e., some False/0 values) + # When all values are True/1.0, passing attention_mask=None is more efficient for SDPA text_attention_mask = encoder_hidden_states_mask.bool() - image_attention_mask = torch.ones( - (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device - ) - # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] - # The attention dispatch will normalize this and extract sequence lengths - attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + if not text_attention_mask.all(): + image_attention_mask = torch.ones( + (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device + ) + # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] + # The attention dispatch will normalize this and extract sequence lengths + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) # Compute joint attention joint_hidden_states = dispatch_attention_fn( @@ -649,6 +652,7 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, guidance: torch.Tensor = None, # TODO: this should probably be removed attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, @@ -670,6 +674,9 @@ def forward( Used to indicate denoising step. img_shapes (`List[Tuple[int, int, int]]`, *optional*): Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. guidance (`torch.Tensor`, *optional*): Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): @@ -686,6 +693,15 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.37.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " + "Please use `txt_seq_len` instead (singular, not plural). 
" + "The new parameter accepts a single int or tensor value instead of a list.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -709,7 +725,7 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) @@ -722,7 +738,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states) ) - image_rotary_emb = self.pos_embed(img_shapes, text_seq_len=text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: From e26e7b36460667454bf55a61d3c907b73aebac5b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 10 Dec 2025 13:20:30 +0000 Subject: [PATCH 22/38] fix order of docstrings --- .../models/transformers/transformer_qwenimage.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index b85413076e33..201503341594 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -237,13 +237,13 @@ def forward( Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `txt_seq_len` instead. If provided, the maximum value will be used. + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. txt_seq_len (`int` or `torch.Tensor`, *optional*): The length of the text sequence. This should match the encoder hidden states length. Can be either an int or a scalar tensor (for torch.compile compatibility). - device: (`torch.device`, *optional*): - The device on which to perform the RoPE computation. - txt_seq_lens (`List[int]`, *optional*, **Deprecated**): - Deprecated parameter. Use `txt_seq_len` instead. If provided, the maximum value will be used. 
""" # Handle deprecated txt_seq_lens parameter if txt_seq_lens is not None: From 59e388296bee3834f275ecccc6d8126d72ecd54d Mon Sep 17 00:00:00 2001 From: cdutr Date: Thu, 11 Dec 2025 19:14:44 -0300 Subject: [PATCH 23/38] Adds performance benchmarks and optimization details for QwenImage Enhances documentation with comprehensive performance insights for QwenImage pipeline: --- docs/source/en/api/pipelines/qwenimage.md | 60 ++++++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index b3dd3dd93618..8a2fa3f47f06 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,12 +108,68 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### Attention Backends + +QwenImage supports multiple attention backends. Benchmarks on A100 80GB: + +**Single Image (30 steps, 512x512):** + +| Backend | Time (s) | +|---------|----------| +| flash_hub | 2.34 | +| native | 2.38 | +| xformers | 2.58 | +| flash_varlen | 2.78 | + +**Batch (2 images, 25 steps, 512x512):** + +| Backend | Time (s) | +|---------|----------| +| flash_hub | 2.85 | +| native | 3.16 | +| flash_varlen | 3.29 | +| xformers | 3.52 | + +### torch.compile + +Using `torch.compile` provides significant speedups with a one-time compilation overhead: + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead on A100) +# Subsequent calls see ~2.4x speedup +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). 
+ ## QwenImagePipeline [[autodoc]] QwenImagePipeline From 60bd4543dd5dbd460af46b171dae6a29459ca726 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 12 Dec 2025 16:24:40 +0000 Subject: [PATCH 24/38] rope_text_seq_len = text_seq_len --- .../transformers/transformer_qwenimage.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 201503341594..effc508ad9a8 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -165,10 +165,9 @@ def compute_text_seq_len_from_mask( has_active = encoder_hidden_states_mask.any(dim=1) per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + # For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always) # Keep as tensor to avoid graph breaks in torch.compile - # torch.maximum works with mixed tensor/scalar and keeps result as tensor - text_seq_len_tensor = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) - rope_text_seq_len = torch.maximum(text_seq_len_tensor, per_sample_len.max()) + rope_text_seq_len = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask @@ -266,6 +265,18 @@ def forward( self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." 
+ ) + if isinstance(video_fhw, list): video_fhw = video_fhw[0] if not isinstance(video_fhw, list): From a5abbb8cf6bd0fce8d9dabbbad182a67df78d44c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 12 Dec 2025 16:45:21 +0000 Subject: [PATCH 25/38] rename to max_txt_seq_len --- .../controlnets/controlnet_qwenimage.py | 2 +- .../transformers/transformer_qwenimage.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 4e9eacc417f9..f759d6797cf9 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -208,7 +208,7 @@ def forward( encoder_hidden_states, encoder_hidden_states_mask ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index effc508ad9a8..46f0c4a35d8f 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -230,19 +230,19 @@ def forward( video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], txt_seq_lens: Optional[List[int]] = None, device: torch.device = None, - txt_seq_len: Optional[Union[int, torch.Tensor]] = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. txt_seq_lens (`List[int]`, *optional*, **Deprecated**): - Deprecated parameter. Use `txt_seq_len` instead. If provided, the maximum value will be used. + Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. - txt_seq_len (`int` or `torch.Tensor`, *optional*): - The length of the text sequence. This should match the encoder hidden states length. Can be either an - int or a scalar tensor (for torch.compile compatibility). + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ # Handle deprecated txt_seq_lens parameter if txt_seq_lens is not None: @@ -250,16 +250,16 @@ def forward( "txt_seq_lens", "0.37.0", "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " - "Please use `txt_seq_len` instead (singular, not plural). " - "The new parameter accepts a single int or tensor value instead of a list.", + "Please use `max_txt_seq_len` instead. 
" + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", standard_warn=False, ) - if txt_seq_len is None: + if max_txt_seq_len is None: # Use max of txt_seq_lens for backward compatibility - txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens - if txt_seq_len is None: - raise ValueError("Either `txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") # Move to device unconditionally to avoid graph breaks in torch.compile self.pos_freqs = self.pos_freqs.to(device) @@ -296,7 +296,7 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_seq_len, ...] + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -749,7 +749,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: From 22cb03dba3b18289dac8765f04f99271976ce64c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 13:00:20 +0000 Subject: [PATCH 26/38] removed deprecated args --- .../transformers/transformer_qwenimage.py | 40 +++++++++++++++---- .../qwenimage/pipeline_qwenimage_layered.py | 6 --- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 3857dd0e600c..9570696e76c5 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -367,14 +367,39 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward(self, video_fhw, txt_seq_lens, device): + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: - txt_length: [bs] a list of 1 integers representing the length of the text + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. 
""" - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Move to device unconditionally to avoid graph breaks in torch.compile + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -400,8 +425,7 @@ def forward(self, video_fhw, txt_seq_lens, device): max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 7bb12c26baa4..6ad972f8774d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,10 +781,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. 
Denoising loop self.scheduler.set_begin_index(0) @@ -809,7 +805,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -825,7 +820,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, From 125a3a41a0a9957941ff68a2c3998402ff19d2ee Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 15:01:13 +0000 Subject: [PATCH 27/38] undo unrelated change --- src/diffusers/models/controlnets/controlnet_qwenimage.py | 1 - src/diffusers/models/transformers/transformer_qwenimage.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index f759d6797cf9..d317c1246e6d 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -185,7 +185,6 @@ def forward( joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: - joint_attention_kwargs = {} lora_scale = 1.0 if USE_PEFT_BACKEND: diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 9570696e76c5..c40d19ac1beb 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -919,7 +919,6 @@ def forward( attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: - attention_kwargs = {} lora_scale = 1.0 if USE_PEFT_BACKEND: From b5b63421b85d00a9a197bd4491bc3a33e0af5f73 Mon Sep 17 00:00:00 2001 From: cdutr Date: Wed, 17 Dec 2025 13:35:26 -0300 Subject: [PATCH 28/38] Updates QwenImage performance documentation Removes detailed attention backend benchmarks and simplifies torch.compile performance description Focuses on key performance improvement with torch.compile, highlighting the specific speedup from 4.70s to 1.93s on an A100 GPU Streamlines the documentation to provide more concise and actionable performance insights --- docs/source/en/api/pipelines/qwenimage.md | 28 +++-------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index 8a2fa3f47f06..7ea6a5ddfb60 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -116,31 +116,9 @@ image = pipe( ## Performance -### Attention Backends - -QwenImage supports multiple attention backends. 
Benchmarks on A100 80GB: - -**Single Image (30 steps, 512x512):** - -| Backend | Time (s) | -|---------|----------| -| flash_hub | 2.34 | -| native | 2.38 | -| xformers | 2.58 | -| flash_varlen | 2.78 | - -**Batch (2 images, 25 steps, 512x512):** - -| Backend | Time (s) | -|---------|----------| -| flash_hub | 2.85 | -| native | 3.16 | -| flash_varlen | 3.29 | -| xformers | 3.52 | - ### torch.compile -Using `torch.compile` provides significant speedups with a one-time compilation overhead: +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): ```python import torch @@ -149,8 +127,8 @@ from diffusers import QwenImagePipeline pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") pipe.transformer = torch.compile(pipe.transformer) -# First call triggers compilation (~7s overhead on A100) -# Subsequent calls see ~2.4x speedup +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster image = pipe("a cat", num_inference_steps=50).images[0] ``` From 61f526581e7fa018b1f51c6355c5fcb55ff97429 Mon Sep 17 00:00:00 2001 From: cdutr Date: Wed, 17 Dec 2025 17:23:26 -0300 Subject: [PATCH 29/38] Updates deprecation warnings for txt_seq_lens parameter Extends deprecation timeline for txt_seq_lens from version 0.37.0 to 0.39.0 across multiple Qwen image-related models Adds a new unit test to verify the deprecation warning behavior for the txt_seq_lens parameter --- .../models/controlnets/controlnet_qwenimage.py | 8 ++++---- .../transformers/transformer_qwenimage.py | 8 ++++---- .../test_models_transformer_qwenimage.py | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index d317c1246e6d..4c85b6ec3852 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -174,9 +174,9 @@ def forward( if txt_seq_lens is not None: deprecate( "txt_seq_lens", - "0.37.0", + "0.39.0", "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " - "version 0.37.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " + "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " "and `encoder_hidden_states_mask`.", standard_warn=False, ) @@ -291,9 +291,9 @@ def forward( if txt_seq_lens is not None: deprecate( "txt_seq_lens", - "0.37.0", + "0.39.0", "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " - "removed in version 0.37.0. The text sequence length is now automatically inferred from " + "removed in version 0.39.0. The text sequence length is now automatically inferred from " "`encoder_hidden_states` and `encoder_hidden_states_mask`.", standard_warn=False, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c40d19ac1beb..fa31e7d37080 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -258,8 +258,8 @@ def forward( if txt_seq_lens is not None: deprecate( "txt_seq_lens", - "0.37.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. 
" "Please use `max_txt_seq_len` instead. " "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", standard_warn=False, @@ -909,8 +909,8 @@ def forward( if txt_seq_lens is not None: deprecate( "txt_seq_lens", - "0.37.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " "Please use `txt_seq_len` instead (singular, not plural). " "The new parameter accepts a single int or tensor value instead of a list.", standard_warn=False, diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 28a0359741bd..62febf97acc1 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -145,6 +145,23 @@ def test_infers_text_seq_len_from_mask(self): self.assertIsNone(per_sample_len_none) self.assertIsNone(normalized_mask_none) + def test_deprecated_txt_seq_lens_warning(self): + """Test that passing the deprecated txt_seq_lens parameter raises a FutureWarning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + # Add the deprecated txt_seq_lens parameter + inputs["txt_seq_lens"] = [inputs["encoder_hidden_states"].shape[1]] + + with self.assertWarns(FutureWarning) as warning: + with torch.no_grad(): + _ = model(**inputs) + + # Verify the warning message mentions the deprecated parameter + self.assertIn("txt_seq_lens", str(warning.warning)) + self.assertIn("deprecated", str(warning.warning).lower()) + def test_non_contiguous_attention_mask(self): """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" init_dict, inputs = self.prepare_init_args_and_inputs_for_common() From 2ef38e2c3457c2bc4dcd7f5d87d60c747985fb25 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 21:42:19 +0000 Subject: [PATCH 30/38] fix compile --- .../transformers/transformer_qwenimage.py | 62 ++++++------ .../test_models_transformer_qwenimage.py | 95 ++++++++++++++++++- 2 files changed, 121 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c40d19ac1beb..f1c5fdc89021 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -165,12 +165,7 @@ def compute_text_seq_len_from_mask( active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) has_active = encoder_hidden_states_mask.any(dim=1) per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) - - # For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always) - # Keep as tensor to avoid graph breaks in torch.compile - rope_text_seq_len = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) - - return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask + return text_seq_len, per_sample_len, encoder_hidden_states_mask class QwenTimestepProjEmbeddings(nn.Module): @@ -271,10 +266,6 @@ def forward( if max_txt_seq_len is None: raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") - # Move to device unconditionally to avoid 
graph breaks in torch.compile - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) - # Validate batch inference with variable-sized images if isinstance(video_fhw, list) and len(video_fhw) > 1: # Check if all instances have the same size @@ -297,8 +288,7 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -306,16 +296,21 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -384,10 +379,6 @@ def forward( device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. """ - # Move to device unconditionally to avoid graph breaks in torch.compile - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) - # Validate batch inference with variable-sized images # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers if isinstance(video_fhw, list) and len(video_fhw) > 1: @@ -412,11 +403,10 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width) - video_freq = video_freq.to(device) + video_freq = self._compute_condition_freqs(frame, height, width, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -425,16 +415,21 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
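
The `max_txt_seq_len` sliced above is ultimately derived from the encoder hidden states and their padding mask rather than from a user-supplied `txt_seq_lens` list. A rough standalone sketch of that derivation, with toy values and plain PyTorch (not the library's own helper):

```python
import torch

# Toy padding mask for two prompts of different lengths: 1.0 marks a real token, 0.0 marks padding.
encoder_hidden_states_mask = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 0.0, 0.0],
    ]
)

bool_mask = encoder_hidden_states_mask.bool()
positions = torch.arange(bool_mask.shape[1]).expand_as(bool_mask)

# Per-sample length = index of the last attended position + 1 -> tensor([5, 3])
per_sample_len = torch.where(bool_mask, positions, positions.new_zeros(())).max(dim=1).values + 1

# The RoPE table only needs the padded (maximum) text width, which is simply the mask width here.
max_txt_seq_len = bool_mask.shape[1]  # 5
```

The per-sample lengths can then drive attention masking or per-sample slicing downstream, while the padded width is all the rotary embedding lookup needs.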
vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0): + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -450,10 +445,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0): return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width): + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -911,8 +909,8 @@ def forward( "txt_seq_lens", "0.37.0", "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. " - "Please use `txt_seq_len` instead (singular, not plural). " - "The new parameter accepts a single int or tensor value instead of a list.", + "Please use `encoder_hidden_states_mask` instead. 
" + "The mask-based approach is more flexible and supports variable-length sequences.", standard_warn=False, ) if attention_kwargs is not None: diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 28a0359741bd..5ee63d6f273e 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -103,9 +103,8 @@ def test_infers_text_seq_len_from_mask(self): inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - # Verify rope_text_seq_len is returned as a tensor (for torch.compile compatibility) - self.assertIsInstance(rope_text_seq_len, torch.Tensor) - self.assertEqual(rope_text_seq_len.ndim, 0) # Should be scalar tensor + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) # Verify per_sample_len is computed correctly (max valid position + 1 = 2) self.assertIsInstance(per_sample_len, torch.Tensor) @@ -116,7 +115,7 @@ def test_infers_text_seq_len_from_mask(self): self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values # Verify rope_text_seq_len is at least the sequence length - self.assertGreaterEqual(int(rope_text_seq_len.item()), inputs["encoder_hidden_states"].shape[1]) + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) # Test 2: Verify model runs successfully with inferred values inputs["encoder_hidden_states_mask"] = normalized_mask @@ -142,6 +141,7 @@ def test_infers_text_seq_len_from_mask(self): inputs["encoder_hidden_states"], None ) self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) self.assertIsNone(per_sample_len_none) self.assertIsNone(normalized_mask_none) @@ -162,6 +162,7 @@ def test_non_contiguous_attention_mask(self): ) self.assertEqual(int(per_sample_len.max().item()), 5) self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) self.assertTrue(normalized_mask.dtype == torch.bool) inputs["encoder_hidden_states_mask"] = normalized_mask @@ -171,6 +172,92 @@ def test_non_contiguous_attention_mask(self): self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + 
+ def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 16, + "num_layers": 2, + "attention_head_dim": 128, + "num_attention_heads": 4, + "joint_attention_dim": 16, + "use_layer3d_rope": True, # Enable layered RoPE + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel From 35efa0638d256726171ea2c669df75a05dc9c194 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 21:44:04 +0000 Subject: [PATCH 31/38] formatting --- src/diffusers/models/transformers/transformer_qwenimage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 28eb1e2652da..372e1cff34e1 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -304,7 +304,9 @@ def forward( return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs From 50c48152e056243a22629c911ce5e05e7acd6cad Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 21:56:02 +0000 Subject: [PATCH 32/38] fix compile tests --- .../transformers/transformer_qwenimage.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 
372e1cff34e1..217c64bf4b51 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -560,16 +560,15 @@ def __call__( f"must match encoder_hidden_states sequence length ({text_seq_len})." ) - # Only create mask if there's actual padding (i.e., some False/0 values) - # When all values are True/1.0, passing attention_mask=None is more efficient for SDPA + # Create joint attention mask + # torch.compile compatible: always create mask when encoder_hidden_states_mask is provided text_attention_mask = encoder_hidden_states_mask.bool() - if not text_attention_mask.all(): - image_attention_mask = torch.ones( - (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device - ) - # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] - # The attention dispatch will normalize this and extract sequence lengths - attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + image_attention_mask = torch.ones( + (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device + ) + # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] + # The attention dispatch will normalize this and extract sequence lengths + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) # Compute joint attention joint_hidden_states = dispatch_attention_fn( From 1433783b1e2ce8bcdb91fdbc95f0e0b67b5114a3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 22:04:16 +0000 Subject: [PATCH 33/38] rename helper --- src/diffusers/models/attention_dispatch.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 3aeaf460ecf9..0366799f3056 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1881,12 +1881,15 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out -def _prepare_attn_mask_native( +def _prepare_additive_attn_mask( attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True ) -> torch.Tensor: """ Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA. + This helper is used by both native SDPA and xformers backends to convert boolean masks to the additive format they + require. 
+ Args: attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out target_dtype: The dtype to convert the mask to (usually query.dtype) @@ -1939,7 +1942,7 @@ def _native_attention( # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out # Use helper to convert boolean to additive mask and reshape to 4D - attn_mask = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype) + attn_mask = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype) if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) @@ -2480,7 +2483,9 @@ def _xformers_attention( ) # Fill in the actual mask values (converting boolean to additive) # Use helper to convert 2D boolean -> 4D additive mask - mask_additive = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype) # [batch, 1, 1, seq_len_k] + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] # Broadcast to [batch, heads, seq_q, seq_len_k] aligned_mask[:, :, :, :original_seq_len] = mask_additive # Mask out the padding (already -inf from zeros -> where with default) From 8de799cbbbebb71734f6c3591e0b283e434f4ab7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Dec 2025 22:07:49 +0000 Subject: [PATCH 34/38] remove duplicate --- .../test_models_transformer_qwenimage.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 955cad0807b3..5ee63d6f273e 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -145,23 +145,6 @@ def test_infers_text_seq_len_from_mask(self): self.assertIsNone(per_sample_len_none) self.assertIsNone(normalized_mask_none) - def test_deprecated_txt_seq_lens_warning(self): - """Test that passing the deprecated txt_seq_lens parameter raises a FutureWarning.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - model.eval() - - # Add the deprecated txt_seq_lens parameter - inputs["txt_seq_lens"] = [inputs["encoder_hidden_states"].shape[1]] - - with self.assertWarns(FutureWarning) as warning: - with torch.no_grad(): - _ = model(**inputs) - - # Verify the warning message mentions the deprecated parameter - self.assertIn("txt_seq_lens", str(warning.warning)) - self.assertIn("deprecated", str(warning.warning).lower()) - def test_non_contiguous_attention_mask(self): """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" init_dict, inputs = self.prepare_init_args_and_inputs_for_common() From fc9374718b9f1a78b91bee0613dbee53445d6a59 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Dec 2025 09:58:29 +0000 Subject: [PATCH 35/38] smaller values --- .../test_models_transformer_qwenimage.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 5ee63d6f273e..384954dfbad7 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -205,12 +205,14 @@ def 
test_layered_model_with_mask(self): init_dict = { "patch_size": 2, "in_channels": 16, - "out_channels": 16, + "out_channels": 4, "num_layers": 2, - "attention_head_dim": 128, - "num_attention_heads": 4, + "attention_head_dim": 16, + "num_attention_heads": 3, "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning } model = self.model_class(**init_dict).to(torch_device) @@ -236,6 +238,9 @@ def test_layered_model_with_mask(self): timestep = torch.tensor([1.0]).to(torch_device) + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) + # Layer structure: 4 layers + 1 condition image img_shapes = [ [ @@ -254,6 +259,7 @@ def test_layered_model_with_mask(self): encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, + additional_t_cond=addition_t_cond, ) self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) From b7c288a36bc67c324d7018e0eb25109f9556dcd0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 20 Dec 2025 10:32:53 +0000 Subject: [PATCH 36/38] removed --- src/diffusers/models/attention_dispatch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0366799f3056..a94afb1f2fbc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2512,6 +2512,3 @@ def _xformers_attention( out = out.flatten(2, 3) return out - - -_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend) From 2f868793b4e4e706e4b56207c1aa60817a5e4a56 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 21 Dec 2025 12:22:21 +0100 Subject: [PATCH 37/38] split attention --- src/diffusers/models/attention_dispatch.py | 42 +++++++++- .../transformers/transformer_qwenimage.py | 30 ++++--- .../pipelines/qwenimage/pipeline_qwenimage.py | 84 +++++++++++++------ 3 files changed, 119 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index a94afb1f2fbc..eafa731e2f55 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -186,6 +186,7 @@ class AttentionBackendName(str, Enum): _NATIVE_MATH = "_native_math" _NATIVE_NPU = "_native_npu" _NATIVE_XLA = "_native_xla" + SPLIT = "split" # `sageattention` SAGE = "sage" @@ -503,7 +504,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask( cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) - max_seqlen_q = seqlens_q.max().item() + max_seqlen_q = seqlens_q.max().item() #TODO item() is inefficient and breaks torch.compile graphs. 
Use 'seq_len' parameter instead (see split attention backend) max_seqlen_k = seqlens_k.max().item() return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) @@ -1975,6 +1976,45 @@ def _native_attention( return out +@_AttentionBackendRegistry.register( + AttentionBackendName.SPLIT, + constraints=[_check_device, _check_shape], + supports_context_parallel=True, +) +def _split_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + seq_len: Optional[torch.Tensor] = None, #attn_mask is ignored if seq_len is passed + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> torch.Tensor: + if seq_len is None: + return _native_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config) + + batch_size, batch_seq_len = query.shape[:2] + if any(sample_seq_len > batch_seq_len for sample_seq_len in seq_len): + raise ValueError("Attention sequence lengths cannot be longer than maximum sequence length") + if len(seq_len) != batch_size: + raise ValueError("Attention sequence lengths must match the batch size") + + result = [] + for index, sample_seq_len in enumerate(seq_len): + sliced_query = query[index, :sample_seq_len, :, :].unsqueeze(0) + sliced_key = key [index, :sample_seq_len, :, :].unsqueeze(0) + sliced_value = value[index, :sample_seq_len, :, :].unsqueeze(0) + sliced_result = _native_attention(sliced_query, sliced_key, sliced_value, None, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config) + + padding = torch.zeros((1, batch_seq_len - sample_seq_len) + sliced_result.shape[2:], device=sliced_result.device, dtype=sliced_result.dtype) + padded_result = torch.cat([sliced_result, padding], dim=1) + result.append(padded_result) + return torch.cat(result, dim=0) + @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_CUDNN, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 217c64bf4b51..146af402a2fe 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -150,7 +150,7 @@ def compute_text_seq_len_from_mask( """ batch_size, text_seq_len = encoder_hidden_states.shape[:2] if encoder_hidden_states_mask is None: - return text_seq_len, None, None + return text_seq_len, [text_seq_len] * batch_size, None if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): raise ValueError( @@ -165,7 +165,7 @@ def compute_text_seq_len_from_mask( active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) has_active = encoder_hidden_states_mask.any(dim=1) per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) - return text_seq_len, per_sample_len, encoder_hidden_states_mask + return text_seq_len, per_sample_len.tolist(), encoder_hidden_states_mask class QwenTimestepProjEmbeddings(nn.Module): @@ -492,6 +492,7 @@ def __call__( encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + encoder_hidden_states_len: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise 
ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") @@ -537,16 +538,17 @@ def __call__( # Concatenate for joint attention # Order: [text, image] - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) # If an encoder_hidden_states_mask is provided, create a joint attention mask. # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). # Only create the mask if there's actual padding, otherwise keep attention_mask=None for better SDPA performance. + batch_size, image_seq_len = hidden_states.shape[:2] + attention_kwargs = {} if encoder_hidden_states_mask is not None and attention_mask is None: - batch_size, image_seq_len = hidden_states.shape[:2] text_seq_len = encoder_hidden_states.shape[1] if encoder_hidden_states_mask.shape[0] != batch_size: @@ -568,7 +570,8 @@ def __call__( ) # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] # The attention dispatch will normalize this and extract sequence lengths - attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1) + attention_kwargs['seq_len'] = [text_sample_len + image_seq_len for text_sample_len in encoder_hidden_states_len] # Compute joint attention joint_hidden_states = dispatch_attention_fn( @@ -580,6 +583,7 @@ def __call__( is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, + attention_kwargs=attention_kwargs, ) # Reshape back @@ -587,8 +591,8 @@ def __call__( joint_hidden_states = joint_hidden_states.to(joint_query.dtype) # Split attention outputs back - txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part - img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + img_attn_output = joint_hidden_states[:, :image_seq_len, :] # Image part + txt_attn_output = joint_hidden_states[:, image_seq_len:, :] # Text part # Apply output projections img_attn_output = attn.to_out[0](img_attn_output) @@ -694,6 +698,7 @@ def forward( encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_hidden_states_len: Optional[torch.Tensor] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, modulate_index: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -728,6 +733,7 @@ def forward( encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, + encoder_hidden_states_len=encoder_hidden_states_len, **joint_attention_kwargs, ) @@ -947,7 +953,9 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + if torch.all(encoder_hidden_states_mask): + encoder_hidden_states_mask = None + text_seq_len, text_seq_len_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) @@ 
-971,6 +979,7 @@ def forward( encoder_hidden_states_mask, temb, image_rotary_emb, + text_seq_len_per_sample, attention_kwargs, modulate_index, ) @@ -982,6 +991,7 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + encoder_hidden_states_len=text_seq_len_per_sample, joint_attention_kwargs=attention_kwargs, modulate_index=modulate_index, ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..ec5ad60d9571 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -473,6 +473,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + batch_negative: bool = False, #TODO remove, only for testing ): r""" Function invoked when calling the pipeline for generation. @@ -603,23 +604,35 @@ def __call__( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt - prompt_embeds, prompt_embeds_mask = self.encode_prompt( - prompt=prompt, - prompt_embeds=prompt_embeds, - prompt_embeds_mask=prompt_embeds_mask, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - ) - if do_true_cfg: - negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - prompt=negative_prompt, - prompt_embeds=negative_prompt_embeds, - prompt_embeds_mask=negative_prompt_embeds_mask, + if do_true_cfg and batch_negative: + combined_prompt_embeds, combined_prompt_embeds_mask = self.encode_prompt( + prompt=[prompt, negative_prompt], +# prompt_embeds=prompt_embeds, +# prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + dtype = combined_prompt_embeds.dtype + else: + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) + dtype = prompt_embeds.dtype + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -628,7 +641,7 @@ def __call__( num_channels_latents, height, width, - prompt_embeds.dtype, + dtype, device, generator, latents, @@ -682,31 +695,50 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): + if do_true_cfg and batch_negative: noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + hidden_states=torch.cat([latents] * 2, dim=0), + timestep=torch.cat([timestep] * 2, dim=0) / 1000, + guidance=torch.cat([guidance] * 2, dim=0) if guidance is not None else None, + encoder_hidden_states_mask=combined_prompt_embeds_mask, + encoder_hidden_states=combined_prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] + noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0) + + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) From 87bbde486e0156fbadce95f208de2c42ed58adf2 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 21 Dec 2025 20:56:32 +0100 Subject: [PATCH 38/38] fix type hints --- src/diffusers/models/transformers/transformer_qwenimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 146af402a2fe..0584dfc290c7 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -492,7 +492,7 @@ def __call__( encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - encoder_hidden_states_len: Optional[torch.Tensor] = None, + encoder_hidden_states_len: Optional[List[int]] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise 
ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
@@ -698,7 +698,7 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        encoder_hidden_states_len: Optional[torch.Tensor] = None,
+        encoder_hidden_states_len: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
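
To make the per-sample slicing idea behind the new `split` backend more concrete, here is a rough standalone sketch using plain `torch.nn.functional.scaled_dot_product_attention` and made-up shapes; it illustrates the approach, not the actual dispatch code in `attention_dispatch.py`:

```python
import torch
import torch.nn.functional as F


def split_attention(query, key, value, seq_len):
    """Attend each sample only over its own valid prefix, then zero-pad back to the batch shape.

    query/key/value: [batch, seq, heads, head_dim]; seq_len: list of valid lengths per sample.
    """
    batch_seq_len = query.shape[1]
    outputs = []
    for i, n in enumerate(seq_len):
        # Slice the valid prefix of sample i and move heads in front of the sequence dim for SDPA.
        q = query[i, :n].permute(1, 0, 2).unsqueeze(0)
        k = key[i, :n].permute(1, 0, 2).unsqueeze(0)
        v = value[i, :n].permute(1, 0, 2).unsqueeze(0)
        out = F.scaled_dot_product_attention(q, k, v)  # [1, heads, n, head_dim]
        out = out.squeeze(0).permute(1, 0, 2)  # [n, heads, head_dim]
        pad = out.new_zeros(batch_seq_len - n, *out.shape[1:])  # zeros for the padded positions
        outputs.append(torch.cat([out, pad], dim=0))
    return torch.stack(outputs, dim=0)  # [batch, seq, heads, head_dim]


# Example: batch of 2, padded joint length 8 (image + text tokens), valid lengths 8 and 6.
q = torch.randn(2, 8, 4, 16)
print(split_attention(q, q, q, seq_len=[8, 6]).shape)  # torch.Size([2, 8, 4, 16])
```

Compared to a single padded, masked call, this never attends over padding tokens at all, at the cost of a Python-level loop over the batch.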