Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 0 additions & 104 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from modelopt.torch.quantization.utils import (
QuantizerAttrNames,
quantizer_attr_names,
reduce_block_amax,
representative_weight_quantizer,
weight_attr_names,
)
Expand Down Expand Up @@ -241,100 +240,6 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
return scaling_factor


def _get_nvfp4_block_size(
    weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = ""
) -> int:
    """Look up the last-dimension block size configured on an NVFP4 quantizer.

    The quantizer's ``block_sizes`` mapping is queried for the key ``-1`` first;
    when that entry is missing or falsy, the explicit last-dimension index
    (``weight.dim() - 1``) is tried instead.

    Args:
        weight_quantizer: Quantizer whose ``block_sizes`` attribute is read.
        weight: Weight tensor; only its rank is used (to derive the last-dim key).
        module_name: Optional name included in error messages.

    Returns:
        The block size stored in ``block_sizes``.

    Raises:
        ValueError: If ``block_sizes`` is ``None`` or neither key yields a value.
    """
    prefix = f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''}"

    configured = weight_quantizer.block_sizes
    if configured is None:
        raise ValueError(f"{prefix} has no block_sizes; cannot compute per-block amax from weight.")

    # Prefer the negative-index key; fall back to the positive last-dim index.
    block_size = configured.get(-1)
    if not block_size:
        block_size = configured.get(weight.dim() - 1)

    if block_size is not None:
        return block_size
    raise ValueError(
        f"{prefix} block_sizes has no -1 or last-dim key; cannot compute per-block amax."
    )


def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None:
    """Store ``tensor`` as the quantizer's ``_amax`` buffer.

    When an ``_amax`` of the same shape already exists, its data is overwritten
    in place (after moving ``tensor`` to the buffer's device). Otherwise any
    existing ``_amax`` attribute is removed and a detached clone of ``tensor``
    is registered as a new buffer.

    Args:
        weight_quantizer: Quantizer that owns (or will own) the ``_amax`` buffer.
        tensor: Values to store.
    """
    current = getattr(weight_quantizer, "_amax", None)

    # Same-shape fast path: keep the existing buffer object and just refill it.
    if current is not None and current.shape == tensor.shape:
        current.data.copy_(tensor.to(current.device))
        return

    # Shape differs or nothing usable is registered: drop the old attribute, then re-register.
    if hasattr(weight_quantizer, "_amax"):
        delattr(weight_quantizer, "_amax")
    weight_quantizer.register_buffer("_amax", tensor.clone().detach())


def _ensure_weight_quantizer_calibrated(
    weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
) -> None:
    """Lazily calibrate a weight quantizer from its weight when amax is missing.

    Used at export time for weight quantizers that did not receive an amax
    during the main calibration pass.

    * For a plain quantizer, a missing/``None`` ``_amax`` triggers a stats
      collection pass over ``weight``.
    * For ``NVFP4StaticQuantizer``, ``_amax`` (per-block) and ``global_amax``
      are each checked independently; whichever is missing or invalid is
      recomputed from the block-wise amax of ``weight``.

    Args:
        weight_quantizer: The weight quantizer to calibrate.
        weight: The weight tensor used for calibration.
        module_name: Optional module name for clearer warning messages.
    """
    if not isinstance(weight_quantizer, NVFP4StaticQuantizer):
        # Generic path: nothing to do once an amax is present.
        if getattr(weight_quantizer, "_amax", None) is not None:
            return
        warn(
            f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
            f"Computing amax from weights. This may occur if: "
            f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
        )
        weight_quantizer.reset_amax()
        enable_stats_collection(weight_quantizer)
        weight_quantizer(weight)
        finish_stats_collection(weight_quantizer)
        return

    def _amax_is_invalid(t: torch.Tensor | None) -> bool:
        # Per the original note: MCore distcp may register but not fill amax, so
        # missing, non-finite or negative values all mean "recompute".
        if t is None:
            return True
        values = t.detach()
        if torch.is_floating_point(values):
            bad = ~torch.isfinite(values) | (values < 0)
            return bool(bad.any().item())
        # Non-floating tensors are accepted as-is.
        return False

    need_per_block = _amax_is_invalid(getattr(weight_quantizer, "_amax", None))
    if hasattr(weight_quantizer, "_global_amax"):
        need_global = _amax_is_invalid(weight_quantizer.global_amax)
    else:
        need_global = True

    if not need_per_block and not need_global:
        return

    # Resolve the block size first so a misconfigured quantizer raises before warning.
    block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name)
    warn(
        f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. "
        f"Computing per-block amax and global_amax from weights. This may occur if: "
        f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
    )
    per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
    if need_per_block:
        _set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device))
    if need_global:
        # Global amax is the maximum over all per-block amax values.
        weight_quantizer.global_amax = per_block_amax.max()


