diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 511eed5d774..880561090c5 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -370,32 +370,42 @@ mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) ## Multi-Node Post-Training Quantization with FSDP2 -ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages HuggingFace's Accelerate library and FSDP2 for distributed model sharding and calibration. +ModelOpt enables quantization of LLMs across multiple GPU nodes using FSDP2 for distributed model sharding and calibration, exposed via the `--use_fsdp2` flag on the standard `hf_ptq.py` entry point. ### Usage -For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements. +Single-node (multiple GPUs): -On each node run the following command: +```bash +torchrun --standalone --nproc_per_node= hf_ptq.py \ + --pyt_ckpt_path \ + --qformat \ + --kv_cache_qformat \ + --batch_size \ + --calib_size \ + --export_path \ + --use_fsdp2 +``` + +Multi-node (run on each node): ```bash -accelerate launch --config_file fsdp2.yaml \ - --num_machines= \ - --machine_rank= \ - --main_process_ip= \ - --main_process_port= \ - --fsdp_transformer_layer_cls_to_wrap= - multinode_ptq.py \ +torchrun \ + --nnodes= --node_rank= \ + --master_addr= --master_port= \ + --nproc_per_node= \ + hf_ptq.py \ --pyt_ckpt_path \ - --qformat \ + --qformat \ --kv_cache_qformat \ --batch_size \ --calib_size \ - --dataset \ --export_path \ - --trust_remote_code + --use_fsdp2 ``` +For layerwise calibration (amortizes cross-node all-gather cost across all calibration batches), use `--qformat nvfp4_max_layerwise`. + The exported checkpoint can be deployed using TensorRT-LLM/ vLLM/ SGLang. For more details refer to the [deployment section](#deployment) of this document. > *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory and choose the right number of GPUs to avoid unnecessary communication.* diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 9455157645c..bbc594b61b2 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -51,6 +51,108 @@ SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] +def setup_distributed_args(args): + """Set ``args.rank``/``world_size``/``device``/``is_main`` (single-process if FSDP2 off).""" + from modelopt.torch.utils import distributed as dist_utils + + if getattr(args, "use_fsdp2", False): + dist_utils.setup() + args.rank = dist_utils.rank() + args.world_size = dist_utils.size() + args.device = torch.device(f"cuda:{dist_utils.local_rank()}") + args.is_main = args.rank == 0 + else: + args.rank = 0 + args.world_size = 1 + args.is_main = True + + +def cleanup_distributed(args): + """Destroy the process group if ``--use_fsdp2`` set it up.""" + from modelopt.torch.utils import distributed as dist_utils + + if getattr(args, "use_fsdp2", False): + dist_utils.cleanup() + + +def _checkpoint_has_mtp_weights(model_path: str) -> bool: + """Return True if the checkpoint's safetensors index advertises MTP weights.""" + candidates = [Path(model_path) / "model.safetensors.index.json"] + try: + from huggingface_hub import try_to_load_from_cache + + cached = try_to_load_from_cache(model_path, "model.safetensors.index.json") + except ImportError: + cached = None + if cached: + candidates.append(Path(cached)) + for index_file in candidates: + if not index_file.exists(): + continue + try: + weight_map = json.load(open(index_file)).get("weight_map", {}) + except (OSError, json.JSONDecodeError): + continue + return any("mtp" in k or "mtp" in v for k, v in weight_map.items()) + return False + + +def validate_fsdp2_supported(args, config): + """Raise ``NotImplementedError`` for model/CLI combos the FSDP2 path doesn't support yet.""" + issues = [] + if "vila" in args.pyt_ckpt_path.lower(): + issues.append("VILA (custom builder + non-standard layer layout)") + if is_nemotron_vl(config) or _is_multimodal_config(config): + issues.append("multimodal / VL models (decoder layers not auto-detectable)") + if getattr(config, "quantization_config", None) is not None: + issues.append("pack-quantized / compressed-tensors checkpoints") + if getattr(args, "specdec_offline_dataset", None) is not None: + issues.append("speculative decoding (--specdec_offline_dataset)") + if getattr(args, "low_memory_mode", False): + issues.append("--low_memory_mode (redundant with FSDP2)") + if _checkpoint_has_mtp_weights(args.pyt_ckpt_path): + issues.append( + "MTP (Multi-Token Prediction) weights — the FSDP2 loader doesn't " + "carry them through; the exported checkpoint would be missing MTP layers" + ) + if issues: + raise NotImplementedError( + "--use_fsdp2 does not support:\n - " + + "\n - ".join(issues) + + "\nRemove --use_fsdp2 or use a standard causal-LM checkpoint." + ) + + +def load_and_prepare_fsdp2_model( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int = 1, + args=None, + trust_remote_code: bool = False, + mp_policy=None, + cpu_offload: bool = False, + attn_implementation: str | None = None, +): + """Validate CLI constraints, then delegate to :func:`load_fsdp2_causal_lm`.""" + from modelopt.torch.utils.distributed import load_fsdp2_causal_lm + + if args is not None: + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code) + validate_fsdp2_supported(args, hf_config) + + return load_fsdp2_causal_lm( + ckpt_path, + device, + rank, + world_size, + trust_remote_code=trust_remote_code, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + attn_implementation=attn_implementation, + ) + + def run_nemotron_vl_preview( full_model, tokenizer, diff --git a/examples/llm_ptq/fsdp2.yaml b/examples/llm_ptq/fsdp2.yaml deleted file mode 100644 index 646d63f9e67..00000000000 --- a/examples/llm_ptq/fsdp2.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# ============================================================================= -# FSDP Configuration for running LLM PTQ on multinode setup. This file is consumed by examples/llm_ptq/multinode_ptq.py -# ============================================================================= - -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -enable_cpu_affinity: false -fsdp_config: - fsdp_activation_checkpointing: false - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_cpu_ram_efficient_loading: true - fsdp_offload_params: false - fsdp_reshard_after_forward: true - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_use_orig_params: true - fsdp_version: 2 -machine_rank: 0 -main_training_function: main -mixed_precision: 'no' -num_machines: 2 -num_processes: 16 -rdzv_backend: c10d -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d52c3ee40bb..5346fc34222 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,7 @@ import argparse import copy +import os import random import time import warnings @@ -28,6 +29,7 @@ from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( build_quant_cfg, + cleanup_distributed, copy_custom_model_files, create_vlm_calibration_loop, get_model, @@ -35,10 +37,12 @@ get_tokenizer, is_enc_dec, is_nemotron_vl, + load_and_prepare_fsdp2_model, load_mtp_weights, needs_checkpoint_path_update, resolve_checkpoint_dir, run_nemotron_vl_preview, + setup_distributed_args, ) from torch.utils.data import DataLoader from transformers import ( @@ -68,7 +72,7 @@ from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights -from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.quantization.utils import is_quantized, patch_fsdp_mp_dtypes from modelopt.torch.speculative.eagle.utils import ( EagleOfflineDataCollator, OfflineSupervisedDataset, @@ -79,6 +83,7 @@ get_max_batch_size, get_supported_datasets, ) +from modelopt.torch.utils.distributed import fsdp_aware_forward_loop, shard_dataloader from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -285,6 +290,8 @@ def make_calib_dataloader( device=device, include_labels=include_labels, ) + if args.use_fsdp2 and calib_dataloader is not None and isinstance(calib_dataloader, DataLoader): + calib_dataloader = shard_dataloader(calib_dataloader, args.rank, args.world_size) return calib_dataloader, first_text_speech_dataset @@ -309,6 +316,11 @@ def auto_quantize( "Auto Quantization is not supported for pipeline parallel size > 1" ) + assert not (args.auto_quantize_bits and args.use_fsdp2), ( + "Auto Quantization is not supported with --use_fsdp2: mtq.auto_quantize " + "is invoked after fsdp2_shard has frozen every parameter." + ) + qformat_list = args.qformat.split(",") assert qformat_list, "No quantization formats provided" # Check if all provided quantization formats are supported @@ -435,7 +447,18 @@ def forward_step(model, batch): def load_model(args: argparse.Namespace): # If low memory mode is enabled, we compress the model while loading the HF checkpoint. calibration_only = False - if args.specdec_offline_dataset is not None or not args.low_memory_mode: + if args.use_fsdp2: + full_model = load_and_prepare_fsdp2_model( + ckpt_path=args.pyt_ckpt_path, + device=args.device, + rank=args.rank, + world_size=args.world_size, + args=args, + trust_remote_code=args.trust_remote_code, + cpu_offload=args.cpu_offload, + attn_implementation=args.attn_implementation, + ) + elif args.specdec_offline_dataset is not None or not args.low_memory_mode: full_model = get_model( args.pyt_ckpt_path, args.device, @@ -477,9 +500,12 @@ def load_model(args: argparse.Namespace): model_type = get_model_type(full_model) - device = full_model.device - if hasattr(full_model, "model"): - device = full_model.model.device + if args.use_fsdp2: + device = args.device + else: + device = full_model.device + if hasattr(full_model, "model"): + device = full_model.model.device processor = None tokenizer = None language_model = full_model @@ -654,6 +680,10 @@ def mono_quantize( # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model. if args.calib_with_images and is_nemotron_vl_model: calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader) + elif args.use_fsdp2: + calibrate_loop = fsdp_aware_forward_loop( + language_model, calib_dataloader, args.device + ) else: calibrate_loop = create_forward_loop( dataloader=calib_dataloader, @@ -773,12 +803,13 @@ def export_quantized( full_model, export_dir=export_path, inplace_mem_efficient=True ) else: - mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( - full_model, args.pyt_ckpt_path - ) - - if mtp_layer_prefixes: - full_model._mtp_layer_prefixes = mtp_layer_prefixes + mtp_state_dict = None + if not args.use_fsdp2: + mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( + full_model, args.pyt_ckpt_path + ) + if mtp_layer_prefixes: + full_model._mtp_layer_prefixes = mtp_layer_prefixes export_hf_checkpoint( full_model, @@ -797,18 +828,21 @@ def export_quantized( tokenizer.padding_side = default_padding_side if default_pad_token is not None: tokenizer.pad_token = default_pad_token - tokenizer.save_pretrained(export_path) + if args.is_main: + tokenizer.save_pretrained(export_path) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used. # This must run AFTER tokenizer.save_pretrained() so original tokenizer files # from the source checkpoint take precedence over regenerated ones (which may # differ in format due to newer transformers versions). - copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) + if args.is_main: + copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) end_time = time.time() - print( - f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" - ) + if args.is_main: + print( + f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" + ) def pre_quantize( @@ -1337,6 +1371,28 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--use_fsdp2", + action="store_true", + help=( + "Run calibration under PyTorch FSDP2 (requires launching with torchrun). " + "Takes precedence over --use_seq_device_map. " + "v1 limitations: standard causal-LM only (no VILA / pack-quantized / speculative / " + "auto-quantize / sparsity / VLM). Rank 0 holds the full model in CPU briefly " + "during the broadcast step; other ranks pay ~0 CPU." + ), + ) + parser.add_argument( + "--cpu_offload", + action="store_true", + help=( + "Only valid with --use_fsdp2. Attach FSDP2's CPUOffloadPolicy so each " + "rank's decoder shard lives on CPU between forwards (streamed to GPU " + "per-layer). Frees GPU memory at the cost of PCIe traffic per layer per " + "batch. Worth it for trillion-param models or tight-GPU setups; usually " + "slows down runs where the model already fits comfortably." + ), + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", @@ -1441,6 +1497,14 @@ def parse_args() -> argparse.Namespace: if args.specdec_offline_dataset is not None and args.low_memory_mode: parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.") + if args.use_fsdp2 and args.use_seq_device_map: + warnings.warn("--use_seq_device_map is ignored when --use_fsdp2 is set.") + args.use_seq_device_map = False + if args.use_fsdp2 and os.environ.get("RANK") is None: + parser.error("--use_fsdp2 requires launching with torchrun") + if args.cpu_offload and not args.use_fsdp2: + parser.error("--cpu_offload requires --use_fsdp2") + return args @@ -1451,6 +1515,8 @@ def main(args: argparse.Namespace): random.seed(RAND_SEED) np.random.seed(RAND_SEED) + setup_distributed_args(args) + # launch a memory monitor to read the currently used GPU memory. launch_memory_monitor() @@ -1487,6 +1553,8 @@ def main(args: argparse.Namespace): device, ) + cleanup_distributed(args) + if __name__ == "__main__": args = parse_args() @@ -1515,4 +1583,5 @@ def main(args: argparse.Namespace): "(multi-format auto-quantize)." ) - main(args) + with patch_fsdp_mp_dtypes(): + main(args) diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py deleted file mode 100644 index 93ef21ea4d4..00000000000 --- a/examples/llm_ptq/multinode_ptq.py +++ /dev/null @@ -1,390 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-node PTQ (Post-Training Quantization) with FSDP2 support.""" - -import argparse -import json -import os -import random -import time -import warnings -from pathlib import Path -from typing import Any - -import numpy as np -import torch -import torch.nn as nn -from accelerate import Accelerator -from example_utils import build_quant_cfg, get_tokenizer -from tqdm import tqdm -from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast - -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -from modelopt.torch.export import get_model_type -from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint -from modelopt.torch.quantization.config import need_calibration -from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes -from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets - -# Constants -RAND_SEED = 1234 - -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { - "int8": mtq.INT8_DEFAULT_CFG, - "int4_awq": mtq.INT4_AWQ_CFG, - "fp8": mtq.FP8_DEFAULT_CFG, - "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, - "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, - "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, - "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, - "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, -} - -KV_QUANT_CFG_CHOICES = { - "none": "none", - "fp8": "FP8_KV_CFG", - "nvfp4": "NVFP4_KV_CFG", - "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", -} - - -# Enable HuggingFace checkpointing -mto.enable_huggingface_checkpointing() - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Multi-node post-training quantization with FSDP2") - - parser.add_argument( - "--pyt_ckpt_path", - required=True, - help="Path to PyTorch checkpoint", - ) - parser.add_argument( - "--qformat", - default="fp8", - choices=QUANT_CFG_CHOICES.keys(), - help="Quantization format", - ) - parser.add_argument( - "--kv_cache_qformat", - default="fp8", - choices=list(KV_QUANT_CFG_CHOICES.keys()), - help="KV cache quantization format", - ) - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size for calibration", - ) - parser.add_argument( - "--calib_size", - type=str, - default="512", - help="Comma-separated list of calibration sizes per dataset", - ) - parser.add_argument( - "--dataset", - help=( - f"name of a dataset, or a comma separated list of datasets. " - f"dataset choices are {get_supported_datasets()}" - ), - type=str, - default=None, - ) - parser.add_argument( - "--export_path", - default="exported_model", - help="Directory to export the quantized model", - ) - parser.add_argument( - "--trust_remote_code", - action="store_true", - help="Trust remote code for HuggingFace models", - ) - parser.add_argument("--awq_block_size", default=0, type=int) - - args = parser.parse_args() - - # Parse comma-separated lists - args.dataset = args.dataset.split(",") if args.dataset else None - args.calib_size = [int(x) for x in args.calib_size.split(",")] - - return args - - -def load_and_prepare_model( - model_path: str, - calib_dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, - trust_remote_code: bool = False, -) -> tuple[nn.Module, str, list[str], torch.utils.data.DataLoader]: - """Load model and prepare it for FSDP2 distributed execution. - - Args: - model_path: Path to the HuggingFace model - calibration_dataloader: Calibration dataloader to be sharded for calibration - accelerator: Accelerate's Accelerator instance - trust_remote_code: Whether to trust remote code - - Returns: - Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader) - """ - model = AutoModelForCausalLM.from_pretrained( - model_path, dtype="auto", trust_remote_code=trust_remote_code - ) - model.eval() - model_type = get_model_type(model) - # Need the original architectures for export - # FSDP prefix is added to the architectures for FSDP2 wrapped models - original_architectures = model.config.architectures - - # FSDP2 requires an optimizer to be prepared together with the model - dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.0) - model, _, calibration_dataloader = accelerator.prepare(model, dummy_optimizer, calib_dataloader) - - return model, model_type, original_architectures, calibration_dataloader - - -def create_calibration_dataloader( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - dataset_names: list[str], - calib_sizes: list[int], - batch_size: int, -) -> torch.utils.data.DataLoader: - """Create calibration dataloader from dataset. - - Args: - tokenizer: HuggingFace tokenizer - dataset_names: List of dataset names (defaults to cnn_dailymail) - calib_sizes: Number of samples for each dataset - batch_size: Batch size for calibration - - Returns: - DataLoader for calibration - """ - - return get_dataset_dataloader( - dataset_name=dataset_names, - tokenizer=tokenizer, - batch_size=batch_size, - num_samples=calib_sizes, - device=None, # Keep data on CPU, calibration loop handles device transfer - include_labels=False, - ) - - -def create_fsdp2_calibration_loop( - model: nn.Module, - dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, -): - """Create calibration loop compatible with FSDP2. - - For FSDP2, we need to use the outer FSDP-wrapped model instead of - the parameter passed by mtq.quantize to properly handle DTensor. - - Args: - model: FSDP2-wrapped model - dataloader: Calibration dataloader - accelerator: Accelerator instance for device management - - Returns: - Calibration function compatible with mtq.quantize - """ - - def calibrate(unwrapped_model): - """Calibration loop that uses the FSDP-wrapped model.""" - for batch in tqdm(dataloader, desc="Calibrating"): - if isinstance(batch, dict): - batch = { - k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v - for k, v in batch.items() - } - # Use outer model (FSDP-wrapped), not the parameter - # Important: We should forward pass using the unwrapped model - # mtq.quantize will unwrap the model & pass to the forward_loop - model(**batch) - - return calibrate - - -def export_model( - model: nn.Module, - accelerator: Accelerator, - export_path: str | Path, - architectures: list[str], -): - """Export quantized model to HuggingFace format. - - Args: - model: Quantized model - accelerator: Accelerator instance for state dict gathering - export_path: Directory to export model to - """ - export_dir = Path(export_path) - export_dir.mkdir(parents=True, exist_ok=True) - - post_state_dict, hf_quant_config = _export_transformers_checkpoint( - model, torch.bfloat16, accelerator=accelerator - ) - - if accelerator.is_main_process: - # Save hf_quant_config.json for backward compatibility - with open(f"{export_dir}/hf_quant_config.json", "w") as file: - json.dump(hf_quant_config, file, indent=4) - - hf_quant_config = convert_hf_quant_config_format(hf_quant_config) - - # Save model - model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False) - - original_config = f"{export_dir}/config.json" - config_data = {} - - with open(original_config) as file: - config_data = json.load(file) - - config_data["quantization_config"] = hf_quant_config - # Update config architectures to use original architectures that does not have FSDP prefix - config_data["architectures"] = architectures - - with open(original_config, "w") as file: - json.dump(config_data, file, indent=4) - - -def main(args): - """Main quantization workflow.""" - # Validate GPU availability - if not torch.cuda.is_available(): - raise OSError("GPU is required for quantization.") - - # Validate quantization format - if args.qformat not in QUANT_CFG_CHOICES: - raise ValueError( - f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}" - ) - - # Set random seeds - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - torch.manual_seed(RAND_SEED) - - # Initialize accelerator - accelerator = Accelerator() - - print(f"Rank: {os.environ.get('RANK', 'Not set')}") - print(f"World Size: {os.environ.get('WORLD_SIZE', 'Not set')}") - print(f"Local Rank: {os.environ.get('LOCAL_RANK', 'Not set')}") - - # Load tokenizer - tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) - default_padding_side = tokenizer.padding_side - tokenizer.padding_side = "left" # Left padding for better calibration - - # Set default dataset if not provided - if args.dataset is None: - args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] - warnings.warn( - "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." - ) - # Adjust calib_size to match dataset length by extending or truncating as needed - args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ - : len(args.dataset) - ] - - # Create calibration dataloader with max batch size - calib_dataloader = create_calibration_dataloader( - tokenizer=tokenizer, - dataset_names=args.dataset, - calib_sizes=args.calib_size, - batch_size=args.batch_size, - ) - - # Load and prepare model - model, model_type, original_architectures, calib_dataloader = load_and_prepare_model( - model_path=args.pyt_ckpt_path, - calib_dataloader=calib_dataloader, - accelerator=accelerator, - trust_remote_code=args.trust_remote_code, - ) - - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - - quant_cfg = build_quant_cfg( - args.qformat, - quant_cfg, - args.awq_block_size, - model_type, - ) - - enable_quant_kv_cache = args.kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. - if enable_quant_kv_cache: - quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( - quant_cfg, - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], - ) - - # Quantize the model - if accelerator.is_main_process: - print("Starting quantization...") - - start_time = time.time() - - if need_calibration(quant_cfg): - calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, accelerator) - else: - calibrate_fn = None - warnings.warn("Dynamic quantization. Calibration skipped.") - - with torch.no_grad(): - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_fn) - - elapsed = time.time() - start_time - - if accelerator.is_main_process: - print(f"Quantization completed in {elapsed:.2f}s") - mtq.print_quant_summary(model) - - start_time = time.time() - export_model(model, accelerator, args.export_path, original_architectures) - elapsed = time.time() - start_time - - if accelerator.is_main_process: - # Restore default padding and export the tokenizer as well. - if tokenizer is not None: - tokenizer.padding_side = default_padding_side - tokenizer.save_pretrained(args.export_path) - # Export the model - print(f"Export completed in {elapsed:.2f}s") - print(f"Model exported to {args.export_path}") - - print("Unpatching FSDP2 MP dtypes") - - -if __name__ == "__main__": - args = parse_args() - # This context manager can be removed once the update to FSDP2 function is reflected in torch - with patch_fsdp_mp_dtypes(): - main(args) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0626d0a8fd5..4d00a705e68 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -49,6 +49,7 @@ except ImportError: HAS_DIFFUSERS = False +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context @@ -710,7 +711,6 @@ def _export_transformers_checkpoint( Args: model: the full torch model to export. The actual quantized model may be a submodule. dtype: the weights data type to export the unquantized layers or the default model data type if None. - accelerator: the accelerator instance in case of distributed export setup. Returns: post_state_dict: Dict containing quantized weights @@ -724,8 +724,6 @@ def _export_transformers_checkpoint( f"({dtype}), which may lead to numerical errors." ) - accelerator = kwargs.get("accelerator") - # Handle input quantizers of experts that are not calibrated for _, sub_module in model.named_modules(): if is_moe(sub_module) and hasattr(sub_module, "experts"): @@ -829,9 +827,14 @@ def _export_transformers_checkpoint( _reconstruct_fused_moe_linear(model) - if accelerator is not None: - # Gather state_dict from all ranks - quantized_state_dict = accelerator.get_state_dict(model) + if any(isinstance(m, FSDPModule) for m in model.modules()): + # FSDP2: gather full state_dict to CPU on rank 0 only. + quantized_state_dict = get_model_state_dict( + model, + options=StateDictOptions( + full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True + ), + ) else: quantized_state_dict = model.state_dict() @@ -1177,6 +1180,10 @@ def export_hf_checkpoint( This function automatically detects whether the model is from transformers or diffusers and applies the appropriate export logic. + Under ``torch.distributed`` (e.g. FSDP2), all ranks participate in the + collective state-dict gather inside ``_export_transformers_checkpoint``; + only rank 0 writes files. A final barrier syncs the other ranks. + Args: model: The full torch model to export. The actual quantized model may be a submodule. Supports both transformers models (e.g., LlamaForCausalLM) and diffusers @@ -1213,6 +1220,28 @@ def export_hf_checkpoint( try: post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + # Under torch.distributed: only rank 0 writes; others sync at the barrier below. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + if is_distributed: + rank = torch.distributed.get_rank() + populated = bool(post_state_dict) + if rank == 0 and not populated: + raise RuntimeError( + "Expected rank 0 to receive a populated state_dict from " + "_export_transformers_checkpoint under FSDP2 export; got empty. " + "Check that StateDictOptions(broadcast_from_rank0=True) is honored " + "by this PyTorch's get_model_state_dict." + ) + if rank != 0 and populated: + raise RuntimeError( + f"Expected rank {rank} to receive an empty state_dict from " + "_export_transformers_checkpoint under FSDP2 export (broadcast_from_rank0=True); " + "got populated. PyTorch's get_model_state_dict semantics may have changed." + ) + if rank != 0: + torch.distributed.barrier() + return + # Only treat the export as quantized when at least one quant_algo field is set. # get_quant_config always returns a dict (even for sparsity-only or unmodified models), # so emitting hf_quant_config.json unconditionally produces a file with @@ -1280,3 +1309,6 @@ def export_hf_checkpoint( " can be saved with torch.save for further inspection." ) raise e + + if is_distributed: + torch.distributed.barrier() diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index cea3d4260e4..d571c05c9e5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -476,11 +476,9 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. If TP is implemented with DTensor, the weight will be a local tensor of the TP DTensor under this context. """ - assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2" - assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" fsdp_module = _get_enclosing_fsdp_module(module, root_model) - assert fsdp_module is not None, "Module is not wrapped by FSDP" + assert fsdp_module is not None, "Module is not wrapped by FSDP2" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim @@ -499,8 +497,29 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), device_mesh=original_device_mesh, ) - originals[name] = (param, collected, original_placements, original_device_mesh) - _set_parameter(module, name, nn.Parameter(collected.to_local())) + local_replicated = collected.to_local() + # cpu_offload: gathered shard is on CPU; mirror to GPU for forward. + if local_replicated.device.type == "cpu" and torch.cuda.is_available(): + working_local = local_replicated.to(torch.cuda.current_device()) + originals[name] = ( + param, + collected, + original_placements, + original_device_mesh, + local_replicated, + working_local, + ) + _set_parameter(module, name, nn.Parameter(working_local)) + else: + originals[name] = ( + param, + collected, + original_placements, + original_device_mesh, + None, + None, + ) + _set_parameter(module, name, nn.Parameter(local_replicated)) yield @@ -510,7 +529,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. collected, original_placements, original_device_mesh, + cpu_local, + gpu_working, ) in originals.items(): + if cpu_local is not None: + cpu_local.data.copy_(gpu_working.data.to(cpu_local.device)) original_param.to_local().data.copy_( collected.redistribute( placements=original_placements, device_mesh=original_device_mesh @@ -520,7 +543,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. @contextmanager -def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None): +def enable_weight_access_and_writeback( + module, + root_model, + name_to_module: dict | None = None, +): """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 7922b688052..ba4ae4b693f 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -27,6 +27,7 @@ import torch import torch.distributed +import torch.nn as nn from torch.distributed.tensor import DTensor __all__ = [ @@ -34,10 +35,15 @@ "ParallelState", "backend", "barrier", + "fsdp2_shard", + "fsdp2_wrap", + "fsdp_aware_forward_loop", "is_available", "is_initialized", "is_master", + "load_fsdp2_causal_lm", "rank", + "shard_dataloader", "size", ] @@ -216,6 +222,488 @@ def cleanup(): torch.distributed.destroy_process_group() +def fsdp2_wrap( + model, + override_cls_name: str | None = None, + mp_policy=None, + device=None, + cpu_offload: bool = False, +): + """Apply FSDP2 ``fully_shard`` to each decoder layer of ``model``. + + Decoder layers are auto-detected via ``LayerActivationCollector.get_decoder_layers``; + pass ``override_cls_name`` to force a specific block class instead. + + Args: + mp_policy: ``MixedPrecisionPolicy`` for compute/reduce dtype (``None`` = no cast). + device: stream each layer here just before sharding (avoids holding the full + model on GPU at once); ``None`` shards in place. + cpu_offload: attach ``CPUOffloadPolicy`` so each shard lives on CPU between + forwards and streams to GPU per-layer. Trades PCIe traffic for GPU memory; + use only when the per-rank shard is the binding constraint. + + The root is intentionally NOT sharded — ``embed_tokens``/``lm_head`` stay plain + replicated tensors, since a DTensor ``embed_tokens.weight`` breaks the embedding + lookup on plain ``input_ids`` during layerwise calibration. + """ + from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard + + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + if override_cls_name: + layers = [m for m in model.modules() if type(m).__name__ == override_cls_name] + if not layers: + raise RuntimeError(f"No modules of class {override_cls_name!r} found in model") + else: + layers = LayerActivationCollector.get_decoder_layers(model) + if layers is None: + raise RuntimeError( + "Could not auto-detect decoder layers; pass override_cls_name explicitly." + ) + fsdp_kwargs: dict[str, Any] = {"reshard_after_forward": True} + if mp_policy is not None: + fsdp_kwargs["mp_policy"] = mp_policy + if cpu_offload: + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + # Snapshot and restore config.architectures around fully_shard, in case the + # wrap mutates the class name that downstream save_pretrained reads. + original_architectures = list(getattr(model.config, "architectures", []) or []) + for layer in layers: + if device is not None: + layer.to(device) + fully_shard(layer, **fsdp_kwargs) + if original_architectures: + model.config.architectures = original_architectures + return model + + +def fsdp2_shard(model, device, src_state_dict=None, mp_policy=None, cpu_offload=False): + """Shard a model across the current process group (accelerate-style rank-0 load). + + Caller contract: ``model`` is built on every rank with params on ``meta`` and + buffers on CPU (use ``init_empty_weights(include_buffers=False)`` around + ``from_config``). Rank 0 additionally passes ``src_state_dict`` captured from a + real CPU model loaded via ``from_pretrained``; other ranks pass ``{}``. + + ``set_model_state_dict(broadcast_from_rank0=True)`` (step 3) is a collective, so + every rank must reach it: non-rank-0 ranks pass ``{}`` (empty, not ``None``) to + participate in the broadcast. ``src_state_dict=None`` skips the broadcast entirely + and must therefore be ``None`` on *all* ranks (e.g. sharding a model that will be + loaded later) — mixing ``None`` with a populated dict across ranks will hang. + + Set ``cpu_offload=True`` to attach FSDP2's ``CPUOffloadPolicy`` to wrapped + layers (each rank's shard lives on CPU between forwards). See + ``fsdp2_wrap`` docstring for the trade-off. + + Root is never sharded (see ``fsdp2_wrap`` docstring). embed_tokens and + lm_head stay as plain replicated tensors on every rank. + + Steps: + 1. ``fsdp2_wrap`` — apply ``fully_shard`` to decoder layers. + 2. Materialize: meta params → empty GPU storage; real CPU buffers → GPU + (preserves their values; ``to_empty`` is NOT used because it would wipe + buffers). + 3. ``set_model_state_dict(broadcast_from_rank0=True)`` — fills params and + persistent buffers from rank 0. + 4. ``model.tie_weights()`` — restore tied embeddings (no-op for untied). + 5. Freeze params (so ``patch_fsdp_mp_dtypes`` trainable-only check passes). + """ + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) + + _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) + + if src_state_dict is not None: + set_model_state_dict( + model, + src_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + + if cpu_offload: + _promote_non_dtensor_to_gpu(model, device) + + if hasattr(model, "tie_weights"): + model.tie_weights() + + model.requires_grad_(False) + + return model + + +def shard_dataloader(loader, rank: int, world_size: int): + """Wrap a DataLoader with a DistributedSampler so each rank sees a unique shard. + + ``drop_last=False`` keeps per-rank batch counts equal (else a rank exits + calibration early and hangs the others on FSDP2 collectives), at the cost of the + sampler repeating up to ``world_size - 1`` samples to pad the even split. + + Forwards all non-sampler DataLoader settings from ``loader`` (workers, pinning, + prefetch, init fn, generator, ...). + """ + from torch.utils.data import DataLoader + from torch.utils.data.distributed import DistributedSampler + + sampler = DistributedSampler( + loader.dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + return DataLoader( + loader.dataset, + batch_size=loader.batch_size, + sampler=sampler, + collate_fn=loader.collate_fn, + num_workers=loader.num_workers, + pin_memory=loader.pin_memory, + timeout=loader.timeout, + worker_init_fn=loader.worker_init_fn, + multiprocessing_context=loader.multiprocessing_context, + generator=loader.generator, + prefetch_factor=loader.prefetch_factor, + persistent_workers=loader.persistent_workers, + pin_memory_device=getattr(loader, "pin_memory_device", ""), + ) + + +def fsdp_aware_forward_loop(wrapped_model, dataloader, device=None): + """Build an ``mtq.quantize`` ``forward_loop`` that respects FSDP wrapping. + + ``mtq.quantize`` hands ``forward_loop`` the *unwrapped* inner module, and calling + that bypasses FSDP's pre/post-forward hooks (no all-gather/reshard) — breaking + calibration. This closure ignores that argument and calls the captured *wrapped* + model instead. + + TODO: ``transformers_trainer.py`` (QLoRA path) has the same logic inlined in + ``_quantize_model``; consolidate it onto this helper. + """ + from tqdm import tqdm + + def calibrate(_unwrapped_model): + for batch in tqdm(dataloader, desc="Calibrating", disable=not is_master()): + if device is not None: + batch = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } + wrapped_model(**batch) + + return calibrate + + +def broadcast_state_dict( + state_dict_or_none: dict | None, + src: int, + device: torch.device, + pg=None, +) -> dict: + """Broadcast a dict of CPU tensors from rank ``src`` to all ranks. + + Two phases: (1) broadcast metadata (key list + shape/dtype) via + ``broadcast_object_list``, (2) broadcast each tensor via ``dist.broadcast``. + Source rank passes the populated dict; non-source ranks pass ``None``. + Returns a dict of tensors on ``device`` on every rank. + """ + is_src = torch.distributed.get_rank() == src + meta: list[Any] = ( + [{name: (tuple(t.shape), t.dtype) for name, t in state_dict_or_none.items()}] + if is_src and state_dict_or_none is not None + else [None] + ) + torch.distributed.broadcast_object_list(meta, src=src, group=pg) + meta_dict = meta[0] + assert meta_dict is not None, f"src rank {src} passed no state dict to broadcast" + + src_state_dict = state_dict_or_none or {} + out: dict[str, torch.Tensor] = {} + for name, (shape, dtype) in meta_dict.items(): + if is_src: + t = src_state_dict[name].to(device, non_blocking=True) + else: + t = torch.empty(shape, dtype=dtype, device=device) + torch.distributed.broadcast(t, src=src, group=pg) + out[name] = t + return out + + +def _read_safetensors_state_dict( + ckpt_path: str, + weight_map: dict, + select: Callable[[str], bool], +) -> dict: + """Read tensors whose name satisfies ``select`` from safetensors files. + + Groups param names by file to avoid re-opening. Returns CPU tensors. + Uses ``safe_open`` so only the requested tensors' bytes are read. + """ + from safetensors import safe_open + + by_file: dict[str, list[str]] = {} + for name, file in weight_map.items(): + if select(name): + by_file.setdefault(file, []).append(name) + + state: dict[str, torch.Tensor] = {} + for file, names in by_file.items(): + with safe_open(os.path.join(ckpt_path, file), framework="pt", device="cpu") as f: + for name in names: + state[name] = f.get_tensor(name) + return state + + +def _materialize_meta_model(model: nn.Module, materialize_device: torch.device) -> None: + """Replace meta-device params/buffers with empty tensors on ``materialize_device``. + + Triggers FSDP2's ``_apply`` override on wrapped modules, which calls + ``reset_sharded_param`` to refresh FSDP's internal state. + """ + + def _fn(t): + is_meta_dtensor = isinstance(t, DTensor) and t._local_tensor.is_meta + if is_meta_dtensor or (not isinstance(t, DTensor) and t.is_meta): + return torch.empty_like(t, device=materialize_device) + return t.to(materialize_device) + + model._apply(_fn) + + +def _promote_non_dtensor_to_gpu(model: nn.Module, device: torch.device) -> None: + """Move all non-DTensor params + buffers in ``model`` to ``device`` in-place. + + Used after CPU-offload loading: decoder DTensor shards stay on CPU (FSDP2 + streams them to GPU per layer), while root-level plain params and buffers + need to live on GPU so forwards work. + """ + for module in model.modules(): + for name, param in list(module._parameters.items()): + if param is None or isinstance(param, DTensor): + continue + module._parameters[name] = nn.Parameter( + param.data.to(device), requires_grad=param.requires_grad + ) + for name, buf in list(module._buffers.items()): + if buf is None or isinstance(buf, DTensor): + continue + module._buffers[name] = buf.to(device) + + +def _load_via_parallel_read( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int, + trust_remote_code: bool, + mp_policy, + cpu_offload: bool, + weight_map: dict, + attn_implementation: str | None = None, +): + """Parallel-read path: each rank reads its share of decoder layers from disk. + + Phase D: each rank reads its owned layers from disk in parallel. + Phase E: per-layer broadcast from owner to all ranks; shard locally into the + FSDP2 DTensor via ``set_model_state_dict(broadcast_from_rank0=False)``. + Phase F: rank 0 reads + broadcasts non-decoder params; loaded into the + unwrapped root. + """ + from accelerate import init_empty_weights + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + from transformers import AutoConfig, AutoModelForCausalLM + + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + config_kwargs: dict[str, Any] = {"trust_remote_code": trust_remote_code} + if attn_implementation is not None: + config_kwargs["attn_implementation"] = attn_implementation + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + + # include_buffers=False keeps computed buffers (rotary inv_freq, etc.) real on CPU. + with init_empty_weights(include_buffers=False): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + if decoder_layers is None: + raise RuntimeError("Could not auto-detect decoder layers for parallel-read loader.") + module_to_name = {m: n for n, m in model.named_modules()} + layer_prefixes = [module_to_name[layer] + "." for layer in decoder_layers] + + fsdp2_wrap(model, mp_policy=mp_policy, cpu_offload=cpu_offload) + + _materialize_meta_model(model, torch.device("cpu") if cpu_offload else device) + + # Each rank reads only its owned layers from disk. + owned: dict[int, dict] = {} + for layer_idx in range(len(decoder_layers)): + if layer_idx % world_size == rank: + prefix = layer_prefixes[layer_idx] + + def _has_prefix(n: str, p: str = prefix) -> bool: + return n.startswith(p) + + owned[layer_idx] = _read_safetensors_state_dict(ckpt_path, weight_map, _has_prefix) + + # Per-layer broadcast from owner, then shard locally. + for layer_idx, layer in enumerate(decoder_layers): + src = layer_idx % world_size + layer_state_full = broadcast_state_dict(owned.get(layer_idx), src=src, device=device) + prefix = layer_prefixes[layer_idx] + stripped = {k[len(prefix) :]: v for k, v in layer_state_full.items()} + if cpu_offload: + stripped = {k: v.cpu() for k, v in stripped.items()} + set_model_state_dict( + layer, + stripped, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False), + ) + if src == rank: + del owned[layer_idx] + del layer_state_full, stripped + + # Non-decoder params (embed, lm_head, norm): rank 0 reads + broadcasts. + layer_prefix_tuple = tuple(layer_prefixes) + non_layer = ( + _read_safetensors_state_dict( + ckpt_path, weight_map, lambda n: not n.startswith(layer_prefix_tuple) + ) + if rank == 0 + else None + ) + non_layer = broadcast_state_dict(non_layer, src=0, device=device) + if cpu_offload: + non_layer = {k: v.cpu() for k, v in non_layer.items()} + missing, unexpected = model.load_state_dict(non_layer, strict=False, assign=False) + real_missing = [k for k in missing if not k.startswith(layer_prefix_tuple)] + if real_missing: + warn(f"Missing non-layer keys on rank {rank}: {real_missing[:5]}...") + if unexpected: + warn(f"Unexpected keys in non-layer state dict on rank {rank}: {unexpected[:3]}...") + + if cpu_offload: + _promote_non_dtensor_to_gpu(model, device) + + if hasattr(model, "tie_weights"): + model.tie_weights() + model.requires_grad_(False) + + return model + + +def load_fsdp2_causal_lm( + ckpt_path: str, + device: torch.device, + rank: int, + world_size: int = 1, + *, + trust_remote_code: bool = False, + mp_policy=None, + cpu_offload: bool = False, + attn_implementation: str | None = None, +): + """Load and FSDP2-shard a HuggingFace causal LM. + + Reusable loader with no dependency on argparse / CLI semantics. + + Default path: **parallel read** — each rank reads its share of decoder + layers from disk in parallel, broadcasts to other ranks. Eliminates the + rank-0 disk bottleneck. Handles ``cpu_offload`` internally. + + Fallback path (when no ``model.safetensors.index.json`` exists): rank-0 + ``from_pretrained`` + ``set_model_state_dict`` broadcast via + :func:`fsdp2_shard`. + + Both paths produce identical sharded models (same FSDP2 wrap layout, root + unsharded, decoder layers DTensor-sharded across the FSDP mesh). + """ + import json + + from accelerate import init_empty_weights + from transformers import AutoConfig, AutoModelForCausalLM + + # HF Hub ID: rank 0 downloads, others wait at the barrier. + resolved_path: str | None = ckpt_path + if not os.path.isdir(ckpt_path): + try: + from huggingface_hub import snapshot_download + except ImportError: + snapshot_download = None + if snapshot_download is not None: + if rank == 0: + resolved_path = snapshot_download(ckpt_path) + if is_initialized(): + barrier() + if rank != 0: + resolved_path = snapshot_download(ckpt_path) + else: + resolved_path = None + + index_path = ( + os.path.join(resolved_path, "model.safetensors.index.json") if resolved_path else None + ) + if resolved_path is not None and index_path is not None and os.path.exists(index_path): + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + return _load_via_parallel_read( + ckpt_path=resolved_path, + device=device, + rank=rank, + world_size=world_size, + trust_remote_code=trust_remote_code, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + weight_map=weight_map, + attn_implementation=attn_implementation, + ) + + # Fallback: rank-0 from_pretrained + broadcast via fsdp2_shard. + config_kwargs: dict[str, Any] = {"trust_remote_code": trust_remote_code} + if attn_implementation is not None: + config_kwargs["attn_implementation"] = attn_implementation + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16 + + if rank == 0: + model_kwargs: dict[str, Any] = { + "torch_dtype": "auto", + "trust_remote_code": trust_remote_code, + "low_cpu_mem_usage": True, + } + if attn_implementation is not None: + model_kwargs["attn_implementation"] = attn_implementation + src_model = AutoModelForCausalLM.from_pretrained(ckpt_path, **model_kwargs) + src_model.eval() + src_state_dict = src_model.state_dict() + else: + src_model = None + src_state_dict = {} + + with init_empty_weights(include_buffers=False): + model = AutoModelForCausalLM.from_config( + hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code + ) + model.eval() + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + + sharded = fsdp2_shard( + model, + device, + src_state_dict=src_state_dict, + mp_policy=mp_policy, + cpu_offload=cpu_offload, + ) + del src_model, src_state_dict + return sharded + + class DistributedProcessGroup: """A convenient wrapper around torch.distributed.ProcessGroup objects.""" diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index c5584ece5cf..0fca03ada9d 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -261,3 +261,53 @@ def _test_persistent_materialization(rank, size): def test_persistent_materialization(dist_workers): dist_workers.run(_test_persistent_materialization) + + +def _test_writeback_root_unwrapped(rank, size): + """Writeback works when only the decoder layers are wrapped and the root is left + unsharded -- the layout ``fsdp2_wrap`` produces (root deliberately not wrapped) and + the one ``layerwise_calib`` save()/full_restore() rely on via + ``enable_weight_access_and_writeback(layer, model)``. + + Regression guard for the stale ``isinstance(root_model, FSDPModule)`` assert that + previously required the root itself to be FSDP-wrapped. + """ + from torch.distributed.tensor import DTensor + + from modelopt.torch.quantization.utils import enable_weight_access_and_writeback + + dim = 32 + torch.manual_seed(1) + # Root is a plain container; model[0] stands in for a decoder layer. + model = nn.Sequential(nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim))).cuda(rank) + synchronize_state_dict(model) + + # Wrap ONLY the "decoder layer" -- intentionally NO ``fully_shard(model)`` on the root, + # mirroring fsdp2_wrap. ``root_model`` (model) is therefore not an FSDPModule. + fully_shard(model[0]) + layer = model[0] + inputs = torch.randn(2, dim).cuda(rank) + + # Warmup forward to trigger FSDP2's lazy_init (mirrors layerwise calibration). + model(inputs) + + # Sharded before the context. + assert isinstance(next(iter(layer.parameters())), DTensor) + + # This is the exact call save()/full_restore() make. Before the fix it tripped the + # ``assert isinstance(root_model, FSDPModule)`` because the root is unwrapped. + with enable_weight_access_and_writeback(layer[0], model): + assert not isinstance(layer[0].weight, DTensor) # gathered to a local replicated tensor + ref_weight = layer[0].weight.clone() + layer[0].weight.data.add_(1.0) # mutate -> exercises the writeback path + + # Restored to a sharded DTensor on exit. + assert isinstance(next(iter(layer.parameters())), DTensor) + + # Modification was written back into the shards. + with enable_weight_access_and_writeback(layer[0], model): + assert torch.allclose(layer[0].weight, ref_weight + 1.0) + + +def test_writeback_root_unwrapped(dist_workers): + dist_workers.run(_test_writeback_root_unwrapped)