From 9019e928994892f5d9076c8e53208771ebab9a15 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 16 Jun 2025 22:23:46 +0530 Subject: [PATCH 01/11] update --- .../models/transformers/transformer_chroma.py | 13 ++++-- .../pipelines/chroma/pipeline_chroma.py | 46 +++++++++++++++++-- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 2b415cfed2fe..20438309cb39 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -250,6 +250,7 @@ def forward( hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: residual = hidden_states @@ -259,6 +260,7 @@ def forward( attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, **joint_attention_kwargs, ) @@ -312,6 +314,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: temb_img, temb_txt = temb[:, :6], temb[:, 6:] @@ -326,6 +329,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, **joint_attention_kwargs, ) @@ -570,6 +574,7 @@ def forward( timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, @@ -659,11 +664,7 @@ def forward( ) if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask=attention_mask ) else: @@ -672,6 +673,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, joint_attention_kwargs=joint_attention_kwargs, ) @@ -704,6 +706,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, joint_attention_kwargs=joint_attention_kwargs, ) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index c111458d3320..5df182562828 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -242,7 +242,7 @@ def _get_t5_prompt_embeds( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds + return prompt_embeds, attention_mask def encode_prompt( self, @@ -292,8 +292,9 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] + prompt_attention_mask = None if prompt_embeds is None: - prompt_embeds = self._get_t5_prompt_embeds( + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( prompt=prompt, 
num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, @@ -303,6 +304,7 @@ def encode_prompt( dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) negative_text_ids = None + negative_prompt_attention_mask = None if do_classifier_free_guidance: if negative_prompt_embeds is None: @@ -323,7 +325,7 @@ def encode_prompt( " the batch size of `prompt`." ) - negative_prompt_embeds = self._get_t5_prompt_embeds( + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, @@ -336,7 +338,14 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) - return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): @@ -534,6 +543,23 @@ def prepare_latents( return latents, latent_image_ids + def _prepare_attention_mask( + self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None + ): + device = prompt_attention_mask.device + attention_mask = torch.cat([prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)]) + + negative_attention_mask = None + if negative_prompt_attention_mask is not None: + negative_attention_mask = torch.cat( + [ + negative_prompt_attention_mask, + torch.ones(batch_size, sequence_length, device=device), + ] + ) + + return attention_mask, negative_attention_mask + @property def guidance_scale(self): return self._guidance_scale @@ -704,8 +730,10 @@ def __call__( ( prompt_embeds, text_ids, + prompt_attention_mask, negative_prompt_embeds, negative_text_ids, + negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -730,6 +758,14 @@ def __call__( generator, latents, ) + + prompt_attention_mask, negative_prompt_attention_mask = self._prepare_attention_mask( + latents.shape[0], + latents.shape[1], + prompt_attention_mask, + negative_prompt_attention_mask, + ) + # 5. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] @@ -801,6 +837,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, + attention_mask=prompt_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -814,6 +851,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, + attention_mask=negative_prompt_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 188b0d2a2f3a84ab95df3d045d6e1bdb928efa0d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Jun 2025 19:32:19 +0200 Subject: [PATCH 02/11] update --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 5df182562828..9048efec8d24 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -235,6 +235,7 @@ def _get_t5_prompt_embeds( dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + attention_mask = attention_mask.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape @@ -331,6 +332,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, device=device, ) + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) if self.text_encoder is not None: @@ -547,7 +549,9 @@ def _prepare_attention_mask( self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None ): device = prompt_attention_mask.device - attention_mask = torch.cat([prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)]) + attention_mask = torch.cat( + [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1 + ) negative_attention_mask = None if negative_prompt_attention_mask is not None: @@ -555,7 +559,8 @@ def _prepare_attention_mask( [ negative_prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device), - ] + ], + dim=1, ) return attention_mask, negative_attention_mask @@ -759,7 +764,7 @@ def __call__( latents, ) - prompt_attention_mask, negative_prompt_attention_mask = self._prepare_attention_mask( + attention_mask, negative_attention_mask = self._prepare_attention_mask( latents.shape[0], latents.shape[1], prompt_attention_mask, @@ -837,7 +842,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - attention_mask=prompt_attention_mask, + attention_mask=attention_mask.to(latents.dtype), joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -851,7 +856,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, - attention_mask=negative_prompt_attention_mask, + attention_mask=negative_attention_mask.to(latents.dtype), joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 602af7411e1a095e699f4fa7ec794f1a91094348 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 16 Jun 2025 23:38:17 +0530 Subject: [PATCH 03/11] update --- .../pipelines/chroma/pipeline_chroma.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git 
a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 9048efec8d24..c1b543656b21 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -549,6 +549,8 @@ def _prepare_attention_mask( self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None ): device = prompt_attention_mask.device + + # Extend the prompt attention mask to account for image tokens in the final sequence attention_mask = torch.cat( [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1 ) @@ -764,13 +766,6 @@ def __call__( latents, ) - attention_mask, negative_attention_mask = self._prepare_attention_mask( - latents.shape[0], - latents.shape[1], - prompt_attention_mask, - negative_prompt_attention_mask, - ) - # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] @@ -781,6 +776,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + + attention_mask, negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, From ad13450cfef414a31404f926a2e1f50b655367e6 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 16 Jun 2025 23:59:40 +0530 Subject: [PATCH 04/11] update --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index c1b543656b21..8308769b7d41 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -243,6 +243,9 @@ def _get_t5_prompt_embeds( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len) + return prompt_embeds, attention_mask def encode_prompt( From d74985c160c989bddf0b2482a8634247e012a5a4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Jun 2025 05:01:10 +0200 Subject: [PATCH 05/11] update --- .../models/transformers/transformer_chroma.py | 7 +++++ .../pipelines/chroma/pipeline_chroma.py | 29 ++++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 20438309cb39..45ae6a8781d7 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -257,6 +257,10 @@ def forward( norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} + + if attention_mask is not None: + attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] + attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, @@ -324,6 +328,9 @@ def forward( encoder_hidden_states, emb=temb_txt ) 
joint_attention_kwargs = joint_attention_kwargs or {} + if attention_mask is not None: + attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] + # Attention. attention_outputs = self.attn( hidden_states=norm_hidden_states, diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 8308769b7d41..02b718411417 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -256,6 +256,8 @@ def encode_prompt( num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, do_classifier_free_guidance: bool = True, max_sequence_length: int = 512, lora_scale: Optional[float] = None, @@ -296,7 +298,6 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] - prompt_attention_mask = None if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( prompt=prompt, @@ -308,7 +309,6 @@ def encode_prompt( dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) negative_text_ids = None - negative_prompt_attention_mask = None if do_classifier_free_guidance: if negative_prompt_embeds is None: @@ -551,19 +551,20 @@ def prepare_latents( def _prepare_attention_mask( self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None ): - device = prompt_attention_mask.device - - # Extend the prompt attention mask to account for image tokens in the final sequence - attention_mask = torch.cat( - [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1 - ) + attention_mask = None + if prompt_attention_mask is not None: + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)], + dim=1, + ) negative_attention_mask = None if negative_prompt_attention_mask is not None: negative_attention_mask = torch.cat( [ negative_prompt_attention_mask, - torch.ones(batch_size, sequence_length, device=device), + torch.ones(batch_size, sequence_length, device=negative_prompt_attention_mask.device), ], dim=1, ) @@ -614,6 +615,8 @@ def __call__( negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -749,6 +752,8 @@ def __call__( negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, do_classifier_free_guidance=self.do_classifier_free_guidance, device=device, num_images_per_prompt=num_images_per_prompt, @@ -848,7 +853,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - attention_mask=attention_mask.to(latents.dtype), + 
attention_mask=attention_mask.to(latents.dtype) if attention_mask is not None else None, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -862,7 +867,9 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, - attention_mask=negative_attention_mask.to(latents.dtype), + attention_mask=negative_attention_mask.to(latents.dtype) + if negative_attention_mask is not None + else None, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 172b2ef73bb876fa640a0ed229d6956f67b39ad7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Jun 2025 05:07:02 +0200 Subject: [PATCH 06/11] update --- src/diffusers/models/attention_processor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 23ae05e2ab96..291e96eb6d6d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2543,7 +2543,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2776,7 +2778,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 7cdd7d2df00682eb89c32d7f27194e5cd37b9b8c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Jun 2025 05:19:30 +0200 Subject: [PATCH 07/11] update --- src/diffusers/models/transformers/transformer_chroma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 45ae6a8781d7..d11f6c2a5e25 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -671,7 +671,7 @@ def forward( ) if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask=attention_mask + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask ) else: From 544dad4c2567cbd9d2d908f3e14b2df6f591b309 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Jun 2025 05:54:38 +0200 Subject: [PATCH 08/11] update --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 02b718411417..a86138f5694d 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -549,7 +549,7 @@ def 
prepare_latents( return latents, latent_image_ids def _prepare_attention_mask( - self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None + self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None ): attention_mask = None if prompt_attention_mask is not None: @@ -558,6 +558,7 @@ def _prepare_attention_mask( [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)], dim=1, ) + attention_mask = attention_mask.to(dtype) negative_attention_mask = None if negative_prompt_attention_mask is not None: @@ -568,6 +569,7 @@ def _prepare_attention_mask( ], dim=1, ) + negative_attention_mask = negative_attention_mask.to(dtype) return attention_mask, negative_attention_mask @@ -788,6 +790,7 @@ def __call__( attention_mask, negative_attention_mask = self._prepare_attention_mask( batch_size=latents.shape[0], sequence_length=image_seq_len, + dtype=latents.dtype, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, ) @@ -853,7 +856,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - attention_mask=attention_mask.to(latents.dtype) if attention_mask is not None else None, + attention_mask=attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -867,9 +870,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, - attention_mask=negative_attention_mask.to(latents.dtype) - if negative_attention_mask is not None - else None, + attention_mask=negative_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 77030474314c58772b25ba40058f5a92f314ea0f Mon Sep 17 00:00:00 2001 From: BuildTools Date: Tue, 17 Jun 2025 00:27:04 -0600 Subject: [PATCH 09/11] add cond + uncond batch --- .../pipelines/chroma/pipeline_chroma.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index a86138f5694d..b9dca0c2a487 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -570,8 +570,9 @@ def _prepare_attention_mask( dim=1, ) negative_attention_mask = negative_attention_mask.to(dtype) + attention_mask = torch.cat([attention_mask, negative_attention_mask], dim=0) - return attention_mask, negative_attention_mask + return attention_mask @property def guidance_scale(self): @@ -763,6 +764,9 @@ def __call__( lora_scale=lora_scale, ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( @@ -787,7 +791,7 @@ def __call__( self.scheduler.config.get("max_shift", 1.15), ) - attention_mask, negative_attention_mask = self._prepare_attention_mask( + attention_mask = self._prepare_attention_mask( batch_size=latents.shape[0], sequence_length=image_seq_len, dtype=latents.dtype, @@ -847,11 +851,13 @@ def __call__( if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, @@ -864,17 +870,8 @@ def __call__( if self.do_classifier_free_guidance: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - attention_mask=negative_attention_mask, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From 08483296d10686b1d61219172a1b52b68a27d3e6 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Tue, 17 Jun 2025 00:42:52 -0600 Subject: [PATCH 10/11] fix batch --- src/diffusers/models/transformers/transformer_chroma.py | 4 ---- src/diffusers/pipelines/chroma/pipeline_chroma.py | 6 +++++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index d11f6c2a5e25..c77d0d2147cb 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -258,8 +258,6 @@ def forward( mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} - if attention_mask is not None: - attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] attn_output = self.attn( hidden_states=norm_hidden_states, @@ -328,8 +326,6 @@ def forward( encoder_hidden_states, emb=temb_txt ) joint_attention_kwargs = joint_attention_kwargs or {} - if attention_mask is not None: - attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] # Attention. 
attention_outputs = self.attn( diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index b9dca0c2a487..ec840fbacde6 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -559,6 +559,7 @@ def _prepare_attention_mask( dim=1, ) attention_mask = attention_mask.to(dtype) + attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] negative_attention_mask = None if negative_prompt_attention_mask is not None: @@ -570,7 +571,10 @@ def _prepare_attention_mask( dim=1, ) negative_attention_mask = negative_attention_mask.to(dtype) - attention_mask = torch.cat([attention_mask, negative_attention_mask], dim=0) + negative_attention_mask = negative_attention_mask[:, None, None, :] * negative_attention_mask[:, None, :, None] + print(attention_mask.shape) + attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0) + print(attention_mask.shape) return attention_mask From deea9dd1d82de46d38a5189b1f95c469e5c61c01 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Tue, 17 Jun 2025 00:47:36 -0600 Subject: [PATCH 11/11] make style, quality --- src/diffusers/models/transformers/transformer_chroma.py | 1 - src/diffusers/pipelines/chroma/pipeline_chroma.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index c77d0d2147cb..47d8d0a48523 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -258,7 +258,6 @@ def forward( mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} - attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index ec840fbacde6..c898d6758379 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -571,7 +571,9 @@ def _prepare_attention_mask( dim=1, ) negative_attention_mask = negative_attention_mask.to(dtype) - negative_attention_mask = negative_attention_mask[:, None, None, :] * negative_attention_mask[:, None, :, None] + negative_attention_mask = ( + negative_attention_mask[:, None, None, :] * negative_attention_mask[:, None, :, None] + ) print(attention_mask.shape) attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0) print(attention_mask.shape)
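
The series above threads the T5 text attention mask from `ChromaPipeline.encode_prompt` through the transformer blocks down to the attention processors, and switches CFG to a single batched forward pass. The sketch below is an illustration only (not part of any patch): it reproduces the masking scheme added in `_prepare_attention_mask` with hypothetical shapes. `text_seq_len=512` matches the pipeline's default `max_sequence_length`; `image_seq_len` depends on the target resolution and is an example value here.

```python
# Standalone sketch of the mask construction these patches implement.
# Shapes and dtype are illustrative, not taken from a real run.
import torch

batch_size = 1
text_seq_len = 512      # pipeline default max_sequence_length for T5
image_seq_len = 1024    # number of packed latent tokens (resolution-dependent, example value)
dtype = torch.bfloat16

# 1D masks from the tokenizer: 1 for real text tokens, 0 for padding.
prompt_attention_mask = torch.ones(batch_size, text_seq_len)
negative_prompt_attention_mask = torch.ones(batch_size, text_seq_len)

def build_mask(mask_1d: torch.Tensor) -> torch.Tensor:
    # Extend the text mask with ones so every image token is kept,
    # then take the outer product to get the (B, 1, S, S) pairwise mask
    # that the pipeline passes to the transformer as `attention_mask`.
    extended = torch.cat(
        [mask_1d, torch.ones(mask_1d.shape[0], image_seq_len, device=mask_1d.device)],
        dim=1,
    ).to(dtype)
    return extended[:, None, None, :] * extended[:, None, :, None]

attention_mask = build_mask(prompt_attention_mask)
negative_attention_mask = build_mask(negative_prompt_attention_mask)

# With classifier-free guidance the unconditional mask comes first, matching
# torch.cat([negative_prompt_embeds, prompt_embeds]) in the pipeline's __call__.
attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
# attention_mask.shape == torch.Size([2, 1, 1536, 1536])
```

The resulting tensor has the same layout as the `attn_mask` argument the patched attention processors forward to `F.scaled_dot_product_attention`, and the uncond/cond ordering matches the `noise_pred.chunk(2)` split introduced in PATCH 09.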