[Bug] train_dreambooth_lora_flux2_klein.py: batch size mismatch with --with_prior_preservation #13292

@vishk23

Description

When using --with_prior_preservation with train_dreambooth_lora_flux2_klein.py,
the prompt-embedding repeat logic incorrectly doubles the batch of text embeddings.

The line:
num_repeat_elements = len(prompts)

should be:
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)

With prior preservation enabled, prompts already contains both the instance and the class
prompts from collate_fn, so repeating by len(prompts) produces 4 embeddings for a batch of
only 2 latents.
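The mismatch can be sketched with plain counts. This is a minimal illustration, not the actual code from train_dreambooth_lora_flux2_klein.py; the variable names (per_prompt_embeds, latent_batch) and the helper function are illustrative assumptions, with num_repeat_elements implementing the proposed fix:

```python
# Illustrative sketch of the batch-size mismatch; names are assumptions,
# not the exact variables used in train_dreambooth_lora_flux2_klein.py.

def num_repeat_elements(prompts, with_prior_preservation):
    """Proposed fix: with prior preservation, `prompts` already contains
    both instance and class prompts from collate_fn, so halve the count."""
    return len(prompts) // 2 if with_prior_preservation else len(prompts)

# With train_batch_size=1 and --with_prior_preservation, collate_fn
# yields one instance prompt plus one class prompt:
prompts = ["a photo of sks dog", "a photo of a dog"]
per_prompt_embeds = len(prompts)  # one embedding per prompt -> 2
latent_batch = 2                  # 1 instance latent + 1 class latent

# Buggy: repeating the embeddings len(prompts) times doubles them.
buggy = per_prompt_embeds * len(prompts)                        # 2 * 2 = 4
# Fixed: repeat count is halved, embeddings match the latent batch.
fixed = per_prompt_embeds * num_repeat_elements(prompts, True)  # 2 * 1 = 2

print(buggy, fixed, latent_batch)  # 4 2 2 -> only the fix matches the latents
```

Without prior preservation, num_repeat_elements falls back to len(prompts), so the non-prior-preservation path is unchanged.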

Reproducible with train_batch_size=1 and with_prior_preservation=True.

Diffusers version: 0.38.0.dev0 (main branch)