def get_activation_scaling_factor(
module: nn.Module, input_quantizer_name: str = "input_quantizer"
) -> torch.Tensor:
Expand Down Expand Up @@ -379,10 +284,6 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
Expand Down Expand Up @@ -424,11 +325,6 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
weight = getattr(module, weight_name)
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
Expand Down
52 changes: 8 additions & 44 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,47 +110,6 @@ def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None:
quantizer(weight)


@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
    """Calibrate, directly from the weight, every weight quantizer lacking a usable amax.

    A quantizer is bootstrapped when it is an enabled, non-dynamic
    ``TensorQuantizer`` with a calibrator, its weight is not on the meta
    device, and its ``amax`` is either ``None`` or all zeros. The original
    docstring describes the intended case as MoE experts that ``max_calibrate``
    skipped because no tokens were routed to them. Activation quantizers are
    not touched; a warning is emitted if anything was bootstrapped.

    Args:
        model: Model whose ``QuantModule`` submodules are scanned.

    Returns:
        The number of weight quantizers that were bootstrapped.
    """
    modules_by_name = dict(model.named_modules())
    n = 0

    for module in modules_by_name.values():
        if not isinstance(module, QuantModule):
            continue
        with enable_weight_access_and_writeback(module, model, modules_by_name):
            for weight, quantizer in module.iter_weights_for_calibration():
                calibratable = (
                    isinstance(quantizer, TensorQuantizer)
                    and not quantizer._disabled
                    and not quantizer._dynamic
                    and quantizer._calibrator is not None
                )
                if not calibratable:
                    continue
                if weight.is_meta:
                    continue

                # Leave quantizers alone when amax is on meta or holds any non-zero value.
                amax = quantizer.amax
                has_usable_amax = amax is not None and (
                    amax.is_meta or not torch.all(amax == 0)
                )
                if has_usable_amax:
                    continue

                _run_and_load_max_stats(
                    quantizer, partial(_collect_weight_stats, weight=weight)
                )
                if hasattr(quantizer._calibrator, "reset"):
                    quantizer._calibrator.reset()
                n += 1

    if n > 0:
        warnings.warn(
            f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
            f"their activation quantizers (if any) remain uncalibrated. "
            f"Increase calib size/seq len to activate all experts.",
            stacklevel=2,
        )
    return n


