Merged
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -40,6 +40,7 @@ Changelog
- Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries.
- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache.
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).
- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for details.

**Bug Fixes**

153 changes: 92 additions & 61 deletions examples/megatron_bridge/README.md
@@ -1,15 +1,15 @@
# Megatron Bridge

This directory contains examples of using Model Optimizer with the [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for pruning, distillation, quantization, etc.
This directory contains examples of using Model Optimizer with the [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for quantization, distillation, pruning, etc.

<div align="center">

| **Section** | **Description** | **Link** |
| :------------: | :------------: | :------------: |
| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] |
| Pruning | Examples of pruning a model using the Minitron algorithm | \[[Link](#pruning)\] |
| Distillation | Examples of distilling a pruned or quantized model | \[[Link](#distillation)\] |
| Post-Training Quantization | Examples of quantizing a model | \[[Link](#post-training-quantization)\] |
| Post-Training Quantization | Quantizing a model | \[[Link](#post-training-quantization)\] |
| Distillation | Distilling a pruned or quantized model | \[[Link](#distillation)\] |
| Pruning | Pruning a model using the Minitron algorithm | \[[Link](#pruning)\] |
| Resources | Extra links to relevant resources | \[[Link](#resources)\] |

</div>
@@ -56,77 +56,40 @@ Note that the default dataset for pruning and quantization is [`nemotron-post-tr
hf auth login --token <your token>
```

## Pruning

This section shows how to prune a HuggingFace model using the Minitron algorithm in the Megatron-Bridge framework. Check out other available pruning algorithms, supported frameworks and models, and a general pruning getting-started guide in the [pruning README](../pruning/README.md).

The script supports three NAS-based pruning targets and one manual export mode:

| Mode | Flag | Description |
| :---: | :---: | :--- |
| NAS | `--prune_target_params` | Prune to a target total parameter count |
| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. |
| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision |
| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. |

Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision.

**Prune by total parameter count**: prune Qwen3-8B to 6B on 2 GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads`, using the following defaults:
1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
at most 20% depth (`num_layers`) and 40% width are pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
and the top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
## Post-Training Quantization

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_params 6e9 \
--hparams_to_skip num_attention_heads \
--output_hf_path /tmp/Qwen3-8B-Pruned-6B
```
This section shows how to quantize a HuggingFace model using ModelOpt in the Megatron-Bridge framework. Quantization is a two-step flow:

**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params):
1. [quantize.py](quantize.py) applies post-training quantization (PTQ) with calibration and saves a **Megatron checkpoint** (with ModelOpt state). Tensor / pipeline / expert parallelism are all supported, and the checkpoint can be reloaded for further training (Quantization Aware Training / Quantization Aware Distillation).
2. [export.py](export.py) converts that Megatron checkpoint to a **HuggingFace (unified) checkpoint** that deploys directly with TensorRT-LLM, vLLM, or SGLang.

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
--prune_target_active_params 3e9 \
--output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active
```
`quantize.py` selects the quantization format via `--quant_cfg`, given either as an alias (e.g. `fp8`, `nvfp4`, `int8_sq`, `int4_awq`, `w4a8_awq`, ...) or as any full config name exposed by ModelOpt (e.g. `FP8_DEFAULT_CFG`). Alternatively, pass a YAML `--recipe` (e.g. `general/ptq/fp8_default-kv_fp8`), which is authoritative for quant_cfg, algorithm, and KV-cache settings. KV-cache quantization can be enabled on top via `--kv_cache_quant` (e.g. `fp8`, `nvfp4`).

**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16):
**Step 1 — quantize** Qwen3-8B to FP8 on 2 GPUs (Tensor Parallelism = 2) using 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration:

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
torchrun --nproc_per_node 2 quantize.py \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_memory_mb 12288 \
--seq_length 4096 \
--calib_batch_size 1 \
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
--quant_cfg fp8 \
--tp_size 2 \
--export_megatron_path /tmp/Qwen3-8B-FP8-megatron
```

**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation):
**Step 2 — export** the Megatron checkpoint to a deployable HuggingFace checkpoint:

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
torchrun --nproc_per_node 1 export.py \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
--output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
--megatron_path /tmp/Qwen3-8B-FP8-megatron \
--export_unified_hf_path /tmp/Qwen3-8B-FP8-hf
```

To see the full usage for advanced configurations, run:
> [!NOTE]
> The HuggingFace unified exporter does not gather tensor-parallel-sharded weights. Use `--pp_size` on `export.py` to shard a large model with pipeline parallelism across GPUs for export.

```bash
torchrun --nproc_per_node 1 prune_minitron.py --help
```
To see the full usage for advanced configurations, run `torchrun --nproc_per_node 1 quantize.py --help` (or `export.py --help`).

> [!TIP]
> If the number of layers in the model is not divisible by the pipeline parallel (PP) size (i.e. the number of GPUs), you can configure
> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`.
> E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU.
For quantization scripts covering VLMs, QAT, and resuming from quantized checkpoints, see the [Megatron-Bridge quantization examples](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/quantization).
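
As a rough illustration of what the calibration in Step 1 computes for a per-tensor FP8 format, the sketch below tracks the largest absolute value seen across calibration batches and derives a scale that maps it onto 448, the largest finite FP8 E4M3 value. This is a pure-Python toy under stated assumptions, not the ModelOpt implementation: `MaxCalibrator`, `round_to_e4m3`, and `fake_quant_fp8` are hypothetical names, and other calibration algorithms (MSE-based ones, for example) may pick scales differently.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


class MaxCalibrator:
    """Hypothetical helper: tracks the running max |x| over calibration batches."""

    def __init__(self):
        self.amax = 0.0

    def observe(self, batch):
        self.amax = max(self.amax, max(abs(x) for x in batch))

    def scale(self):
        # Per-tensor scale so that amax maps onto the FP8 E4M3 maximum.
        return self.amax / FP8_E4M3_MAX


def round_to_e4m3(value):
    """Round a float to the nearest FP8 E4M3 value, saturating at +/-448."""
    if value == 0.0:
        return 0.0
    sign = -1.0 if value < 0 else 1.0
    mag = min(abs(value), FP8_E4M3_MAX)
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) == e - 1.
    exponent = max(math.frexp(mag)[1] - 1, -6)  # -6 is the smallest normal exponent
    step = 2.0 ** (exponent - 3)  # 3 mantissa bits: 8 steps per power of two
    return sign * min(round(mag / step) * step, FP8_E4M3_MAX)


def fake_quant_fp8(value, scale):
    """Quantize then dequantize, so the result exposes the rounding error."""
    return round_to_e4m3(value / scale) * scale


calib = MaxCalibrator()
for batch in ([0.5, -2.0], [1.0, 3.5]):  # stand-in for calibration activations
    calib.observe(batch)
scale = calib.scale()  # 3.5 / 448

print(fake_quant_fp8(0.3, scale))    # 0.3125: 0.3 snaps to the nearest FP8 grid point
print(fake_quant_fp8(100.0, scale))  # 3.5: values beyond amax saturate
```

Values inside the calibrated range pick up a small rounding error, while values outside it are clipped to the calibrated maximum, which is why the calibration set should be representative of the data seen at deployment.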

## Distillation

@@ -230,9 +193,77 @@ For more details, see the [Megatron-Bridge conversion README](https://github.com

See [examples/pruning/](../pruning/README.md#tutorials--results) for distillation experiment results covering Minitron and Puzzletron pruning algorithms.

## Post-Training Quantization
## Pruning

This section shows how to prune a HuggingFace model using the Minitron algorithm in the Megatron-Bridge framework. Check out other available pruning algorithms, supported frameworks and models, and a general pruning getting-started guide in the [pruning README](../pruning/README.md).

Check out quantization scripts for LLMs and VLMs in the Megatron-Bridge repository [here](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/quantization).
The script supports three NAS-based pruning targets and one manual export mode:

| Mode | Flag | Description |
| :---: | :---: | :--- |
| NAS | `--prune_target_params` | Prune to a target total parameter count |
| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. |
| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision |
| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. |

Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision.

**Prune by total parameter count**: prune Qwen3-8B to 6B on 2 GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads`, using the following defaults:
1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
at most 20% depth (`num_layers`) and 40% width are pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
and the top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_params 6e9 \
--hparams_to_skip num_attention_heads \
--output_hf_path /tmp/Qwen3-8B-Pruned-6B
```

**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params):

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
--prune_target_active_params 3e9 \
--output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active
```

**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16):

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_memory_mb 12288 \
--seq_length 4096 \
--calib_batch_size 1 \
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
```
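
The memory target can be sanity-checked with back-of-the-envelope arithmetic. The sketch below assumes the footprint is BF16 weights (2 bytes per parameter) plus a BF16 KV-cache, and that 1 MB means 1024² bytes; the exact accounting inside `prune_minitron.py` may differ. `estimate_memory_mb` is a hypothetical helper, and the layer count, KV-head count, and head dimension are assumed values for unpruned Qwen3-8B that should be checked against the model's `config.json`.

```python
def estimate_memory_mb(num_params, num_layers, num_kv_heads, head_dim,
                       seq_length, batch_size, bytes_per_elem=2):
    """Return (weights_mb, kv_cache_mb), with BF16 (2 bytes/element) by default."""
    weights_bytes = num_params * bytes_per_elem
    # The KV-cache stores one key and one value vector per layer, KV head, and token.
    kv_bytes = (2 * num_layers * num_kv_heads * head_dim
                * seq_length * batch_size * bytes_per_elem)
    mib = 1024 ** 2
    return weights_bytes / mib, kv_bytes / mib


# Assumed Qwen3-8B attention shape: 36 layers, 8 KV heads, head_dim 128.
weights_mb, kv_mb = estimate_memory_mb(
    num_params=6e9, num_layers=36, num_kv_heads=8, head_dim=128,
    seq_length=4096, batch_size=1,
)
print(f"weights={weights_mb:.0f} MB, kv_cache={kv_mb:.0f} MB, "
      f"total={weights_mb + kv_mb:.0f} MB")
```

Under these assumptions a 6B-parameter model lands slightly below the 12288 MB budget, while the unpruned 8B model does not, since its weights alone exceed 15000 MB.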

**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation):

```bash
torchrun --nproc_per_node 2 prune_minitron.py \
--pp_size 2 \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
--output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
```
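
To see why this export config corresponds to roughly 6B parameters, the sketch below counts the parameters of a dense decoder with grouped-query attention and a gated MLP. It is only an estimate under stated assumptions: `dense_decoder_params` is a hypothetical helper, and the Qwen3-8B values used here (36 layers, 32 attention heads, 8 KV heads, head dimension 128, vocabulary size 151936, untied embeddings, per-head QK norms) are assumptions to verify against the model's `config.json`.

```python
def dense_decoder_params(hidden_size, ffn_hidden_size, num_layers=36,
                         num_heads=32, num_kv_heads=8, head_dim=128,
                         vocab_size=151936, tied_embeddings=False):
    """Approximate parameter count for a GQA + gated-MLP decoder (no biases)."""
    q_out = num_heads * head_dim
    kv_out = num_kv_heads * head_dim
    # q_proj, k_proj, v_proj, then o_proj back to hidden_size.
    attn = hidden_size * (q_out + 2 * kv_out) + q_out * hidden_size
    # gate_proj, up_proj, down_proj.
    mlp = 3 * hidden_size * ffn_hidden_size
    # Two RMSNorms per layer plus per-head q/k norms (assumed for Qwen3).
    norms = 2 * hidden_size + 2 * head_dim
    embeddings = vocab_size * hidden_size * (1 if tied_embeddings else 2)
    final_norm = hidden_size
    return num_layers * (attn + mlp + norms) + embeddings + final_norm


original = dense_decoder_params(hidden_size=4096, ffn_hidden_size=12288)
pruned = dense_decoder_params(hidden_size=3584, ffn_hidden_size=9216)
print(f"original: {original / 1e9:.2f}B, pruned: {pruned / 1e9:.2f}B")
```

With these assumed dimensions, shrinking `hidden_size` to 3584 and `ffn_hidden_size` to 9216 brings the estimate from about 8.2B down to just under 6B parameters.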

To see the full usage for advanced configurations, run:

```bash
torchrun --nproc_per_node 1 prune_minitron.py --help
```

> [!TIP]
> If the number of layers in the model is not divisible by the pipeline parallel (PP) size (i.e. the number of GPUs), you can configure
> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`.
> E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU.
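
The arithmetic behind the tip can be checked with a small sketch. It assumes the layers left over after the first and last stages are split evenly across the middle stages; `layers_per_stage` is a hypothetical helper for illustration, not part of the script.

```python
def layers_per_stage(num_layers, pp_size, first=None, last=None):
    """Return the number of layers on each pipeline stage."""
    if first is None and last is None:
        if num_layers % pp_size != 0:
            raise ValueError("num_layers is not divisible by pp_size; set first/last")
        return [num_layers // pp_size] * pp_size
    if first is None or last is None:
        raise ValueError("this sketch expects both first and last to be set")

    middle_stages = pp_size - 2
    remaining = num_layers - first - last
    if middle_stages <= 0 or remaining % middle_stages != 0:
        raise ValueError("remaining layers do not split evenly over middle stages")
    return [first] + [remaining // middle_stages] * middle_stages + [last]


# Qwen3-8B example from the tip: 36 layers on 8 GPUs.
print(layers_per_stage(36, 8, first=3, last=3))  # [3, 5, 5, 5, 5, 5, 5, 3]
```

Without the two flags, 36 layers cannot be divided evenly over 8 stages, which is the case the first branch rejects.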

## Resources

10 changes: 3 additions & 7 deletions examples/megatron_bridge/distill.py
@@ -51,7 +51,7 @@
from transformers import AutoConfig

import modelopt.torch.utils.distributed as dist
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils import print_args, print_rank_0

with contextlib.suppress(ModuleNotFoundError):
    import modelopt.torch.puzzletron.plugins.mbridge  # noqa: F401
@@ -112,7 +112,6 @@ def get_args():
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--cp_size", type=int, default=1, help="Context parallel size")
    parser.add_argument("--ep_size", type=int, default=1, help="Expert parallel size")
    parser.add_argument("--etp_size", type=int, default=1, help="Expert tensor parallel size")

    # Dataset arguments
    parser.add_argument(
@@ -225,10 +224,7 @@ def get_args():
    if args.hf_export_path and not args.student_hf_model:
        raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.")

    print_rank_0("\n==================== Arguments ====================")
    for k, v in args.__dict__.items():
        print_rank_0(f"{k:<35} {v}")
    print_rank_0("===================================================\n")
    print_args(args)

    return args

@@ -249,7 +245,7 @@ def _build_model_provider(hf_path):
    provider.pipeline_dtype = torch.bfloat16
    provider.context_parallel_size = args.cp_size
    provider.expert_model_parallel_size = args.ep_size
    provider.expert_tensor_parallel_size = args.etp_size
    provider.expert_tensor_parallel_size = 1  # Expert tensor parallelism is not supported
    provider.seq_length = args.seq_length
    if args.recompute_granularity is not None:
        provider.recompute_granularity = args.recompute_granularity