
Fix incorrect batch handling in _prepare_image_ids usage in train_dreambooth_lora_flux2_img2img.py #13811

@Zhu1116


Describe the bug

Description

In train_dreambooth_lora_flux2_img2img.py, there is a bug in how _prepare_image_ids is used for conditional image inputs (cond_model_input).

Problem

_prepare_image_ids in Flux2Pipeline is designed for multiple reference images within a single sample. Each image is assigned a different temporal coordinate in its position IDs (T=10, T=20, T=30, ...) so that the model can distinguish the reference images of the same instance.
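To make the ID layout concrete, here is a minimal pure-Python sketch of the (T, H, W, L) coordinates built for a list of reference latents. It assumes the default scale of 10 and the T = scale + scale * index rule described above; `prepare_image_ids_sketch` and its `latent_sizes` argument are hypothetical stand-ins that take (height, width) pairs instead of tensors, not the Diffusers implementation.

```python
from itertools import product


def prepare_image_ids_sketch(latent_sizes, scale=10):
    """Hypothetical stand-in for the (T, H, W, L) IDs of a list of reference images.

    latent_sizes: list of (height, width) pairs, one per reference image.
    Returns one flat list of 4-tuples concatenated over all images
    (the library method returns a tensor instead of a list).
    """
    ids = []
    for index, (height, width) in enumerate(latent_sizes):
        t = scale + scale * index  # 10 for the first image, 20 for the second, ...
        # Cartesian product over T (one value), rows, columns and L (one value, 0).
        ids.extend(product([t], range(height), range(width), range(1)))
    return ids


# Three reference images of ONE sample, each a 2x2 latent grid.
ids = prepare_image_ids_sketch([(2, 2), (2, 2), (2, 2)])
print(sorted({t for t, _, _, _ in ids}))  # [10, 20, 30]
print(len(ids))  # 12
```

The temporal offset depends only on an image's position in the list, which is exactly why the list must contain the reference images of one sample and nothing else.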

However, in the training script, cond_model_input has shape:

(B, C, H, W)

i.e., each batch element corresponds to an independent training sample, and each sample contains only one conditional image.

Buggy Behavior

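In the current implementation (around line 1703), the code:

  1. Splits the batch into a list of single-image tensors
  2. Calls _prepare_image_ids on that list, which treats every batch element as another reference image of the same sample
  3. Therefore assigns a different temporal ID to each batch index
  4. Reshapes the concatenated IDs back to the batch dimension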

As a result:

  • sample 0 → T=10
  • sample 1 → T=20
  • sample 2 → T=30
  • ...

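This is incorrect because batch elements are independent samples and should have no temporal relationship to one another. The temporal ID of a sample's conditional image ends up depending on its position in the batch, while a single reference image passed to Flux2Pipeline receives T=10, so every sample after the first is trained with a position ID that differs from the one it would get on its own.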

Expected Behavior

Each sample in the batch should use the same temporal ID (T=10, the value assigned to the first image in the list), since each sample has only one conditional image.

There should be no cross-sample temporal offset.

Reproduction

Incorrect Implementation

The current implementation is:

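```python
# Each batch element becomes a separate "reference image" in the list ...
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
# ... so _prepare_image_ids assigns T=10, 20, 30, ... according to list position
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
    device=cond_model_input.device
)
# Reshape the concatenated IDs back to (B, N, D); sample i keeps T = 10 * (i + 1)
cond_model_input_ids = cond_model_input_ids.view(
    cond_model_input.shape[0], -1, model_input_ids.shape[-1]
)
```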
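The effect can be reproduced without a GPU. The sketch below re-implements the ID layout in pure Python (the helper `prepare_image_ids_sketch` is hypothetical and assumes the default scale of 10), passes a batch of B single-image samples as if they were B reference images of one sample, and then splits the concatenated IDs back into B chunks, as the `.view(B, -1, ...)` call does.

```python
from itertools import product


def prepare_image_ids_sketch(latent_sizes, scale=10):
    """Hypothetical stand-in for the (T, H, W, L) IDs of a list of reference images."""
    ids = []
    for index, (height, width) in enumerate(latent_sizes):
        t = scale + scale * index  # T grows with the position of the image in the list
        ids.extend(product([t], range(height), range(width), range(1)))
    return ids


batch_size, height, width = 3, 2, 2
# Buggy usage: the batch is turned into a list, so each sample looks like another reference image.
flat_ids = prepare_image_ids_sketch([(height, width)] * batch_size)
tokens_per_sample = height * width
# Equivalent of .view(batch_size, -1, 4): split the flat sequence into per-sample chunks.
per_sample_ids = [
    flat_ids[i * tokens_per_sample:(i + 1) * tokens_per_sample] for i in range(batch_size)
]
t_per_sample = [sorted({t for t, _, _, _ in chunk}) for chunk in per_sample_ids]
print(t_per_sample)  # [[10], [20], [30]]
```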

Suggested Fix

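Instead of passing the whole batch to _prepare_image_ids as if it were a list of reference images, generate the IDs for a single sample and then expand them across the batch dimension. This works because all samples in a batched tensor share the same latent height and width.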

Fix:

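```python
# Build IDs for one sample only: the list holds a single image, so T=10
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(
    [cond_model_input[0:1]]
).to(device=cond_model_input.device)

# Expand across batch dimension: batch of 1 -> batch of B (a view, no copy)
cond_model_input_ids = cond_model_input_ids.expand(
    cond_model_input.shape[0], -1, -1
)
```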
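One way to sanity-check the suggested fix is to compare it against building IDs for each sample separately, which is what "independent samples" means. The pure-Python sketch below uses a hypothetical helper `prepare_image_ids_sketch` (default scale of 10 assumed): it builds IDs for one sample and shares them across the batch, mirroring what `.expand(B, -1, -1)` does in PyTorch.

```python
from itertools import product


def prepare_image_ids_sketch(latent_sizes, scale=10):
    """Hypothetical stand-in for the (T, H, W, L) IDs of a list of reference images."""
    ids = []
    for index, (height, width) in enumerate(latent_sizes):
        t = scale + scale * index
        ids.extend(product([t], range(height), range(width), range(1)))
    return ids


batch_size, height, width = 3, 2, 2

# Suggested fix: IDs for a single sample, shared by every batch element.
# Every entry references the same list, much like expand() returns a view over shared storage.
single_sample_ids = prepare_image_ids_sketch([(height, width)])
fixed_ids = [single_sample_ids for _ in range(batch_size)]

# Reference: one separate call per sample, i.e. truly independent samples.
independent_ids = [prepare_image_ids_sketch([(height, width)]) for _ in range(batch_size)]

assert fixed_ids == independent_ids
print(sorted({t for chunk in fixed_ids for t, _, _, _ in chunk}))  # [10]
```

In the real script, the expanded tensor has stride 0 along the batch dimension. torch.cat along the token dimension accepts it as is, but if later code writes to these IDs in place or reshapes across the batch dimension, `.repeat(B, 1, 1)` or `.contiguous()` would be the safer choice.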


System Info

  • 🤗 Diffusers version: 0.38.0
  • Platform: Linux-5.15.0-72-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.13
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.15.0
  • Transformers version: 5.9.0
  • Accelerate version: 1.13.0
  • PEFT version: 0.19.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.8.0-rc.0
  • xFormers version: not installed
  • Accelerator: NVIDIA A800-SXM4-80GB, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Labels: bug (Something isn't working)