@torch.no_grad()
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
Expand Down Expand Up @@ -304,7 +263,14 @@ def max_calibrate(
See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
details on the remaining arguments.
"""
_run_and_load_max_stats(model, forward_loop)
# Always run weight calibration on the weight tensor directly so every weight
Comment thread
sychen52 marked this conversation as resolved.
# quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms
# (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``.
enable_stats_collection(model)
weight_only_quantize(model)
Comment thread
sychen52 marked this conversation as resolved.
if forward_loop is not None:
forward_loop(model)
finish_stats_collection(model)

# Sync quantizer amax across local experts within each rank (for SequentialMLP)
for name, module in model.named_modules():
Expand All @@ -314,8 +280,6 @@ def max_calibrate(
# Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated).
_check_nvfp4_static_tp_supported(model)

_bootstrap_uncalibrated_weight_quantizers(model)

# Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so
# the static blockwise fake-quant path is used in forward and export picks up the
# two-level (per-block + global) scaling. Run before the ``distributed_sync`` early
Expand Down
63 changes: 2 additions & 61 deletions tests/gpu/torch/export/test_export_weight_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

import torch
import torch.nn as nn
from _test_utils.torch.export.utils import ToyModel, partial_nvfp4_config, partial_w4a8_config
from _test_utils.torch.export.utils import ToyModel, partial_w4a8_config
from torch.nn import functional as F
from torch.nn import init

import modelopt.torch.quantization as mtq
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
from modelopt.torch.quantization.nn.modules.quant_module import QuantModule, QuantModuleRegistry
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
from modelopt.torch.quantization.utils import quantizer_attr_names, reduce_block_amax
from modelopt.torch.quantization.utils import quantizer_attr_names


class ToyLinear(nn.Module):
Expand Down Expand Up @@ -122,61 +121,3 @@ def test_export_per_block_quantized_weight():
assert hasattr(model.linears[2], quantizer_attrs.output_quantizer)
assert not getattr(model.linears[2], quantizer_attrs.output_quantizer).is_enabled
assert not hasattr(model.linears[2], quantizer_attrs.output_scale)


def test_export_nvfp4_static_weight_dynamic_vs_static_match():
    """Exported NVFP4 weights and scales must agree between dynamic and static quantizers.

    Two identically seeded models are quantized; one has its NVFP4 layers
    converted to ``NVFP4StaticQuantizer``. The amax buffers on layer 1 are then
    removed from both models so that export has to recompute them from the
    weights (the lazy calibration done by ``_ensure_weight_quantizer_calibrated``).
    """
    device = "cuda"
    dims = [32, 32, 32, 32]
    block_size = 16
    calib_input = torch.randn(1, 4, 32, device=device)
    nvfp4_layer_indices = [1, 2]  # layers with NVFP4 enabled in partial_nvfp4_config

    def _build_quantized_model():
        # Re-seed before every build so both models start from identical weights.
        torch.manual_seed(42)
        model = ToyModel(dims=dims).to(device)
        mtq.quantize(model, partial_nvfp4_config, lambda x: x(calib_input))
        return model

    model_dynamic = _build_quantized_model()
    model_static = _build_quantized_model()

    # Convert NVFP4 layers to NVFP4StaticQuantizer with per-block and global amax
    for idx in nvfp4_layer_indices:
        static_layer = model_static.linears[idx]
        static_weight = static_layer.weight.data
        per_block_amax = reduce_block_amax(static_weight, block_sizes={-1: block_size})
        tq = static_layer.weight_quantizer
        if hasattr(tq, "_amax"):
            delattr(tq, "_amax")
        tq.register_buffer("_amax", per_block_amax.to(static_weight.device).clone().detach())
        NVFP4StaticQuantizer.from_tensor_quantizer(tq, global_amax=per_block_amax.max())

    # Clear amaxs on layer 1 to exercise lazy calibration during export
    for model, is_static in ((model_dynamic, False), (model_static, True)):
        wq = model.linears[1].weight_quantizer
        if hasattr(wq, "_amax"):
            delattr(wq, "_amax")
        if is_static and hasattr(wq, "_global_amax"):
            delattr(wq, "_global_amax")

    quantizer_attrs = quantizer_attr_names("weight")
    for idx in nvfp4_layer_indices:
        _export_quantized_weight(model_dynamic.linears[idx], torch.float32, "weight")
        _export_quantized_weight(model_static.linears[idx], torch.float32, "weight")

    for idx in nvfp4_layer_indices:
        dyn_linear = model_dynamic.linears[idx]
        sta_linear = model_static.linears[idx]

        assert torch.equal(dyn_linear.weight, sta_linear.weight), (
            f"Layer {idx}: exported NVFP4 weight should match (dynamic vs static)"
        )

        dyn_scale = getattr(dyn_linear, quantizer_attrs.weight_scale).float()
        sta_scale = getattr(sta_linear, quantizer_attrs.weight_scale).float()
        assert torch.allclose(dyn_scale, sta_scale), f"Layer {idx}: weight_scale should match"

        dyn_scale_2 = getattr(dyn_linear, quantizer_attrs.weight_scale_2).float()
        sta_scale_2 = getattr(sta_linear, quantizer_attrs.weight_scale_2).float()
        assert torch.allclose(dyn_scale_2, sta_scale_2), (
            f"Layer {idx}: weight_scale_2 should match"
        )
Loading