diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 1d1777811387..7b6df9c0202d 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
 This way, the text encoder model is not loaded into memory during training.
 > [!NOTE]
 > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
+### FSDP Text Encoder
+Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, you can pass the `--fsdp_text_encoder` flag to shard the text encoder with FSDP and compute the prompt embeddings in a distributed fashion.
+This way, the memory cost of the text encoder is spread across the participating GPUs.
 ### CPU Offloading
 To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
 ### Latent Caching
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
 > [!NOTE]
 > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
+### FSDP on the transformer
+When the accelerate configuration enables FSDP, the transformer blocks are wrapped automatically. For example, set the configuration to:
+
+```yaml
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: HYBRID_SHARD
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
+  fsdp_forward_prefetch: true
+  fsdp_sync_module_states: false
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_use_orig_params: false
+  fsdp_activation_checkpointing: true
+  fsdp_reshard_after_forward: true
+  fsdp_cpu_ram_efficient_loading: false
+```
+
 ## LoRA + DreamBooth
 [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
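For reference, a launch command that combines an FSDP-enabled accelerate config with the new `--fsdp_text_encoder` flag could look like the sketch below. The config file name, environment variables, and hyperparameters are placeholders in the spirit of the other examples in this README; check `parse_args` in the training script for the authoritative flag list.

```shell
# Illustrative only: adjust the config path, variables, and hyperparameters to your setup.
accelerate launch --config_file fsdp_config.yaml train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --fsdp_text_encoder \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --max_train_steps=500
```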
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 81306940af8f..6bba0b94b1b2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -44,6 +44,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -75,13 +76,16 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -93,6 +97,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb @@ -722,6 +729,7 @@ def parse_args(input_args=None): ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -1219,7 +1227,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = accelerator.state.fsdp_plugin is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1263,17 +1275,42 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) - # make sure to pop weight so that corresponding model is not saved again + if is_fsdp: + 
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: weights.pop() Flux2Pipeline.save_lora_weights( @@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1507,6 +1551,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. 
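The FSDP branch of `save_model_hook` above works because `get_peft_model_state_dict` accepts a pre-gathered `state_dict` (here produced by `accelerator.get_state_dict`) and filters it down to just the LoRA keys on the main process. Below is a small, self-contained sketch of that filtering mechanism using a toy module instead of the Flux.2 transformer; the module and the printed key names are illustrative only.

```python
# Toy illustration (not part of the PR): `get_peft_model_state_dict` can consume an externally
# gathered state dict and return only the LoRA entries, which is what the FSDP save path relies on.
import torch
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


model = get_peft_model(TinyBlock(), LoraConfig(r=4, target_modules=["proj"]))

# Stand-in for `accelerator.get_state_dict(model)`, which gathers the full, unsharded weights under FSDP.
full_state_dict = model.state_dict()

# PEFT filters the gathered dict down to the LoRA parameters; only these are serialized by `save_lora_weights`.
lora_state_dict = get_peft_model_state_dict(model, state_dict=full_state_dict)
print(sorted(lora_state_dict))  # e.g. ['base_model.model.proj.lora_A.weight', 'base_model.model.proj.lora_B.weight']
```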
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 0b9b9f993094..c22c48ecaeb6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -43,6 +43,7 @@ import shutil from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -74,13 +75,16 @@ from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -93,6 +97,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb @@ -691,6 +698,7 @@ def 
parse_args(input_args=None): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -1156,7 +1164,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = accelerator.state.fsdp_plugin is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1200,17 +1212,42 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) - # make sure to pop weight so that corresponding model is not saved again + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: weights.pop() Flux2Pipeline.save_lora_weights( @@ -1222,13 +1259,20 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1430,6 +1474,21 @@ def _encode_single(prompt: str): 
args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. @@ -1461,6 +1520,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1700,7 +1761,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1759,15 +1820,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 7a98fa3da14a..3e9968d47fdd 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -6,11 +6,18 @@ import re import warnings from contextlib import contextmanager -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, 
Tuple, Type, Union
 
 import numpy as np
 import torch
+
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
 from .schedulers import SchedulerMixin
@@ -18,6 +25,7 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     deprecate,
+    is_accelerate_available,
     is_peft_available,
     is_torch_npu_available,
     is_torchvision_available,
@@ -31,6 +39,9 @@
 if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
     import deepspeed
 
+if is_accelerate_available():
+    from accelerate.logging import get_logger
+
 if is_peft_available():
     from peft import set_peft_model_state_dict
@@ -394,6 +405,86 @@ def find_nearest_bucket(h, w, bucket_options):
     return best_bucket_idx
 
 
+def _to_cpu_contiguous(state_dicts) -> dict:
+    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
+
+
+def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
+    """
+    Extract the FSDP settings from an `Accelerator` and convert them into PyTorch FSDP kwargs.
+    """
+    kwargs = {}
+    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
+
+    if fsdp_plugin is None:
+        raise ValueError(
+            "FSDP is not enabled in the Accelerate configuration. Configure FSDP (e.g. via `accelerate config`) "
+            "before using this helper."
+        )
+
+    # Use the plugin's sharding strategy, falling back to FULL_SHARD if it is unset.
+    kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
+
+    return kwargs
+
+
+def wrap_with_fsdp(
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+) -> FSDP:
+    """
+    Wrap a model with FSDP using common defaults and transformer auto-wrapping.
+
+    Args:
+        model: Model to wrap.
+        device: Target device (e.g. `accelerator.device`).
+        offload: Whether to enable CPU parameter offloading.
+        use_orig_params: Whether to expose the original (unflattened) parameters, which is needed when only a
+            subset of parameters (e.g. LoRA layers) is trainable.
+        limit_all_gathers: Whether to limit all-gathers to reduce peak memory.
+        fsdp_kwargs: Extra FSDP arguments (e.g. `sharding_strategy`), usually derived from the Accelerate config.
+            An `auto_wrap_policy` passed here takes precedence over the policy built from `transformer_layer_cls`.
+        transformer_layer_cls: Layer classes to auto-wrap. If not provided, the class of the wrapped language
+            model's first decoder layer is used.
+
+    Returns:
+        The FSDP-wrapped model.
+    """
+    logger = get_logger(__name__)
+
+    if transformer_layer_cls is None:
+        # Default to the class of the wrapped language model's first decoder layer.
+        inferred_cls = type(model.model.language_model.layers[0])
+        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {inferred_cls.__name__}")
+        transformer_layer_cls = {inferred_cls}
+    elif not isinstance(transformer_layer_cls, set):
+        transformer_layer_cls = {transformer_layer_cls}
+
+    # Auto-wrap every submodule whose class is in `transformer_layer_cls`.
+    auto_wrap_policy = partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=transformer_layer_cls,
+    )
+
+    config = {
+        "device_id": device,
+        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
+        "use_orig_params": use_orig_params,
+        "limit_all_gathers": limit_all_gathers,
+        "auto_wrap_policy": auto_wrap_policy,
+    }
+
+    if fsdp_kwargs:
+        config.update(fsdp_kwargs)
+
+    return FSDP(model, **config)
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
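To make the intended call pattern of the two new helpers concrete, here is a minimal usage sketch. It assumes the process was started with `accelerate launch` using an FSDP-enabled config; the model id and the way the text-encoding pipeline is constructed are illustrative placeholders that mirror how the training scripts use these functions, and the import paths may need adjusting to your diffusers version.

```python
# Usage sketch for `get_fsdp_kwargs_from_accelerator` and `wrap_with_fsdp`.
# Assumes a distributed run launched via `accelerate launch` with FSDP enabled;
# the model id below is a placeholder.
import torch
import torch.distributed as dist
from accelerate import Accelerator

from diffusers import Flux2Pipeline
from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp

accelerator = Accelerator()

# Load only the text encoder / tokenizer components; skip the heavy transformer and VAE.
text_encoding_pipeline = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev",  # placeholder model id
    transformer=None,
    vae=None,
    torch_dtype=torch.bfloat16,
)

# Reuse the sharding strategy configured in Accelerate, then shard the text encoder across ranks.
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoding_pipeline.text_encoder = wrap_with_fsdp(
    model=text_encoding_pipeline.text_encoder,
    device=accelerator.device,
    offload=False,
    fsdp_kwargs=fsdp_kwargs,
)

# Make sure every rank has finished wrapping before any rank starts encoding prompts.
dist.barrier()
```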