Skip to content
68 changes: 66 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,68 @@ def __getitem__(self, index):
return example


# These helpers only matter for prior preservation, where instance and class prompt
# embedding batches are concatenated and may not share the same mask/sequence length.
def _materialize_prompt_embedding_mask(
prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
) -> torch.Tensor:
"""Return a dense mask tensor for a prompt embedding batch."""
batch_size, seq_len = prompt_embeds.shape[:2]

if prompt_embeds_mask is None:
return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)

if prompt_embeds_mask.shape != (batch_size, seq_len):
raise ValueError(
f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
f"({batch_size}, {seq_len})."
)

return prompt_embeds_mask.to(device=prompt_embeds.device)


def _pad_prompt_embedding_pair(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would maybe also add a comment on how this is only relevant during prior preservation.

Copy link
Copy Markdown
Author

@chenyangzhu1 chenyangzhu1 Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

I added the following comment

# These helpers only matter for prior preservation, where instance and class prompt
# embedding batches are concatenated and may not share the same mask/sequence length.

prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pad one prompt embedding batch and its mask to a shared sequence length."""
prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
pad_width = target_seq_len - prompt_embeds.shape[1]

if pad_width <= 0:
return prompt_embeds, prompt_embeds_mask

prompt_embeds = torch.cat(
[prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
)
prompt_embeds_mask = torch.cat(
[prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
)

return prompt_embeds, prompt_embeds_mask


def concat_prompt_embedding_batches(
*prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Concatenate prompt embedding batches while handling missing masks and length mismatches."""
if not prompt_embedding_pairs:
raise ValueError("At least one prompt embedding pair must be provided.")

target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
padded_pairs = [
_pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
]

merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)

if merged_mask.all():
return merged_prompt_embeds, None

return merged_prompt_embeds, merged_mask


def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
Expand Down Expand Up @@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
prompt_embeds = instance_prompt_embeds
prompt_embeds_mask = instance_prompt_embeds_mask
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
(instance_prompt_embeds, instance_prompt_embeds_mask),
(class_prompt_embeds, class_prompt_embeds_mask),
)

# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
Expand Down
Loading