Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions examples/llm_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -370,32 +370,42 @@ mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop)

## Multi-Node Post-Training Quantization with FSDP2

ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages HuggingFace's Accelerate library and FSDP2 for distributed model sharding and calibration.
ModelOpt enables quantization of LLMs across multiple GPU nodes using FSDP2 for distributed model sharding and calibration, exposed via the `--use_fsdp2` flag on the standard `hf_ptq.py` entry point.

### Usage

For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized to meet user-specific requirements.
Single-node (multiple GPUs):

On each node run the following command:
```bash
torchrun --standalone --nproc_per_node=<num_gpus> hf_ptq.py \
--pyt_ckpt_path <path_to_model> \
--qformat <fp8/nvfp4/nvfp4_max/nvfp4_max_layerwise/...> \
--kv_cache_qformat <fp8/nvfp4/nvfp4_affine/none> \
--batch_size <calib_batch_size> \
--calib_size <num_calib_samples> \
--export_path <export_path> \
--use_fsdp2
```

Multi-node (run on each node):

```bash
accelerate launch --config_file fsdp2.yaml \
--num_machines=<num_nodes> \
--machine_rank=<current_node_rank> \
--main_process_ip=<node0_ip_addr> \
--main_process_port=<port> \
--fsdp_transformer_layer_cls_to_wrap=<decoder_layer_name>
multinode_ptq.py \
torchrun \
--nnodes=<num_nodes> --node_rank=<current_node_rank> \
--master_addr=<node0_ip_addr> --master_port=<port> \
--nproc_per_node=<num_gpus_per_node> \
hf_ptq.py \
--pyt_ckpt_path <path_to_model> \
--qformat <fp8/nvfp4/nvfp4_mlp_only/nvfp4_experts_only/nvfp4_omlp_only/nvfp4_awq/int8> \
--qformat <qformat> \
--kv_cache_qformat <fp8/nvfp4/nvfp4_affine/none> \
--batch_size <calib_batch_size> \
--calib_size <num_calib_samples> \
--dataset <dataset> \
--export_path <export_path> \
--trust_remote_code
--use_fsdp2
```

For layerwise calibration (amortizes cross-node all-gather cost across all calibration batches), use `--qformat nvfp4_max_layerwise`.

The exported checkpoint can be deployed using TensorRT-LLM, vLLM, or SGLang. For more details, refer to the [deployment section](#deployment) of this document.

> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory and choose the right number of GPUs to avoid unnecessary communication.*
Expand Down
102 changes: 102 additions & 0 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,108 @@
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]


def setup_distributed_args(args):
    """Populate the distributed-run attributes on ``args`` in place.

    With ``--use_fsdp2`` the distributed helpers are set up and ``args.rank``,
    ``args.world_size``, ``args.device`` and ``args.is_main`` are filled in from
    them. Without the flag the run is treated as a single process
    (``rank=0``, ``world_size=1``, ``is_main=True``) and ``args.device`` is left
    untouched.

    Args:
        args: Parsed CLI namespace. Only the optional ``use_fsdp2`` attribute is
            read; a missing attribute is treated as ``False``.
    """
    if not getattr(args, "use_fsdp2", False):
        args.rank = 0
        args.world_size = 1
        args.is_main = True
        return

    # Imported lazily so the single-process path never touches the
    # distributed helpers it does not use.
    from modelopt.torch.utils import distributed as dist_utils

    dist_utils.setup()
    args.rank = dist_utils.rank()
    args.world_size = dist_utils.size()
    # Each process drives the GPU matching its local rank on the node.
    args.device = torch.device(f"cuda:{dist_utils.local_rank()}")
    args.is_main = args.rank == 0
Comment thread
sugunav14 marked this conversation as resolved.


def cleanup_distributed(args):
    """Destroy the process group if ``--use_fsdp2`` set it up; otherwise do nothing.

    Args:
        args: Parsed CLI namespace. Only the optional ``use_fsdp2`` attribute is
            read; a missing attribute is treated as ``False``.
    """
    if getattr(args, "use_fsdp2", False):
        # Imported lazily: the single-process path has nothing to tear down and
        # should not need the distributed helpers at all.
        from modelopt.torch.utils import distributed as dist_utils

        dist_utils.cleanup()


