Fix QwenImage txt_seq_lens handling #12702
base: main
@@ -20,7 +20,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -31,6 +31,7 @@
    QwenImageTransformerBlock,
    QwenTimestepProjEmbeddings,
    RMSNorm,
    compute_text_seq_len_from_mask,
)
@@ -136,7 +137,7 @@ def forward(
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.
        The [`QwenImageControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ def forward(
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
                Image shapes for RoPE computation.
            txt_seq_lens (`List[int]`, *optional*):
                **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
                length.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
            the first element is the controlnet block samples.
        """
        # Handle deprecated txt_seq_lens parameter
        if txt_seq_lens is not None:
            deprecate(
                "txt_seq_lens",
                "0.39.0",
                "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
                "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
                "and `encoder_hidden_states_mask`.",
                standard_warn=False,
            )

        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
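For context on the deprecation path added above: existing callers that still pass `txt_seq_lens` keep working but receive a warning, and can simply drop the argument. A minimal sketch of that behavior, assuming `diffusers.utils.deprecate` reduces to a `FutureWarning` (the tiny stand-in below is illustrative, not the library implementation):

```python
import warnings

import torch


def deprecate(name, version, message, standard_warn=True):
    # Illustrative stand-in for `diffusers.utils.deprecate`; with
    # standard_warn=False only the custom message is emitted.
    warnings.warn(message, FutureWarning, stacklevel=2)


def forward(encoder_hidden_states, encoder_hidden_states_mask=None, txt_seq_lens=None):
    if txt_seq_lens is not None:
        deprecate(
            "txt_seq_lens",
            "0.39.0",
            "Passing `txt_seq_lens` is deprecated; the text sequence length is now "
            "inferred from `encoder_hidden_states` and `encoder_hidden_states_mask`.",
            standard_warn=False,
        )
    # The argument is otherwise ignored: the length comes from the embeddings.
    return encoder_hidden_states.shape[1]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    seq_len = forward(torch.randn(1, 12, 8), txt_seq_lens=[12])

print(seq_len)      # 12
print(len(caught))  # 1 -> the deprecation warning fired
```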
@@ -186,32 +202,47 @@ def forward(

        temb = self.time_text_embed(timestep, hidden_states)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
            encoder_hidden_states, encoder_hidden_states_mask
        )

        image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        # Construct joint attention mask once to avoid reconstructing in every block
        block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
        if encoder_hidden_states_mask is not None:
            # Build joint mask: [text_mask, all_ones_for_image]
            batch_size, image_seq_len = hidden_states.shape[:2]
            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
            block_attention_kwargs["attention_mask"] = joint_attention_mask

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    None,  # Don't pass encoder_hidden_states_mask (using attention_mask instead)
                    temb,
                    image_rotary_emb,
                    block_attention_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                    joint_attention_kwargs=block_attention_kwargs,
                )
            block_samples = block_samples + (hidden_states,)
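The helper `compute_text_seq_len_from_mask` is only imported in this diff, not defined here; judging from how its three return values are unpacked above, a hedged stand-in plus the joint-mask construction from this hunk might look roughly like the sketch below (the helper body and the unused middle value are assumptions; the `torch.cat` mirrors the diff):

```python
from typing import Optional, Tuple

import torch


def compute_text_seq_len_from_mask(
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_mask: Optional[torch.Tensor],
) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Illustrative stand-in: the real helper lives in the QwenImage transformer
    # module. The RoPE length comes from the embedding's sequence dimension and
    # the mask is normalized to bool; the per-sample lengths are an assumption.
    text_seq_len = encoder_hidden_states.shape[1]
    if encoder_hidden_states_mask is None:
        return text_seq_len, None, None
    mask = encoder_hidden_states_mask.bool()
    return text_seq_len, mask.sum(dim=1), mask


# Toy shapes: 2 prompts padded to 16 text tokens, 64 image tokens, dim 8.
encoder_hidden_states = torch.randn(2, 16, 8)
encoder_hidden_states_mask = torch.tensor([[1] * 10 + [0] * 6, [1] * 16])
hidden_states = torch.randn(2, 64, 8)

text_seq_len, _, norm_mask = compute_text_seq_len_from_mask(
    encoder_hidden_states, encoder_hidden_states_mask
)

# Joint mask exactly as in the hunk above: [text_mask, all_ones_for_image].
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([norm_mask, image_mask], dim=1)

print(text_seq_len)                # 16
print(joint_attention_mask.shape)  # torch.Size([2, 80])
```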
@@ -267,6 +298,15 @@ def forward(
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[QwenImageControlNetOutput, Tuple]:
        if txt_seq_lens is not None:
            deprecate(
                "txt_seq_lens",
                "0.39.0",
                "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
                "removed in version 0.39.0. The text sequence length is now automatically inferred from "
                "`encoder_hidden_states` and `encoder_hidden_states_mask`.",
                standard_warn=False,
            )
        # ControlNet-Union with multiple conditions
        # only load one ControlNet for saving memories
        if len(self.nets) == 1:
@@ -281,7 +321,6 @@ def forward(
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                timestep=timestep,
                img_shapes=img_shapes,
                txt_seq_lens=txt_seq_lens,
Collaborator: should we also build joint mask for controlnet?

Author: good catch, I think we should add it.
                joint_attention_kwargs=joint_attention_kwargs,
                return_dict=return_dict,
            )
Reviewer: Instead of specifying performance numbers on torch.compile and other attention backends, maybe we could highlight this point and include with and without torch.compile numbers? @cdutr WDYT?

Author (@cdutr): Good point! I've simplified the Performance section to focus on torch.compile with the before/after numbers, and removed the attention backend tables since the differences between backends are minimal compared to the torch.compile gains.
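For reference, a minimal sketch of the kind of with/without torch.compile comparison being discussed; the pipeline class follows diffusers' QwenImage integration, but the checkpoint id, prompt, step count, and run counts are placeholders rather than the setup behind the PR's numbers:

```python
import time

import torch
from diffusers import QwenImagePipeline  # pipeline name per the diffusers QwenImage integration


def benchmark(pipe, compile_transformer: bool, n_warmup: int = 1, n_runs: int = 3) -> float:
    if compile_transformer:
        pipe.transformer = torch.compile(pipe.transformer)
    prompt = "a photo of a cat"  # placeholder prompt
    for _ in range(n_warmup):  # warm-up also triggers compilation
        pipe(prompt, num_inference_steps=4)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        pipe(prompt, num_inference_steps=4)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs


pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image",  # checkpoint id is an assumption, not taken from the PR
    torch_dtype=torch.bfloat16,
).to("cuda")

eager_s = benchmark(pipe, compile_transformer=False)
compiled_s = benchmark(pipe, compile_transformer=True)
print(f"eager: {eager_s:.2f} s/run, compiled: {compiled_s:.2f} s/run")
```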