From e46d1ad64e1e353731281f91ee9bc4d68b5c1ab2 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Sat, 23 May 2026 14:49:58 +0530 Subject: [PATCH 1/2] examples/dreambooth: apply SNR min-weighting to prior loss when snr_gamma is set --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ac8dd9243df6..e3efbd2a3163 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1853,6 +1853,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() + if args.with_prior_preservation: + # Apply the same SNR weighting to the prior loss for consistency. + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") + prior_loss = prior_loss.mean(dim=list(range(1, len(prior_loss.shape)))) * mse_loss_weights + prior_loss = prior_loss.mean() + if args.with_prior_preservation: # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss From fd952c819f1123aacf0a4080cdb07d834dabe354 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Sat, 23 May 2026 14:51:13 +0530 Subject: [PATCH 2/2] examples/dreambooth: add test for dreambooth lora sdxl with snr_gamma and prior preservation --- examples/dreambooth/test_dreambooth_lora.py | 26 +++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora.py b/examples/dreambooth/test_dreambooth_lora.py index e950807d372d..b646d47f4171 100644 --- a/examples/dreambooth/test_dreambooth_lora.py +++ b/examples/dreambooth/test_dreambooth_lora.py @@ -377,3 +377,29 @@ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit # checkpoint-2 should have been deleted {"checkpoint-4", "checkpoint-6"}, ) + + def test_dreambooth_lora_sdxl_snr_gamma_with_prior_preservation(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 2 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --snr_gamma 5.0 + --with_prior_preservation + --class_data_dir docs/source/en/imgs + --class_prompt photo + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # Verify training completed and produced a valid LoRA weights file. + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))