diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 23ae05e2ab96..291e96eb6d6d 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2543,7 +2543,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 2b415cfed2fe..47d8d0a48523 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -250,15 +250,18 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )
 
@@ -312,6 +315,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +325,13 @@ def forward(
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )
 
@@ -570,6 +576,7 @@ def forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
@@ -659,11 +666,7 @@ def forward(
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                 )
 
             else:
@@ -672,6 +675,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
@@ -704,6 +708,7 @@ def forward(
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index c111458d3320..c898d6758379 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -235,6 +235,7 @@ def _get_t5_prompt_embeds(
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
 
@@ -242,7 +243,10 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        return prompt_embeds
+        attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+        attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+        return prompt_embeds, attention_mask
 
     def encode_prompt(
         self,
@@ -252,6 +256,8 @@ def encode_prompt(
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -293,7 +299,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -323,12 +329,13 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )
 
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+
             negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         if self.text_encoder is not None:
@@ -336,7 +343,14 @@ def encode_prompt(
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )
 
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
@@ -534,6 +548,36 @@ def prepare_latents(
 
         return latents, latent_image_ids
 
+    def _prepare_attention_mask(
+        self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None
+    ):
+        attention_mask = None
+        if prompt_attention_mask is not None:
+            # Extend the prompt attention mask to account for image tokens in the final sequence
+            attention_mask = torch.cat(
+                [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)],
+                dim=1,
+            )
+            attention_mask = attention_mask.to(dtype)
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+        negative_attention_mask = None
+        if negative_prompt_attention_mask is not None:
+            negative_attention_mask = torch.cat(
+                [
+                    negative_prompt_attention_mask,
+                    torch.ones(batch_size, sequence_length, device=negative_prompt_attention_mask.device),
+                ],
+                dim=1,
+            )
+            negative_attention_mask = negative_attention_mask.to(dtype)
+            negative_attention_mask = (
+                negative_attention_mask[:, None, None, :] * negative_attention_mask[:, None, :, None]
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
+
+        return attention_mask
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -578,6 +622,8 @@ def __call__(
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -704,13 +750,17 @@ def __call__(
         (
             prompt_embeds,
             text_ids,
+            prompt_attention_mask,
             negative_prompt_embeds,
             negative_text_ids,
+            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -718,6 +768,9 @@ def __call__(
             lora_scale=lora_scale,
         )
 
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -730,6 +783,7 @@ def __call__(
             generator,
             latents,
         )
+
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
@@ -740,6 +794,15 @@ def __call__(
             self.scheduler.config.get("base_shift", 0.5),
             self.scheduler.config.get("max_shift", 1.15),
         )
+
+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
@@ -792,15 +855,18 @@ def __call__(
                 if image_embeds is not None:
                     self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -808,16 +874,8 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
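
For reference, the flow this patch adds is: `_get_t5_prompt_embeds` now returns the T5 padding mask alongside the embeddings, `_prepare_attention_mask` extends that mask with ones for the image tokens and takes an outer product to get a pairwise (seq_len x seq_len) mask, and the modified processors in `attention_processor.py` hand it to `F.scaled_dot_product_attention` as `attn_mask`. The snippet below is a minimal standalone sketch of that idea with made-up sizes, not code from the patch. One caveat worth keeping in mind: `scaled_dot_product_attention` treats a boolean `attn_mask` as keep (True) / drop (False) but adds a floating-point `attn_mask` to the attention scores, so a raw 0/1 float mask acts as a small additive bias rather than a hard mask; the sketch therefore converts to a boolean key-padding mask, which also avoids fully masked rows.

# Standalone sketch (not from the patch): extending a T5 padding mask over the image
# tokens and applying it in F.scaled_dot_product_attention. All sizes are made up.
import torch
import torch.nn.functional as F

batch, text_len, image_len, heads, head_dim = 2, 6, 4, 1, 8
seq_len = text_len + image_len

# 1 = real text token, 0 = padding (what the tokenizer returns).
prompt_attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                                      [1, 1, 1, 1, 1, 0]])

# Extend with ones so image tokens are never masked, as _prepare_attention_mask does.
mask_1d = torch.cat(
    [prompt_attention_mask, torch.ones(batch, image_len, dtype=prompt_attention_mask.dtype)], dim=1
)

# The patch then takes an outer product to build a pairwise (batch, 1, seq, seq) mask:
# position i may attend to position j only if both are kept.
pairwise = mask_1d[:, None, None, :] * mask_1d[:, None, :, None]
print(pairwise.shape)  # torch.Size([2, 1, 10, 10])

# Boolean key-padding mask, broadcastable over heads and query positions; every row keeps
# at least the image tokens, so no row is fully masked (fully masked rows produce NaNs).
attn_mask = mask_1d[:, None, None, :].bool()  # (batch, 1, 1, seq)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 1, 10, 8])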
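
From the caller's side nothing has to change: `prompt_attention_mask` and `negative_prompt_attention_mask` are optional overrides, and by default the pipeline derives them from the tokenizer output inside `_get_t5_prompt_embeds`. With classifier-free guidance enabled, the negative and positive branches now run as a single batched forward pass (embeddings concatenated as [negative, positive], predictions split with `chunk(2)`). A hedged usage sketch follows; the checkpoint id is a placeholder assumption, not taken from the patch.

# Usage sketch for the updated ChromaPipeline. Substitute whatever diffusers-format
# Chroma weights you actually use for the placeholder checkpoint id below.
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)  # placeholder id
pipe.to("cuda")

image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    negative_prompt="low quality, blurry",  # with CFG on, this runs in the same batched forward pass
    guidance_scale=4.0,
    num_inference_steps=28,
).images[0]
image.save("chroma.png")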