def _checkpoint_has_mtp_weights(model_path: str) -> bool:
    """Return True if the checkpoint's safetensors index advertises MTP weights.

    Two locations are probed, in order: ``model.safetensors.index.json`` directly
    under ``model_path``, then the copy that ``huggingface_hub`` may hold in its
    local cache when ``model_path`` is a hub repo id. The first index file that
    exists and parses decides the answer; unreadable or malformed files are
    skipped.

    Args:
        model_path: Local checkpoint directory or Hugging Face repo id.

    Returns:
        True when any key or value of the index's ``weight_map`` contains the
        substring ``"mtp"``; False otherwise, including when no usable index
        file is found.
    """
    index_name = "model.safetensors.index.json"
    candidates = [Path(model_path) / index_name]

    try:
        from huggingface_hub import try_to_load_from_cache

        cached = try_to_load_from_cache(model_path, index_name)
    except (ImportError, ValueError):
        # ValueError covers huggingface_hub's repo-id validation error, which is
        # raised when ``model_path`` is a filesystem path rather than a repo id.
        cached = None
    # The cache lookup may also return a non-string "known missing" sentinel,
    # which is truthy; only real paths are usable here.
    if isinstance(cached, (str, Path)):
        candidates.append(Path(cached))

    for index_file in candidates:
        if not index_file.exists():
            continue
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(index_file, encoding="utf-8") as handle:
                index = json.load(handle)
        except (OSError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            continue
        weight_map = index.get("weight_map", {}) if isinstance(index, dict) else None
        if not isinstance(weight_map, dict):
            # Structurally unexpected index: treat like an unparseable file.
            continue
        return any("mtp" in name or "mtp" in shard for name, shard in weight_map.items())
    return False


def validate_fsdp2_supported(args, config):
    """Reject model/CLI combinations that the FSDP2 path does not support yet.

    Every check is evaluated, so all detected problems are reported together in
    a single error rather than one at a time.

    Args:
        args: Parsed CLI namespace. Reads ``pyt_ckpt_path`` and, when present,
            ``specdec_offline_dataset`` and ``low_memory_mode``.
        config: Model config object. Its optional ``quantization_config``
            attribute is inspected, and it is handed to the multimodal
            detection helpers.

    Raises:
        NotImplementedError: If at least one unsupported combination is found.
    """
    ckpt_path = args.pyt_ckpt_path
    # (is_unsupported, reason) pairs, evaluated top to bottom.
    checks = (
        (
            "vila" in ckpt_path.lower(),
            "VILA (custom builder + non-standard layer layout)",
        ),
        (
            is_nemotron_vl(config) or _is_multimodal_config(config),
            "multimodal / VL models (decoder layers not auto-detectable)",
        ),
        (
            getattr(config, "quantization_config", None) is not None,
            "pack-quantized / compressed-tensors checkpoints",
        ),
        (
            getattr(args, "specdec_offline_dataset", None) is not None,
            "speculative decoding (--specdec_offline_dataset)",
        ),
        (
            getattr(args, "low_memory_mode", False),
            "--low_memory_mode (redundant with FSDP2)",
        ),
        (
            _checkpoint_has_mtp_weights(ckpt_path),
            "MTP (Multi-Token Prediction) weights — the FSDP2 loader doesn't "
            "carry them through; the exported checkpoint would be missing MTP layers",
        ),
    )
    issues = [reason for unsupported, reason in checks if unsupported]
    if not issues:
        return

    separator = "\n - "
    raise NotImplementedError(
        "--use_fsdp2 does not support:\n - "
        + separator.join(issues)
        + "\nRemove --use_fsdp2 or use a standard causal-LM checkpoint."
    )


def load_and_prepare_fsdp2_model(
    ckpt_path: str,
    device: torch.device,
    rank: int,
    world_size: int = 1,
    args=None,
    trust_remote_code: bool = False,
    mp_policy=None,
    cpu_offload: bool = False,
    attn_implementation: str | None = None,
):
    """Check CLI constraints, then load the model via :func:`load_fsdp2_causal_lm`.

    When ``args`` is supplied, the checkpoint's config is loaded with
    ``AutoConfig.from_pretrained`` and passed to :func:`validate_fsdp2_supported`
    before any loading happens. When ``args`` is ``None`` that validation is
    skipped.

    Args:
        ckpt_path: Checkpoint path forwarded to the config loader and the model loader.
        device: Device forwarded to the loader.
        rank: Rank forwarded to the loader.
        world_size: World size forwarded to the loader.
        args: Optional parsed CLI namespace; enables validation when given.
        trust_remote_code: Forwarded to both the config lookup and the loader.
        mp_policy: Forwarded to the loader unchanged.
        cpu_offload: Forwarded to the loader unchanged.
        attn_implementation: Forwarded to the loader unchanged.

    Returns:
        Whatever :func:`load_fsdp2_causal_lm` returns.

    Raises:
        NotImplementedError: Propagated from :func:`validate_fsdp2_supported`.
    """
    from modelopt.torch.utils.distributed import load_fsdp2_causal_lm

    if args is not None:
        validate_fsdp2_supported(
            args,
            AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code),
        )

    # Keyword-only options passed straight through to the loader.
    loader_options = {
        "trust_remote_code": trust_remote_code,
        "mp_policy": mp_policy,
        "cpu_offload": cpu_offload,
        "attn_implementation": attn_implementation,
    }
    return load_fsdp2_causal_lm(ckpt_path, device, rank, world_size, **loader_options)


def run_nemotron_vl_preview(
full_model,
tokenizer,
Expand Down
30 changes: 0 additions & 30 deletions examples/llm_ptq/fsdp2.yaml

This file was deleted.

Loading