From c766e27c77bee8efb3ec81da66cf7bc75870fc5b Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 19:55:36 +0800 Subject: [PATCH 1/9] Add FSDP option for Flux2 --- .../dreambooth/train_dreambooth_lora_flux2.py | 68 ++++++++++++++-- .../train_dreambooth_lora_flux2_img2img.py | 67 ++++++++++++++-- src/diffusers/training_utils.py | 79 ++++++++++++++++++- 3 files changed, 197 insertions(+), 17 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 81306940af8f..1bd95a326494 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -47,6 +47,7 @@ import numpy as np import torch +import torch.distributed as dist import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -80,8 +81,10 @@ compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -722,6 +725,7 @@ def parse_args(input_args=None): ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -1219,7 +1223,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = accelerator.state.fsdp_plugin is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1507,6 +1515,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. 
@@ -1536,6 +1559,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1836,15 +1861,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + is_fsdp = accelerator.state.fsdp_plugin is not None + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 0b9b9f993094..d5372e01a395 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -46,6 +46,7 @@ import numpy as np import torch +import torch.distributed as dist import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -79,8 +80,10 @@ compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -691,6 +694,7 @@ def parse_args(input_args=None): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -1156,7 +1160,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = accelerator.state.fsdp_plugin is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if 
args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1430,6 +1438,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. @@ -1461,6 +1484,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1759,15 +1784,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 7a98fa3da14a..7edd21be24d2 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -6,10 +6,15 @@ import re import warnings from contextlib import contextmanager -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type import numpy as np import torch +import torch.distributed as dist +from torch.distributed.fsdp import CPUOffload, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as 
FSDP +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from functools import partial from .models import UNet2DConditionModel from .pipelines import DiffusionPipeline @@ -394,6 +399,78 @@ def find_nearest_bucket(h, w, bucket_options): return best_bucket_idx +def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: + """ + Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs. + """ + + kwargs = {} + fsdp_plugin = accelerator.state.fsdp_plugin + + if fsdp_plugin is None: + # FSDP not enabled in Accelerator + kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD + else: + # FSDP is enabled → use plugin's strategy, or default if None + kwargs["sharding_strategy"] = ( + fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD + ) + + return kwargs + + +def wrap_with_fsdp( + model: torch.nn.Module, + device: Union[str, torch.device], + offload: bool = True, + use_orig_params: bool = True, + limit_all_gathers: bool = True, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None, +) -> FSDP: + """ + Wrap a model with FSDP using common defaults and optional transformer auto-wrapping. + + Args: + model: Model to wrap + device: Target device (e.g., accelerator.device) + offload: Whether to enable CPU parameter offloading + use_orig_params: Whether to use original parameters + limit_all_gathers: Whether to limit all gathers + fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config + transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs) + + Returns: + FSDP-wrapped model + """ + + if transformer_layer_cls is None: + # Set the default layers if transformer_layer_cls is not provided + transformer_layer_cls = type(model.model.language_model.layers[0]) + + # Add auto-wrap policy if transformer layers specified + auto_wrap_policy = partial( + transformer_auto_wrap_policy, + transformer_layer_cls={transformer_layer_cls}, + ) + + config = { + "device_id": device, + "cpu_offload": CPUOffload(offload_params=offload) if offload else None, + "use_orig_params": use_orig_params, + "limit_all_gathers": limit_all_gathers, + "auto_wrap_policy": auto_wrap_policy + } + + if fsdp_kwargs: + config.update(fsdp_kwargs) + + fsdp_model = FSDP(model, **config) + if dist.is_initialized(): + dist.barrier() + return fsdp_model + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ From 647c66aaf3bdba17b4601d9ff971da2e8ce92e50 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 22 Dec 2025 05:42:17 +0000 Subject: [PATCH 2/9] Apply style fixes --- src/diffusers/training_utils.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 7edd21be24d2..9407909cf049 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -6,7 +6,8 @@ import re import warnings from contextlib import contextmanager -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import numpy as np import torch @@ -14,7 +15,6 @@ from torch.distributed.fsdp import CPUOffload, ShardingStrategy from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from functools 
import partial from .models import UNet2DConditionModel from .pipelines import DiffusionPipeline @@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD else: # FSDP is enabled → use plugin's strategy, or default if None - kwargs["sharding_strategy"] = ( - fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD - ) + kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD return kwargs def wrap_with_fsdp( - model: torch.nn.Module, - device: Union[str, torch.device], - offload: bool = True, - use_orig_params: bool = True, - limit_all_gathers: bool = True, - fsdp_kwargs: Optional[Dict[str, Any]] = None, - transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None, + model: torch.nn.Module, + device: Union[str, torch.device], + offload: bool = True, + use_orig_params: bool = True, + limit_all_gathers: bool = True, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None, ) -> FSDP: """ Wrap a model with FSDP using common defaults and optional transformer auto-wrapping. @@ -459,7 +457,7 @@ def wrap_with_fsdp( "cpu_offload": CPUOffload(offload_params=offload) if offload else None, "use_orig_params": use_orig_params, "limit_all_gathers": limit_all_gathers, - "auto_wrap_policy": auto_wrap_policy + "auto_wrap_policy": auto_wrap_policy, } if fsdp_kwargs: From f931ec31a577cf0283be54a166c37e4dcd1ee800 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Tue, 23 Dec 2025 15:56:13 +0800 Subject: [PATCH 3/9] Add FSDP option for Flux2 --- .../dreambooth/train_dreambooth_lora_flux2.py | 64 +++++++++++++------ .../train_dreambooth_lora_flux2_img2img.py | 62 +++++++++++++----- src/diffusers/training_utils.py | 19 ++++-- 3 files changed, 104 insertions(+), 41 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 1bd95a326494..e25c7f166913 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1271,19 +1271,42 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} + transformer_lora_layers_to_save = None + modules_to_save = {} + + if is_fsdp: for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + state_dict = accelerator.get_state_dict(models) - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if accelerator.is_main_process: + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(model), state_dict=state_dict, + ) + transformer_lora_layers_to_save = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers_to_save.items() + } + modules_to_save["transformer"] = model + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + else: + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + 
modules_to_save = {} + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + if accelerator.is_main_process: Flux2Pipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, @@ -1293,13 +1316,19 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1802,7 +1831,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1861,7 +1890,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() - is_fsdp = accelerator.state.fsdp_plugin is not None if is_fsdp: transformer = unwrap_model(transformer) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index d5372e01a395..2062994a0dc1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1208,19 +1208,41 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} + transformer_lora_layers_to_save = None + modules_to_save = {} + if is_fsdp: for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + state_dict = accelerator.get_state_dict(models) - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if accelerator.is_main_process: + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(model), state_dict=state_dict, + ) + transformer_lora_layers_to_save = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers_to_save.items() + } + modules_to_save["transformer"] = model + + # make sure to pop weight so that corresponding model is not saved again + if 
weights: + weights.pop() + else: + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + modules_to_save = {} + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + if accelerator.is_main_process: Flux2Pipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, @@ -1230,13 +1252,19 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1725,7 +1753,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9407909cf049..56e5fe4e5a89 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,16 +5,17 @@ import random import re import warnings +from accelerate.logging import get_logger from contextlib import contextmanager from functools import partial from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import numpy as np import torch -import torch.distributed as dist -from torch.distributed.fsdp import CPUOffload, ShardingStrategy -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +if getattr(torch, "distributed", None) is not None: + from torch.distributed.fsdp import CPUOffload, ShardingStrategy + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from .models import UNet2DConditionModel from .pipelines import DiffusionPipeline @@ -405,6 +406,11 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: """ kwargs = {} + fsdp_state = getattr(accelerator.state, "fsdp_plugin", None) + + if fsdp_state is None: + raise ValueError("Accelerate isn't configured to handle FSDP. 
Please update your installation.") + fsdp_plugin = accelerator.state.fsdp_plugin if fsdp_plugin is None: @@ -442,9 +448,12 @@ def wrap_with_fsdp( FSDP-wrapped model """ + logger = get_logger(__name__) + if transformer_layer_cls is None: # Set the default layers if transformer_layer_cls is not provided transformer_layer_cls = type(model.model.language_model.layers[0]) + logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}") # Add auto-wrap policy if transformer layers specified auto_wrap_policy = partial( @@ -464,8 +473,6 @@ def wrap_with_fsdp( config.update(fsdp_kwargs) fsdp_model = FSDP(model, **config) - if dist.is_initialized(): - dist.barrier() return fsdp_model From 8bce38c0862ea69cb08d5cdc5be031b0d4711344 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Tue, 23 Dec 2025 16:23:48 +0800 Subject: [PATCH 4/9] Add FSDP option for Flux2 --- examples/dreambooth/train_dreambooth_lora_flux2.py | 6 ++++-- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 6 ++++-- src/diffusers/training_utils.py | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index e25c7f166913..ff502d93099a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1281,7 +1281,8 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(model), state_dict=state_dict, + unwrap_model(model), + state_dict=state_dict, ) transformer_lora_layers_to_save = { k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v @@ -1326,7 +1327,8 @@ def load_model_hook(models, input_dir): raise ValueError(f"unexpected save model: {model.__class__}") else: transformer_ = Flux2Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", + args.pretrained_model_name_or_path, + subfolder="transformer", ) transformer_.add_adapter(transformer_lora_config) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 2062994a0dc1..cd2d493b1837 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1217,7 +1217,8 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(model), state_dict=state_dict, + unwrap_model(model), + state_dict=state_dict, ) transformer_lora_layers_to_save = { k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v @@ -1262,7 +1263,8 @@ def load_model_hook(models, input_dir): raise ValueError(f"unexpected save model: {model.__class__}") else: transformer_ = Flux2Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", + args.pretrained_model_name_or_path, + subfolder="transformer", ) transformer_.add_adapter(transformer_lora_config) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 56e5fe4e5a89..9b09d2c814d2 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,13 +5,15 @@ import random import re import warnings -from accelerate.logging import get_logger from contextlib import contextmanager from functools import partial from typing import Any, 
Dict, Iterable, List, Optional, Set, Tuple, Type, Union import numpy as np import torch +from accelerate.logging import get_logger + + if getattr(torch, "distributed", None) is not None: from torch.distributed.fsdp import CPUOffload, ShardingStrategy from torch.distributed.fsdp import FullyShardedDataParallel as FSDP From 6cfac4642f9f1e8fca83c490a42cea2686d65ada Mon Sep 17 00:00:00 2001 From: js1234567 Date: Wed, 24 Dec 2025 15:43:21 +0800 Subject: [PATCH 5/9] Add FSDP option for Flux2 --- examples/dreambooth/README_flux2.md | 3 + .../dreambooth/train_dreambooth_lora_flux2.py | 71 +++++++++--------- .../train_dreambooth_lora_flux2_img2img.py | 72 ++++++++++--------- src/diffusers/training_utils.py | 7 ++ 4 files changed, 87 insertions(+), 66 deletions(-) diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md index 1d1777811387..876cdf270519 100644 --- a/examples/dreambooth/README_flux2.md +++ b/examples/dreambooth/README_flux2.md @@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take This way, the text encoder model is not loaded into memory during training. > [!NOTE] > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`. +### FSDP Text Encoder +Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. +This way, the memory cost can be distributed in multiple nodes. ### CPU Offloading To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed. 
### Latent Caching diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index ff502d93099a..71ef89a3594b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -47,7 +47,6 @@ import numpy as np import torch -import torch.distributed as dist import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -64,6 +63,7 @@ from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import Mistral3ForConditionalGeneration, PixtralProcessor +from typing import Any import diffusers from diffusers import ( @@ -76,6 +76,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, @@ -96,6 +97,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb @@ -1271,43 +1275,44 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - transformer_lora_layers_to_save = None - modules_to_save = {} + transformer_cls = type(unwrap_model(transformer)) - if is_fsdp: - for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(transformer))): - state_dict = accelerator.get_state_dict(models) + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None - if accelerator.is_main_process: - transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(model), - state_dict=state_dict, - ) - transformer_lora_layers_to_save = { - k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v - for k, v in transformer_lora_layers_to_save.items() - } - modules_to_save["transformer"] = model - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - else: - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(models) if is_fsdp else None + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model + ** peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = 
_to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + Flux2Pipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index cd2d493b1837..48d4000cf812 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -46,7 +46,6 @@ import numpy as np import torch -import torch.distributed as dist import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -62,6 +61,7 @@ from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import Mistral3ForConditionalGeneration, PixtralProcessor +from typing import Any import diffusers from diffusers import ( @@ -75,6 +75,7 @@ from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, @@ -96,6 +97,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb @@ -1208,42 +1212,44 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - transformer_lora_layers_to_save = None - modules_to_save = {} - if is_fsdp: - for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(transformer))): - state_dict = accelerator.get_state_dict(models) + transformer_cls = type(unwrap_model(transformer)) - if accelerator.is_main_process: - transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(model), - state_dict=state_dict, - ) - transformer_lora_layers_to_save = { - k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v - for k, v in transformer_lora_layers_to_save.items() - } - modules_to_save["transformer"] = model - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - else: - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(models) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: + peft_kwargs = 
{} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + Flux2Pipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9b09d2c814d2..90523c4c3c1f 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -402,6 +402,13 @@ def find_nearest_bucket(h, w, bucket_options): return best_bucket_idx +def _to_cpu_contiguous(state_dicts) -> dict: + return { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in state_dicts.items() + } + + def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: """ Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs. From af339debf46431694980476d67f97701a142c77d Mon Sep 17 00:00:00 2001 From: js1234567 Date: Wed, 24 Dec 2025 17:11:05 +0800 Subject: [PATCH 6/9] Add FSDP option for Flux2 --- examples/dreambooth/README_flux2.md | 2 +- examples/dreambooth/train_dreambooth_lora_flux2.py | 8 ++++---- .../dreambooth/train_dreambooth_lora_flux2_img2img.py | 8 ++++---- src/diffusers/training_utils.py | 5 +---- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md index 876cdf270519..41a77c3bbcc8 100644 --- a/examples/dreambooth/README_flux2.md +++ b/examples/dreambooth/README_flux2.md @@ -100,7 +100,7 @@ This way, the text encoder model is not loaded into memory during training. > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`. ### FSDP Text Encoder Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. -This way, the memory cost can be distributed in multiple nodes. +This way, it distributes the memory cost across multiple nodes. ### CPU Offloading To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed. 
### Latent Caching diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 71ef89a3594b..6bba0b94b1b2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -44,6 +44,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -63,7 +64,6 @@ from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import Mistral3ForConditionalGeneration, PixtralProcessor -from typing import Any import diffusers from diffusers import ( @@ -1292,7 +1292,7 @@ def save_model_hook(models, weights, output_dir): raise ValueError("No transformer model found in 'models'") # 2) Optionally gather FSDP state dict once - state_dict = accelerator.get_state_dict(models) if is_fsdp else None + state_dict = accelerator.get_state_dict(model) if is_fsdp else None # 3) Only main process materializes the LoRA state dict transformer_lora_layers_to_save = None @@ -1302,8 +1302,8 @@ def save_model_hook(models, weights, output_dir): peft_kwargs["state_dict"] = state_dict transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(transformer_model) if is_fsdp else transformer_model - ** peft_kwargs, + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, ) if is_fsdp: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 48d4000cf812..c22c48ecaeb6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -43,6 +43,7 @@ import shutil from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -61,7 +62,6 @@ from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import Mistral3ForConditionalGeneration, PixtralProcessor -from typing import Any import diffusers from diffusers import ( @@ -75,7 +75,7 @@ from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor from diffusers.training_utils import ( _collate_lora_metadata, - _to_cpu_contiguous + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, @@ -1229,7 +1229,7 @@ def save_model_hook(models, weights, output_dir): raise ValueError("No transformer model found in 'models'") # 2) Optionally gather FSDP state dict once - state_dict = accelerator.get_state_dict(models) if is_fsdp else None + state_dict = accelerator.get_state_dict(model) if is_fsdp else None # 3) Only main process materializes the LoRA state dict transformer_lora_layers_to_save = None @@ -1239,7 +1239,7 @@ def save_model_hook(models, weights, output_dir): peft_kwargs["state_dict"] = state_dict transformer_lora_layers_to_save = get_peft_model_state_dict( - unwrap_model(transformer_model) if is_fsdp else transformer_model + unwrap_model(transformer_model) if is_fsdp else transformer_model, **peft_kwargs, ) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 90523c4c3c1f..2d2f26b266a1 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -403,10 +403,7 @@ def find_nearest_bucket(h, w, bucket_options): def _to_cpu_contiguous(state_dicts) -> dict: - return { - k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else 
v - for k, v in state_dicts.items() - } + return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()} def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: From 8da9ea7d4a49e65b4276c4dc168e02bbe74ce50b Mon Sep 17 00:00:00 2001 From: js1234567 Date: Wed, 7 Jan 2026 09:50:02 +0800 Subject: [PATCH 7/9] Add FSDP option for Flux2 --- examples/dreambooth/README_flux2.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md index 41a77c3bbcc8..69bffc9d7a8c 100644 --- a/examples/dreambooth/README_flux2.md +++ b/examples/dreambooth/README_flux2.md @@ -169,6 +169,26 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. +### FSDP on the Transformers +By enabling FSDP in the accelerate configuration, the transformer blocks will be wrapped automatically. For example, set the configuration to: + +```yaml +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_sharding_strategy: HYBRID_SHARD + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock + fsdp_forward_prefetch: true + fsdp_sync_module_states: false + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_use_orig_params: false + fsdp_activation_checkpointing: true + fsdp_reshard_after_forward: true + fsdp_cpu_ram_efficient_loading: false +``` + ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. From f392c60cde3186b026f9747c6972e2a61d4d9156 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 7 Jan 2026 09:38:28 +0530 Subject: [PATCH 8/9] Update examples/dreambooth/README_flux2.md --- examples/dreambooth/README_flux2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md index 69bffc9d7a8c..7b6df9c0202d 100644 --- a/examples/dreambooth/README_flux2.md +++ b/examples/dreambooth/README_flux2.md @@ -169,7 +169,7 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. -### FSDP on the Transformers +### FSDP on the transformer By enabling FSDP in the accelerate configuration, the transformer blocks will be wrapped automatically. For example, set the configuration to: ```yaml From 1b98e10614053fa28c8400ebbeea2856ef85c8eb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Jan 2026 09:50:34 +0530 Subject: [PATCH 9/9] guard accelerate import.
--- src/diffusers/training_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 2d2f26b266a1..3e9968d47fdd 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -11,7 +11,6 @@ import numpy as np import torch -from accelerate.logging import get_logger if getattr(torch, "distributed", None) is not None: @@ -26,6 +25,7 @@ convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, + is_accelerate_available, is_peft_available, is_torch_npu_available, is_torchvision_available, @@ -39,6 +39,9 @@ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed +if is_accelerate_available(): + from accelerate.logging import get_logger + if is_peft_available(): from peft import set_peft_model_state_dict
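
To summarize how the pieces introduced across these patches fit together, here is a minimal sketch (not part of the patch series) of wrapping the Flux.2 text encoder with the new helpers from `src/diffusers/training_utils.py`. The `pipeline` argument stands in for the text-encoding pipeline built in `train_dreambooth_lora_flux2.py`, `shard_text_encoder` is an illustrative name, and the code assumes the script was launched via `accelerate launch` with an FSDP config (otherwise `get_fsdp_kwargs_from_accelerator` raises).

```python
# Minimal sketch, assuming the helpers added by this patch series are available in
# diffusers.training_utils and that Accelerate was configured with an FSDP plugin.
import torch.distributed as dist
from accelerate import Accelerator

from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp


def shard_text_encoder(pipeline, accelerator: Accelerator, offload: bool = False):
    """Replace `pipeline.text_encoder` with an FSDP-sharded copy (what --fsdp_text_encoder does)."""
    # Reuse the sharding strategy from Accelerate's FSDP plugin (falls back to
    # FULL_SHARD when the plugin does not specify one).
    fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)

    pipeline.text_encoder = wrap_with_fsdp(
        model=pipeline.text_encoder,
        device=accelerator.device,
        offload=offload,          # becomes CPUOffload(offload_params=True) when enabled
        limit_all_gathers=True,
        use_orig_params=True,
        fsdp_kwargs=fsdp_kwargs,
        # transformer_layer_cls is omitted: the helper infers it from
        # model.model.language_model.layers[0] (the Mistral Small 3.1 decoder layer).
        # Note that it wraps the value in a set itself, so pass a single layer class
        # if you override it.
    )

    if dist.is_initialized():
        dist.barrier()  # keep ranks in sync before prompt embeddings are computed
    return pipeline
```

This mirrors the `--fsdp_text_encoder` branch added to both training scripts: with the flag enabled, per-batch prompt encoding runs through the sharded encoder instead of the `offload_models` context manager.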
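
The saving path is the other half of the change: under FSDP the LoRA weights must be gathered collectively before only the main process serializes them. The sketch below condenses the logic that `save_model_hook` and the final save converge on in the later patches; `export_lora_state_dict` is an illustrative wrapper, not a function from the PR, and it assumes PEFT adapters are attached to `transformer` and that `unwrap_model` is the training script's helper.

```python
# Condensed sketch of the FSDP-aware LoRA export used by the training scripts above.
from peft.utils import get_peft_model_state_dict

from diffusers.training_utils import _to_cpu_contiguous


def export_lora_state_dict(accelerator, transformer, unwrap_model, is_fsdp: bool):
    # Gathering the full state dict is a collective op, so every rank must call it.
    state_dict = accelerator.get_state_dict(transformer) if is_fsdp else None

    if not accelerator.is_main_process:
        return None

    peft_kwargs = {"state_dict": state_dict} if is_fsdp else {}
    lora_layers = get_peft_model_state_dict(
        unwrap_model(transformer) if is_fsdp else transformer, **peft_kwargs
    )
    if is_fsdp:
        # Detach, move to CPU and make contiguous so the tensors can be serialized.
        lora_layers = _to_cpu_contiguous(lora_layers)
    return lora_layers
```

The main process can then hand the result to `Flux2Pipeline.save_lora_weights(output_dir, transformer_lora_layers=...)`, as both scripts do.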