From 4347e9f4afeb0ff128535f922e00f3f92470395a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 21 May 2026 21:16:35 +0000 Subject: [PATCH 01/10] initial refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/fsdp2.yaml | 30 --- examples/llm_ptq/multinode_ptq.py | 340 ++++++++++++++++++------------ 2 files changed, 207 insertions(+), 163 deletions(-) delete mode 100644 examples/llm_ptq/fsdp2.yaml diff --git a/examples/llm_ptq/fsdp2.yaml b/examples/llm_ptq/fsdp2.yaml deleted file mode 100644 index 646d63f9e67..00000000000 --- a/examples/llm_ptq/fsdp2.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# ============================================================================= -# FSDP Configuration for running LLM PTQ on multinode setup. This file is consumed by examples/llm_ptq/multinode_ptq.py -# ============================================================================= - -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -enable_cpu_affinity: false -fsdp_config: - fsdp_activation_checkpointing: false - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_cpu_ram_efficient_loading: true - fsdp_offload_params: false - fsdp_reshard_after_forward: true - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_use_orig_params: true - fsdp_version: 2 -machine_rank: 0 -main_training_function: main -mixed_precision: 'no' -num_machines: 2 -num_processes: 16 -rdzv_backend: c10d -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 93ef21ea4d4..649077592bc 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -16,6 +16,7 @@ """Multi-node PTQ (Post-Training Quantization) with FSDP2 support.""" import argparse +import copy import json import os import random @@ -26,11 +27,24 @@ import numpy as np import torch +import torch.distributed as dist import torch.nn as nn -from accelerate import Accelerator from example_utils import build_quant_cfg, get_tokenizer +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + set_model_state_dict, +) +from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy, fully_shard +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm -from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq @@ -39,16 +53,26 @@ from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.quantization.config import need_calibration from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets -# Constants RAND_SEED = 1234 + +def _nvfp4_max_cfg(*, layerwise: bool) -> dict[str, Any]: + """NVFP4 quant config with explicit max calibration and a layerwise toggle.""" + cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) + cfg["algorithm"] = {"method": "max", "layerwise": layerwise} + return cfg + + QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { "int8": mtq.INT8_DEFAULT_CFG, "int4_awq": mtq.INT4_AWQ_CFG, "fp8": mtq.FP8_DEFAULT_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_max": _nvfp4_max_cfg(layerwise=False), + "nvfp4_max_layerwise": _nvfp4_max_cfg(layerwise=True), "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, @@ -64,7 +88,6 @@ } -# Enable HuggingFace checkpointing mto.enable_huggingface_checkpointing() @@ -72,16 +95,9 @@ def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Multi-node post-training quantization with FSDP2") + parser.add_argument("--pyt_ckpt_path", required=True, help="Path to PyTorch checkpoint") parser.add_argument( - "--pyt_ckpt_path", - required=True, - help="Path to PyTorch checkpoint", - ) - parser.add_argument( - "--qformat", - default="fp8", - choices=QUANT_CFG_CHOICES.keys(), - help="Quantization format", + "--qformat", default="fp8", choices=QUANT_CFG_CHOICES.keys(), help="Quantization format" ) parser.add_argument( "--kv_cache_qformat", @@ -89,12 +105,7 @@ def parse_args(): choices=list(KV_QUANT_CFG_CHOICES.keys()), help="KV cache quantization format", ) - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size for calibration", - ) + parser.add_argument("--batch_size", type=int, default=1, help="Batch size for calibration") parser.add_argument( "--calib_size", type=str, @@ -103,17 +114,15 @@ def parse_args(): ) parser.add_argument( "--dataset", + type=str, + default=None, help=( f"name of a dataset, or a comma separated list of datasets. " f"dataset choices are {get_supported_datasets()}" ), - type=str, - default=None, ) parser.add_argument( - "--export_path", - default="exported_model", - help="Directory to export the quantized model", + "--export_path", default="exported_model", help="Directory to export the quantized model" ) parser.add_argument( "--trust_remote_code", @@ -121,47 +130,138 @@ def parse_args(): help="Trust remote code for HuggingFace models", ) parser.add_argument("--awq_block_size", default=0, type=int) + parser.add_argument( + "--fsdp_transformer_layer_cls_to_wrap", + default=None, + help=( + "Override auto-detect by transformer layer class name " + "(e.g. LlamaDecoderLayer). Auto-detected when omitted." + ), + ) + parser.add_argument( + "--cpu_offload", + action="store_true", + help="Keep FSDP2 sharded params on CPU; gather to GPU per layer forward.", + ) args = parser.parse_args() - # Parse comma-separated lists args.dataset = args.dataset.split(",") if args.dataset else None args.calib_size = [int(x) for x in args.calib_size.split(",")] return args +def setup_distributed() -> tuple[int, int, int, torch.device]: + """Initialize torch.distributed from torchrun env vars and pin the CUDA device.""" + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + print(f"Rank: {rank}, World Size: {world_size}, Local Rank: {local_rank}") + return rank, world_size, local_rank, device + + +def _resolve_decoder_layers(model: nn.Module, override_cls_name: str | None): + """Return the list of decoder layers to apply ``fully_shard`` to.""" + if override_cls_name: + layers = [m for m in model.modules() if type(m).__name__ == override_cls_name] + if not layers: + raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") + return layers + layers = LayerActivationCollector.get_decoder_layers(model) + if layers is None: + raise RuntimeError( + "Could not auto-detect decoder layers; pass " + "--fsdp_transformer_layer_cls_to_wrap explicitly." + ) + return layers + + +def fsdp2_wrap( + model: nn.Module, + override_cls_name: str | None = None, + cpu_offload: bool = False, +) -> nn.Module: + """Apply FSDP2 ``fully_shard`` to each decoder layer, then to the root module.""" + offload_policy: OffloadPolicy = CPUOffloadPolicy() if cpu_offload else OffloadPolicy() + for layer in _resolve_decoder_layers(model, override_cls_name): + fully_shard(layer, reshard_after_forward=True, offload_policy=offload_policy) + fully_shard(model, reshard_after_forward=True, offload_policy=offload_policy) + return model + + def load_and_prepare_model( model_path: str, - calib_dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, + device: torch.device, + rank: int, trust_remote_code: bool = False, -) -> tuple[nn.Module, str, list[str], torch.utils.data.DataLoader]: - """Load model and prepare it for FSDP2 distributed execution. - - Args: - model_path: Path to the HuggingFace model - calibration_dataloader: Calibration dataloader to be sharded for calibration - accelerator: Accelerate's Accelerator instance - trust_remote_code: Whether to trust remote code - - Returns: - Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader) + override_cls_name: str | None = None, + cpu_offload: bool = False, +) -> tuple[nn.Module, str, list[str]]: + """Load model and shard it with FSDP2 using rank-0-only CPU realization. + + Only rank 0 reads real weights from disk; every other rank instantiates the + model on the ``meta`` device. After ``fully_shard`` sets up the sharded + DTensor layout and ``to_empty`` allocates per-rank shard storage, rank 0's + full state dict is broadcast into the sharded structure. """ - model = AutoModelForCausalLM.from_pretrained( - model_path, dtype="auto", trust_remote_code=trust_remote_code - ) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + dtype = getattr(config, "torch_dtype", None) or torch.bfloat16 + + if rank == 0: + src_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + src_model.eval() + cpu_state_dict = src_model.state_dict() + else: + src_model = None + cpu_state_dict = {} + + with torch.device("meta"): + model = AutoModelForCausalLM.from_config( + config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) model.eval() + model_type = get_model_type(model) - # Need the original architectures for export - # FSDP prefix is added to the architectures for FSDP2 wrapped models original_architectures = model.config.architectures - # FSDP2 requires an optimizer to be prepared together with the model - dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.0) - model, _, calibration_dataloader = accelerator.prepare(model, dummy_optimizer, calib_dataloader) + fsdp2_wrap(model, override_cls_name=override_cls_name, cpu_offload=cpu_offload) + + # For CPU offload: FSDP2 requires its managed params on CPU at lazy_init, + # so materialize the whole model on CPU. Otherwise materialize on GPU. + materialize_device = torch.device("cpu") if cpu_offload else device + model.to_empty(device=materialize_device) + + set_model_state_dict( + model, + cpu_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + + # With CPU offload, FSDP-managed params stay on CPU but buffers (e.g. MoE + # router corrections, RoPE caches) must live on GPU for layer forwards. + if cpu_offload: + for b in model.buffers(): + b.data = b.data.to(device, non_blocking=True) + torch.cuda.synchronize() + + # Freeze every param so patch_fsdp_mp_dtypes' trainable-only check skips the + # uniform-dtype assertion (e.g. Nemotron-H ships mixed bf16/fp32 weights). + for p in model.parameters(): + p.requires_grad_(False) + + del cpu_state_dict, src_model - return model, model_type, original_architectures, calibration_dataloader + return model, model_type, original_architectures def create_calibration_dataloader( @@ -169,191 +269,167 @@ def create_calibration_dataloader( dataset_names: list[str], calib_sizes: list[int], batch_size: int, -) -> torch.utils.data.DataLoader: - """Create calibration dataloader from dataset. - - Args: - tokenizer: HuggingFace tokenizer - dataset_names: List of dataset names (defaults to cnn_dailymail) - calib_sizes: Number of samples for each dataset - batch_size: Batch size for calibration - - Returns: - DataLoader for calibration - """ - +) -> DataLoader: + """Create calibration dataloader from dataset.""" return get_dataset_dataloader( dataset_name=dataset_names, tokenizer=tokenizer, batch_size=batch_size, num_samples=calib_sizes, - device=None, # Keep data on CPU, calibration loop handles device transfer + device=None, include_labels=False, ) +def shard_dataloader(loader: DataLoader, rank: int, world_size: int) -> DataLoader: + """Wrap a DataLoader with a DistributedSampler so each rank sees a unique shard.""" + sampler = DistributedSampler( + loader.dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + return DataLoader( + loader.dataset, + batch_size=loader.batch_size, + sampler=sampler, + collate_fn=loader.collate_fn, + num_workers=loader.num_workers, + pin_memory=loader.pin_memory, + ) + + def create_fsdp2_calibration_loop( model: nn.Module, - dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, + dataloader: DataLoader, + device: torch.device, ): - """Create calibration loop compatible with FSDP2. - - For FSDP2, we need to use the outer FSDP-wrapped model instead of - the parameter passed by mtq.quantize to properly handle DTensor. - - Args: - model: FSDP2-wrapped model - dataloader: Calibration dataloader - accelerator: Accelerator instance for device management - - Returns: - Calibration function compatible with mtq.quantize - """ + """Calibration loop that forwards through the FSDP-wrapped model.""" def calibrate(unwrapped_model): - """Calibration loop that uses the FSDP-wrapped model.""" for batch in tqdm(dataloader, desc="Calibrating"): if isinstance(batch, dict): batch = { - k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v - for k, v in batch.items() + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } - # Use outer model (FSDP-wrapped), not the parameter - # Important: We should forward pass using the unwrapped model - # mtq.quantize will unwrap the model & pass to the forward_loop + # Use outer (FSDP-wrapped) model, not the unwrapped parameter passed by mtq.quantize. model(**batch) return calibrate +class _Fsdp2StateDictAdapter: + """Shim exposing ``.get_state_dict(model)`` to ``_export_transformers_checkpoint``.""" + + def get_state_dict(self, model: nn.Module): + return get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + def export_model( model: nn.Module, - accelerator: Accelerator, + rank: int, export_path: str | Path, architectures: list[str], ): - """Export quantized model to HuggingFace format. - - Args: - model: Quantized model - accelerator: Accelerator instance for state dict gathering - export_path: Directory to export model to - """ + """Export the quantized model to HuggingFace format on rank 0.""" export_dir = Path(export_path) export_dir.mkdir(parents=True, exist_ok=True) + adapter = _Fsdp2StateDictAdapter() post_state_dict, hf_quant_config = _export_transformers_checkpoint( - model, torch.bfloat16, accelerator=accelerator + model, torch.bfloat16, accelerator=adapter ) - if accelerator.is_main_process: - # Save hf_quant_config.json for backward compatibility + if rank == 0: with open(f"{export_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) hf_quant_config = convert_hf_quant_config_format(hf_quant_config) - # Save model model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False) original_config = f"{export_dir}/config.json" - config_data = {} - with open(original_config) as file: config_data = json.load(file) config_data["quantization_config"] = hf_quant_config - # Update config architectures to use original architectures that does not have FSDP prefix config_data["architectures"] = architectures with open(original_config, "w") as file: json.dump(config_data, file, indent=4) + dist.barrier() + def main(args): """Main quantization workflow.""" - # Validate GPU availability if not torch.cuda.is_available(): raise OSError("GPU is required for quantization.") - # Validate quantization format if args.qformat not in QUANT_CFG_CHOICES: raise ValueError( f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}" ) - # Set random seeds random.seed(RAND_SEED) np.random.seed(RAND_SEED) torch.manual_seed(RAND_SEED) - # Initialize accelerator - accelerator = Accelerator() + rank, world_size, _, device = setup_distributed() - print(f"Rank: {os.environ.get('RANK', 'Not set')}") - print(f"World Size: {os.environ.get('WORLD_SIZE', 'Not set')}") - print(f"Local Rank: {os.environ.get('LOCAL_RANK', 'Not set')}") - - # Load tokenizer tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) default_padding_side = tokenizer.padding_side - tokenizer.padding_side = "left" # Left padding for better calibration + tokenizer.padding_side = "left" - # Set default dataset if not provided if args.dataset is None: args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] warnings.warn( "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." ) - # Adjust calib_size to match dataset length by extending or truncating as needed args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ : len(args.dataset) ] - # Create calibration dataloader with max batch size calib_dataloader = create_calibration_dataloader( tokenizer=tokenizer, dataset_names=args.dataset, calib_sizes=args.calib_size, batch_size=args.batch_size, ) + calib_dataloader = shard_dataloader(calib_dataloader, rank, world_size) - # Load and prepare model - model, model_type, original_architectures, calib_dataloader = load_and_prepare_model( + model, model_type, original_architectures = load_and_prepare_model( model_path=args.pyt_ckpt_path, - calib_dataloader=calib_dataloader, - accelerator=accelerator, + device=device, + rank=rank, trust_remote_code=args.trust_remote_code, + override_cls_name=args.fsdp_transformer_layer_cls_to_wrap, + cpu_offload=args.cpu_offload, ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] - - quant_cfg = build_quant_cfg( - args.qformat, - quant_cfg, - args.awq_block_size, - model_type, - ) + quant_cfg = build_quant_cfg(args.qformat, quant_cfg, args.awq_block_size, model_type) enable_quant_kv_cache = args.kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. if enable_quant_kv_cache: quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], ) - # Quantize the model - if accelerator.is_main_process: + if rank == 0: print("Starting quantization...") start_time = time.time() if need_calibration(quant_cfg): - calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, accelerator) + calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, device) else: calibrate_fn = None warnings.warn("Dynamic quantization. Calibration skipped.") @@ -363,28 +439,26 @@ def main(args): elapsed = time.time() - start_time - if accelerator.is_main_process: + if rank == 0: print(f"Quantization completed in {elapsed:.2f}s") mtq.print_quant_summary(model) start_time = time.time() - export_model(model, accelerator, args.export_path, original_architectures) + export_model(model, rank, args.export_path, original_architectures) elapsed = time.time() - start_time - if accelerator.is_main_process: - # Restore default padding and export the tokenizer as well. + if rank == 0: if tokenizer is not None: tokenizer.padding_side = default_padding_side tokenizer.save_pretrained(args.export_path) - # Export the model print(f"Export completed in {elapsed:.2f}s") print(f"Model exported to {args.export_path}") - print("Unpatching FSDP2 MP dtypes") + dist.barrier() + dist.destroy_process_group() if __name__ == "__main__": args = parse_args() - # This context manager can be removed once the update to FSDP2 function is reflected in torch with patch_fsdp_mp_dtypes(): main(args) From af7bfdd1c5de64b322ba6786690906224a854bc6 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 22 May 2026 00:40:30 +0000 Subject: [PATCH 02/10] cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/multinode_ptq.py | 35 ++++++------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 649077592bc..d77c7c473bc 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -35,7 +35,7 @@ get_model_state_dict, set_model_state_dict, ) -from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy, fully_shard +from torch.distributed.fsdp import fully_shard from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm @@ -138,12 +138,6 @@ def parse_args(): "(e.g. LlamaDecoderLayer). Auto-detected when omitted." ), ) - parser.add_argument( - "--cpu_offload", - action="store_true", - help="Keep FSDP2 sharded params on CPU; gather to GPU per layer forward.", - ) - args = parser.parse_args() args.dataset = args.dataset.split(",") if args.dataset else None @@ -181,16 +175,11 @@ def _resolve_decoder_layers(model: nn.Module, override_cls_name: str | None): return layers -def fsdp2_wrap( - model: nn.Module, - override_cls_name: str | None = None, - cpu_offload: bool = False, -) -> nn.Module: +def fsdp2_wrap(model: nn.Module, override_cls_name: str | None = None) -> nn.Module: """Apply FSDP2 ``fully_shard`` to each decoder layer, then to the root module.""" - offload_policy: OffloadPolicy = CPUOffloadPolicy() if cpu_offload else OffloadPolicy() for layer in _resolve_decoder_layers(model, override_cls_name): - fully_shard(layer, reshard_after_forward=True, offload_policy=offload_policy) - fully_shard(model, reshard_after_forward=True, offload_policy=offload_policy) + fully_shard(layer, reshard_after_forward=True) + fully_shard(model, reshard_after_forward=True) return model @@ -200,7 +189,6 @@ def load_and_prepare_model( rank: int, trust_remote_code: bool = False, override_cls_name: str | None = None, - cpu_offload: bool = False, ) -> tuple[nn.Module, str, list[str]]: """Load model and shard it with FSDP2 using rank-0-only CPU realization. @@ -234,12 +222,9 @@ def load_and_prepare_model( model_type = get_model_type(model) original_architectures = model.config.architectures - fsdp2_wrap(model, override_cls_name=override_cls_name, cpu_offload=cpu_offload) + fsdp2_wrap(model, override_cls_name=override_cls_name) - # For CPU offload: FSDP2 requires its managed params on CPU at lazy_init, - # so materialize the whole model on CPU. Otherwise materialize on GPU. - materialize_device = torch.device("cpu") if cpu_offload else device - model.to_empty(device=materialize_device) + model.to_empty(device=device) set_model_state_dict( model, @@ -247,13 +232,6 @@ def load_and_prepare_model( options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), ) - # With CPU offload, FSDP-managed params stay on CPU but buffers (e.g. MoE - # router corrections, RoPE caches) must live on GPU for layer forwards. - if cpu_offload: - for b in model.buffers(): - b.data = b.data.to(device, non_blocking=True) - torch.cuda.synchronize() - # Freeze every param so patch_fsdp_mp_dtypes' trainable-only check skips the # uniform-dtype assertion (e.g. Nemotron-H ships mixed bf16/fp32 weights). for p in model.parameters(): @@ -408,7 +386,6 @@ def main(args): rank=rank, trust_remote_code=args.trust_remote_code, override_cls_name=args.fsdp_transformer_layer_cls_to_wrap, - cpu_offload=args.cpu_offload, ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] From 0ce7cd7f0a902729a9333b605fcaebcf43b766d9 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 26 May 2026 04:49:45 +0000 Subject: [PATCH 03/10] prototype, untested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 100 ++++++++++++++++++++++ examples/llm_ptq/hf_ptq.py | 119 ++++++++++++++++++++++++--- examples/llm_ptq/multinode_ptq.py | 13 ++- modelopt/torch/utils/distributed.py | 123 ++++++++++++++++++++++++++++ 4 files changed, 342 insertions(+), 13 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 9455157645c..ff2ba9631bc 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -51,6 +51,106 @@ SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] +# --------------------------------------------------------------------------- +# FSDP2 helpers (opt-in via --use_fsdp2 in hf_ptq.py / multinode_ptq.py). +# --------------------------------------------------------------------------- + + +def setup_distributed_args(args): + """Populate ``args.rank`` / ``world_size`` / ``device`` / ``is_main``. + + When ``--use_fsdp2`` is set, initializes the distributed process group and + pins this rank's CUDA device. When the flag is off, fills no-op values so + downstream helpers can use ``args.is_main`` and ``args.rank`` uniformly. + """ + from modelopt.torch.utils import distributed as dist_utils + + if getattr(args, "use_fsdp2", False): + dist_utils.setup() + args.rank = dist_utils.rank() + args.world_size = dist_utils.size() + args.device = torch.device(f"cuda:{dist_utils.local_rank()}") + args.is_main = args.rank == 0 + else: + args.rank = 0 + args.world_size = 1 + args.device = None + args.is_main = True + + +def cleanup_distributed(args): + """Tear down the distributed process group if FSDP2 set it up.""" + from modelopt.torch.utils import distributed as dist_utils + + if getattr(args, "use_fsdp2", False): + dist_utils.cleanup() + + +def load_and_prepare_fsdp2_model( + ckpt_path: str, + device: torch.device, + rank: int, + trust_remote_code: bool = False, + mp_policy=None, +): + """Load and FSDP2-shard a causal LM with rank-0-only CPU realization. + + Rank 0 reads weights from disk on CPU via ``from_pretrained``; other ranks + build a structural skeleton on the ``meta`` device. ``fsdp2_shard`` then + slices each decoder layer, allocates per-rank GPU shard storage, broadcasts + rank-0's weights into the shards, and freezes params. + + Memory: only rank 0 holds the full CPU copy. Each rank ends with + ``model_size / world_size`` of GPU shard storage. + + v1 supports standard transformers families only (causal LMs that load + cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized / speculative + are not validated under FSDP2 and should go through ``get_model``. + """ + from modelopt.torch.utils.distributed import fsdp2_shard + + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + if rank == 0: + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype="auto", + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + else: + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + with torch.device("meta"): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + # Disable HF KV cache; calibration is single-pass and (for layerwise) replays. + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + return fsdp2_shard(model, device, rank, mp_policy=mp_policy) + + +def create_fsdp2_calibration_loop(model, dataloader, device): + """Calibration closure that forwards through the outer FSDP2-wrapped model. + + Required because ``mtq.quantize`` unwraps the model before calling + ``forward_loop``; calling the unwrapped inner module skips FSDP2's pre/post + forward hooks and breaks the all-gather. The closure captures the outer + ``model`` and ignores the ``unwrapped_model`` argument. + """ + from tqdm import tqdm + + def calibrate(unwrapped_model): + for batch in tqdm(dataloader, desc="Calibrating"): + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + + return calibrate + + def run_nemotron_vl_preview( full_model, tokenizer, diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d52c3ee40bb..e7515e9eb44 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,8 @@ import argparse import copy +import json +import os import random import time import warnings @@ -23,22 +25,27 @@ import numpy as np import torch +import torch.distributed as dist from accelerate.hooks import remove_hook_from_module from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( build_quant_cfg, + cleanup_distributed, copy_custom_model_files, + create_fsdp2_calibration_loop, create_vlm_calibration_loop, get_model, get_processor, get_tokenizer, is_enc_dec, is_nemotron_vl, + load_and_prepare_fsdp2_model, load_mtp_weights, needs_checkpoint_path_update, resolve_checkpoint_dir, run_nemotron_vl_preview, + setup_distributed_args, ) from torch.utils.data import DataLoader from transformers import ( @@ -65,10 +72,12 @@ has_spec_opt, save_expert_token_count_table, ) +from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model +from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights -from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.quantization.utils import is_quantized, patch_fsdp_mp_dtypes from modelopt.torch.speculative.eagle.utils import ( EagleOfflineDataCollator, OfflineSupervisedDataset, @@ -79,6 +88,7 @@ get_max_batch_size, get_supported_datasets, ) +from modelopt.torch.utils.distributed import Fsdp2StateDictAdapter, shard_dataloader from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -285,6 +295,9 @@ def make_calib_dataloader( device=device, include_labels=include_labels, ) + if args.use_fsdp2 and calib_dataloader is not None and isinstance(calib_dataloader, DataLoader): + # Each rank sees a disjoint shard of the calibration set. + calib_dataloader = shard_dataloader(calib_dataloader, args.rank, args.world_size) return calib_dataloader, first_text_speech_dataset @@ -435,7 +448,14 @@ def forward_step(model, batch): def load_model(args: argparse.Namespace): # If low memory mode is enabled, we compress the model while loading the HF checkpoint. calibration_only = False - if args.specdec_offline_dataset is not None or not args.low_memory_mode: + if args.use_fsdp2: + full_model = load_and_prepare_fsdp2_model( + ckpt_path=args.pyt_ckpt_path, + device=args.device, + rank=args.rank, + trust_remote_code=args.trust_remote_code, + ) + elif args.specdec_offline_dataset is not None or not args.low_memory_mode: full_model = get_model( args.pyt_ckpt_path, args.device, @@ -477,9 +497,12 @@ def load_model(args: argparse.Namespace): model_type = get_model_type(full_model) - device = full_model.device - if hasattr(full_model, "model"): - device = full_model.model.device + if args.use_fsdp2: + device = args.device + else: + device = full_model.device + if hasattr(full_model, "model"): + device = full_model.model.device processor = None tokenizer = None language_model = full_model @@ -654,6 +677,12 @@ def mono_quantize( # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model. if args.calib_with_images and is_nemotron_vl_model: calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader) + elif args.use_fsdp2: + # mtq.quantize passes the unwrapped inner module to forward_loop; + # FSDP2 needs hooks fired on the outer wrapped model. + calibrate_loop = create_fsdp2_calibration_loop( + language_model, calib_dataloader, args.device + ) else: calibrate_loop = create_forward_loop( dataloader=calib_dataloader, @@ -680,6 +709,43 @@ def mono_quantize( warnings.warn("Skipping quantization: model is already quantized.") +def _export_fsdp2_hf_checkpoint(args: argparse.Namespace, full_model, export_path: str) -> None: + """FSDP2-aware HF checkpoint export. + + Gathers the full state dict from FSDP2 shards via ``Fsdp2StateDictAdapter``, + saves it on rank 0 only, then patches the saved config with quantization + metadata and the original (pre-FSDP-prefix) architectures list. + """ + adapter = Fsdp2StateDictAdapter() + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + full_model, torch.bfloat16, accelerator=adapter + ) + + if args.is_main: + export_dir = Path(export_path) + export_dir.mkdir(parents=True, exist_ok=True) + # Save hf_quant_config.json for backward compatibility. + with open(f"{export_dir}/hf_quant_config.json", "w") as f: + json.dump(hf_quant_config, f, indent=4) + hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + full_model.save_pretrained( + export_dir, state_dict=post_state_dict, save_modelopt_state=False + ) + original_config = f"{export_dir}/config.json" + with open(original_config) as f: + config_data = json.load(f) + config_data["quantization_config"] = hf_quant_config + # Strip FSDP-prefixed architectures and restore the original list captured pre-wrap. + original_archs = getattr( + full_model, "_original_architectures", full_model.config.architectures + ) + if original_archs: + config_data["architectures"] = original_archs + with open(original_config, "w") as f: + json.dump(config_data, f, indent=4) + dist.barrier() + + def export_quantized( args: argparse.Namespace, full_model: torch.nn.Module, @@ -772,6 +838,8 @@ def export_quantized( export_hf_vllm_fq_checkpoint( full_model, export_dir=export_path, inplace_mem_efficient=True ) + elif args.use_fsdp2: + _export_fsdp2_hf_checkpoint(args, full_model, export_path) else: mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( full_model, args.pyt_ckpt_path @@ -793,22 +861,26 @@ def export_quantized( ) # Restore default padding and export the tokenizer as well. + # Under FSDP2 only rank 0 writes to disk. if tokenizer is not None: tokenizer.padding_side = default_padding_side if default_pad_token is not None: tokenizer.pad_token = default_pad_token - tokenizer.save_pretrained(export_path) + if args.is_main: + tokenizer.save_pretrained(export_path) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used. # This must run AFTER tokenizer.save_pretrained() so original tokenizer files # from the source checkpoint take precedence over regenerated ones (which may # differ in format due to newer transformers versions). - copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) + if args.is_main: + copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) end_time = time.time() - print( - f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" - ) + if args.is_main: + print( + f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" + ) def pre_quantize( @@ -1337,6 +1409,17 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--use_fsdp2", + action="store_true", + help=( + "Run calibration under PyTorch FSDP2 (requires launching with torchrun). " + "Takes precedence over --use_seq_device_map. " + "v1 limitations: standard causal-LM only (no VILA / pack-quantized / speculative / " + "auto-quantize / sparsity / VLM); per-rank load memory ceiling is roughly model_size " + "(use multinode_ptq.py for models that don't fit on every rank)." + ), + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", @@ -1441,6 +1524,12 @@ def parse_args() -> argparse.Namespace: if args.specdec_offline_dataset is not None and args.low_memory_mode: parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.") + if args.use_fsdp2 and args.use_seq_device_map: + warnings.warn("--use_seq_device_map is ignored when --use_fsdp2 is set.") + args.use_seq_device_map = False + if args.use_fsdp2 and os.environ.get("RANK") is None: + parser.error("--use_fsdp2 requires launching with torchrun") + return args @@ -1451,6 +1540,10 @@ def main(args: argparse.Namespace): random.seed(RAND_SEED) np.random.seed(RAND_SEED) + # Populate args.rank / world_size / device / is_main. When --use_fsdp2 is off, + # these default to single-process values so downstream helpers can use them uniformly. + setup_distributed_args(args) + # launch a memory monitor to read the currently used GPU memory. launch_memory_monitor() @@ -1487,6 +1580,8 @@ def main(args: argparse.Namespace): device, ) + cleanup_distributed(args) + if __name__ == "__main__": args = parse_args() @@ -1515,4 +1610,6 @@ def main(args: argparse.Namespace): "(multi-format auto-quantize)." ) - main(args) + # patch_fsdp_mp_dtypes is a no-op when no FSDP2 wrap is applied; safe unconditionally. + with patch_fsdp_mp_dtypes(): + main(args) diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index d77c7c473bc..9020d344d00 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -176,10 +176,15 @@ def _resolve_decoder_layers(model: nn.Module, override_cls_name: str | None): def fsdp2_wrap(model: nn.Module, override_cls_name: str | None = None) -> nn.Module: - """Apply FSDP2 ``fully_shard`` to each decoder layer, then to the root module.""" + """Apply FSDP2 ``fully_shard`` to each decoder layer only. + + The root is intentionally not sharded so embed_tokens / lm_head stay as + plain replicated tensors. Sharding the root makes those weights DTensors, + which collides with modelopt's layerwise forward patching (mixed + plain-tensor / DTensor inputs at the embedding lookup). + """ for layer in _resolve_decoder_layers(model, override_cls_name): fully_shard(layer, reshard_after_forward=True) - fully_shard(model, reshard_after_forward=True) return model @@ -286,11 +291,15 @@ def create_fsdp2_calibration_loop( """Calibration loop that forwards through the FSDP-wrapped model.""" def calibrate(unwrapped_model): + # Force use_cache=False so layerwise replays don't accumulate KV across batches. + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False for batch in tqdm(dataloader, desc="Calibrating"): if isinstance(batch, dict): batch = { k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } + batch.setdefault("use_cache", False) # Use outer (FSDP-wrapped) model, not the unwrapped parameter passed by mtq.quantize. model(**batch) diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 7922b688052..ffe69f83db5 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -31,13 +31,17 @@ __all__ = [ "DistributedProcessGroup", + "Fsdp2StateDictAdapter", "ParallelState", "backend", "barrier", + "fsdp2_shard", + "fsdp2_wrap", "is_available", "is_initialized", "is_master", "rank", + "shard_dataloader", "size", ] @@ -216,6 +220,125 @@ def cleanup(): torch.distributed.destroy_process_group() +# --------------------------------------------------------------------------- +# FSDP2 helpers — used by examples/llm_ptq to run PTQ calibration under FSDP2. +# --------------------------------------------------------------------------- + + +def fsdp2_wrap(model, override_cls_name: str | None = None, mp_policy=None): + """Apply FSDP2 ``fully_shard`` to each decoder layer of ``model``. + + Decoder layers are auto-detected via + ``modelopt.torch.quantization.utils.layerwise_calib.LayerActivationCollector.get_decoder_layers``. + Pass ``override_cls_name`` to force a specific transformer block class. Pass + ``mp_policy`` (a ``torch.distributed.fsdp.MixedPrecisionPolicy``) to control + compute / reduce dtype; default ``None`` means no upcast / downcast. + + The root module is intentionally not sharded so embeddings / lm_head stay as + plain tensors (avoids DTensor / plain-tensor mismatches with modelopt's + layerwise forward patching). + """ + from torch.distributed.fsdp import fully_shard + + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + if override_cls_name: + layers = [m for m in model.modules() if type(m).__name__ == override_cls_name] + if not layers: + raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") + else: + layers = LayerActivationCollector.get_decoder_layers(model) + if layers is None: + raise RuntimeError( + "Could not auto-detect decoder layers; pass override_cls_name explicitly." + ) + fsdp_kwargs: dict[str, Any] = {"reshard_after_forward": True} + if mp_policy is not None: + fsdp_kwargs["mp_policy"] = mp_policy + for layer in layers: + fully_shard(layer, **fsdp_kwargs) + return model + + +def fsdp2_shard(model, device, rank, mp_policy=None): + """Shard a loaded model across the current process group. + + Expects rank 0 to pass a real CPU model and other ranks to pass a meta + skeleton with matching structure. After this call every rank holds its + per-rank GPU shard, populated from rank 0's source. + + Steps: stash ``_original_architectures`` (FSDP2 may mutate + ``model.config.architectures``); capture rank-0's state_dict; ``fsdp2_wrap`` + per decoder layer; ``to_empty`` allocates per-rank GPU shard storage; + ``set_model_state_dict(broadcast_from_rank0=True)`` streams the data; freeze + params (needed by ``patch_fsdp_mp_dtypes``' trainable-only check). + """ + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + model._original_architectures = list(model.config.architectures or []) + cpu_state_dict = model.state_dict() if rank == 0 else {} + + fsdp2_wrap(model, mp_policy=mp_policy) + model.to_empty(device=device) + set_model_state_dict( + model, + cpu_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + # TODO(temp workaround): FSDP2's _init_mp_dtypes asserts uniform dtype across + # trainable params. patch_fsdp_mp_dtypes narrows the check to trainable-only; + # freezing here makes trainable empty so mixed-dtype models (Nemotron-H, etc.) + # pass. PTQ doesn't need gradients anyway. + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def shard_dataloader(loader, rank: int, world_size: int): + """Wrap a DataLoader with a DistributedSampler so each rank sees a unique shard. + + Preserves the input loader's ``batch_size``, ``collate_fn``, ``num_workers``, + and ``pin_memory``. + """ + from torch.utils.data import DataLoader + from torch.utils.data.distributed import DistributedSampler + + sampler = DistributedSampler( + loader.dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + return DataLoader( + loader.dataset, + batch_size=loader.batch_size, + sampler=sampler, + collate_fn=loader.collate_fn, + num_workers=loader.num_workers, + pin_memory=loader.pin_memory, + ) + + +class Fsdp2StateDictAdapter: + """Adapter exposing ``.get_state_dict(model)`` for FSDP2-sharded models. + + Satisfies the ``accelerator=`` kwarg of + ``modelopt.torch.export.unified_export_hf._export_transformers_checkpoint``. + Backed by ``get_model_state_dict`` which materializes a full unsharded state + dict on every rank (with CPU offload to bound peak GPU memory during gather). + """ + + def get_state_dict(self, model): + """Return the full unsharded state dict gathered from FSDP2 shards (CPU-offloaded).""" + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + + return get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + class DistributedProcessGroup: """A convenient wrapper around torch.distributed.ProcessGroup objects.""" From 46cb80e40f432b1f7aecf71036d1fbc0de41059b Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 28 May 2026 23:21:43 +0000 Subject: [PATCH 04/10] tested layerwise + non layerwise Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/README.md | 36 +- examples/llm_ptq/example_utils.py | 101 +++- examples/llm_ptq/hf_ptq.py | 24 +- examples/llm_ptq/multinode_ptq.py | 450 ------------------ .../torch/quantization/utils/core_utils.py | 20 +- modelopt/torch/utils/distributed.py | 128 ++++- 6 files changed, 238 insertions(+), 521 deletions(-) delete mode 100644 examples/llm_ptq/multinode_ptq.py diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 511eed5d774..880561090c5 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -370,32 +370,42 @@ mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) ## Multi-Node Post-Training Quantization with FSDP2 -ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages HuggingFace's Accelerate library and FSDP2 for distributed model sharding and calibration. +ModelOpt enables quantization of LLMs across multiple GPU nodes using FSDP2 for distributed model sharding and calibration, exposed via the `--use_fsdp2` flag on the standard `hf_ptq.py` entry point. ### Usage -For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements. +Single-node (multiple GPUs): -On each node run the following command: +```bash +torchrun --standalone --nproc_per_node= hf_ptq.py \ + --pyt_ckpt_path \ + --qformat \ + --kv_cache_qformat \ + --batch_size \ + --calib_size \ + --export_path \ + --use_fsdp2 +``` + +Multi-node (run on each node): ```bash -accelerate launch --config_file fsdp2.yaml \ - --num_machines= \ - --machine_rank= \ - --main_process_ip= \ - --main_process_port= \ - --fsdp_transformer_layer_cls_to_wrap= - multinode_ptq.py \ +torchrun \ + --nnodes= --node_rank= \ + --master_addr= --master_port= \ + --nproc_per_node= \ + hf_ptq.py \ --pyt_ckpt_path \ - --qformat \ + --qformat \ --kv_cache_qformat \ --batch_size \ --calib_size \ - --dataset \ --export_path \ - --trust_remote_code + --use_fsdp2 ``` +For layerwise calibration (amortizes cross-node all-gather cost across all calibration batches), use `--qformat nvfp4_max_layerwise`. + The exported checkpoint can be deployed using TensorRT-LLM/ vLLM/ SGLang. For more details refer to the [deployment section](#deployment) of this document. > *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory and choose the right number of GPUs to avoid unnecessary communication.* diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ff2ba9631bc..6d3cf134e35 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -52,7 +52,7 @@ # --------------------------------------------------------------------------- -# FSDP2 helpers (opt-in via --use_fsdp2 in hf_ptq.py / multinode_ptq.py). +# FSDP2 helpers (opt-in via --use_fsdp2 in hf_ptq.py). # --------------------------------------------------------------------------- @@ -86,48 +86,98 @@ def cleanup_distributed(args): dist_utils.cleanup() +def validate_fsdp2_supported(args, config): + """Raise NotImplementedError if the model config is not FSDP2-supported in v1. + + Called after ``AutoConfig.from_pretrained`` (cheap) and before any heavy + loading work, so unsupported configurations fail fast with a clear message + instead of crashing later inside a DTensor traceback. + """ + issues = [] + if "vila" in args.pyt_ckpt_path.lower(): + issues.append("VILA (custom builder + non-standard layer layout)") + if is_nemotron_vl(config) or _is_multimodal_config(config): + issues.append("multimodal / VL models (decoder layers not auto-detectable)") + if getattr(config, "quantization_config", None) is not None: + issues.append("pack-quantized / compressed-tensors checkpoints") + if getattr(args, "specdec_offline_dataset", None) is not None: + issues.append("speculative decoding (--specdec_offline_dataset)") + if getattr(args, "low_memory_mode", False): + issues.append("--low_memory_mode (redundant with FSDP2)") + if issues: + raise NotImplementedError( + "--use_fsdp2 does not support:\n - " + + "\n - ".join(issues) + + "\nRemove --use_fsdp2 or use a standard causal-LM checkpoint." + ) + + def load_and_prepare_fsdp2_model( ckpt_path: str, device: torch.device, rank: int, + args=None, trust_remote_code: bool = False, mp_policy=None, ): - """Load and FSDP2-shard a causal LM with rank-0-only CPU realization. + """Load and FSDP2-shard a causal LM (accelerate-style rank-0-only CPU load). + + Replicates ``accelerate.init_empty_weights(include_buffers=False)`` + + ``load_checkpoint_in_model`` manually: - Rank 0 reads weights from disk on CPU via ``from_pretrained``; other ranks - build a structural skeleton on the ``meta`` device. ``fsdp2_shard`` then - slices each decoder layer, allocates per-rank GPU shard storage, broadcasts - rank-0's weights into the shards, and freezes params. + - Rank 0: ``from_pretrained`` on CPU; capture ``src_state_dict``. + - Other ranks: ``from_config`` under ``init_params_on_meta`` → params on + meta (~0 CPU), buffers computed on CPU from config (RoPE inv_freq etc.). + - ``fsdp2_shard`` wraps decoder layers (root stays unsharded), materializes + meta→GPU, broadcasts state_dict from rank 0, re-ties weights, freezes. - Memory: only rank 0 holds the full CPU copy. Each rank ends with - ``model_size / world_size`` of GPU shard storage. + Memory: rank 0 holds the full BF16 model in CPU during the broadcast + (~model_size bytes); other ranks pay ~0 CPU. Each rank ends with + ``model_size / world_size`` GPU shard storage plus replicated + ``embed_tokens`` + ``lm_head`` (~few-GiB total). v1 supports standard transformers families only (causal LMs that load - cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized / speculative - are not validated under FSDP2 and should go through ``get_model``. + cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized / + speculative / VL go through ``get_model`` and don't get FSDP2. """ - from modelopt.torch.utils.distributed import fsdp2_shard + from modelopt.torch.utils.distributed import fsdp2_shard, init_params_on_meta hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + if args is not None: + validate_fsdp2_supported(args, hf_config) + + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + if rank == 0: - model = AutoModelForCausalLM.from_pretrained( + src_model = AutoModelForCausalLM.from_pretrained( ckpt_path, torch_dtype="auto", trust_remote_code=trust_remote_code, low_cpu_mem_usage=True, ) + src_model.eval() + src_state_dict = src_model.state_dict() else: - dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 - with torch.device("meta"): - model = AutoModelForCausalLM.from_config( - hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code - ) + src_model = None + src_state_dict = {} + + with init_params_on_meta(): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) model.eval() - # Disable HF KV cache; calibration is single-pass and (for layerwise) replays. if hasattr(model, "config") and hasattr(model.config, "use_cache"): model.config.use_cache = False - return fsdp2_shard(model, device, rank, mp_policy=mp_policy) + + sharded = fsdp2_shard( + model, + device, + rank, + src_state_dict=src_state_dict, + mp_policy=mp_policy, + ) + del src_model, src_state_dict + return sharded def create_fsdp2_calibration_loop(model, dataloader, device): @@ -664,9 +714,11 @@ def get_model( # Note: Forcibly converting the model precision between bf16 and fp16 may introduce accuracy drop model_kwargs = config_kwargs.copy() - # Don't set torch_dtype for VILA models as they handle it explicitly in their builder + # Don't set torch_dtype for VILA models as they handle it explicitly in their builder. + # Use the legacy ``torch_dtype`` kwarg name — newer transformers forwards the ``dtype`` + # kwarg through to the model class ``__init__``, which custom modeling code may reject. if "vila" not in ckpt_path.lower(): - model_kwargs.setdefault("dtype", "auto") + model_kwargs.setdefault("torch_dtype", "auto") if "vila" in ckpt_path.lower(): hf_vila = AutoModel.from_pretrained( @@ -717,7 +769,7 @@ def has_pack_quantized_config(config): ckpt_path, device_map="auto", trust_remote_code=trust_remote_code, - dtype="auto", + torch_dtype="auto", ) else: architecture = hf_config.architectures[0] @@ -749,7 +801,10 @@ def has_pack_quantized_config(config): model_kwargs2 = model_kwargs.copy() if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["dtype"] = torch_dtype + # Use the legacy ``torch_dtype`` kwarg; some custom modeling classes + # reject the newer ``dtype`` name when it's forwarded via **kwargs. + model_kwargs2["torch_dtype"] = torch_dtype + model_kwargs2.pop("dtype", None) model_kwargs2.pop("max_memory", None) model = from_config(hf_config, **model_kwargs2) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e7515e9eb44..b229fbec179 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -110,6 +110,13 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: break +def _nvfp4_max_cfg(*, layerwise: bool) -> dict[str, Any]: + """NVFP4 quant config with explicit max calibration and a layerwise toggle.""" + cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) + cfg["algorithm"] = {"method": "max", "layerwise": layerwise} + return cfg + + QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, @@ -118,6 +125,8 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_max": _nvfp4_max_cfg(layerwise=False), + "nvfp4_max_layerwise": _nvfp4_max_cfg(layerwise=True), "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -453,6 +462,7 @@ def load_model(args: argparse.Namespace): ckpt_path=args.pyt_ckpt_path, device=args.device, rank=args.rank, + args=args, trust_remote_code=args.trust_remote_code, ) elif args.specdec_offline_dataset is not None or not args.low_memory_mode: @@ -911,6 +921,10 @@ def pre_quantize( # Generate preview before quantization if args.skip_generate: generated_ids_before_ptq = None + elif args.use_fsdp2: + # FSDP2 generation is slow cross-node (~seconds/token); 5 tokens is + # enough to sanity-check coherence. + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=5) elif model_type == "deepseek": # DeepSeek generation may go OOM, so we skip it generated_ids_before_ptq = None @@ -987,7 +1001,11 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + # FSDP2 cross-node generation is slow, so cap at 5 tokens for preview. + max_new_tokens = 5 if args.use_fsdp2 else 100 + generated_ids_after_ptq = full_model.generate( + preview_input_ids, max_new_tokens=max_new_tokens + ) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -1416,8 +1434,8 @@ def parse_args() -> argparse.Namespace: "Run calibration under PyTorch FSDP2 (requires launching with torchrun). " "Takes precedence over --use_seq_device_map. " "v1 limitations: standard causal-LM only (no VILA / pack-quantized / speculative / " - "auto-quantize / sparsity / VLM); per-rank load memory ceiling is roughly model_size " - "(use multinode_ptq.py for models that don't fit on every rank)." + "auto-quantize / sparsity / VLM). Rank 0 holds the full model in CPU briefly " + "during the broadcast step; other ranks pay ~0 CPU." ), ) parser.add_argument( diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py deleted file mode 100644 index 9020d344d00..00000000000 --- a/examples/llm_ptq/multinode_ptq.py +++ /dev/null @@ -1,450 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-node PTQ (Post-Training Quantization) with FSDP2 support.""" - -import argparse -import copy -import json -import os -import random -import time -import warnings -from pathlib import Path -from typing import Any - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from example_utils import build_quant_cfg, get_tokenizer -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - set_model_state_dict, -) -from torch.distributed.fsdp import fully_shard -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from tqdm import tqdm -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - PreTrainedTokenizer, - PreTrainedTokenizerFast, -) - -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -from modelopt.torch.export import get_model_type -from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint -from modelopt.torch.quantization.config import need_calibration -from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes -from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector -from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets - -RAND_SEED = 1234 - - -def _nvfp4_max_cfg(*, layerwise: bool) -> dict[str, Any]: - """NVFP4 quant config with explicit max calibration and a layerwise toggle.""" - cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) - cfg["algorithm"] = {"method": "max", "layerwise": layerwise} - return cfg - - -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { - "int8": mtq.INT8_DEFAULT_CFG, - "int4_awq": mtq.INT4_AWQ_CFG, - "fp8": mtq.FP8_DEFAULT_CFG, - "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_max": _nvfp4_max_cfg(layerwise=False), - "nvfp4_max_layerwise": _nvfp4_max_cfg(layerwise=True), - "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, - "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, - "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, - "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, - "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, -} - -KV_QUANT_CFG_CHOICES = { - "none": "none", - "fp8": "FP8_KV_CFG", - "nvfp4": "NVFP4_KV_CFG", - "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", -} - - -mto.enable_huggingface_checkpointing() - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Multi-node post-training quantization with FSDP2") - - parser.add_argument("--pyt_ckpt_path", required=True, help="Path to PyTorch checkpoint") - parser.add_argument( - "--qformat", default="fp8", choices=QUANT_CFG_CHOICES.keys(), help="Quantization format" - ) - parser.add_argument( - "--kv_cache_qformat", - default="fp8", - choices=list(KV_QUANT_CFG_CHOICES.keys()), - help="KV cache quantization format", - ) - parser.add_argument("--batch_size", type=int, default=1, help="Batch size for calibration") - parser.add_argument( - "--calib_size", - type=str, - default="512", - help="Comma-separated list of calibration sizes per dataset", - ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help=( - f"name of a dataset, or a comma separated list of datasets. " - f"dataset choices are {get_supported_datasets()}" - ), - ) - parser.add_argument( - "--export_path", default="exported_model", help="Directory to export the quantized model" - ) - parser.add_argument( - "--trust_remote_code", - action="store_true", - help="Trust remote code for HuggingFace models", - ) - parser.add_argument("--awq_block_size", default=0, type=int) - parser.add_argument( - "--fsdp_transformer_layer_cls_to_wrap", - default=None, - help=( - "Override auto-detect by transformer layer class name " - "(e.g. LlamaDecoderLayer). Auto-detected when omitted." - ), - ) - args = parser.parse_args() - - args.dataset = args.dataset.split(",") if args.dataset else None - args.calib_size = [int(x) for x in args.calib_size.split(",")] - - return args - - -def setup_distributed() -> tuple[int, int, int, torch.device]: - """Initialize torch.distributed from torchrun env vars and pin the CUDA device.""" - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - device = torch.device(f"cuda:{local_rank}") - if not dist.is_initialized(): - dist.init_process_group(backend="nccl") - print(f"Rank: {rank}, World Size: {world_size}, Local Rank: {local_rank}") - return rank, world_size, local_rank, device - - -def _resolve_decoder_layers(model: nn.Module, override_cls_name: str | None): - """Return the list of decoder layers to apply ``fully_shard`` to.""" - if override_cls_name: - layers = [m for m in model.modules() if type(m).__name__ == override_cls_name] - if not layers: - raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") - return layers - layers = LayerActivationCollector.get_decoder_layers(model) - if layers is None: - raise RuntimeError( - "Could not auto-detect decoder layers; pass " - "--fsdp_transformer_layer_cls_to_wrap explicitly." - ) - return layers - - -def fsdp2_wrap(model: nn.Module, override_cls_name: str | None = None) -> nn.Module: - """Apply FSDP2 ``fully_shard`` to each decoder layer only. - - The root is intentionally not sharded so embed_tokens / lm_head stay as - plain replicated tensors. Sharding the root makes those weights DTensors, - which collides with modelopt's layerwise forward patching (mixed - plain-tensor / DTensor inputs at the embedding lookup). - """ - for layer in _resolve_decoder_layers(model, override_cls_name): - fully_shard(layer, reshard_after_forward=True) - return model - - -def load_and_prepare_model( - model_path: str, - device: torch.device, - rank: int, - trust_remote_code: bool = False, - override_cls_name: str | None = None, -) -> tuple[nn.Module, str, list[str]]: - """Load model and shard it with FSDP2 using rank-0-only CPU realization. - - Only rank 0 reads real weights from disk; every other rank instantiates the - model on the ``meta`` device. After ``fully_shard`` sets up the sharded - DTensor layout and ``to_empty`` allocates per-rank shard storage, rank 0's - full state dict is broadcast into the sharded structure. - """ - config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) - dtype = getattr(config, "torch_dtype", None) or torch.bfloat16 - - if rank == 0: - src_model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype="auto", - trust_remote_code=trust_remote_code, - low_cpu_mem_usage=True, - ) - src_model.eval() - cpu_state_dict = src_model.state_dict() - else: - src_model = None - cpu_state_dict = {} - - with torch.device("meta"): - model = AutoModelForCausalLM.from_config( - config, torch_dtype=dtype, trust_remote_code=trust_remote_code - ) - model.eval() - - model_type = get_model_type(model) - original_architectures = model.config.architectures - - fsdp2_wrap(model, override_cls_name=override_cls_name) - - model.to_empty(device=device) - - set_model_state_dict( - model, - cpu_state_dict, - options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), - ) - - # Freeze every param so patch_fsdp_mp_dtypes' trainable-only check skips the - # uniform-dtype assertion (e.g. Nemotron-H ships mixed bf16/fp32 weights). - for p in model.parameters(): - p.requires_grad_(False) - - del cpu_state_dict, src_model - - return model, model_type, original_architectures - - -def create_calibration_dataloader( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - dataset_names: list[str], - calib_sizes: list[int], - batch_size: int, -) -> DataLoader: - """Create calibration dataloader from dataset.""" - return get_dataset_dataloader( - dataset_name=dataset_names, - tokenizer=tokenizer, - batch_size=batch_size, - num_samples=calib_sizes, - device=None, - include_labels=False, - ) - - -def shard_dataloader(loader: DataLoader, rank: int, world_size: int) -> DataLoader: - """Wrap a DataLoader with a DistributedSampler so each rank sees a unique shard.""" - sampler = DistributedSampler( - loader.dataset, - num_replicas=world_size, - rank=rank, - shuffle=False, - drop_last=False, - ) - return DataLoader( - loader.dataset, - batch_size=loader.batch_size, - sampler=sampler, - collate_fn=loader.collate_fn, - num_workers=loader.num_workers, - pin_memory=loader.pin_memory, - ) - - -def create_fsdp2_calibration_loop( - model: nn.Module, - dataloader: DataLoader, - device: torch.device, -): - """Calibration loop that forwards through the FSDP-wrapped model.""" - - def calibrate(unwrapped_model): - # Force use_cache=False so layerwise replays don't accumulate KV across batches. - if hasattr(model, "config") and hasattr(model.config, "use_cache"): - model.config.use_cache = False - for batch in tqdm(dataloader, desc="Calibrating"): - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - batch.setdefault("use_cache", False) - # Use outer (FSDP-wrapped) model, not the unwrapped parameter passed by mtq.quantize. - model(**batch) - - return calibrate - - -class _Fsdp2StateDictAdapter: - """Shim exposing ``.get_state_dict(model)`` to ``_export_transformers_checkpoint``.""" - - def get_state_dict(self, model: nn.Module): - return get_model_state_dict( - model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - - -def export_model( - model: nn.Module, - rank: int, - export_path: str | Path, - architectures: list[str], -): - """Export the quantized model to HuggingFace format on rank 0.""" - export_dir = Path(export_path) - export_dir.mkdir(parents=True, exist_ok=True) - - adapter = _Fsdp2StateDictAdapter() - post_state_dict, hf_quant_config = _export_transformers_checkpoint( - model, torch.bfloat16, accelerator=adapter - ) - - if rank == 0: - with open(f"{export_dir}/hf_quant_config.json", "w") as file: - json.dump(hf_quant_config, file, indent=4) - - hf_quant_config = convert_hf_quant_config_format(hf_quant_config) - - model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False) - - original_config = f"{export_dir}/config.json" - with open(original_config) as file: - config_data = json.load(file) - - config_data["quantization_config"] = hf_quant_config - config_data["architectures"] = architectures - - with open(original_config, "w") as file: - json.dump(config_data, file, indent=4) - - dist.barrier() - - -def main(args): - """Main quantization workflow.""" - if not torch.cuda.is_available(): - raise OSError("GPU is required for quantization.") - - if args.qformat not in QUANT_CFG_CHOICES: - raise ValueError( - f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}" - ) - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - torch.manual_seed(RAND_SEED) - - rank, world_size, _, device = setup_distributed() - - tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) - default_padding_side = tokenizer.padding_side - tokenizer.padding_side = "left" - - if args.dataset is None: - args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] - warnings.warn( - "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." - ) - args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ - : len(args.dataset) - ] - - calib_dataloader = create_calibration_dataloader( - tokenizer=tokenizer, - dataset_names=args.dataset, - calib_sizes=args.calib_size, - batch_size=args.batch_size, - ) - calib_dataloader = shard_dataloader(calib_dataloader, rank, world_size) - - model, model_type, original_architectures = load_and_prepare_model( - model_path=args.pyt_ckpt_path, - device=device, - rank=rank, - trust_remote_code=args.trust_remote_code, - override_cls_name=args.fsdp_transformer_layer_cls_to_wrap, - ) - - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - quant_cfg = build_quant_cfg(args.qformat, quant_cfg, args.awq_block_size, model_type) - - enable_quant_kv_cache = args.kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - - if enable_quant_kv_cache: - quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( - quant_cfg, - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], - ) - - if rank == 0: - print("Starting quantization...") - - start_time = time.time() - - if need_calibration(quant_cfg): - calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, device) - else: - calibrate_fn = None - warnings.warn("Dynamic quantization. Calibration skipped.") - - with torch.no_grad(): - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_fn) - - elapsed = time.time() - start_time - - if rank == 0: - print(f"Quantization completed in {elapsed:.2f}s") - mtq.print_quant_summary(model) - - start_time = time.time() - export_model(model, rank, args.export_path, original_architectures) - elapsed = time.time() - start_time - - if rank == 0: - if tokenizer is not None: - tokenizer.padding_side = default_padding_side - tokenizer.save_pretrained(args.export_path) - print(f"Export completed in {elapsed:.2f}s") - print(f"Model exported to {args.export_path}") - - dist.barrier() - dist.destroy_process_group() - - -if __name__ == "__main__": - args = parse_args() - with patch_fsdp_mp_dtypes(): - main(args) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index cea3d4260e4..4d9ed0bdf3e 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -399,18 +399,26 @@ def is_pow2(n): def _get_fsdp2_mesh(module: nn.Module): - """Get the mesh info of the model.""" + """Get the mesh info of the model. + + Prefers ``post_forward_mesh_info`` (set when ``reshard_after_forward=True``); + falls back to ``mesh_info`` if it's None (observed under some PyTorch FSDP2 + configurations: eval mode + all-frozen params at the time ``persistent + _materialization`` queries the state — the mesh itself is still valid). + """ try: from torch.distributed._composable_state import _get_module_state except ImportError: return None fsdp_state = _get_module_state(module) - if ( - fsdp_state._fsdp_param_group - and fsdp_state._fsdp_param_group.post_forward_mesh_info is not None - ): - return fsdp_state._fsdp_param_group.post_forward_mesh_info.mesh + pg = getattr(fsdp_state, "_fsdp_param_group", None) + if pg is None: + return None + info = pg.post_forward_mesh_info or pg.mesh_info + if info is None: + return None + return info.mesh def _get_module_name(module: nn.Module, root_model: nn.Module, name_to_module: dict | None = None): diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index ffe69f83db5..0e5a4ce8d03 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -20,13 +20,14 @@ import os import time from collections.abc import Callable -from contextlib import suppress +from contextlib import contextmanager, suppress from datetime import timedelta from typing import Any from warnings import warn import torch import torch.distributed +import torch.nn as nn from torch.distributed.tensor import DTensor __all__ = [ @@ -37,6 +38,7 @@ "barrier", "fsdp2_shard", "fsdp2_wrap", + "init_params_on_meta", "is_available", "is_initialized", "is_master", @@ -225,7 +227,12 @@ def cleanup(): # --------------------------------------------------------------------------- -def fsdp2_wrap(model, override_cls_name: str | None = None, mp_policy=None): +def fsdp2_wrap( + model, + override_cls_name: str | None = None, + mp_policy=None, + device=None, +): """Apply FSDP2 ``fully_shard`` to each decoder layer of ``model``. Decoder layers are auto-detected via @@ -234,9 +241,17 @@ def fsdp2_wrap(model, override_cls_name: str | None = None, mp_policy=None): ``mp_policy`` (a ``torch.distributed.fsdp.MixedPrecisionPolicy``) to control compute / reduce dtype; default ``None`` means no upcast / downcast. - The root module is intentionally not sharded so embeddings / lm_head stay as - plain tensors (avoids DTensor / plain-tensor mismatches with modelopt's - layerwise forward patching). + Pass ``device`` to stream each layer to that device just before sharding it + (avoids holding the full model on GPU simultaneously). When ``device`` is + ``None``, layers are sharded on whatever device they're already on. + + The root module is intentionally NOT sharded — ``embed_tokens`` and + ``lm_head`` stay as plain replicated tensors. This costs ~few-GiB per rank + (full copies of embed + lm_head) but unifies the layerwise and non-layerwise + code paths: a DTensor-wrapped ``embed_tokens.weight`` raises a "mixed Tensor + / DTensor" error when modelopt's layerwise calibration passes plain + ``input_ids`` into the embedding lookup, and FSDP2's root pre-forward hook + doesn't auto-wrap LongTensor inputs. """ from torch.distributed.fsdp import fully_shard @@ -256,41 +271,102 @@ def fsdp2_wrap(model, override_cls_name: str | None = None, mp_policy=None): if mp_policy is not None: fsdp_kwargs["mp_policy"] = mp_policy for layer in layers: + if device is not None: + layer.to(device) fully_shard(layer, **fsdp_kwargs) return model -def fsdp2_shard(model, device, rank, mp_policy=None): - """Shard a loaded model across the current process group. +@contextmanager +def init_params_on_meta(): + """Replicate ``accelerate.init_empty_weights(include_buffers=False)``. - Expects rank 0 to pass a real CPU model and other ranks to pass a meta - skeleton with matching structure. After this call every rank holds its - per-rank GPU shard, populated from rank 0's source. + Inside this context, ``nn.Module.register_parameter`` is patched so newly + registered parameters land on the ``meta`` device (zero CPU bytes). Buffer + registration is NOT patched — buffers are computed normally on CPU during + ``__init__`` (e.g. ``Qwen2RotaryEmbedding.__init__`` produces a real CPU + ``inv_freq`` from config). - Steps: stash ``_original_architectures`` (FSDP2 may mutate - ``model.config.architectures``); capture rank-0's state_dict; ``fsdp2_wrap`` - per decoder layer; ``to_empty`` allocates per-rank GPU shard storage; - ``set_model_state_dict(broadcast_from_rank0=True)`` streams the data; freeze - params (needed by ``patch_fsdp_mp_dtypes``' trainable-only check). + Use around ``from_config(...)`` on non-rank-0 ranks to build a meta-skeleton + that ``fsdp2_shard`` will materialize via ``set_model_state_dict`` + broadcast from rank 0. + """ + original = nn.Module.register_parameter + + def patched(self, name, param): + original(self, name, param) + if param is not None: + p = self._parameters[name] + self._parameters[name] = nn.Parameter(p.to("meta"), requires_grad=p.requires_grad) + + nn.Module.register_parameter = patched + try: + yield + finally: + nn.Module.register_parameter = original + + +def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None): + """Shard a model across the current process group (accelerate-style rank-0 load). + + Caller contract: ``model`` is built on every rank with params on ``meta`` and + buffers on CPU (use ``init_params_on_meta`` around ``from_config``). Rank 0 + additionally passes ``src_state_dict`` captured from a real CPU model loaded + via ``from_pretrained``; other ranks pass ``None`` or ``{}``. + + Root is never sharded (see ``fsdp2_wrap`` docstring). embed_tokens and + lm_head stay as plain replicated tensors on every rank. + + Steps (each timed and logged per-rank for diagnostics): + 1. ``fsdp2_wrap`` — apply ``fully_shard`` to decoder layers. + 2. Materialize: meta params → empty GPU storage; real CPU buffers → GPU + (preserves their values; ``to_empty`` is NOT used because it would wipe + buffers). + 3. ``set_model_state_dict(broadcast_from_rank0=True)`` — fills params and + persistent buffers from rank 0. + 4. ``model.tie_weights()`` — restore tied embeddings (no-op for untied). + 5. Freeze params (so ``patch_fsdp_mp_dtypes`` trainable-only check passes). """ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + def _log(msg): + print(f"[rank {rank}] [fsdp2_shard] {msg}", flush=True) + model._original_architectures = list(model.config.architectures or []) - cpu_state_dict = model.state_dict() if rank == 0 else {} + t0 = time.perf_counter() fsdp2_wrap(model, mp_policy=mp_policy) - model.to_empty(device=device) - set_model_state_dict( - model, - cpu_state_dict, - options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), - ) - # TODO(temp workaround): FSDP2's _init_mp_dtypes asserts uniform dtype across - # trainable params. patch_fsdp_mp_dtypes narrows the check to trainable-only; - # freezing here makes trainable empty so mixed-dtype models (Nemotron-H, etc.) - # pass. PTQ doesn't need gradients anyway. + _log(f"fsdp2_wrap (decoders) took {time.perf_counter() - t0:.1f}s") + + def _materialize(t): + is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta + if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): + # empty_like preserves DTensor-ness (returns DTensor with empty local). + return torch.empty_like(t, device=device) + return t.to(device) + + t0 = time.perf_counter() + model._apply(_materialize) + _log(f"materialize (meta→GPU, CPU buf→GPU) took {time.perf_counter() - t0:.1f}s") + + if src_state_dict is not None: + t0 = time.perf_counter() + set_model_state_dict( + model, + src_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + _log(f"set_model_state_dict broadcast took {time.perf_counter() - t0:.1f}s") + + if hasattr(model, "tie_weights"): + t0 = time.perf_counter() + model.tie_weights() + _log(f"tie_weights took {time.perf_counter() - t0:.1f}s") + + t0 = time.perf_counter() for p in model.parameters(): p.requires_grad_(False) + _log(f"freeze params took {time.perf_counter() - t0:.1f}s") return model From 705316dcfc39b7b99cef91a84102a4fc2f759d05 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 29 May 2026 18:56:33 +0000 Subject: [PATCH 05/10] added and tested CPU offloading policy Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 36 +++---------- examples/llm_ptq/hf_ptq.py | 33 +++++++----- modelopt/torch/utils/distributed.py | 82 ++++++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 48 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 6d3cf134e35..79f57132714 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -119,6 +119,7 @@ def load_and_prepare_fsdp2_model( args=None, trust_remote_code: bool = False, mp_policy=None, + cpu_offload: bool = False, ): """Load and FSDP2-shard a causal LM (accelerate-style rank-0-only CPU load). @@ -175,32 +176,12 @@ def load_and_prepare_fsdp2_model( rank, src_state_dict=src_state_dict, mp_policy=mp_policy, + cpu_offload=cpu_offload, ) del src_model, src_state_dict return sharded -def create_fsdp2_calibration_loop(model, dataloader, device): - """Calibration closure that forwards through the outer FSDP2-wrapped model. - - Required because ``mtq.quantize`` unwraps the model before calling - ``forward_loop``; calling the unwrapped inner module skips FSDP2's pre/post - forward hooks and breaks the all-gather. The closure captures the outer - ``model`` and ignores the ``unwrapped_model`` argument. - """ - from tqdm import tqdm - - def calibrate(unwrapped_model): - for batch in tqdm(dataloader, desc="Calibrating"): - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - - return calibrate - - def run_nemotron_vl_preview( full_model, tokenizer, @@ -714,11 +695,9 @@ def get_model( # Note: Forcibly converting the model precision between bf16 and fp16 may introduce accuracy drop model_kwargs = config_kwargs.copy() - # Don't set torch_dtype for VILA models as they handle it explicitly in their builder. - # Use the legacy ``torch_dtype`` kwarg name — newer transformers forwards the ``dtype`` - # kwarg through to the model class ``__init__``, which custom modeling code may reject. + # Don't set torch_dtype for VILA models as they handle it explicitly in their builder if "vila" not in ckpt_path.lower(): - model_kwargs.setdefault("torch_dtype", "auto") + model_kwargs.setdefault("dtype", "auto") if "vila" in ckpt_path.lower(): hf_vila = AutoModel.from_pretrained( @@ -769,7 +748,7 @@ def has_pack_quantized_config(config): ckpt_path, device_map="auto", trust_remote_code=trust_remote_code, - torch_dtype="auto", + dtype="auto", ) else: architecture = hf_config.architectures[0] @@ -801,10 +780,7 @@ def has_pack_quantized_config(config): model_kwargs2 = model_kwargs.copy() if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - # Use the legacy ``torch_dtype`` kwarg; some custom modeling classes - # reject the newer ``dtype`` name when it's forwarded via **kwargs. - model_kwargs2["torch_dtype"] = torch_dtype - model_kwargs2.pop("dtype", None) + model_kwargs2["dtype"] = torch_dtype model_kwargs2.pop("max_memory", None) model = from_config(hf_config, **model_kwargs2) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index b229fbec179..064c3e4a5ad 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -33,7 +33,6 @@ build_quant_cfg, cleanup_distributed, copy_custom_model_files, - create_fsdp2_calibration_loop, create_vlm_calibration_loop, get_model, get_processor, @@ -88,7 +87,11 @@ get_max_batch_size, get_supported_datasets, ) -from modelopt.torch.utils.distributed import Fsdp2StateDictAdapter, shard_dataloader +from modelopt.torch.utils.distributed import ( + Fsdp2StateDictAdapter, + fsdp_aware_forward_loop, + shard_dataloader, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -464,6 +467,7 @@ def load_model(args: argparse.Namespace): rank=args.rank, args=args, trust_remote_code=args.trust_remote_code, + cpu_offload=args.cpu_offload, ) elif args.specdec_offline_dataset is not None or not args.low_memory_mode: full_model = get_model( @@ -690,7 +694,7 @@ def mono_quantize( elif args.use_fsdp2: # mtq.quantize passes the unwrapped inner module to forward_loop; # FSDP2 needs hooks fired on the outer wrapped model. - calibrate_loop = create_fsdp2_calibration_loop( + calibrate_loop = fsdp_aware_forward_loop( language_model, calib_dataloader, args.device ) else: @@ -921,10 +925,6 @@ def pre_quantize( # Generate preview before quantization if args.skip_generate: generated_ids_before_ptq = None - elif args.use_fsdp2: - # FSDP2 generation is slow cross-node (~seconds/token); 5 tokens is - # enough to sanity-check coherence. - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=5) elif model_type == "deepseek": # DeepSeek generation may go OOM, so we skip it generated_ids_before_ptq = None @@ -1001,11 +1001,7 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - # FSDP2 cross-node generation is slow, so cap at 5 tokens for preview. - max_new_tokens = 5 if args.use_fsdp2 else 100 - generated_ids_after_ptq = full_model.generate( - preview_input_ids, max_new_tokens=max_new_tokens - ) + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -1438,6 +1434,17 @@ def parse_args() -> argparse.Namespace: "during the broadcast step; other ranks pay ~0 CPU." ), ) + parser.add_argument( + "--cpu_offload", + action="store_true", + help=( + "Only valid with --use_fsdp2. Attach FSDP2's CPUOffloadPolicy so each " + "rank's decoder shard lives on CPU between forwards (streamed to GPU " + "per-layer). Frees GPU memory at the cost of PCIe traffic per layer per " + "batch. Worth it for trillion-param models or tight-GPU setups; usually " + "slows down runs where the model already fits comfortably." + ), + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", @@ -1547,6 +1554,8 @@ def parse_args() -> argparse.Namespace: args.use_seq_device_map = False if args.use_fsdp2 and os.environ.get("RANK") is None: parser.error("--use_fsdp2 requires launching with torchrun") + if args.cpu_offload and not args.use_fsdp2: + parser.error("--cpu_offload requires --use_fsdp2") return args diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 0e5a4ce8d03..4ee84a58d6a 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -38,6 +38,7 @@ "barrier", "fsdp2_shard", "fsdp2_wrap", + "fsdp_aware_forward_loop", "init_params_on_meta", "is_available", "is_initialized", @@ -232,6 +233,7 @@ def fsdp2_wrap( override_cls_name: str | None = None, mp_policy=None, device=None, + cpu_offload: bool = False, ): """Apply FSDP2 ``fully_shard`` to each decoder layer of ``model``. @@ -245,6 +247,14 @@ def fsdp2_wrap( (avoids holding the full model on GPU simultaneously). When ``device`` is ``None``, layers are sharded on whatever device they're already on. + Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to each + wrapped layer. Each rank's shard then lives on CPU between forward passes + and is streamed to GPU per-layer (H2D + all-gather + compute + reshard + + D2H). Useful when the per-rank decoder shard is the binding GPU constraint + (e.g., 200B+ models on tight GPU budgets) or when you want headroom for a + larger calibration batch. Adds PCIe traffic per layer per batch; on + setups where the model already fits comfortably it usually slows the run. + The root module is intentionally NOT sharded — ``embed_tokens`` and ``lm_head`` stay as plain replicated tensors. This costs ~few-GiB per rank (full copies of embed + lm_head) but unifies the layerwise and non-layerwise @@ -270,6 +280,10 @@ def fsdp2_wrap( fsdp_kwargs: dict[str, Any] = {"reshard_after_forward": True} if mp_policy is not None: fsdp_kwargs["mp_policy"] = mp_policy + if cpu_offload: + from torch.distributed.fsdp import CPUOffloadPolicy + + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() for layer in layers: if device is not None: layer.to(device) @@ -306,7 +320,7 @@ def patched(self, name, param): nn.Module.register_parameter = original -def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None): +def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None, cpu_offload=False): """Shard a model across the current process group (accelerate-style rank-0 load). Caller contract: ``model`` is built on every rank with params on ``meta`` and @@ -314,6 +328,10 @@ def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None): additionally passes ``src_state_dict`` captured from a real CPU model loaded via ``from_pretrained``; other ranks pass ``None`` or ``{}``. + Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to wrapped + layers (each rank's shard lives on CPU between forwards). See + ``fsdp2_wrap`` docstring for the trade-off. + Root is never sharded (see ``fsdp2_wrap`` docstring). embed_tokens and lm_head stay as plain replicated tensors on every rank. @@ -335,19 +353,24 @@ def _log(msg): model._original_architectures = list(model.config.architectures or []) t0 = time.perf_counter() - fsdp2_wrap(model, mp_policy=mp_policy) + fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) _log(f"fsdp2_wrap (decoders) took {time.perf_counter() - t0:.1f}s") + # With CPU offload, FSDP2 requires DTensor params on CPU at lazy_init time + # (it streams them to GPU per-layer during forward). Also, set_model_state_dict + # rejects mixed-device models — so materialize everything on CPU first, + # broadcast, then promote non-DTensor params + buffers to GPU after. + materialize_device = torch.device("cpu") if cpu_offload else device + def _materialize(t): is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): - # empty_like preserves DTensor-ness (returns DTensor with empty local). - return torch.empty_like(t, device=device) - return t.to(device) + return torch.empty_like(t, device=materialize_device) + return t.to(materialize_device) t0 = time.perf_counter() model._apply(_materialize) - _log(f"materialize (meta→GPU, CPU buf→GPU) took {time.perf_counter() - t0:.1f}s") + _log(f"materialize (→{'CPU' if cpu_offload else 'GPU'}) took {time.perf_counter() - t0:.1f}s") if src_state_dict is not None: t0 = time.perf_counter() @@ -358,6 +381,23 @@ def _materialize(t): ) _log(f"set_model_state_dict broadcast took {time.perf_counter() - t0:.1f}s") + if cpu_offload: + # FSDP-managed (DTensor) params stay on CPU — FSDP2 streams them per layer. + # Move everything else (root-level plain params + all buffers) to GPU now. + t0 = time.perf_counter() + for module in model.modules(): + for name, param in list(module._parameters.items()): + if param is None or isinstance(param, DTensor): + continue + module._parameters[name] = nn.Parameter( + param.data.to(device), requires_grad=param.requires_grad + ) + for name, buf in list(module._buffers.items()): + if buf is None or isinstance(buf, DTensor): + continue + module._buffers[name] = buf.to(device) + _log(f"promote non-DTensor → GPU took {time.perf_counter() - t0:.1f}s") + if hasattr(model, "tie_weights"): t0 = time.perf_counter() model.tie_weights() @@ -396,6 +436,36 @@ def shard_dataloader(loader, rank: int, world_size: int): ) +def fsdp_aware_forward_loop(wrapped_model, dataloader, device=None): + """Build an ``mtq.quantize`` ``forward_loop`` that respects FSDP wrapping. + + ``mtq.quantize`` strips the FSDP wrapper before calling ``forward_loop``, + handing the user the unwrapped inner module. Calling the unwrapped module + bypasses FSDP's pre/post-forward hooks (no all-gather, no reshard), which + breaks calibration on FSDP2. The closure returned here captures the outer + *wrapped* model and ignores the ``unwrapped_model`` argument that + ``mtq.quantize`` passes in. + + Used by ``examples/llm_ptq/hf_ptq.py`` under ``--use_fsdp2``. + + TODO: ``modelopt/torch/quantization/plugins/transformers_trainer.py`` (the + QLoRA path) currently has the same logic inlined inside ``_quantize_model``. + Consolidate that call site to use this helper too. + """ + from tqdm import tqdm + + def calibrate(_unwrapped_model): + for batch in tqdm(dataloader, desc="Calibrating"): + if device is not None and isinstance(batch, dict): + batch = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } + wrapped_model(**batch) + + return calibrate + + class Fsdp2StateDictAdapter: """Adapter exposing ``.get_state_dict(model)`` for FSDP2-sharded models. From 59c482c112ab4991fb54c6373557316eee54a6f1 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:16:57 +0000 Subject: [PATCH 06/10] added process parallel loading Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 229 +++++++++++++++++- examples/llm_ptq/hf_ptq.py | 5 +- .../torch/quantization/utils/core_utils.py | 33 ++- modelopt/torch/utils/distributed.py | 68 ++++-- 4 files changed, 304 insertions(+), 31 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 79f57132714..b08831bcbf3 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -27,6 +27,7 @@ from typing import Any import torch +import torch.nn as nn import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory @@ -112,30 +113,205 @@ def validate_fsdp2_supported(args, config): ) +def _read_safetensors_state_dict_for_prefix( + ckpt_path: str, + weight_map: dict, + prefix: str, +) -> dict: + """Read all tensors whose name starts with ``prefix`` from safetensors files. + + Groups param names by file to avoid re-opening the same file. Returns CPU tensors. + Uses safetensors' ``safe_open`` so only the requested tensors' bytes are read + (the file is mmap-backed, not fully loaded). + """ + import safetensors + + by_file: dict[str, list[str]] = {} + for name, file in weight_map.items(): + if name.startswith(prefix): + by_file.setdefault(file, []).append(name) + + state: dict[str, torch.Tensor] = {} + for file, names in by_file.items(): + with safetensors.safe_open( + os.path.join(ckpt_path, file), framework="pt", device="cpu" + ) as f: + for name in names: + state[name] = f.get_tensor(name) + return state + + +def _read_non_layer_state_dict( + ckpt_path: str, + weight_map: dict, + layer_prefixes: list, +) -> dict: + """Read everything NOT under any decoder-layer prefix (embed, lm_head, norm, ...).""" + import safetensors + + prefixes = tuple(layer_prefixes) + by_file: dict[str, list[str]] = {} + for name, file in weight_map.items(): + if not name.startswith(prefixes): + by_file.setdefault(file, []).append(name) + + state: dict[str, torch.Tensor] = {} + for file, names in by_file.items(): + with safetensors.safe_open( + os.path.join(ckpt_path, file), framework="pt", device="cpu" + ) as f: + for name in names: + state[name] = f.get_tensor(name) + return state + + +def _load_via_parallel_read( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int, + args, + trust_remote_code: bool, + mp_policy, + cpu_offload: bool, + weight_map: dict, +): + """Parallel-read path: each rank reads its share of decoder layers from disk. + + Avoids the rank-0 bottleneck of the rank-0-load-and-broadcast path. Each layer + is owned by ``layer_idx % world_size`` and broadcast from its owner. + + See ``/home/svelury/.claude/plans/parallel-read-loader.md`` for the design. + """ + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + from modelopt.torch.utils.distributed import ( + broadcast_state_dict, + fsdp2_wrap, + init_params_on_meta, + load_state_dict_into_fsdp2_layer, + ) + + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + if args is not None: + validate_fsdp2_supported(args, hf_config) + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + + # Phase A: meta skeleton on every rank + with init_params_on_meta(): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + + # Phase B: wrap decoder layers (root NOT wrapped). Discover prefixes. + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + if decoder_layers is None: + raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") + module_to_name = {m: n for n, m in model.named_modules()} + layer_prefixes = [module_to_name[layer] + "." for layer in decoder_layers] + model._original_architectures = list(model.config.architectures or []) + + fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) + + layer_to_rank = {i: i % world_size for i in range(len(decoder_layers))} + + # Phase C: materialize meta → empty tensors. With cpu_offload, DTensor shards + # land on CPU (FSDP2 will stream them to GPU per-layer during forward). + materialize_device = torch.device("cpu") if cpu_offload else device + from torch.distributed.tensor import DTensor + + def _materialize(t): + is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta + if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): + return torch.empty_like(t, device=materialize_device) + return t.to(materialize_device) + + model._apply(_materialize) + + # Phase D: each rank reads its owned layers from disk in parallel. + owned: dict[int, dict] = {} + for layer_idx, owner in layer_to_rank.items(): + if owner == rank: + owned[layer_idx] = _read_safetensors_state_dict_for_prefix( + ckpt_path, weight_map, layer_prefixes[layer_idx] + ) + + # Phase E: per-layer broadcast + shard. Broadcasts run on GPU (NCCL requires + # CUDA tensors); with cpu_offload we copy back to CPU before writing into the + # CPU-resident DTensor shard. + for layer_idx, layer in enumerate(decoder_layers): + src = layer_to_rank[layer_idx] + layer_state_full = broadcast_state_dict(owned.get(layer_idx), src=src, device=device) + prefix = layer_prefixes[layer_idx] + stripped = {k[len(prefix) :]: v for k, v in layer_state_full.items()} + if cpu_offload: + stripped = {k: v.cpu() for k, v in stripped.items()} + load_state_dict_into_fsdp2_layer(layer, stripped) + if src == rank: + del owned[layer_idx] + del layer_state_full, stripped + + # Phase F: non-decoder params (embed, lm_head, norm) — rank 0 reads + broadcasts. + non_layer = ( + _read_non_layer_state_dict(ckpt_path, weight_map, layer_prefixes) if rank == 0 else None + ) + non_layer = broadcast_state_dict(non_layer, src=0, device=device) + if cpu_offload: + non_layer = {k: v.cpu() for k, v in non_layer.items()} + # Root is NOT FSDP-wrapped → these are plain nn.Parameters / buffers. Direct copy. + missing, unexpected = model.load_state_dict(non_layer, strict=False, assign=False) + if unexpected: + warnings.warn( + f"Unexpected keys in non-layer state dict on rank {rank}: {unexpected[:3]}..." + ) + + if cpu_offload: + # FSDP-managed (DTensor) decoder shards stay on CPU. Promote everything + # else (root-level plain params + all buffers) to GPU now. + for module in model.modules(): + for name, param in list(module._parameters.items()): + if param is None or isinstance(param, DTensor): + continue + module._parameters[name] = nn.Parameter( + param.data.to(device), requires_grad=param.requires_grad + ) + for name, buf in list(module._buffers.items()): + if buf is None or isinstance(buf, DTensor): + continue + module._buffers[name] = buf.to(device) + + # Phase G: tie weights, freeze. + if hasattr(model, "tie_weights"): + model.tie_weights() + for p in model.parameters(): + p.requires_grad_(False) + + return model + + def load_and_prepare_fsdp2_model( ckpt_path: str, device: torch.device, rank: int, + world_size: int = 1, args=None, trust_remote_code: bool = False, mp_policy=None, cpu_offload: bool = False, ): - """Load and FSDP2-shard a causal LM (accelerate-style rank-0-only CPU load). + """Load and FSDP2-shard a causal LM. - Replicates ``accelerate.init_empty_weights(include_buffers=False)`` + - ``load_checkpoint_in_model`` manually: + Default path: **parallel read** — each rank reads its share of decoder layers + from disk and broadcasts to other ranks. Eliminates the rank-0 disk bottleneck. - - Rank 0: ``from_pretrained`` on CPU; capture ``src_state_dict``. - - Other ranks: ``from_config`` under ``init_params_on_meta`` → params on - meta (~0 CPU), buffers computed on CPU from config (RoPE inv_freq etc.). - - ``fsdp2_shard`` wraps decoder layers (root stays unsharded), materializes - meta→GPU, broadcasts state_dict from rank 0, re-ties weights, freezes. + Fallback path (when no ``model.safetensors.index.json`` exists, or when + ``cpu_offload=True``): rank-0 ``from_pretrained`` + ``set_model_state_dict`` + broadcast. Same behavior as previous versions. - Memory: rank 0 holds the full BF16 model in CPU during the broadcast - (~model_size bytes); other ranks pay ~0 CPU. Each rank ends with - ``model_size / world_size`` GPU shard storage plus replicated - ``embed_tokens`` + ``lm_head`` (~few-GiB total). + Both paths produce identical sharded models (same FSDP2 wrap layout, root + unsharded, decoder layers DTensor-sharded across the FSDP mesh). v1 supports standard transformers families only (causal LMs that load cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized / @@ -143,6 +319,35 @@ def load_and_prepare_fsdp2_model( """ from modelopt.torch.utils.distributed import fsdp2_shard, init_params_on_meta + # Try parallel-read path first if the checkpoint has an index file. + # Resolve ckpt_path: if it's a local directory, use as-is; otherwise it's a HF + # Hub ID and we need to materialize the cache directory before parallel read. + resolved_path = ckpt_path + if not os.path.isdir(ckpt_path): + if snapshot_download is None: + resolved_path = None # will fall back to rank-0-broadcast path + else: + resolved_path = snapshot_download(ckpt_path) + + index_path = Path(resolved_path) / "model.safetensors.index.json" if resolved_path else None + if index_path is not None and index_path.exists(): + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + result = _load_via_parallel_read( + ckpt_path=resolved_path, + device=device, + rank=rank, + world_size=world_size, + args=args, + trust_remote_code=trust_remote_code, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + weight_map=weight_map, + ) + if result is not None: + return result + + # Fallback: existing rank-0-load + broadcast path. hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) if args is not None: validate_fsdp2_supported(args, hf_config) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 064c3e4a5ad..7e3bcc2f66c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -465,6 +465,7 @@ def load_model(args: argparse.Namespace): ckpt_path=args.pyt_ckpt_path, device=args.device, rank=args.rank, + world_size=args.world_size, args=args, trust_remote_code=args.trust_remote_code, cpu_offload=args.cpu_offload, @@ -944,7 +945,7 @@ def pre_quantize( trust_remote_code=args.trust_remote_code, ) else: - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=10) return preview_input_ids, generated_ids_before_ptq @@ -1001,7 +1002,7 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=10) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4d9ed0bdf3e..b0889794c10 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -507,8 +507,31 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), device_mesh=original_device_mesh, ) - originals[name] = (param, collected, original_placements, original_device_mesh) - _set_parameter(module, name, nn.Parameter(collected.to_local())) + local_replicated = collected.to_local() + # With FSDP2 CPUOffloadPolicy, the DTensor's local shard lives on CPU, so the + # gathered local tensor is also on CPU. Mirror it onto the current GPU so + # calibration forwards (activations on GPU) see a same-device weight. + if local_replicated.device.type == "cpu" and torch.cuda.is_available(): + working_local = local_replicated.to(torch.cuda.current_device()) + originals[name] = ( + param, + collected, + original_placements, + original_device_mesh, + local_replicated, + working_local, + ) + _set_parameter(module, name, nn.Parameter(working_local)) + else: + originals[name] = ( + param, + collected, + original_placements, + original_device_mesh, + None, + None, + ) + _set_parameter(module, name, nn.Parameter(local_replicated)) yield @@ -518,7 +541,13 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. collected, original_placements, original_device_mesh, + cpu_local, + gpu_working, ) in originals.items(): + if cpu_local is not None: + # Mirror GPU-resident modifications back into the CPU-resident DTensor local shard + # so the subsequent redistribute-back sees the up-to-date values. + cpu_local.data.copy_(gpu_working.data.to(cpu_local.device)) original_param.to_local().data.copy_( collected.redistribute( placements=original_placements, device_mesh=original_device_mesh diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 4ee84a58d6a..dd914d31396 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -347,14 +347,9 @@ def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None, cpu_of """ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - def _log(msg): - print(f"[rank {rank}] [fsdp2_shard] {msg}", flush=True) - model._original_architectures = list(model.config.architectures or []) - t0 = time.perf_counter() fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) - _log(f"fsdp2_wrap (decoders) took {time.perf_counter() - t0:.1f}s") # With CPU offload, FSDP2 requires DTensor params on CPU at lazy_init time # (it streams them to GPU per-layer during forward). Also, set_model_state_dict @@ -368,23 +363,18 @@ def _materialize(t): return torch.empty_like(t, device=materialize_device) return t.to(materialize_device) - t0 = time.perf_counter() model._apply(_materialize) - _log(f"materialize (→{'CPU' if cpu_offload else 'GPU'}) took {time.perf_counter() - t0:.1f}s") if src_state_dict is not None: - t0 = time.perf_counter() set_model_state_dict( model, src_state_dict, options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), ) - _log(f"set_model_state_dict broadcast took {time.perf_counter() - t0:.1f}s") if cpu_offload: # FSDP-managed (DTensor) params stay on CPU — FSDP2 streams them per layer. # Move everything else (root-level plain params + all buffers) to GPU now. - t0 = time.perf_counter() for module in model.modules(): for name, param in list(module._parameters.items()): if param is None or isinstance(param, DTensor): @@ -396,17 +386,13 @@ def _materialize(t): if buf is None or isinstance(buf, DTensor): continue module._buffers[name] = buf.to(device) - _log(f"promote non-DTensor → GPU took {time.perf_counter() - t0:.1f}s") if hasattr(model, "tie_weights"): - t0 = time.perf_counter() model.tie_weights() - _log(f"tie_weights took {time.perf_counter() - t0:.1f}s") - t0 = time.perf_counter() for p in model.parameters(): p.requires_grad_(False) - _log(f"freeze params took {time.perf_counter() - t0:.1f}s") + return model @@ -466,6 +452,58 @@ def calibrate(_unwrapped_model): return calibrate +def broadcast_state_dict( + state_dict_or_none: dict | None, + src: int, + device: torch.device, + pg=None, +) -> dict: + """Broadcast a dict of CPU tensors from rank ``src`` to all ranks. + + Two phases: (1) broadcast metadata (key list + shape/dtype) via + ``broadcast_object_list``, (2) broadcast each tensor via ``dist.broadcast``. + Source rank passes the populated dict; non-source ranks pass ``None``. + Returns a dict of tensors on ``device`` on every rank. + """ + is_src = torch.distributed.get_rank() == src + meta: list[Any] = ( + [{name: (tuple(t.shape), t.dtype) for name, t in state_dict_or_none.items()}] + if is_src and state_dict_or_none is not None + else [None] + ) + torch.distributed.broadcast_object_list(meta, src=src, group=pg) + meta_dict = meta[0] + assert meta_dict is not None + + src_state_dict = state_dict_or_none or {} + out: dict[str, torch.Tensor] = {} + for name, (shape, dtype) in meta_dict.items(): + if is_src: + t = src_state_dict[name].to(device, non_blocking=True) + else: + t = torch.empty(shape, dtype=dtype, device=device) + torch.distributed.broadcast(t, src=src, group=pg) + out[name] = t + return out + + +def load_state_dict_into_fsdp2_layer(layer: nn.Module, full_state_dict: dict) -> None: + """Load full (replicated) tensors into an FSDP2-wrapped layer's DTensor local shards. + + Each rank already has the full tensor; we just need to shard locally. + Uses ``set_model_state_dict(broadcast_from_rank0=False)`` — each rank holds the + full tensor in ``full_state_dict``, so no collective is needed; the helper just + slices each rank's local shard from the full tensor. + """ + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + set_model_state_dict( + layer, + full_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False), + ) + + class Fsdp2StateDictAdapter: """Adapter exposing ``.get_state_dict(model)`` for FSDP2-sharded models. From 2b5080796b9081b39c3bfacb21424362af7b0ea7 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:16:22 +0000 Subject: [PATCH 07/10] claude self-review comments Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 263 +---------- examples/llm_ptq/hf_ptq.py | 18 +- modelopt/torch/export/unified_export_hf.py | 13 +- .../torch/quantization/utils/core_utils.py | 34 +- modelopt/torch/utils/distributed.py | 439 +++++++++++++----- tests/gpu/torch/quantization/test_fsdp2.py | 50 ++ 6 files changed, 402 insertions(+), 415 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index b08831bcbf3..1d9fadf6368 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -27,7 +27,6 @@ from typing import Any import torch -import torch.nn as nn import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory @@ -113,184 +112,6 @@ def validate_fsdp2_supported(args, config): ) -def _read_safetensors_state_dict_for_prefix( - ckpt_path: str, - weight_map: dict, - prefix: str, -) -> dict: - """Read all tensors whose name starts with ``prefix`` from safetensors files. - - Groups param names by file to avoid re-opening the same file. Returns CPU tensors. - Uses safetensors' ``safe_open`` so only the requested tensors' bytes are read - (the file is mmap-backed, not fully loaded). - """ - import safetensors - - by_file: dict[str, list[str]] = {} - for name, file in weight_map.items(): - if name.startswith(prefix): - by_file.setdefault(file, []).append(name) - - state: dict[str, torch.Tensor] = {} - for file, names in by_file.items(): - with safetensors.safe_open( - os.path.join(ckpt_path, file), framework="pt", device="cpu" - ) as f: - for name in names: - state[name] = f.get_tensor(name) - return state - - -def _read_non_layer_state_dict( - ckpt_path: str, - weight_map: dict, - layer_prefixes: list, -) -> dict: - """Read everything NOT under any decoder-layer prefix (embed, lm_head, norm, ...).""" - import safetensors - - prefixes = tuple(layer_prefixes) - by_file: dict[str, list[str]] = {} - for name, file in weight_map.items(): - if not name.startswith(prefixes): - by_file.setdefault(file, []).append(name) - - state: dict[str, torch.Tensor] = {} - for file, names in by_file.items(): - with safetensors.safe_open( - os.path.join(ckpt_path, file), framework="pt", device="cpu" - ) as f: - for name in names: - state[name] = f.get_tensor(name) - return state - - -def _load_via_parallel_read( - ckpt_path: str, - device: torch.device, - rank: int, - world_size: int, - args, - trust_remote_code: bool, - mp_policy, - cpu_offload: bool, - weight_map: dict, -): - """Parallel-read path: each rank reads its share of decoder layers from disk. - - Avoids the rank-0 bottleneck of the rank-0-load-and-broadcast path. Each layer - is owned by ``layer_idx % world_size`` and broadcast from its owner. - - See ``/home/svelury/.claude/plans/parallel-read-loader.md`` for the design. - """ - from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector - from modelopt.torch.utils.distributed import ( - broadcast_state_dict, - fsdp2_wrap, - init_params_on_meta, - load_state_dict_into_fsdp2_layer, - ) - - hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) - if args is not None: - validate_fsdp2_supported(args, hf_config) - dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 - - # Phase A: meta skeleton on every rank - with init_params_on_meta(): - model = AutoModelForCausalLM.from_config( - hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code - ) - model.eval() - if hasattr(model, "config") and hasattr(model.config, "use_cache"): - model.config.use_cache = False - - # Phase B: wrap decoder layers (root NOT wrapped). Discover prefixes. - decoder_layers = LayerActivationCollector.get_decoder_layers(model) - if decoder_layers is None: - raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") - module_to_name = {m: n for n, m in model.named_modules()} - layer_prefixes = [module_to_name[layer] + "." for layer in decoder_layers] - model._original_architectures = list(model.config.architectures or []) - - fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) - - layer_to_rank = {i: i % world_size for i in range(len(decoder_layers))} - - # Phase C: materialize meta → empty tensors. With cpu_offload, DTensor shards - # land on CPU (FSDP2 will stream them to GPU per-layer during forward). - materialize_device = torch.device("cpu") if cpu_offload else device - from torch.distributed.tensor import DTensor - - def _materialize(t): - is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta - if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): - return torch.empty_like(t, device=materialize_device) - return t.to(materialize_device) - - model._apply(_materialize) - - # Phase D: each rank reads its owned layers from disk in parallel. - owned: dict[int, dict] = {} - for layer_idx, owner in layer_to_rank.items(): - if owner == rank: - owned[layer_idx] = _read_safetensors_state_dict_for_prefix( - ckpt_path, weight_map, layer_prefixes[layer_idx] - ) - - # Phase E: per-layer broadcast + shard. Broadcasts run on GPU (NCCL requires - # CUDA tensors); with cpu_offload we copy back to CPU before writing into the - # CPU-resident DTensor shard. - for layer_idx, layer in enumerate(decoder_layers): - src = layer_to_rank[layer_idx] - layer_state_full = broadcast_state_dict(owned.get(layer_idx), src=src, device=device) - prefix = layer_prefixes[layer_idx] - stripped = {k[len(prefix) :]: v for k, v in layer_state_full.items()} - if cpu_offload: - stripped = {k: v.cpu() for k, v in stripped.items()} - load_state_dict_into_fsdp2_layer(layer, stripped) - if src == rank: - del owned[layer_idx] - del layer_state_full, stripped - - # Phase F: non-decoder params (embed, lm_head, norm) — rank 0 reads + broadcasts. - non_layer = ( - _read_non_layer_state_dict(ckpt_path, weight_map, layer_prefixes) if rank == 0 else None - ) - non_layer = broadcast_state_dict(non_layer, src=0, device=device) - if cpu_offload: - non_layer = {k: v.cpu() for k, v in non_layer.items()} - # Root is NOT FSDP-wrapped → these are plain nn.Parameters / buffers. Direct copy. - missing, unexpected = model.load_state_dict(non_layer, strict=False, assign=False) - if unexpected: - warnings.warn( - f"Unexpected keys in non-layer state dict on rank {rank}: {unexpected[:3]}..." - ) - - if cpu_offload: - # FSDP-managed (DTensor) decoder shards stay on CPU. Promote everything - # else (root-level plain params + all buffers) to GPU now. - for module in model.modules(): - for name, param in list(module._parameters.items()): - if param is None or isinstance(param, DTensor): - continue - module._parameters[name] = nn.Parameter( - param.data.to(device), requires_grad=param.requires_grad - ) - for name, buf in list(module._buffers.items()): - if buf is None or isinstance(buf, DTensor): - continue - module._buffers[name] = buf.to(device) - - # Phase G: tie weights, freeze. - if hasattr(model, "tie_weights"): - model.tie_weights() - for p in model.parameters(): - p.requires_grad_(False) - - return model - - def load_and_prepare_fsdp2_model( ckpt_path: str, device: torch.device, @@ -301,90 +122,30 @@ def load_and_prepare_fsdp2_model( mp_policy=None, cpu_offload: bool = False, ): - """Load and FSDP2-shard a causal LM. - - Default path: **parallel read** — each rank reads its share of decoder layers - from disk and broadcasts to other ranks. Eliminates the rank-0 disk bottleneck. + """CLI-side FSDP2 loader: validate against CLI constraints, then delegate to core. - Fallback path (when no ``model.safetensors.index.json`` exists, or when - ``cpu_offload=True``): rank-0 ``from_pretrained`` + ``set_model_state_dict`` - broadcast. Same behavior as previous versions. + Runs :func:`validate_fsdp2_supported` to enforce example-script policy (VILA, + multimodal/VL, ``--low_memory_mode``, ``--specdec_offline_dataset``, etc.), + then calls :func:`modelopt.torch.utils.distributed.load_fsdp2_causal_lm`. - Both paths produce identical sharded models (same FSDP2 wrap layout, root - unsharded, decoder layers DTensor-sharded across the FSDP mesh). - - v1 supports standard transformers families only (causal LMs that load - cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized / - speculative / VL go through ``get_model`` and don't get FSDP2. + The core loader is fully reusable (no argparse coupling); this wrapper exists + to keep the CLI policy at the example edge. """ - from modelopt.torch.utils.distributed import fsdp2_shard, init_params_on_meta - - # Try parallel-read path first if the checkpoint has an index file. - # Resolve ckpt_path: if it's a local directory, use as-is; otherwise it's a HF - # Hub ID and we need to materialize the cache directory before parallel read. - resolved_path = ckpt_path - if not os.path.isdir(ckpt_path): - if snapshot_download is None: - resolved_path = None # will fall back to rank-0-broadcast path - else: - resolved_path = snapshot_download(ckpt_path) - - index_path = Path(resolved_path) / "model.safetensors.index.json" if resolved_path else None - if index_path is not None and index_path.exists(): - with open(index_path) as f: - weight_map = json.load(f)["weight_map"] - result = _load_via_parallel_read( - ckpt_path=resolved_path, - device=device, - rank=rank, - world_size=world_size, - args=args, - trust_remote_code=trust_remote_code, - mp_policy=mp_policy, - cpu_offload=cpu_offload, - weight_map=weight_map, - ) - if result is not None: - return result + from modelopt.torch.utils.distributed import load_fsdp2_causal_lm - # Fallback: existing rank-0-load + broadcast path. - hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) if args is not None: + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) validate_fsdp2_supported(args, hf_config) - dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 - - if rank == 0: - src_model = AutoModelForCausalLM.from_pretrained( - ckpt_path, - torch_dtype="auto", - trust_remote_code=trust_remote_code, - low_cpu_mem_usage=True, - ) - src_model.eval() - src_state_dict = src_model.state_dict() - else: - src_model = None - src_state_dict = {} - - with init_params_on_meta(): - model = AutoModelForCausalLM.from_config( - hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code - ) - model.eval() - if hasattr(model, "config") and hasattr(model.config, "use_cache"): - model.config.use_cache = False - - sharded = fsdp2_shard( - model, + return load_fsdp2_causal_lm( + ckpt_path, device, rank, - src_state_dict=src_state_dict, + world_size, + trust_remote_code=trust_remote_code, mp_policy=mp_policy, cpu_offload=cpu_offload, ) - del src_model, src_state_dict - return sharded def run_nemotron_vl_preview( diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7e3bcc2f66c..a1f5d1b4162 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -87,11 +87,7 @@ get_max_batch_size, get_supported_datasets, ) -from modelopt.torch.utils.distributed import ( - Fsdp2StateDictAdapter, - fsdp_aware_forward_loop, - shard_dataloader, -) +from modelopt.torch.utils.distributed import fsdp_aware_forward_loop, shard_dataloader from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -727,14 +723,12 @@ def mono_quantize( def _export_fsdp2_hf_checkpoint(args: argparse.Namespace, full_model, export_path: str) -> None: """FSDP2-aware HF checkpoint export. - Gathers the full state dict from FSDP2 shards via ``Fsdp2StateDictAdapter``, - saves it on rank 0 only, then patches the saved config with quantization - metadata and the original (pre-FSDP-prefix) architectures list. + ``_export_transformers_checkpoint`` detects the FSDP2 wrap internally and + gathers a full unsharded state dict (CPU-offloaded). Rank 0 then writes the + safetensors and patches the saved config with quantization metadata and the + original (pre-FSDP-prefix) architectures list. """ - adapter = Fsdp2StateDictAdapter() - post_state_dict, hf_quant_config = _export_transformers_checkpoint( - full_model, torch.bfloat16, accelerator=adapter - ) + post_state_dict, hf_quant_config = _export_transformers_checkpoint(full_model, torch.bfloat16) if args.is_main: export_dir = Path(export_path) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0626d0a8fd5..07e67da2f14 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -49,6 +49,7 @@ except ImportError: HAS_DIFFUSERS = False +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context @@ -710,7 +711,6 @@ def _export_transformers_checkpoint( Args: model: the full torch model to export. The actual quantized model may be a submodule. dtype: the weights data type to export the unquantized layers or the default model data type if None. - accelerator: the accelerator instance in case of distributed export setup. Returns: post_state_dict: Dict containing quantized weights @@ -724,8 +724,6 @@ def _export_transformers_checkpoint( f"({dtype}), which may lead to numerical errors." ) - accelerator = kwargs.get("accelerator") - # Handle input quantizers of experts that are not calibrated for _, sub_module in model.named_modules(): if is_moe(sub_module) and hasattr(sub_module, "experts"): @@ -829,9 +827,12 @@ def _export_transformers_checkpoint( _reconstruct_fused_moe_linear(model) - if accelerator is not None: - # Gather state_dict from all ranks - quantized_state_dict = accelerator.get_state_dict(model) + if any(isinstance(m, FSDPModule) for m in model.modules()): + # FSDP2-sharded model: gather full state_dict from all ranks (with CPU + # offload to bound peak GPU memory during the gather). + quantized_state_dict = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) else: quantized_state_dict = model.state_dict() diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index b0889794c10..2e27b67ea13 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -399,26 +399,18 @@ def is_pow2(n): def _get_fsdp2_mesh(module: nn.Module): - """Get the mesh info of the model. - - Prefers ``post_forward_mesh_info`` (set when ``reshard_after_forward=True``); - falls back to ``mesh_info`` if it's None (observed under some PyTorch FSDP2 - configurations: eval mode + all-frozen params at the time ``persistent - _materialization`` queries the state — the mesh itself is still valid). - """ + """Get the mesh info of the model.""" try: from torch.distributed._composable_state import _get_module_state except ImportError: return None fsdp_state = _get_module_state(module) - pg = getattr(fsdp_state, "_fsdp_param_group", None) - if pg is None: - return None - info = pg.post_forward_mesh_info or pg.mesh_info - if info is None: - return None - return info.mesh + if ( + fsdp_state._fsdp_param_group + and fsdp_state._fsdp_param_group.post_forward_mesh_info is not None + ): + return fsdp_state._fsdp_param_group.post_forward_mesh_info.mesh def _get_module_name(module: nn.Module, root_model: nn.Module, name_to_module: dict | None = None): @@ -484,11 +476,13 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. If TP is implemented with DTensor, the weight will be a local tensor of the TP DTensor under this context. """ - assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2" - assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" + # Note: ``root_model`` need not itself be an ``FSDPModule``. ``fsdp2_wrap`` shards + # only the decoder layers and leaves the root unsharded (see its docstring), so the + # real precondition is that ``module`` is *under* FSDP2 — enforced by the assert + # below, since ``_get_enclosing_fsdp_module`` only ever returns ``FSDPModule`` instances. fsdp_module = _get_enclosing_fsdp_module(module, root_model) - assert fsdp_module is not None, "Module is not wrapped by FSDP" + assert fsdp_module is not None, "Module is not wrapped by FSDP2" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim @@ -557,7 +551,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. @contextmanager -def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None): +def enable_weight_access_and_writeback( + module, + root_model, + name_to_module: dict | None = None, +): """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index dd914d31396..feeb3acc50b 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -20,7 +20,7 @@ import os import time from collections.abc import Callable -from contextlib import contextmanager, suppress +from contextlib import suppress from datetime import timedelta from typing import Any from warnings import warn @@ -32,17 +32,16 @@ __all__ = [ "DistributedProcessGroup", - "Fsdp2StateDictAdapter", "ParallelState", "backend", "barrier", "fsdp2_shard", "fsdp2_wrap", "fsdp_aware_forward_loop", - "init_params_on_meta", "is_available", "is_initialized", "is_master", + "load_fsdp2_causal_lm", "rank", "shard_dataloader", "size", @@ -237,33 +236,22 @@ def fsdp2_wrap( ): """Apply FSDP2 ``fully_shard`` to each decoder layer of ``model``. - Decoder layers are auto-detected via - ``modelopt.torch.quantization.utils.layerwise_calib.LayerActivationCollector.get_decoder_layers``. - Pass ``override_cls_name`` to force a specific transformer block class. Pass - ``mp_policy`` (a ``torch.distributed.fsdp.MixedPrecisionPolicy``) to control - compute / reduce dtype; default ``None`` means no upcast / downcast. - - Pass ``device`` to stream each layer to that device just before sharding it - (avoids holding the full model on GPU simultaneously). When ``device`` is - ``None``, layers are sharded on whatever device they're already on. - - Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to each - wrapped layer. Each rank's shard then lives on CPU between forward passes - and is streamed to GPU per-layer (H2D + all-gather + compute + reshard + - D2H). Useful when the per-rank decoder shard is the binding GPU constraint - (e.g., 200B+ models on tight GPU budgets) or when you want headroom for a - larger calibration batch. Adds PCIe traffic per layer per batch; on - setups where the model already fits comfortably it usually slows the run. - - The root module is intentionally NOT sharded — ``embed_tokens`` and - ``lm_head`` stay as plain replicated tensors. This costs ~few-GiB per rank - (full copies of embed + lm_head) but unifies the layerwise and non-layerwise - code paths: a DTensor-wrapped ``embed_tokens.weight`` raises a "mixed Tensor - / DTensor" error when modelopt's layerwise calibration passes plain - ``input_ids`` into the embedding lookup, and FSDP2's root pre-forward hook - doesn't auto-wrap LongTensor inputs. + Decoder layers are auto-detected via ``LayerActivationCollector.get_decoder_layers``; + pass ``override_cls_name`` to force a specific block class instead. + + Args: + mp_policy: ``MixedPrecisionPolicy`` for compute/reduce dtype (``None`` = no cast). + device: stream each layer here just before sharding (avoids holding the full + model on GPU at once); ``None`` shards in place. + cpu_offload: attach ``CPUOffloadPolicy`` so each shard lives on CPU between + forwards and streams to GPU per-layer. Trades PCIe traffic for GPU memory; + use only when the per-rank shard is the binding constraint. + + The root is intentionally NOT sharded — ``embed_tokens``/``lm_head`` stay plain + replicated tensors, since a DTensor ``embed_tokens.weight`` breaks the embedding + lookup on plain ``input_ids`` during layerwise calibration. """ - from torch.distributed.fsdp import fully_shard + from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector @@ -273,7 +261,7 @@ def fsdp2_wrap( raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") else: layers = LayerActivationCollector.get_decoder_layers(model) - if layers is None: + if not layers: raise RuntimeError( "Could not auto-detect decoder layers; pass override_cls_name explicitly." ) @@ -281,8 +269,6 @@ def fsdp2_wrap( if mp_policy is not None: fsdp_kwargs["mp_policy"] = mp_policy if cpu_offload: - from torch.distributed.fsdp import CPUOffloadPolicy - fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() for layer in layers: if device is not None: @@ -291,42 +277,22 @@ def fsdp2_wrap( return model -@contextmanager -def init_params_on_meta(): - """Replicate ``accelerate.init_empty_weights(include_buffers=False)``. - - Inside this context, ``nn.Module.register_parameter`` is patched so newly - registered parameters land on the ``meta`` device (zero CPU bytes). Buffer - registration is NOT patched — buffers are computed normally on CPU during - ``__init__`` (e.g. ``Qwen2RotaryEmbedding.__init__`` produces a real CPU - ``inv_freq`` from config). - - Use around ``from_config(...)`` on non-rank-0 ranks to build a meta-skeleton - that ``fsdp2_shard`` will materialize via ``set_model_state_dict`` - broadcast from rank 0. - """ - original = nn.Module.register_parameter - - def patched(self, name, param): - original(self, name, param) - if param is not None: - p = self._parameters[name] - self._parameters[name] = nn.Parameter(p.to("meta"), requires_grad=p.requires_grad) - - nn.Module.register_parameter = patched - try: - yield - finally: - nn.Module.register_parameter = original - - -def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None, cpu_offload=False): +def fsdp2_shard(model, device, src_state_dict=None, mp_policy=None, cpu_offload=False): """Shard a model across the current process group (accelerate-style rank-0 load). Caller contract: ``model`` is built on every rank with params on ``meta`` and - buffers on CPU (use ``init_params_on_meta`` around ``from_config``). Rank 0 - additionally passes ``src_state_dict`` captured from a real CPU model loaded - via ``from_pretrained``; other ranks pass ``None`` or ``{}``. + buffers on CPU (use ``init_empty_weights(include_buffers=False)`` around + ``from_config``). Rank 0 additionally passes ``src_state_dict`` captured from a + real CPU model loaded via ``from_pretrained``; other ranks pass ``{}``. + + ``set_model_state_dict(broadcast_from_rank0=True)`` (step 3) is a collective, so + every rank must reach it: non-rank-0 ranks pass ``{}`` (empty, not ``None``) to + participate in the broadcast. ``src_state_dict=None`` skips the broadcast entirely + and must therefore be ``None`` on *all* ranks (e.g. sharding a model that will be + loaded later) — mixing ``None`` with a populated dict across ranks will hang. + + Also sets ``model._original_architectures`` (FSDP2 wrapping can clobber + ``config.architectures``, which export reads back). Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to wrapped layers (each rank's shard lives on CPU between forwards). See @@ -335,7 +301,7 @@ def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None, cpu_of Root is never sharded (see ``fsdp2_wrap`` docstring). embed_tokens and lm_head stay as plain replicated tensors on every rank. - Steps (each timed and logged per-rank for diagnostics): + Steps: 1. ``fsdp2_wrap`` — apply ``fully_shard`` to decoder layers. 2. Materialize: meta params → empty GPU storage; real CPU buffers → GPU (preserves their values; ``to_empty`` is NOT used because it would wipe @@ -355,15 +321,7 @@ def fsdp2_shard(model, device, rank, src_state_dict=None, mp_policy=None, cpu_of # (it streams them to GPU per-layer during forward). Also, set_model_state_dict # rejects mixed-device models — so materialize everything on CPU first, # broadcast, then promote non-DTensor params + buffers to GPU after. - materialize_device = torch.device("cpu") if cpu_offload else device - - def _materialize(t): - is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta - if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): - return torch.empty_like(t, device=materialize_device) - return t.to(materialize_device) - - model._apply(_materialize) + _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) if src_state_dict is not None: set_model_state_dict( @@ -373,25 +331,12 @@ def _materialize(t): ) if cpu_offload: - # FSDP-managed (DTensor) params stay on CPU — FSDP2 streams them per layer. - # Move everything else (root-level plain params + all buffers) to GPU now. - for module in model.modules(): - for name, param in list(module._parameters.items()): - if param is None or isinstance(param, DTensor): - continue - module._parameters[name] = nn.Parameter( - param.data.to(device), requires_grad=param.requires_grad - ) - for name, buf in list(module._buffers.items()): - if buf is None or isinstance(buf, DTensor): - continue - module._buffers[name] = buf.to(device) + _promote_non_dtensor_to_gpu(model, device) if hasattr(model, "tie_weights"): model.tie_weights() - for p in model.parameters(): - p.requires_grad_(False) + model.requires_grad_(False) return model @@ -399,8 +344,9 @@ def _materialize(t): def shard_dataloader(loader, rank: int, world_size: int): """Wrap a DataLoader with a DistributedSampler so each rank sees a unique shard. - Preserves the input loader's ``batch_size``, ``collate_fn``, ``num_workers``, - and ``pin_memory``. + ``drop_last=False`` keeps per-rank batch counts equal (else a rank exits + calibration early and hangs the others on FSDP2 collectives), at the cost of the + sampler repeating up to ``world_size - 1`` samples to pad the even split. """ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -425,24 +371,19 @@ def shard_dataloader(loader, rank: int, world_size: int): def fsdp_aware_forward_loop(wrapped_model, dataloader, device=None): """Build an ``mtq.quantize`` ``forward_loop`` that respects FSDP wrapping. - ``mtq.quantize`` strips the FSDP wrapper before calling ``forward_loop``, - handing the user the unwrapped inner module. Calling the unwrapped module - bypasses FSDP's pre/post-forward hooks (no all-gather, no reshard), which - breaks calibration on FSDP2. The closure returned here captures the outer - *wrapped* model and ignores the ``unwrapped_model`` argument that - ``mtq.quantize`` passes in. + ``mtq.quantize`` hands ``forward_loop`` the *unwrapped* inner module, and calling + that bypasses FSDP's pre/post-forward hooks (no all-gather/reshard) — breaking + calibration. This closure ignores that argument and calls the captured *wrapped* + model instead. - Used by ``examples/llm_ptq/hf_ptq.py`` under ``--use_fsdp2``. - - TODO: ``modelopt/torch/quantization/plugins/transformers_trainer.py`` (the - QLoRA path) currently has the same logic inlined inside ``_quantize_model``. - Consolidate that call site to use this helper too. + TODO: ``transformers_trainer.py`` (QLoRA path) has the same logic inlined in + ``_quantize_model``; consolidate it onto this helper. """ from tqdm import tqdm def calibrate(_unwrapped_model): - for batch in tqdm(dataloader, desc="Calibrating"): - if device is not None and isinstance(batch, dict): + for batch in tqdm(dataloader, desc="Calibrating", disable=not is_master()): + if device is not None: batch = { k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items() @@ -473,7 +414,7 @@ def broadcast_state_dict( ) torch.distributed.broadcast_object_list(meta, src=src, group=pg) meta_dict = meta[0] - assert meta_dict is not None + assert meta_dict is not None, f"src rank {src} passed no state dict to broadcast" src_state_dict = state_dict_or_none or {} out: dict[str, torch.Tensor] = {} @@ -487,40 +428,282 @@ def broadcast_state_dict( return out -def load_state_dict_into_fsdp2_layer(layer: nn.Module, full_state_dict: dict) -> None: - """Load full (replicated) tensors into an FSDP2-wrapped layer's DTensor local shards. +def _read_safetensors_state_dict( + ckpt_path: str, + weight_map: dict, + select: Callable[[str], bool], +) -> dict: + """Read tensors whose name satisfies ``select`` from safetensors files. - Each rank already has the full tensor; we just need to shard locally. - Uses ``set_model_state_dict(broadcast_from_rank0=False)`` — each rank holds the - full tensor in ``full_state_dict``, so no collective is needed; the helper just - slices each rank's local shard from the full tensor. + Groups param names by file to avoid re-opening. Returns CPU tensors. + Uses ``safe_open`` so only the requested tensors' bytes are read. """ + from safetensors import safe_open + + by_file: dict[str, list[str]] = {} + for name, file in weight_map.items(): + if select(name): + by_file.setdefault(file, []).append(name) + + state: dict[str, torch.Tensor] = {} + for file, names in by_file.items(): + with safe_open(os.path.join(ckpt_path, file), framework="pt", device="cpu") as f: + for name in names: + state[name] = f.get_tensor(name) + return state + + +def _materialize_meta_model(model: nn.Module, materialize_device: torch.device) -> None: + """Replace meta-device params/buffers with empty tensors on ``materialize_device``. + + Triggers FSDP2's ``_apply`` override on wrapped modules, which calls + ``reset_sharded_param`` to refresh FSDP's internal state. + """ + + def _fn(t): + is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta + if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): + return torch.empty_like(t, device=materialize_device) + return t.to(materialize_device) + + model._apply(_fn) + + +def _promote_non_dtensor_to_gpu(model: nn.Module, device: torch.device) -> None: + """Move all non-DTensor params + buffers in ``model`` to ``device`` in-place. + + Used after CPU-offload loading: decoder DTensor shards stay on CPU (FSDP2 + streams them to GPU per layer), while root-level plain params and buffers + need to live on GPU so forwards work. + """ + for module in model.modules(): + for name, param in list(module._parameters.items()): + if param is None or isinstance(param, DTensor): + continue + module._parameters[name] = nn.Parameter( + param.data.to(device), requires_grad=param.requires_grad + ) + for name, buf in list(module._buffers.items()): + if buf is None or isinstance(buf, DTensor): + continue + module._buffers[name] = buf.to(device) + + +def _load_via_parallel_read( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int, + trust_remote_code: bool, + mp_policy, + cpu_offload: bool, + weight_map: dict, +): + """Parallel-read path: each rank reads its share of decoder layers from disk. + + Phase D: each rank reads its owned layers from disk in parallel. + Phase E: per-layer broadcast from owner to all ranks; shard locally into the + FSDP2 DTensor via ``set_model_state_dict(broadcast_from_rank0=False)``. + Phase F: rank 0 reads + broadcasts non-decoder params; loaded into the + unwrapped root. + """ + from accelerate import init_empty_weights from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + from transformers import AutoConfig, AutoModelForCausalLM + + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + + # Phase A: meta skeleton on every rank. include_buffers=False so computed + # buffers (e.g. rotary inv_freq, often non-persistent) are built for real on + # CPU here rather than stranded on meta with nothing to materialize them. + with init_empty_weights(include_buffers=False): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + + # Phase B: wrap decoder layers (root NOT wrapped). Discover prefixes. + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + if decoder_layers is None: + raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") + module_to_name = {m: n for n, m in model.named_modules()} + layer_prefixes = [module_to_name[layer] + "." for layer in decoder_layers] + model._original_architectures = list(model.config.architectures or []) + + fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) - set_model_state_dict( - layer, - full_state_dict, - options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False), + # Phase C: materialize meta → empty tensors. With cpu_offload, DTensor shards + # land on CPU (FSDP2 streams them to GPU per-layer during forward). + _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) + + # Phase D: each rank reads its owned layers from disk in parallel. + owned: dict[int, dict] = {} + for layer_idx in range(len(decoder_layers)): + if layer_idx % world_size == rank: + prefix = layer_prefixes[layer_idx] + + def _has_prefix(n: str, p: str = prefix) -> bool: + return n.startswith(p) + + owned[layer_idx] = _read_safetensors_state_dict(ckpt_path, weight_map, _has_prefix) + + # Phase E: per-layer broadcast + shard. Broadcasts run on GPU (NCCL requires + # CUDA tensors); with cpu_offload we copy back to CPU before writing into the + # CPU-resident DTensor shard. + for layer_idx, layer in enumerate(decoder_layers): + src = layer_idx % world_size + layer_state_full = broadcast_state_dict(owned.get(layer_idx), src=src, device=device) + prefix = layer_prefixes[layer_idx] + stripped = {k[len(prefix) :]: v for k, v in layer_state_full.items()} + if cpu_offload: + stripped = {k: v.cpu() for k, v in stripped.items()} + # Slice each rank's local DTensor shard from the full tensor it already holds + # (broadcast_from_rank0=False → no collective needed). + set_model_state_dict( + layer, + stripped, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False), + ) + if src == rank: + del owned[layer_idx] + del layer_state_full, stripped + + # Phase F: non-decoder params (embed, lm_head, norm) — rank 0 reads + broadcasts. + layer_prefix_tuple = tuple(layer_prefixes) + non_layer = ( + _read_safetensors_state_dict( + ckpt_path, weight_map, lambda n: not n.startswith(layer_prefix_tuple) + ) + if rank == 0 + else None ) + non_layer = broadcast_state_dict(non_layer, src=0, device=device) + if cpu_offload: + non_layer = {k: v.cpu() for k, v in non_layer.items()} + missing, unexpected = model.load_state_dict(non_layer, strict=False, assign=False) + real_missing = [k for k in missing if not k.startswith(layer_prefix_tuple)] + if real_missing: + warn(f"Missing non-layer keys on rank {rank}: {real_missing[:5]}...") + if unexpected: + warn(f"Unexpected keys in non-layer state dict on rank {rank}: {unexpected[:3]}...") + + if cpu_offload: + _promote_non_dtensor_to_gpu(model, device) + + # Phase G: tie weights, freeze. + if hasattr(model, "tie_weights"): + model.tie_weights() + model.requires_grad_(False) + return model + + +def load_fsdp2_causal_lm( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int = 1, + *, + trust_remote_code: bool = False, + mp_policy=None, + cpu_offload: bool = False, +): + """Load and FSDP2-shard a HuggingFace causal LM. + + Reusable loader with no dependency on argparse / CLI semantics. -class Fsdp2StateDictAdapter: - """Adapter exposing ``.get_state_dict(model)`` for FSDP2-sharded models. + Default path: **parallel read** — each rank reads its share of decoder + layers from disk in parallel, broadcasts to other ranks. Eliminates the + rank-0 disk bottleneck. Handles ``cpu_offload`` internally. - Satisfies the ``accelerator=`` kwarg of - ``modelopt.torch.export.unified_export_hf._export_transformers_checkpoint``. - Backed by ``get_model_state_dict`` which materializes a full unsharded state - dict on every rank (with CPU offload to bound peak GPU memory during gather). + Fallback path (when no ``model.safetensors.index.json`` exists): rank-0 + ``from_pretrained`` + ``set_model_state_dict`` broadcast via + :func:`fsdp2_shard`. + + Both paths produce identical sharded models (same FSDP2 wrap layout, root + unsharded, decoder layers DTensor-sharded across the FSDP mesh). """ + import json - def get_state_dict(self, model): - """Return the full unsharded state dict gathered from FSDP2 shards (CPU-offloaded).""" - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + from accelerate import init_empty_weights + from transformers import AutoConfig, AutoModelForCausalLM - return get_model_state_dict( - model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), + # Resolve ckpt_path: local dir as-is, otherwise HF Hub ID — rank 0 downloads, + # others wait at the barrier so we don't contend on the cache lock. + resolved_path: str | None = ckpt_path + if not os.path.isdir(ckpt_path): + try: + from huggingface_hub import snapshot_download + except ImportError: + snapshot_download = None + if snapshot_download is not None: + if rank == 0: + resolved_path = snapshot_download(ckpt_path) + if is_initialized(): + barrier() + if rank != 0: + resolved_path = snapshot_download(ckpt_path) + else: + resolved_path = None + + index_path = ( + os.path.join(resolved_path, "model.safetensors.index.json") if resolved_path else None + ) + if resolved_path is not None and index_path is not None and os.path.exists(index_path): + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + return _load_via_parallel_read( + ckpt_path=resolved_path, + device=device, + rank=rank, + world_size=world_size, + trust_remote_code=trust_remote_code, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + weight_map=weight_map, + ) + + # Fallback: rank-0 from_pretrained + broadcast via fsdp2_shard. + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + + if rank == 0: + src_model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype="auto", + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, ) + src_model.eval() + src_state_dict = src_model.state_dict() + else: + src_model = None + src_state_dict = {} + + # Meta skeleton on every rank; include_buffers=False keeps computed buffers + # (e.g. rotary inv_freq) real on CPU instead of stranded on meta. + with init_empty_weights(include_buffers=False): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + + sharded = fsdp2_shard( + model, + device, + src_state_dict=src_state_dict, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + ) + del src_model, src_state_dict + return sharded class DistributedProcessGroup: diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index c5584ece5cf..0fca03ada9d 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -261,3 +261,53 @@ def _test_persistent_materialization(rank, size): def test_persistent_materialization(dist_workers): dist_workers.run(_test_persistent_materialization) + + +def _test_writeback_root_unwrapped(rank, size): + """Writeback works when only the decoder layers are wrapped and the root is left + unsharded -- the layout ``fsdp2_wrap`` produces (root deliberately not wrapped) and + the one ``layerwise_calib`` save()/full_restore() rely on via + ``enable_weight_access_and_writeback(layer, model)``. + + Regression guard for the stale ``isinstance(root_model, FSDPModule)`` assert that + previously required the root itself to be FSDP-wrapped. + """ + from torch.distributed.tensor import DTensor + + from modelopt.torch.quantization.utils import enable_weight_access_and_writeback + + dim = 32 + torch.manual_seed(1) + # Root is a plain container; model[0] stands in for a decoder layer. + model = nn.Sequential(nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim))).cuda(rank) + synchronize_state_dict(model) + + # Wrap ONLY the "decoder layer" -- intentionally NO ``fully_shard(model)`` on the root, + # mirroring fsdp2_wrap. ``root_model`` (model) is therefore not an FSDPModule. + fully_shard(model[0]) + layer = model[0] + inputs = torch.randn(2, dim).cuda(rank) + + # Warmup forward to trigger FSDP2's lazy_init (mirrors layerwise calibration). + model(inputs) + + # Sharded before the context. + assert isinstance(next(iter(layer.parameters())), DTensor) + + # This is the exact call save()/full_restore() make. Before the fix it tripped the + # ``assert isinstance(root_model, FSDPModule)`` because the root is unwrapped. + with enable_weight_access_and_writeback(layer[0], model): + assert not isinstance(layer[0].weight, DTensor) # gathered to a local replicated tensor + ref_weight = layer[0].weight.clone() + layer[0].weight.data.add_(1.0) # mutate -> exercises the writeback path + + # Restored to a sharded DTensor on exit. + assert isinstance(next(iter(layer.parameters())), DTensor) + + # Modification was written back into the shards. + with enable_weight_access_and_writeback(layer[0], model): + assert torch.allclose(layer[0].weight, ref_weight + 1.0) + + +def test_writeback_root_unwrapped(dist_workers): + dist_workers.run(_test_writeback_root_unwrapped) From 9ff4c897679cde154b433083dd34b501b7dc8c93 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:23:47 +0000 Subject: [PATCH 08/10] coderabbit reviews Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 5 ++++- examples/llm_ptq/hf_ptq.py | 6 ++++++ modelopt/torch/utils/distributed.py | 27 +++++++++++++++++++-------- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 1d9fadf6368..81976e37942 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -74,7 +74,8 @@ def setup_distributed_args(args): else: args.rank = 0 args.world_size = 1 - args.device = None + # Leave ``args.device`` as parsed from ``--device`` (e.g. "cuda", "cpu"); + # downstream helpers (``get_model`` etc.) consume it directly. args.is_main = True @@ -121,6 +122,7 @@ def load_and_prepare_fsdp2_model( trust_remote_code: bool = False, mp_policy=None, cpu_offload: bool = False, + attn_implementation: str | None = None, ): """CLI-side FSDP2 loader: validate against CLI constraints, then delegate to core. @@ -145,6 +147,7 @@ def load_and_prepare_fsdp2_model( trust_remote_code=trust_remote_code, mp_policy=mp_policy, cpu_offload=cpu_offload, + attn_implementation=attn_implementation, ) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a1f5d1b4162..f35209a956f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -330,6 +330,11 @@ def auto_quantize( "Auto Quantization is not supported for pipeline parallel size > 1" ) + assert not (args.auto_quantize_bits and args.use_fsdp2), ( + "Auto Quantization is not supported with --use_fsdp2: mtq.auto_quantize " + "is invoked after fsdp2_shard has frozen every parameter." + ) + qformat_list = args.qformat.split(",") assert qformat_list, "No quantization formats provided" # Check if all provided quantization formats are supported @@ -465,6 +470,7 @@ def load_model(args: argparse.Namespace): args=args, trust_remote_code=args.trust_remote_code, cpu_offload=args.cpu_offload, + attn_implementation=args.attn_implementation, ) elif args.specdec_offline_dataset is not None or not args.low_memory_mode: full_model = get_model( diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index feeb3acc50b..41a76502b8f 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -498,6 +498,7 @@ def _load_via_parallel_read( mp_policy, cpu_offload: bool, weight_map: dict, + attn_implementation: str | None = None, ): """Parallel-read path: each rank reads its share of decoder layers from disk. @@ -513,7 +514,10 @@ def _load_via_parallel_read( from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector - hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + config_kwargs: dict[str, Any] = {"trust_remote_code": trust_remote_code} + if attn_implementation is not None: + config_kwargs["attn_implementation"] = attn_implementation + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 # Phase A: meta skeleton on every rank. include_buffers=False so computed @@ -612,6 +616,7 @@ def load_fsdp2_causal_lm( trust_remote_code: bool = False, mp_policy=None, cpu_offload: bool = False, + attn_implementation: str | None = None, ): """Load and FSDP2-shard a HuggingFace causal LM. @@ -666,19 +671,25 @@ def load_fsdp2_causal_lm( mp_policy=mp_policy, cpu_offload=cpu_offload, weight_map=weight_map, + attn_implementation=attn_implementation, ) # Fallback: rank-0 from_pretrained + broadcast via fsdp2_shard. - hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + config_kwargs: dict[str, Any] = {"trust_remote_code": trust_remote_code} + if attn_implementation is not None: + config_kwargs["attn_implementation"] = attn_implementation + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 if rank == 0: - src_model = AutoModelForCausalLM.from_pretrained( - ckpt_path, - torch_dtype="auto", - trust_remote_code=trust_remote_code, - low_cpu_mem_usage=True, - ) + model_kwargs: dict[str, Any] = { + "torch_dtype": "auto", + "trust_remote_code": trust_remote_code, + "low_cpu_mem_usage": True, + } + if attn_implementation is not None: + model_kwargs["attn_implementation"] = attn_implementation + src_model = AutoModelForCausalLM.from_pretrained(ckpt_path, **model_kwargs) src_model.eval() src_state_dict = src_model.state_dict() else: From 999f94c714ea0fb53f41ef408831c08424564140 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 2 Jun 2026 23:22:54 +0000 Subject: [PATCH 09/10] modelopt bot review Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 7 --- examples/llm_ptq/hf_ptq.py | 57 +++---------------- modelopt/torch/export/unified_export_hf.py | 44 +++++++++++++- .../torch/quantization/utils/core_utils.py | 10 +--- modelopt/torch/utils/distributed.py | 44 ++++++-------- 5 files changed, 66 insertions(+), 96 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 81976e37942..34776c1d280 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -51,11 +51,6 @@ SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] -# --------------------------------------------------------------------------- -# FSDP2 helpers (opt-in via --use_fsdp2 in hf_ptq.py). -# --------------------------------------------------------------------------- - - def setup_distributed_args(args): """Populate ``args.rank`` / ``world_size`` / ``device`` / ``is_main``. @@ -74,8 +69,6 @@ def setup_distributed_args(args): else: args.rank = 0 args.world_size = 1 - # Leave ``args.device`` as parsed from ``--device`` (e.g. "cuda", "cpu"); - # downstream helpers (``get_model`` etc.) consume it directly. args.is_main = True diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f35209a956f..e89410e2998 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,7 +15,6 @@ import argparse import copy -import json import os import random import time @@ -25,7 +24,6 @@ import numpy as np import torch -import torch.distributed as dist from accelerate.hooks import remove_hook_from_module from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static @@ -71,9 +69,7 @@ has_spec_opt, save_expert_token_count_table, ) -from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model -from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized, patch_fsdp_mp_dtypes @@ -304,7 +300,6 @@ def make_calib_dataloader( include_labels=include_labels, ) if args.use_fsdp2 and calib_dataloader is not None and isinstance(calib_dataloader, DataLoader): - # Each rank sees a disjoint shard of the calibration set. calib_dataloader = shard_dataloader(calib_dataloader, args.rank, args.world_size) return calib_dataloader, first_text_speech_dataset @@ -695,8 +690,6 @@ def mono_quantize( if args.calib_with_images and is_nemotron_vl_model: calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader) elif args.use_fsdp2: - # mtq.quantize passes the unwrapped inner module to forward_loop; - # FSDP2 needs hooks fired on the outer wrapped model. calibrate_loop = fsdp_aware_forward_loop( language_model, calib_dataloader, args.device ) @@ -726,41 +719,6 @@ def mono_quantize( warnings.warn("Skipping quantization: model is already quantized.") -def _export_fsdp2_hf_checkpoint(args: argparse.Namespace, full_model, export_path: str) -> None: - """FSDP2-aware HF checkpoint export. - - ``_export_transformers_checkpoint`` detects the FSDP2 wrap internally and - gathers a full unsharded state dict (CPU-offloaded). Rank 0 then writes the - safetensors and patches the saved config with quantization metadata and the - original (pre-FSDP-prefix) architectures list. - """ - post_state_dict, hf_quant_config = _export_transformers_checkpoint(full_model, torch.bfloat16) - - if args.is_main: - export_dir = Path(export_path) - export_dir.mkdir(parents=True, exist_ok=True) - # Save hf_quant_config.json for backward compatibility. - with open(f"{export_dir}/hf_quant_config.json", "w") as f: - json.dump(hf_quant_config, f, indent=4) - hf_quant_config = convert_hf_quant_config_format(hf_quant_config) - full_model.save_pretrained( - export_dir, state_dict=post_state_dict, save_modelopt_state=False - ) - original_config = f"{export_dir}/config.json" - with open(original_config) as f: - config_data = json.load(f) - config_data["quantization_config"] = hf_quant_config - # Strip FSDP-prefixed architectures and restore the original list captured pre-wrap. - original_archs = getattr( - full_model, "_original_architectures", full_model.config.architectures - ) - if original_archs: - config_data["architectures"] = original_archs - with open(original_config, "w") as f: - json.dump(config_data, f, indent=4) - dist.barrier() - - def export_quantized( args: argparse.Namespace, full_model: torch.nn.Module, @@ -854,7 +812,12 @@ def export_quantized( full_model, export_dir=export_path, inplace_mem_efficient=True ) elif args.use_fsdp2: - _export_fsdp2_hf_checkpoint(args, full_model, export_path) + export_hf_checkpoint( + full_model, + torch.bfloat16, + export_dir=export_path, + architectures_override=getattr(full_model, "_original_architectures", None), + ) else: mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( full_model, args.pyt_ckpt_path @@ -876,7 +839,6 @@ def export_quantized( ) # Restore default padding and export the tokenizer as well. - # Under FSDP2 only rank 0 writes to disk. if tokenizer is not None: tokenizer.padding_side = default_padding_side if default_pad_token is not None: @@ -945,7 +907,7 @@ def pre_quantize( trust_remote_code=args.trust_remote_code, ) else: - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=10) + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) return preview_input_ids, generated_ids_before_ptq @@ -1002,7 +964,7 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=10) + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -1568,8 +1530,6 @@ def main(args: argparse.Namespace): random.seed(RAND_SEED) np.random.seed(RAND_SEED) - # Populate args.rank / world_size / device / is_main. When --use_fsdp2 is off, - # these default to single-process values so downstream helpers can use them uniformly. setup_distributed_args(args) # launch a memory monitor to read the currently used GPU memory. @@ -1638,6 +1598,5 @@ def main(args: argparse.Namespace): "(multi-format auto-quantize)." ) - # patch_fsdp_mp_dtypes is a no-op when no FSDP2 wrap is applied; safe unconditionally. with patch_fsdp_mp_dtypes(): main(args) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 07e67da2f14..c49cb938303 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -828,10 +828,12 @@ def _export_transformers_checkpoint( _reconstruct_fused_moe_linear(model) if any(isinstance(m, FSDPModule) for m in model.modules()): - # FSDP2-sharded model: gather full state_dict from all ranks (with CPU - # offload to bound peak GPU memory during the gather). + # FSDP2: gather full state_dict to CPU on rank 0 only. quantized_state_dict = get_model_state_dict( - model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + model, + options=StateDictOptions( + full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True + ), ) else: quantized_state_dict = model.state_dict() @@ -1171,6 +1173,7 @@ def export_hf_checkpoint( components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, max_shard_size: int | str = "10GB", + architectures_override: list[str] | None = None, **kwargs, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -1178,6 +1181,10 @@ def export_hf_checkpoint( This function automatically detects whether the model is from transformers or diffusers and applies the appropriate export logic. + Under ``torch.distributed`` (e.g. FSDP2), all ranks participate in the + collective state-dict gather inside ``_export_transformers_checkpoint``; + only rank 0 writes files. A final barrier syncs the other ranks. + Args: model: The full torch model to export. The actual quantized model may be a submodule. Supports both transformers models (e.g., LlamaForCausalLM) and diffusers @@ -1190,6 +1197,9 @@ def export_hf_checkpoint( to export. If None, all quantized components are exported. extra_state_dict: Extra state dictionary to add to the exported model. max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB". + architectures_override: If set, written into ``config.json`` as + ``architectures``. Use this to restore the original architectures list + after FSDP2 wrapping, which prefixes class names. **kwargs: Internal-only keyword arguments. Supported key: merged_base_safetensor_path (str, optional). When provided, merges the exported diffusion transformer weights with non-transformer components (VAE, vocoder, text encoders, etc.) @@ -1214,6 +1224,28 @@ def export_hf_checkpoint( try: post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + # Under torch.distributed: only rank 0 writes; others sync at the barrier below. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + if is_distributed: + rank = torch.distributed.get_rank() + populated = bool(post_state_dict) + if rank == 0 and not populated: + raise RuntimeError( + "Expected rank 0 to receive a populated state_dict from " + "_export_transformers_checkpoint under FSDP2 export; got empty. " + "Check that StateDictOptions(broadcast_from_rank0=True) is honored " + "by this PyTorch's get_model_state_dict." + ) + if rank != 0 and populated: + raise RuntimeError( + f"Expected rank {rank} to receive an empty state_dict from " + "_export_transformers_checkpoint under FSDP2 export (broadcast_from_rank0=True); " + "got populated. PyTorch's get_model_state_dict semantics may have changed." + ) + if rank != 0: + torch.distributed.barrier() + return + # Only treat the export as quantized when at least one quant_algo field is set. # get_quant_config always returns a dict (even for sparsity-only or unmodified models), # so emitting hf_quant_config.json unconditionally produces a file with @@ -1272,6 +1304,9 @@ def export_hf_checkpoint( if sparse_attn_config is not None: config_data["sparse_attention_config"] = sparse_attn_config + if architectures_override: + config_data["architectures"] = architectures_override + with open(original_config, "w") as file: json.dump(config_data, file, indent=4) @@ -1281,3 +1316,6 @@ def export_hf_checkpoint( " can be saved with torch.save for further inspection." ) raise e + + if is_distributed: + torch.distributed.barrier() diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 2e27b67ea13..d571c05c9e5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -477,10 +477,6 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. TP DTensor under this context. """ assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" - # Note: ``root_model`` need not itself be an ``FSDPModule``. ``fsdp2_wrap`` shards - # only the decoder layers and leaves the root unsharded (see its docstring), so the - # real precondition is that ``module`` is *under* FSDP2 — enforced by the assert - # below, since ``_get_enclosing_fsdp_module`` only ever returns ``FSDPModule`` instances. fsdp_module = _get_enclosing_fsdp_module(module, root_model) assert fsdp_module is not None, "Module is not wrapped by FSDP2" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) @@ -502,9 +498,7 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. device_mesh=original_device_mesh, ) local_replicated = collected.to_local() - # With FSDP2 CPUOffloadPolicy, the DTensor's local shard lives on CPU, so the - # gathered local tensor is also on CPU. Mirror it onto the current GPU so - # calibration forwards (activations on GPU) see a same-device weight. + # cpu_offload: gathered shard is on CPU; mirror to GPU for forward. if local_replicated.device.type == "cpu" and torch.cuda.is_available(): working_local = local_replicated.to(torch.cuda.current_device()) originals[name] = ( @@ -539,8 +533,6 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. gpu_working, ) in originals.items(): if cpu_local is not None: - # Mirror GPU-resident modifications back into the CPU-resident DTensor local shard - # so the subsequent redistribute-back sees the up-to-date values. cpu_local.data.copy_(gpu_working.data.to(cpu_local.device)) original_param.to_local().data.copy_( collected.redistribute( diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 41a76502b8f..3c7fbed0126 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -222,11 +222,6 @@ def cleanup(): torch.distributed.destroy_process_group() -# --------------------------------------------------------------------------- -# FSDP2 helpers — used by examples/llm_ptq to run PTQ calibration under FSDP2. -# --------------------------------------------------------------------------- - - def fsdp2_wrap( model, override_cls_name: str | None = None, @@ -261,7 +256,7 @@ def fsdp2_wrap( raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") else: layers = LayerActivationCollector.get_decoder_layers(model) - if not layers: + if layers is None: raise RuntimeError( "Could not auto-detect decoder layers; pass override_cls_name explicitly." ) @@ -317,10 +312,6 @@ def fsdp2_shard(model, device, src_state_dict=None, mp_policy=None, cpu_offload= fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) - # With CPU offload, FSDP2 requires DTensor params on CPU at lazy_init time - # (it streams them to GPU per-layer during forward). Also, set_model_state_dict - # rejects mixed-device models — so materialize everything on CPU first, - # broadcast, then promote non-DTensor params + buffers to GPU after. _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) if src_state_dict is not None: @@ -347,6 +338,9 @@ def shard_dataloader(loader, rank: int, world_size: int): ``drop_last=False`` keeps per-rank batch counts equal (else a rank exits calibration early and hangs the others on FSDP2 collectives), at the cost of the sampler repeating up to ``world_size - 1`` samples to pad the even split. + + Forwards all non-sampler DataLoader settings from ``loader`` (workers, pinning, + prefetch, init fn, generator, ...). """ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -365,6 +359,13 @@ def shard_dataloader(loader, rank: int, world_size: int): collate_fn=loader.collate_fn, num_workers=loader.num_workers, pin_memory=loader.pin_memory, + timeout=loader.timeout, + worker_init_fn=loader.worker_init_fn, + multiprocessing_context=loader.multiprocessing_context, + generator=loader.generator, + prefetch_factor=loader.prefetch_factor, + persistent_workers=loader.persistent_workers, + pin_memory_device=getattr(loader, "pin_memory_device", ""), ) @@ -520,9 +521,7 @@ def _load_via_parallel_read( hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 - # Phase A: meta skeleton on every rank. include_buffers=False so computed - # buffers (e.g. rotary inv_freq, often non-persistent) are built for real on - # CPU here rather than stranded on meta with nothing to materialize them. + # include_buffers=False keeps computed buffers (rotary inv_freq, etc.) real on CPU. with init_empty_weights(include_buffers=False): model = AutoModelForCausalLM.from_config( hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code @@ -531,7 +530,6 @@ def _load_via_parallel_read( if hasattr(model, "config") and hasattr(model.config, "use_cache"): model.config.use_cache = False - # Phase B: wrap decoder layers (root NOT wrapped). Discover prefixes. decoder_layers = LayerActivationCollector.get_decoder_layers(model) if decoder_layers is None: raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") @@ -541,11 +539,9 @@ def _load_via_parallel_read( fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) - # Phase C: materialize meta → empty tensors. With cpu_offload, DTensor shards - # land on CPU (FSDP2 streams them to GPU per-layer during forward). _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) - # Phase D: each rank reads its owned layers from disk in parallel. + # Each rank reads only its owned layers from disk. owned: dict[int, dict] = {} for layer_idx in range(len(decoder_layers)): if layer_idx % world_size == rank: @@ -556,9 +552,7 @@ def _has_prefix(n: str, p: str = prefix) -> bool: owned[layer_idx] = _read_safetensors_state_dict(ckpt_path, weight_map, _has_prefix) - # Phase E: per-layer broadcast + shard. Broadcasts run on GPU (NCCL requires - # CUDA tensors); with cpu_offload we copy back to CPU before writing into the - # CPU-resident DTensor shard. + # Per-layer broadcast from owner, then shard locally. for layer_idx, layer in enumerate(decoder_layers): src = layer_idx % world_size layer_state_full = broadcast_state_dict(owned.get(layer_idx), src=src, device=device) @@ -566,8 +560,6 @@ def _has_prefix(n: str, p: str = prefix) -> bool: stripped = {k[len(prefix) :]: v for k, v in layer_state_full.items()} if cpu_offload: stripped = {k: v.cpu() for k, v in stripped.items()} - # Slice each rank's local DTensor shard from the full tensor it already holds - # (broadcast_from_rank0=False → no collective needed). set_model_state_dict( layer, stripped, @@ -577,7 +569,7 @@ def _has_prefix(n: str, p: str = prefix) -> bool: del owned[layer_idx] del layer_state_full, stripped - # Phase F: non-decoder params (embed, lm_head, norm) — rank 0 reads + broadcasts. + # Non-decoder params (embed, lm_head, norm): rank 0 reads + broadcasts. layer_prefix_tuple = tuple(layer_prefixes) non_layer = ( _read_safetensors_state_dict( @@ -599,7 +591,6 @@ def _has_prefix(n: str, p: str = prefix) -> bool: if cpu_offload: _promote_non_dtensor_to_gpu(model, device) - # Phase G: tie weights, freeze. if hasattr(model, "tie_weights"): model.tie_weights() model.requires_grad_(False) @@ -638,8 +629,7 @@ def load_fsdp2_causal_lm( from accelerate import init_empty_weights from transformers import AutoConfig, AutoModelForCausalLM - # Resolve ckpt_path: local dir as-is, otherwise HF Hub ID — rank 0 downloads, - # others wait at the barrier so we don't contend on the cache lock. + # HF Hub ID: rank 0 downloads, others wait at the barrier. resolved_path: str | None = ckpt_path if not os.path.isdir(ckpt_path): try: @@ -696,8 +686,6 @@ def load_fsdp2_causal_lm( src_model = None src_state_dict = {} - # Meta skeleton on every rank; include_buffers=False keeps computed buffers - # (e.g. rotary inv_freq) real on CPU instead of stranded on meta. with init_empty_weights(include_buffers=False): model = AutoModelForCausalLM.from_config( hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code From 5a75f62e9fab910dcde2603ff88f1386028f66ae Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 2 Jun 2026 23:48:45 +0000 Subject: [PATCH 10/10] clean up (1) Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 53 +++++++++++++--------- examples/llm_ptq/hf_ptq.py | 29 +++--------- modelopt/torch/export/unified_export_hf.py | 7 --- modelopt/torch/utils/distributed.py | 11 ++--- 4 files changed, 43 insertions(+), 57 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 34776c1d280..bbc594b61b2 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -52,12 +52,7 @@ def setup_distributed_args(args): - """Populate ``args.rank`` / ``world_size`` / ``device`` / ``is_main``. - - When ``--use_fsdp2`` is set, initializes the distributed process group and - pins this rank's CUDA device. When the flag is off, fills no-op values so - downstream helpers can use ``args.is_main`` and ``args.rank`` uniformly. - """ + """Set ``args.rank``/``world_size``/``device``/``is_main`` (single-process if FSDP2 off).""" from modelopt.torch.utils import distributed as dist_utils if getattr(args, "use_fsdp2", False): @@ -73,20 +68,37 @@ def setup_distributed_args(args): def cleanup_distributed(args): - """Tear down the distributed process group if FSDP2 set it up.""" + """Destroy the process group if ``--use_fsdp2`` set it up.""" from modelopt.torch.utils import distributed as dist_utils if getattr(args, "use_fsdp2", False): dist_utils.cleanup() -def validate_fsdp2_supported(args, config): - """Raise NotImplementedError if the model config is not FSDP2-supported in v1. +def _checkpoint_has_mtp_weights(model_path: str) -> bool: + """Return True if the checkpoint's safetensors index advertises MTP weights.""" + candidates = [Path(model_path) / "model.safetensors.index.json"] + try: + from huggingface_hub import try_to_load_from_cache - Called after ``AutoConfig.from_pretrained`` (cheap) and before any heavy - loading work, so unsupported configurations fail fast with a clear message - instead of crashing later inside a DTensor traceback. - """ + cached = try_to_load_from_cache(model_path, "model.safetensors.index.json") + except ImportError: + cached = None + if cached: + candidates.append(Path(cached)) + for index_file in candidates: + if not index_file.exists(): + continue + try: + weight_map = json.load(open(index_file)).get("weight_map", {}) + except (OSError, json.JSONDecodeError): + continue + return any("mtp" in k or "mtp" in v for k, v in weight_map.items()) + return False + + +def validate_fsdp2_supported(args, config): + """Raise ``NotImplementedError`` for model/CLI combos the FSDP2 path doesn't support yet.""" issues = [] if "vila" in args.pyt_ckpt_path.lower(): issues.append("VILA (custom builder + non-standard layer layout)") @@ -98,6 +110,11 @@ def validate_fsdp2_supported(args, config): issues.append("speculative decoding (--specdec_offline_dataset)") if getattr(args, "low_memory_mode", False): issues.append("--low_memory_mode (redundant with FSDP2)") + if _checkpoint_has_mtp_weights(args.pyt_ckpt_path): + issues.append( + "MTP (Multi-Token Prediction) weights — the FSDP2 loader doesn't " + "carry them through; the exported checkpoint would be missing MTP layers" + ) if issues: raise NotImplementedError( "--use_fsdp2 does not support:\n - " @@ -117,15 +134,7 @@ def load_and_prepare_fsdp2_model( cpu_offload: bool = False, attn_implementation: str | None = None, ): - """CLI-side FSDP2 loader: validate against CLI constraints, then delegate to core. - - Runs :func:`validate_fsdp2_supported` to enforce example-script policy (VILA, - multimodal/VL, ``--low_memory_mode``, ``--specdec_offline_dataset``, etc.), - then calls :func:`modelopt.torch.utils.distributed.load_fsdp2_causal_lm`. - - The core loader is fully reusable (no argparse coupling); this wrapper exists - to keep the CLI policy at the example edge. - """ + """Validate CLI constraints, then delegate to :func:`load_fsdp2_causal_lm`.""" from modelopt.torch.utils.distributed import load_fsdp2_causal_lm if args is not None: diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e89410e2998..5346fc34222 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -105,13 +105,6 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: break -def _nvfp4_max_cfg(*, layerwise: bool) -> dict[str, Any]: - """NVFP4 quant config with explicit max calibration and a layerwise toggle.""" - cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) - cfg["algorithm"] = {"method": "max", "layerwise": layerwise} - return cfg - - QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, @@ -120,8 +113,6 @@ def _nvfp4_max_cfg(*, layerwise: bool) -> dict[str, Any]: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_max": _nvfp4_max_cfg(layerwise=False), - "nvfp4_max_layerwise": _nvfp4_max_cfg(layerwise=True), "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -811,20 +802,14 @@ def export_quantized( export_hf_vllm_fq_checkpoint( full_model, export_dir=export_path, inplace_mem_efficient=True ) - elif args.use_fsdp2: - export_hf_checkpoint( - full_model, - torch.bfloat16, - export_dir=export_path, - architectures_override=getattr(full_model, "_original_architectures", None), - ) else: - mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( - full_model, args.pyt_ckpt_path - ) - - if mtp_layer_prefixes: - full_model._mtp_layer_prefixes = mtp_layer_prefixes + mtp_state_dict = None + if not args.use_fsdp2: + mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( + full_model, args.pyt_ckpt_path + ) + if mtp_layer_prefixes: + full_model._mtp_layer_prefixes = mtp_layer_prefixes export_hf_checkpoint( full_model, diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c49cb938303..4d00a705e68 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1173,7 +1173,6 @@ def export_hf_checkpoint( components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, max_shard_size: int | str = "10GB", - architectures_override: list[str] | None = None, **kwargs, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -1197,9 +1196,6 @@ def export_hf_checkpoint( to export. If None, all quantized components are exported. extra_state_dict: Extra state dictionary to add to the exported model. max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB". - architectures_override: If set, written into ``config.json`` as - ``architectures``. Use this to restore the original architectures list - after FSDP2 wrapping, which prefixes class names. **kwargs: Internal-only keyword arguments. Supported key: merged_base_safetensor_path (str, optional). When provided, merges the exported diffusion transformer weights with non-transformer components (VAE, vocoder, text encoders, etc.) @@ -1304,9 +1300,6 @@ def export_hf_checkpoint( if sparse_attn_config is not None: config_data["sparse_attention_config"] = sparse_attn_config - if architectures_override: - config_data["architectures"] = architectures_override - with open(original_config, "w") as file: json.dump(config_data, file, indent=4) diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 3c7fbed0126..ba4ae4b693f 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -265,10 +265,15 @@ def fsdp2_wrap( fsdp_kwargs["mp_policy"] = mp_policy if cpu_offload: fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + # Snapshot and restore config.architectures around fully_shard, in case the + # wrap mutates the class name that downstream save_pretrained reads. + original_architectures = list(getattr(model.config, "architectures", []) or []) for layer in layers: if device is not None: layer.to(device) fully_shard(layer, **fsdp_kwargs) + if original_architectures: + model.config.architectures = original_architectures return model @@ -286,9 +291,6 @@ def fsdp2_shard(model, device, src_state_dict=None, mp_policy=None, cpu_offload= and must therefore be ``None`` on *all* ranks (e.g. sharding a model that will be loaded later) — mixing ``None`` with a populated dict across ranks will hang. - Also sets ``model._original_architectures`` (FSDP2 wrapping can clobber - ``config.architectures``, which export reads back). - Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to wrapped layers (each rank's shard lives on CPU between forwards). See ``fsdp2_wrap`` docstring for the trade-off. @@ -308,8 +310,6 @@ def fsdp2_shard(model, device, src_state_dict=None, mp_policy=None, cpu_offload= """ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - model._original_architectures = list(model.config.architectures or []) - fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) @@ -535,7 +535,6 @@ def _load_via_parallel_read( raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") module_to_name = {m: n for n, m in model.named_modules()} layer_prefixes = [module_to_name[layer] + "." for layer in decoder_layers] - model._original_architectures = list(model.config.architectures or []) fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload)