Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
bcc05b1
add rabbit feedback
Fridah-nv Feb 6, 2026
e2c781e
minor
Fridah-nv Feb 13, 2026
7f21be1
tested perplexity
sugunav14 Feb 4, 2026
9b600bf
tested, revert later
sugunav14 Feb 9, 2026
0cc53df
tested
sugunav14 Feb 10, 2026
cde122a
refactor
sugunav14 Feb 11, 2026
c2aeed5
Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQa…
realAsma Feb 6, 2026
d1ebcca
address reviewers feedback, delegate scaling factor calculation to NV…
Fridah-nv Feb 6, 2026
2cf8294
tested perplexity
sugunav14 Feb 4, 2026
abf6e8d
tested exported checkpoints on 0211
sugunav14 Feb 12, 2026
c604539
tested nano v3
sugunav14 Feb 13, 2026
f94e577
added activation MSE logging
sugunav14 Feb 16, 2026
b570e7b
super v3 run
sugunav14 Feb 17, 2026
ae30ff1
debug logs
sugunav14 Feb 17, 2026
f809975
added activationmse logging helper
sugunav14 Feb 17, 2026
8bf4e29
input amax sync added + tested gptq super sft checkpoint
sugunav14 Feb 19, 2026
01640cf
checkpoints generated on 0223
sugunav14 Feb 23, 2026
c1c7c96
tested perplexity
sugunav14 Feb 4, 2026
bcf6fe3
tested, revert later
sugunav14 Feb 9, 2026
558041c
tested
sugunav14 Feb 10, 2026
a907038
initial cleanup
sugunav14 Feb 24, 2026
1e5ce74
cleanup
sugunav14 Feb 24, 2026
96f35c8
removed stray config
sugunav14 Feb 24, 2026
72faf3e
removed stray prints
sugunav14 Feb 24, 2026
890f28b
fix rebase issues
sugunav14 Mar 6, 2026
40a6657
minor
sugunav14 Mar 6, 2026
cd9246a
tested e2e on qwen
sugunav14 Mar 6, 2026
49e54cc
removed perplexity eval
sugunav14 Mar 6, 2026
2bc18cd
update
sugunav14 Mar 6, 2026
c311950
revert later
sugunav14 Mar 16, 2026
fec8f89
minor update
sugunav14 Mar 19, 2026
806e8ac
update
sugunav14 Mar 18, 2026
3f2d7c0
gptq faster
sugunav14 Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import argparse
import copy
import os
import random
import time
import warnings
Expand Down Expand Up @@ -61,6 +62,11 @@
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.metrics import (
ActivationMSELogger,
compute_perplexity,
get_wikitext2,
)
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.utils.dataset_utils import (
Expand Down Expand Up @@ -97,6 +103,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None:
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
Expand All @@ -106,6 +113,8 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None:
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG,
"nvfp4_gptq": mtq.NVFP4_GPTQ_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}
Expand Down Expand Up @@ -676,6 +685,12 @@ def export_quantized(
"They will be set at deployment time."
)

if getattr(args, "eval_perplexity", False) and tokenizer is not None:
seq_len = getattr(args, "eval_perplexity_seq_len", 2048)
eval_data = get_wikitext2(tokenizer, seq_len)
ppl = compute_perplexity(full_model, eval_data)
print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}")

# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)
Expand Down Expand Up @@ -907,6 +922,64 @@ def quantize_main(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)

# Collect original (unquantized) activations before quantization modifies the model
mse_logger = None
if getattr(args, "measure_activation_mse", False):
n_mse = getattr(args, "activation_mse_max_samples", 16)
mse_save_dir = getattr(args, "activation_mse_save_dir", None)
mse_input_path = getattr(args, "activation_mse_input_path", None)

# Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader
mse_data = None
if mse_input_path is not None:
if mse_input_path.endswith(".json"):
if os.path.isfile(mse_input_path):
print(f"Loading MSE input data from existing .json file: {mse_input_path}")
texts = ActivationMSELogger.load_raw_text(mse_input_path)
mse_data = ActivationMSELogger.tokenize_raw_text(
texts,
tokenizer,
max_length=args.calib_seq,
)
else:
assert tokenizer is not None, (
"--activation_mse_input_path with .json requires a tokenizer to decode"
)
print(f"Creating MSE input data .json file: {mse_input_path}")
texts = ActivationMSELogger.materialize_raw_text(
calib_dataloader,
mse_input_path,
tokenizer=tokenizer,
max_samples=n_mse,
)
mse_data = ActivationMSELogger.tokenize_raw_text(
texts,
tokenizer,
max_length=args.calib_seq,
)
elif mse_input_path.endswith(".pt"):
if os.path.isfile(mse_input_path):
print(f"Loading MSE input data from existing .pt file: {mse_input_path}")
mse_data = ActivationMSELogger.load_data(mse_input_path)
else:
print(f"Creating MSE input data .pt file: {mse_input_path}")
mse_data = ActivationMSELogger.materialize_data(
calib_dataloader,
mse_input_path,
max_samples=n_mse,
)
else:
raise ValueError(
f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}"
)

if mse_data is None:
mse_data = calib_dataloader

mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir)
print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...")
mse_logger.collect(language_model, mse_data, phase="original")

if args.auto_quantize_bits:
assert len(args.qformat.split(",")) > 1, (
"Auto quantization needs multiple quantization format."
Expand Down Expand Up @@ -999,6 +1072,23 @@ def quantize_main(
is_nemotron_vl_model,
first_text_speech_dataset,
)

if mse_logger is not None:
import gc

print("Collecting quantized activations for MSE...")
mse_logger.collect(language_model, mse_data, phase="quantized")

mse_logger.compute_mse()
print(mse_logger.summary())

if getattr(args, "activation_mse_save_dir", None):
mse_logger.save()

del mse_logger, mse_data
gc.collect()
torch.cuda.empty_cache()

export_quantized(
args,
full_model,
Expand Down Expand Up @@ -1157,6 +1247,7 @@ def parse_args() -> argparse.Namespace:
default=False,
action="store_true",
)

parser.add_argument(
"--low_memory_mode",
help=(
Expand Down Expand Up @@ -1215,6 +1306,48 @@ def parse_args() -> argparse.Namespace:
"Does not impact non-MOE models."
),
)
parser.add_argument(
"--eval_perplexity",
action=argparse.BooleanOptionalAction,
default=False,
help="Evaluate Wikitext-2 perplexity after quantization (before export).",
)
parser.add_argument(
"--eval_perplexity_seq_len",
type=int,
default=2048,
help="Sequence length for perplexity evaluation (default: 2048).",
)
parser.add_argument(
"--measure_activation_mse",
action=argparse.BooleanOptionalAction,
default=False,
help="Measure per-layer activation MSE (original vs quantized) after quantization.",
)
parser.add_argument(
"--activation_mse_max_samples",
type=int,
default=16,
help="Max calibration samples for activation MSE (default: 16).",
)
parser.add_argument(
"--activation_mse_save_dir",
type=str,
default=None,
help="Directory to save activation MSE results. If not set, results are only printed.",
)
parser.add_argument(
"--activation_mse_input_path",
type=str,
default=None,
help=(
"Path to frozen MSE input data. Supports two formats:\n"
" .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes "
"with the current model's tokenizer; if not, decodes calibration data to text and saves.\n"
" .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; "
"if not, materializes from calibration data and saves."
),
)

args = parser.parse_args()
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
Expand Down
8 changes: 7 additions & 1 deletion modelopt/torch/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
"""Quantization package."""

# Initialize mode and plugins
from . import mode, plugins, utils
from . import metrics, mode, plugins, utils

# Add methods to mtq namespace
from .compress import *
from .config import *
from .conversion import *
from .metrics import (
ActivationMSELogger,
compute_perplexity,
get_wikitext2,
measure_per_layer_activation_mse,
)
from .model_quant import *
from .nn.modules.quant_module import QuantModuleRegistry
from .utils import update_quant_cfg_with_kv_cache_quant
88 changes: 88 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,20 @@
"algorithm": "max",
}

NVFP4_WEIGHT_ONLY_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {"enable": False},
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
}

NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
Expand Down Expand Up @@ -459,6 +473,39 @@
},
}

NVFP4_WEIGHT_ONLY_GPTQ_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {"enable": False},
**_default_disabled_quantizer_cfg,
},
"algorithm": {"method": "gptq", "use_sequential": True},
}

NVFP4_GPTQ_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {"method": "gptq", "use_sequential": True},
}

MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
Expand Down Expand Up @@ -679,6 +726,9 @@
"NVFP4_AWQ_FULL_CFG",
"NVFP4_AWQ_LITE_CFG",
"NVFP4_DEFAULT_CFG",
"NVFP4_GPTQ_CFG",
"NVFP4_WEIGHT_ONLY_CFG",
"NVFP4_WEIGHT_ONLY_GPTQ_CFG",
"NVFP4_FP8_MHA_CONFIG",
"NVFP4_KV_CFG",
"NVFP4_KV_ROTATE_CFG",
Expand Down Expand Up @@ -1392,6 +1442,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig):
)


class GPTQConfig(QuantizeAlgorithmConfig):
"""The config for GPTQ lite.

GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation.

GPTQ lite does not perform sequential quantization of layers. This means that the updated
activations are not used to process the next layer.

The default values are taken from the official GPTQ implementation:
https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35

Note: This feature is currently experimental and may not translate to improved accuracy as expected.


"""

method: Literal["gptq"] = ModeloptField("gptq")
percdamp: float | None = ModeloptField(
default=0.01,
gt=0.0,
le=1.0,
title="Percentage damping factor.",
description="The percentage of average Hessian diagonal used for damping.",
)
block_size: int | None = ModeloptField(
default=128,
title="Block size for GPTQ weight update.",
description="""The block size for GPTQ weight update, which must be a multiple of the
group_size used in the quantization.""",
)
hessian_state_path: str | None = ModeloptField(
default=None,
title="Path to the Hessian state file.",
description="""The path to the Hessian state file. If hessian path exists, we load from
hessian file instead of recomputing them.""",
)


QuantizeQuantCfgType = dict[
str | Callable,
QuantizerAttributeConfig
Expand Down
18 changes: 16 additions & 2 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
AWQFullCalibConfig,
AWQLiteCalibConfig,
CompressConfig,
GPTQConfig,
GPTQLiteConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
Expand All @@ -59,6 +60,7 @@
)
from .model_calib import (
awq,
gptq,
gptq_lite,
local_hessian_calibrate,
max_calibrate,
Expand Down Expand Up @@ -240,8 +242,8 @@ def wrapped_calib_func(
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
assert method in ["max", "gptq"], (
f"Sequential calibration currently only supports max and gptq calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
Expand Down Expand Up @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
return GPTQLiteConfig

_calib_func = gptq_lite


@CalibrateModeRegistry.register_mode
class GPTQModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for GPTQ calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
"""Specifies the config class for the mode."""
return GPTQConfig

_calib_func = gptq
Loading
Loading