diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 337ed597301..ae6e58d9fb3 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -40,6 +40,7 @@ Changelog - Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). +- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md `_ for details. 
**Bug Fixes** diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 56dbe2ef7d7..4d000103ddf 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -1,15 +1,15 @@ # Megatron Bridge -This directory contains examples of using Model Optimizer with [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for pruning, distillation, quantization, etc. +This directory contains examples of using Model Optimizer with [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for quantization, distillation, pruning, etc.
 | **Section** | **Description** | **Link** |
 | :------------: | :------------: | :------------: |
 | Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] |
-| Pruning | Examples of pruning a model using Minitron algorithm | \[[Link](#pruning)\] |
-| Distillation | Examples of distillation a pruned or quantized model | \[[Link](#distillation)\] |
-| Post-Training Quantization | Examples of quantizing a model | \[[Link](#post-training-quantization)\] |
+| Post-Training Quantization | Quantizing a model | \[[Link](#post-training-quantization)\] |
+| Distillation | Distilling a pruned or quantized model | \[[Link](#distillation)\] |
+| Pruning | Pruning a model using Minitron algorithm | \[[Link](#pruning)\] |
 | Resources | Extra links to relevant resources | \[[Link](#resources)\] |
@@ -56,77 +56,40 @@ Note that the default dataset for pruning and quantization is [`nemotron-post-tr hf auth login --token ``` -## Pruning - -This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). - -The script supports three NAS-based pruning targets and one manual export mode: - -| Mode | Flag | Description | -| :---: | :---: | :--- | -| NAS | `--prune_target_params` | Prune to a target total parameter count | -| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. | -| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision | -| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. | - -Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision. - -**Prune by total parameter count** — prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults: - 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration, - at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...), - top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. 
+## Post-Training Quantization -```bash -torchrun --nproc_per_node 2 prune_minitron.py \ - --pp_size 2 \ - --hf_model_name_or_path Qwen/Qwen3-8B \ - --prune_target_params 6e9 \ - --hparams_to_skip num_attention_heads \ - --output_hf_path /tmp/Qwen3-8B-Pruned-6B -``` +This section shows how to quantize a HuggingFace model using ModelOpt in the Megatron-Bridge framework. Quantization is a two-step flow: -**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params): +1. [quantize.py](quantize.py) applies post-training quantization (PTQ) with calibration and saves a **Megatron checkpoint** (with ModelOpt state). Tensor / pipeline / expert parallelism are all supported, and the checkpoint can be reloaded for further training (Quantization Aware Training / Quantization Aware Distillation). +2. [export.py](export.py) converts that Megatron checkpoint to a **HuggingFace (unified) checkpoint** that deploys directly with TensorRT-LLM, vLLM, or SGLang. -```bash -torchrun --nproc_per_node 2 prune_minitron.py \ - --pp_size 2 \ - --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --prune_target_active_params 3e9 \ - --output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active -``` +`quantize.py` supports the following formats via `--quant_cfg` (e.g. `fp8`, `nvfp4`, `int8_sq`, `int4_awq`, `w4a8_awq`, ...). You can also pass any full config name exposed by ModelOpt (e.g. `FP8_DEFAULT_CFG`) or a YAML `--recipe` (e.g. `general/ptq/fp8_default-kv_fp8`, authoritative for quant_cfg + algorithm + KV-cache). KV-cache quantization can be enabled on top via `--kv_cache_quant` (e.g. `fp8`, `nvfp4`). 
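The `--quant_cfg` lookup described above (short alias first, then full config name) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the script's actual code: `FULL_CONFIGS`, `ALIASES`, and `resolve_quant_cfg` are hypothetical names, and the placeholder dicts stand in for the real ModelOpt config objects, whose contents are not reproduced here.

```python
import copy

# Placeholder stand-ins for ModelOpt configs (hypothetical contents, for illustration only).
FULL_CONFIGS = {
    "FP8_DEFAULT_CFG": {"quant_cfg": [], "algorithm": "max"},
    "NVFP4_DEFAULT_CFG": {"quant_cfg": [], "algorithm": "max"},
}

# Short aliases map onto full config names (hypothetical subset).
ALIASES = {"fp8": "FP8_DEFAULT_CFG", "nvfp4": "NVFP4_DEFAULT_CFG"}


def resolve_quant_cfg(name: str) -> dict:
    """Resolve a short alias or a full config name to a private copy of the config."""
    if name in ALIASES:
        cfg = FULL_CONFIGS[ALIASES[name]]
    elif name in FULL_CONFIGS:
        cfg = FULL_CONFIGS[name]
    else:
        raise ValueError(
            f"Unsupported quant_cfg '{name}'. Use one of {sorted(ALIASES)} "
            f"or a full name from {sorted(FULL_CONFIGS)}."
        )
    # Deep-copy so later customizations do not mutate the shared module-level config.
    return copy.deepcopy(cfg)
```

Returning a deep copy keeps per-run customizations (for example, disabling input quantizers for weight-only mode) from leaking into the shared defaults.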
-**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16): +**Step 1 — quantize** Qwen3-8B to FP8 on 2 GPUs (Tensor Parallelism = 2) using 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration: ```bash -torchrun --nproc_per_node 2 prune_minitron.py \ - --pp_size 2 \ +torchrun --nproc_per_node 2 quantize.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ - --prune_target_memory_mb 12288 \ - --seq_length 4096 \ - --calib_batch_size 1 \ - --output_hf_path /tmp/Qwen3-8B-Pruned-12GB + --quant_cfg fp8 \ + --tp_size 2 \ + --export_megatron_path /tmp/Qwen3-8B-FP8-megatron ``` -**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation): +**Step 2 — export** the Megatron checkpoint to a deployable HuggingFace checkpoint: ```bash -torchrun --nproc_per_node 2 prune_minitron.py \ - --pp_size 2 \ +torchrun --nproc_per_node 1 export.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ - --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \ - --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual + --megatron_path /tmp/Qwen3-8B-FP8-megatron \ + --export_unified_hf_path /tmp/Qwen3-8B-FP8-hf ``` -To see the full usage for advanced configurations, run: +> [!NOTE] +> The HuggingFace unified exporter does not gather tensor-parallel-sharded weights. Use `--pp_size` on `export.py` to shard a large model with pipeline parallelism across GPUs for export. -```bash -torchrun --nproc_per_node 1 prune_minitron.py --help -``` +To see the full usage for advanced configurations, run `torchrun --nproc_per_node 1 quantize.py --help` (or `export.py --help`). -> [!TIP] -> If number of layers in the model is not divisible by number of GPUs i.e. 
pipeline parallel (PP) size, you can configure -> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`. -> E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU. +For quantization scripts covering VLMs, QAT, and resuming quantized checkpoints, see the Megatron-Bridge repository [here](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/quantization). ## Distillation @@ -230,9 +193,77 @@ For more details, see the [Megatron-Bridge conversion README](https://github.com See [examples/pruning/](../pruning/README.md#tutorials--results) for distillation experiment results covering Minitron and Puzzletron pruning algorithms. -## Post-Training Quantization +## Pruning + +This section shows how to prune a HuggingFace model using the Minitron algorithm in the Megatron-Bridge framework. Check out other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). -Checkout Quantization scripts for LLMs and VLMs in the Megatron-Bridge repository [here](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/quantization). +The script supports three NAS-based pruning targets and one manual export mode: + +| Mode | Flag | Description | +| :---: | :---: | :--- | +| NAS | `--prune_target_params` | Prune to a target total parameter count | +| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. | +| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision | +| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model.
| + +Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision. + +**Prune by total parameter count** — prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults: + 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration, + at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...), + top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_params 6e9 \ + --hparams_to_skip num_attention_heads \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B +``` + +**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. 
prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params): + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --prune_target_active_params 3e9 \ + --output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active +``` + +**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16): + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_memory_mb 12288 \ + --seq_length 4096 \ + --calib_batch_size 1 \ + --output_hf_path /tmp/Qwen3-8B-Pruned-12GB +``` + +**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation): + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual +``` + +To see the full usage for advanced configurations, run: + +```bash +torchrun --nproc_per_node 1 prune_minitron.py --help +``` + +> [!TIP] +> If the number of layers in the model is not divisible by the number of GPUs, i.e. the pipeline parallel (PP) size, you can configure +> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`. +> E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU.
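
The 3-5-5-5-5-5-5-3 split in the tip above is plain arithmetic, which the sketch below reproduces. `layers_per_stage` is a hypothetical helper written for this illustration (it is not part of ModelOpt or Megatron-Bridge), and it assumes that the layers left after fixing the first and last stages must divide evenly across the remaining stages.

```python
from typing import List, Optional


def layers_per_stage(
    num_layers: int,
    pp_size: int,
    first: Optional[int] = None,
    last: Optional[int] = None,
) -> List[int]:
    """Return the number of layers assigned to each pipeline stage."""
    remaining_layers = num_layers
    remaining_stages = pp_size
    if first is not None:
        remaining_layers -= first
        remaining_stages -= 1
    if last is not None:
        remaining_layers -= last
        remaining_stages -= 1
    if remaining_stages < 0 or remaining_layers < 0:
        raise ValueError("first/last stage settings do not fit num_layers / pp_size")
    if remaining_stages == 0:
        if remaining_layers != 0:
            raise ValueError("layers left over with no middle stages to hold them")
        middle = []
    else:
        # Assumption: the middle stages must share the remaining layers evenly.
        if remaining_layers % remaining_stages != 0:
            raise ValueError(
                f"{remaining_layers} layers cannot be split evenly over "
                f"{remaining_stages} stages"
            )
        middle = [remaining_layers // remaining_stages] * remaining_stages
    head = [first] if first is not None else []
    tail = [last] if last is not None else []
    return head + middle + tail


# Qwen3-8B example from the tip: 36 layers, 8 GPUs, first and last stage set to 3.
print(layers_per_stage(36, 8, first=3, last=3))
```

With an even split (no `first` / `last`), 36 layers over 8 stages fails the divisibility check, which is the case the two flags are meant to handle.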
## Resources diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 16a0a85f842..9b9014f9560 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -51,7 +51,7 @@ from transformers import AutoConfig import modelopt.torch.utils.distributed as dist -from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils import print_args, print_rank_0 with contextlib.suppress(ModuleNotFoundError): import modelopt.torch.puzzletron.plugins.mbridge # noqa: F401 @@ -112,7 +112,6 @@ def get_args(): parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--cp_size", type=int, default=1, help="Context parallel size") parser.add_argument("--ep_size", type=int, default=1, help="Expert parallel size") - parser.add_argument("--etp_size", type=int, default=1, help="Expert tensor parallel size") # Dataset arguments parser.add_argument( @@ -225,10 +224,7 @@ def get_args(): if args.hf_export_path and not args.student_hf_model: raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.") - print_rank_0("\n==================== Arguments ====================") - for k, v in args.__dict__.items(): - print_rank_0(f"{k:<35} {v}") - print_rank_0("===================================================\n") + print_args(args) return args @@ -249,7 +245,7 @@ def _build_model_provider(hf_path): provider.pipeline_dtype = torch.bfloat16 provider.context_parallel_size = args.cp_size provider.expert_model_parallel_size = args.ep_size - provider.expert_tensor_parallel_size = args.etp_size + provider.expert_tensor_parallel_size = 1 # Expert tensor parallelism is not supported provider.seq_length = args.seq_length if args.recompute_granularity is not None: provider.recompute_granularity = args.recompute_granularity diff --git a/examples/megatron_bridge/export.py b/examples/megatron_bridge/export.py new file mode 100644 index 00000000000..71a00ead469 --- 
/dev/null +++ b/examples/megatron_bridge/export.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Export a quantized Megatron checkpoint (produced by quantize.py) to a HuggingFace (unified) +checkpoint that can be deployed directly with TensorRT-LLM, vLLM, or SGLang. + +The process is as follows: + 1. Build the Megatron-Core model structure + tokenizer from the original HuggingFace model. + 2. Load the quantized Megatron checkpoint (ModelOpt state + weights are restored automatically). + 3. Export the model to a HuggingFace (unified) checkpoint via ModelOpt. + +The HuggingFace unified exporter does not gather tensor-parallel-sharded weights, so this script +always loads the checkpoint at tensor_model_parallel_size=1 (re-sharding from whatever TP was used +during quantization). Use --pp_size to shard a large model across GPUs for export. + +Example usage to export an FP8 checkpoint produced by quantize.py: + + torchrun --nproc_per_node 1 export.py \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --megatron_path /tmp/Qwen3-8B-FP8-megatron \ + --export_unified_hf_path /tmp/Qwen3-8B-FP8-hf + +See `README.md` in this directory for more details. 
+""" + +import argparse + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.hf_pretrained.utils import is_safe_repo +from megatron.core.utils import unwrap_model + +import modelopt.torch.export as mtex +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import print_args, print_rank_0 + +_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--hf_model_name_or_path", + type=str, + required=True, + help="Original HuggingFace model (used for the model structure, tokenizer, and config).", + ) + parser.add_argument( + "--megatron_path", + type=str, + required=True, + help="Path to the quantized Megatron checkpoint produced by quantize.py.", + ) + parser.add_argument( + "--export_unified_hf_path", + type=str, + required=True, + help="Directory to write the exported HuggingFace (unified) checkpoint to.", + ) + parser.add_argument("--trust_remote_code", action="store_true") + + # Only Pipeline parallelism is supported for export + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=list(_DTYPE_MAP), + help="Data type for the exported weights.", + ) + parser.add_argument( + "--export_extra_modules", + action="store_true", + help="Export extra modules such as Medusa heads, EAGLE, or MTP.", + ) + + args = parser.parse_args() + + print_args(args) + + return args + + +def main(args: argparse.Namespace): + trust_remote_code = is_safe_repo( + trust_remote_code=args.trust_remote_code, hf_path=args.hf_model_name_or_path + ) + torch_dtype = _DTYPE_MAP[args.dtype] + + # Build the model structure + tokenizer from HF (weights come from the Megatron checkpoint). 
+ bridge = AutoBridge.from_hf_pretrained( + args.hf_model_name_or_path, trust_remote_code=trust_remote_code + ) + provider = bridge.to_megatron_provider(load_weights=False) + provider.tensor_model_parallel_size = 1 # Tensor parallelism is not supported + provider.pipeline_model_parallel_size = args.pp_size + provider.expert_model_parallel_size = 1 # Expert parallelism is not supported + provider.expert_tensor_parallel_size = 1 # Expert tensor parallelism is not supported + provider.pipeline_dtype = torch_dtype + provider.finalize() + provider.initialize_model_parallel(seed=0) + + # Load the quantized checkpoint. ModelOpt state + weights are restored automatically, and the + # checkpoint is re-sharded to TP=1 regardless of the parallelism used during quantization. + print_rank_0(f"Loading quantized Megatron checkpoint from {args.megatron_path}...") + megatron_model = bridge.load_megatron_model( + args.megatron_path, + mp_overrides={ + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": args.pp_size, + "expert_model_parallel_size": 1, + "expert_tensor_parallel_size": 1, + }, + wrap_with_ddp=False, + ) + unwrapped_model = unwrap_model(megatron_model[0]) + + # Extra modules (Medusa / EAGLE / MTP) only exist on the last pipeline stage. Use an all-reduce + # MAX over all ranks (rather than a broadcast from a hard-coded source rank) so the decision is + # correct regardless of pipeline placement / global rank ordering. 
+ has_extra_modules = hasattr(unwrapped_model, "eagle_module") or hasattr( + unwrapped_model, "medusa_heads" + ) + if torch.distributed.is_initialized(): + flag = torch.tensor( + [int(has_extra_modules)], dtype=torch.int, device=torch.cuda.current_device() + ) + torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MAX) + has_extra_modules = bool(flag.item()) + export_extra_modules = has_extra_modules and args.export_extra_modules + + print_rank_0( + f"Exporting to HuggingFace (unified) checkpoint at {args.export_unified_hf_path}..." + ) + mtex.export_mcore_gpt_to_hf( + unwrapped_model, + args.hf_model_name_or_path, + export_extra_modules=export_extra_modules, + dtype=torch_dtype, + export_dir=args.export_unified_hf_path, + moe_router_dtype=getattr(unwrapped_model.config, "moe_router_dtype", None), + trust_remote_code=trust_remote_code, + ) + print_rank_0(f"Exported HuggingFace checkpoint to {args.export_unified_hf_path}") + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index ffe0b834b65..247cbf170e0 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -52,7 +52,7 @@ import modelopt.torch.opt as mto import modelopt.torch.prune as mtp import modelopt.torch.utils.distributed as dist -from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0 +from modelopt.torch.utils import get_supported_datasets, print_args, print_rank_0, warn_rank_0 from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu @@ -256,10 +256,7 @@ def get_args() -> argparse.Namespace: raise ValueError("--prune_export_config must parse to a dictionary.") 
args.prune_export_config = prune_export_config - print_rank_0("\n==================== Arguments ====================") - for k, v in args.__dict__.items(): - print_rank_0(f"{k:<35} {v}") - print_rank_0("===================================================\n") + print_args(args) return args @@ -280,7 +277,8 @@ def main(args: argparse.Namespace): hf_model_name_or_path=args.hf_model_name_or_path, trust_remote_code=args.trust_remote_code, provider_overrides={ - "tensor_model_parallel_size": 1, + "tensor_model_parallel_size": 1, # Tensor parallelism is not supported + "expert_tensor_parallel_size": 1, # Expert tensor parallelism is not supported "pipeline_model_parallel_size": args.pp_size, "num_layers_in_first_pipeline_stage": args.num_layers_in_first_pipeline_stage, "num_layers_in_last_pipeline_stage": args.num_layers_in_last_pipeline_stage, diff --git a/examples/megatron_bridge/quantize.py b/examples/megatron_bridge/quantize.py new file mode 100644 index 00000000000..fff16518d5f --- /dev/null +++ b/examples/megatron_bridge/quantize.py @@ -0,0 +1,384 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example script for post-training quantization (PTQ) of a GPT / Mamba model using ModelOpt on a +Megatron-Bridge model (loaded from HF). + +The process is as follows: + 1. 
Load a pretrained HuggingFace model into a Megatron-Core model via Megatron-Bridge. + 2. Apply ModelOpt quantization (fake-quant) with calibration on a few samples from a dataset. + The quantization format is specified either by a short --quant_cfg alias or a --recipe YAML. + 3. (Optional) Compress weights to a real low-bit representation. + 4. Save the quantized model as a Megatron checkpoint (with ModelOpt state). The checkpoint can be + reloaded for further training (QAT / distillation) or converted to a HuggingFace (unified) + checkpoint for deployment with `export.py` (see that script for TensorRT-LLM / vLLM / SGLang). + +Tensor / pipeline / expert parallelism are all supported here — the Megatron checkpoint is saved +sharded and can be re-sharded on load (e.g. `export.py` reloads it at TP=1 for the HF export). + +Example usage to quantize Qwen3-8B to FP8 on 2 GPUs (Tensor Parallelism = 2): + 1024 samples from nemotron-post-training-dataset-v2 are used for calibration. + + torchrun --nproc_per_node 2 quantize.py \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --quant_cfg fp8 \ + --tp_size 2 \ + --export_megatron_path /tmp/Qwen3-8B-FP8-megatron + +Equivalent run using a YAML recipe (authoritative for quant_cfg + algorithm + KV-cache config): + + torchrun --nproc_per_node 2 quantize.py \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --recipe general/ptq/fp8_default-kv_fp8 \ + --tp_size 2 \ + --export_megatron_path /tmp/Qwen3-8B-FP8-megatron + +To convert the saved Megatron checkpoint to a deployable HuggingFace checkpoint, run `export.py`. + +To see the full usage for advanced configurations, run: + torchrun --nproc_per_node 1 quantize.py --help + +See `README.md` in this directory for more details. 
+""" + +import argparse +import copy + +import torch + +import modelopt.torch.quantization as mtq +import modelopt.torch.utils.distributed as dist +from modelopt.recipe import ModelOptPTQRecipe, load_recipe +from modelopt.torch.utils import print_args, print_rank_0, warn_rank_0 +from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf +from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop +from modelopt.torch.utils.plugins.megatron_generate import megatron_generate + +# Curated short-name aliases for the most common quantization configs. Any other config exposed by +# ``mtq.config.choices`` (e.g. ``FP8_DEFAULT_CFG``) can also be passed by its full name. +QUANT_CFG_CHOICES = { + "int8": mtq.INT8_DEFAULT_CFG, + "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, + "fp8": mtq.FP8_DEFAULT_CFG, + "fp8_blockwise": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + "int4_awq": mtq.INT4_AWQ_CFG, + "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, + "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, +} + +# KV-cache quantization configs (applied on top of the weight/activation quant config). +KV_QUANT_CFG_CHOICES = { + "none": "none", + "fp8": "FP8_KV_CFG", + "nvfp4": "NVFP4_KV_CFG", + "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", +} + +# TODO: Add AutoQuantize (mtq.auto_quantize) support to automatically search a per-layer mix of +# quantization formats that meets a target compression / accuracy constraint, instead of applying a +# single fixed --quant_cfg / --recipe to the whole model. 
+ + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--hf_model_name_or_path", type=str, required=True) + parser.add_argument("--trust_remote_code", action="store_true") + parser.add_argument( + "--export_megatron_path", + type=str, + required=True, + help="Path to save the quantized model in Megatron checkpoint format (with ModelOpt state).", + ) + + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--ep_size", type=int, default=1, help="Expert parallel size") + + # Quantization arguments + parser.add_argument( + "--recipe", + type=str, + default=None, + help=( + "PTQ recipe YAML file or builtin name (e.g. 'general/ptq/fp8_default-kv_fp8'). " + "When set, --quant_cfg, --kv_cache_quant, --weight_only, and --moe_calib_experts_ratio " + "are ignored; the recipe is authoritative for quant_cfg, algorithm, and KV-cache config." + ), + ) + parser.add_argument( + "--quant_cfg", + type=str, + default="fp8", + help=( + f"Quantization config. Short aliases: {', '.join(QUANT_CFG_CHOICES)}. " + "You can also pass any full config name exposed by modelopt (e.g. FP8_DEFAULT_CFG). " + "Ignored when --recipe is set." + ), + ) + parser.add_argument( + "--kv_cache_quant", + type=str, + default="none", + choices=list(KV_QUANT_CFG_CHOICES), + help="KV-cache quantization config to apply on top of --quant_cfg. Ignored when --recipe is set.", + ) + parser.add_argument( + "--weight_only", + action="store_true", + help="Disable input (activation) quantization, i.e. 
weight-only quantization.", + ) + parser.add_argument( + "--compress", + action="store_true", + help="Compress weights to a real low-bit representation (instead of fake quantization).", + ) + parser.add_argument( + "--moe_calib_experts_ratio", + type=float, + default=None, + help=( + "Fraction of experts (in (0.0, 1.0]) to calibrate per forward pass for MoE models. " + "Lower values speed up calibration of models with many experts; ignored for dense models." + ), + ) + + # Calibration dataset arguments + parser.add_argument( + "--calib_dataset_name", + type=str, + default="nemotron-post-training-dataset-v2", + help="HF Dataset name or local path used for calibration.", + ) + parser.add_argument( + "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" + ) + parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size") + parser.add_argument("--seq_length", type=int, default=4096, help="Calibration sequence length") + + # Post-quantization generation (sanity check) arguments + parser.add_argument( + "--prompts", + type=str, + default="Hello!|Born in California, Soyer trained as a", + help="Prompts to sanity-check the quantized model. 
Use | to separate batches.", + ) + parser.add_argument( + "--osl", + type=int, + default=32, + help="Output sequence length for the generation sanity check.", + ) + parser.add_argument( + "--skip_generate", + action="store_true", + help="Skip the post-quantization generation sanity check.", + ) + + args = parser.parse_args() + + if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): + parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") + + print_args(args) + + return args + + +def get_quant_config(args: argparse.Namespace) -> dict: + """Build the ModelOpt quantization config dict from the parsed arguments.""" + if args.recipe is not None: + # A YAML recipe is authoritative: it encodes quant_cfg + algorithm + KV-cache config + # directly, so the --quant_cfg / --kv_cache_quant / --weight_only / --moe_calib_experts_ratio + # customizations below are skipped. + print_rank_0(f"Using recipe {args.recipe} for quantization") + if ( + args.kv_cache_quant != "none" + or args.weight_only + or args.moe_calib_experts_ratio is not None + ): + warn_rank_0( + "--kv_cache_quant / --weight_only / --moe_calib_experts_ratio are ignored when " + "--recipe is set; the recipe is authoritative." + ) + recipe = load_recipe(args.recipe) + if not isinstance(recipe, ModelOptPTQRecipe): + raise TypeError( + f"Expected a PTQ recipe but got {type(recipe).__name__} from {args.recipe}" + ) + return recipe.quantize.model_dump() + + if args.quant_cfg in QUANT_CFG_CHOICES: + mtq_config = QUANT_CFG_CHOICES[args.quant_cfg] + elif args.quant_cfg in mtq.config.choices: + mtq_config = getattr(mtq, args.quant_cfg) + else: + raise ValueError( + f"Unsupported --quant_cfg '{args.quant_cfg}'. Choose one of the short aliases " + f"({', '.join(QUANT_CFG_CHOICES)}) or a full config name from {mtq.config.choices}." 
+ ) + + # Deepcopy so we don't mutate the shared module-level config, and normalize the inner quant_cfg + # to the list format so we can safely append customizations below. + mtq_config = copy.deepcopy(mtq_config) + mtq_config["quant_cfg"] = mtq.normalize_quant_cfg_list(mtq_config["quant_cfg"]) + + if args.weight_only: + mtq_config["quant_cfg"].append({"quantizer_name": "*input_quantizer", "enable": False}) + + if args.kv_cache_quant != "none": + kv_cache_quant_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_quant])["quant_cfg"] + mtq_config = mtq.utils.update_quant_cfg_with_kv_cache_quant(mtq_config, kv_cache_quant_cfg) + + # For MoE models, optionally calibrate only a fraction of experts per forward pass for speed. + if args.moe_calib_experts_ratio is not None: + algorithm = mtq_config.get("algorithm") + if isinstance(algorithm, str): + mtq_config["algorithm"] = { + "method": algorithm, + "moe_calib_experts_ratio": args.moe_calib_experts_ratio, + } + elif isinstance(algorithm, dict): + algorithm["moe_calib_experts_ratio"] = args.moe_calib_experts_ratio + else: + warn_rank_0( + f"Quantization algorithm {algorithm!r} does not support moe_calib_experts_ratio; ignoring." + ) + + return mtq_config + + +def main(args: argparse.Namespace): + bridge, _provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf( + hf_model_name_or_path=args.hf_model_name_or_path, + trust_remote_code=args.trust_remote_code, + provider_overrides={ + "tensor_model_parallel_size": args.tp_size, + "pipeline_model_parallel_size": args.pp_size, + "expert_model_parallel_size": args.ep_size, + "expert_tensor_parallel_size": 1, # Expert tensor parallelism is not supported + "pipeline_dtype": torch.bfloat16, + "seq_length": args.seq_length, + }, + init_model_parallel=True, + # Grouped GEMM is not supported for PTQ + export; use the per-expert (sequential) MLP. 
+ moe_grouped_gemm=False, + ) + + mtq_config = get_quant_config(args) + + # KV-cache quantization is incompatible with weight compression. Validate on the *resolved* + # config (KV-cache quantizers are named ``*[kv]_bmm_quantizer``) so this also covers + # recipe-driven KV-cache configs, not just the --kv_cache_quant flag. + if args.compress and any( + isinstance(entry, dict) and "bmm_quantizer" in str(entry.get("quantizer_name", "")) + for entry in mtq.normalize_quant_cfg_list(mtq_config["quant_cfg"]) + ): + raise ValueError("--compress cannot be combined with KV-cache quantization.") + + print_rank_0(f"Quantizing the model with: {args.recipe or args.quant_cfg}") + if "awq" in str(mtq_config.get("algorithm")): + print_rank_0( + "AWQ calibration can take longer than other methods; " + "reduce --calib_num_samples to speed it up." + ) + + # Dynamic and weight-only configs need no activation statistics, so skip both the + # (potentially expensive) calibration dataset download and the calibration forward pass. + if mtq.need_calibration(mtq_config): + forward_loop = get_megatron_calibration_forward_loop( + tokenizer, + dataset_name=args.calib_dataset_name, + num_samples=args.calib_num_samples, + seq_length=args.seq_length, + batch_size=args.calib_batch_size, + # Calibrate on unpacked sequences. pack=True is Megatron pretraining-style global-stream + # document packing, which changes the per-sample calibration statistics. + pack=False, + ) + else: + warn_rank_0("Dynamic or weight-only quantization detected; skipping calibration.") + forward_loop = None + + if hasattr(unwrapped_model, "calibration_mode"): + # Some model wrappers (e.g. distillation/speculative) gate calibration behind a flag. + # Reset it in a finally so a failure mid-calibration doesn't leave the flag set for the + # subsequent compress / save calls. 
+ unwrapped_model.calibration_mode = True + try: + mtq.quantize(unwrapped_model, mtq_config, forward_loop) + finally: + unwrapped_model.calibration_mode = False + else: + mtq.quantize(unwrapped_model, mtq_config, forward_loop) + + if args.compress: + mtq.compress(unwrapped_model) + print_rank_0("Weights are now compressed to low-bit!") + + # Save the quantizer summary alongside the checkpoint for later inspection. Only the master + # rank writes the file to avoid a multi-rank race on the same path. + if dist.is_master(): + mtq.print_quant_summary(unwrapped_model, args.export_megatron_path) + + print_rank_0(f"Saving quantized model to {args.export_megatron_path} in Megatron format...") + bridge.save_megatron_model( + model, + args.export_megatron_path, + hf_tokenizer_path=args.hf_model_name_or_path, + hf_tokenizer_kwargs={"trust_remote_code": args.trust_remote_code}, + ) + print_rank_0(f"Saved quantized model to {args.export_megatron_path} in Megatron format") + print_rank_0( + "To deploy this model (TensorRT-LLM / vLLM / SGLang), convert it to a HuggingFace " + f"checkpoint with export.py:\n" + f" torchrun --nproc_per_node export.py " + f"--hf_model_name_or_path {args.hf_model_name_or_path} " + f"--megatron_path {args.export_megatron_path} " + f"--export_unified_hf_path {args.export_megatron_path}_hf" + ) + + # Sanity-check generation with the fake-quantized model. Skipped when --compress is set: the + # weights are now real low-bit and megatron_generate may not support compressed forward for + # every quant format. + if args.compress and not args.skip_generate: + warn_rank_0( + "Skipping the post-quantization generation sanity check because --compress is set." 
+ ) + if not args.skip_generate and not args.compress: + print_rank_0("Testing quantized model with custom prompts...") + unwrapped_model.eval() + for idx, prompt in enumerate(args.prompts.split("|")): + tokens = tokenizer(prompt, return_tensors="pt") + # enable_kv_cache=False avoids pre-allocating the static KV cache: this is a short + # sanity-check generation and the KV-cache allocation can OOM tight quantization runs + # on large MoE models. + generated_ids = megatron_generate( + unwrapped_model, tokens.input_ids.cuda(), osl=args.osl, enable_kv_cache=False + ) + generated_texts = tokenizer.batch_decode(generated_ids) + print_rank_0(f"Prompt {idx + 1}: {prompt}") + print_rank_0(f"Generated: {generated_texts}") + + print_rank_0("Done!") + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index 724dda14434..d986034747c 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -15,6 +15,7 @@ """Utility functions for logging.""" +import argparse import contextlib import os import re @@ -37,6 +38,7 @@ "capture_io", "no_stdout", "num2hrb", + "print_args", "print_rank_0", "silence_matched_warnings", "warn_rank_0", @@ -110,6 +112,15 @@ def print_rank_0(*args, **kwargs): print(*args, **kwargs) +def print_args(args: argparse.Namespace, title: str = "Arguments") -> None: + """Pretty-print an ``argparse.Namespace`` (one entry per line) on rank 0.""" + header = f"{'=' * 20} {title} {'=' * 20}" + print_rank_0(f"\n{header}") + for key, value in vars(args).items(): + print_rank_0(f"{key:<35} {value}") + print_rank_0("=" * len(header) + "\n") + + def warn_rank_0(message, *args, **kwargs): """Issues a warning only on the master process. 
diff --git a/tests/examples/megatron_bridge/test_quantize_export.py b/tests/examples/megatron_bridge/test_quantize_export.py new file mode 100644 index 00000000000..f22a5e73b17 --- /dev/null +++ b/tests/examples/megatron_bridge/test_quantize_export.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for quantize.py and export.py scripts.""" + +from pathlib import Path + +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.transformers_models import create_tiny_qwen3_dir + + +def test_quantize_and_export(tmp_path: Path, num_gpus): + """Quantize a tiny Qwen3 via a YAML recipe and export it to a unified HF checkpoint.""" + # Use a vLLM-friendly head_dim (64) since the default tiny config (head_dim=2) is unsupported. + hf_model_path = create_tiny_qwen3_dir( + tmp_path, + with_tokenizer=True, + hidden_size=128, + num_attention_heads=2, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=256, + max_position_embeddings=512, + ) + megatron_path = tmp_path / "qwen3_fp8_megatron" + hf_export_path = tmp_path / "qwen3_fp8_hf" + + # Step 1: quantize (tensor parallelism is supported here) and save a Megatron checkpoint. The + # checkpoint must carry the ModelOpt state so it can be reloaded (for export or further QAT/QAD). 
+ quantize_cmd = extend_cmd_parts( + ["torchrun", f"--nproc_per_node={num_gpus}", "quantize.py", "--skip_generate"], + hf_model_name_or_path=hf_model_path, + recipe="general/ptq/fp8_default-kv_fp8", + tp_size=num_gpus, + calib_dataset_name="cnn_dailymail", + calib_num_samples=16, + calib_batch_size=1, + seq_length=32, + export_megatron_path=megatron_path, + ) + run_example_command(quantize_cmd, example_path="megatron_bridge", setup_free_port=True) + assert (megatron_path / "latest_checkpointed_iteration.txt").exists() + assert list(megatron_path.rglob("modelopt_state")), ( + "Expected modelopt_state in the Megatron checkpoint" + ) + + # Step 2: export to HF (re-shards to TP=1) on a single rank. export.py reloads the quantized + # Megatron checkpoint (restoring the ModelOpt quantizers) before converting to HF. + export_cmd = extend_cmd_parts( + ["torchrun", "--nproc_per_node=1", "export.py"], + hf_model_name_or_path=hf_model_path, + megatron_path=megatron_path, + export_unified_hf_path=hf_export_path, + ) + run_example_command(export_cmd, example_path="megatron_bridge", setup_free_port=True) + + # HF (unified) quantized checkpoint exists with the exported quantization config + weights. + # hf_quant_config.json is only written when the reloaded model is actually quantized, so its + # presence also confirms export.py restored the ModelOpt quantizers from the checkpoint. + assert (hf_export_path / "config.json").exists() + assert (hf_export_path / "hf_quant_config.json").exists() + assert list(hf_export_path.glob("*.safetensors")), "Expected exported safetensors weights" + + # The exported unified checkpoint should be loadable and runnable by vLLM. 
The deployment check below + # is disabled because it hangs in CI; to validate deployment locally in the NeMo container, uncomment it. + # + # import vllm + # llm = vllm.LLM( + # model=str(hf_export_path), + # tensor_parallel_size=1, + # enforce_eager=True, + # gpu_memory_utilization=0.4, + # max_model_len=128, + # dtype="bfloat16", + # ) + # outputs = llm.generate(["Hello!"], vllm.SamplingParams(max_tokens=4)) + # assert outputs and outputs[0].outputs and outputs[0].outputs[0].text
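
Review note: the `--moe_calib_experts_ratio` handling in `get_quant_config` promotes a bare algorithm name to a dict, updates an existing dict, and warns otherwise. The standalone sketch below mirrors that branch logic using only the standard library so it can be exercised without ModelOpt installed. The helper name `with_moe_calib_ratio` is hypothetical, the algorithm names `"max"` and `"awq_lite"` are illustrative values only, and `warnings.warn` stands in for `warn_rank_0`; this is a reading aid under those assumptions, not the script's actual code path.

```python
import copy
import warnings


def with_moe_calib_ratio(config: dict, ratio: float) -> dict:
    """Hypothetical helper: return a copy of ``config`` whose algorithm carries ``ratio``."""
    # Same range check the script performs in get_args() via parser.error().
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in the range (0.0, 1.0].")

    # Deep-copy so a shared, module-level config is never mutated in place.
    config = copy.deepcopy(config)
    algorithm = config.get("algorithm")

    if isinstance(algorithm, str):
        # Bare method name: promote it to the dict form and attach the ratio.
        config["algorithm"] = {"method": algorithm, "moe_calib_experts_ratio": ratio}
    elif isinstance(algorithm, dict):
        # Already a dict: add or overwrite the ratio, keep the other keys.
        algorithm["moe_calib_experts_ratio"] = ratio
    else:
        # None or another type: nothing to attach the ratio to, so only warn.
        warnings.warn(
            f"Quantization algorithm {algorithm!r} does not support "
            "moe_calib_experts_ratio; ignoring."
        )
    return config


if __name__ == "__main__":
    shared = {"quant_cfg": [], "algorithm": "max"}  # illustrative config shape
    print(with_moe_calib_ratio(shared, 0.5)["algorithm"])
    print(shared["algorithm"])  # the input config is left untouched
```

One design point the sketch makes visible: because the copy happens before any branch runs, all three cases leave the caller's config unchanged, which is the same reason the script deep-copies `mtq_config` before appending customizations.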