diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 8ae40cdf67c..d8ddf442924 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -41,7 +41,6 @@ from modelopt.torch.quantization.utils import ( QuantizerAttrNames, quantizer_attr_names, - reduce_block_amax, representative_weight_quantizer, weight_attr_names, ) @@ -241,100 +240,6 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor: return scaling_factor -def _get_nvfp4_block_size( - weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = "" -) -> int: - """Return block size for NVFP4 from quantizer's block_sizes; raise if missing.""" - prefix = f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''}" - block_sizes = weight_quantizer.block_sizes - if block_sizes is None: - raise ValueError(f"{prefix} has no block_sizes; cannot compute per-block amax from weight.") - block_size = block_sizes.get(-1) or block_sizes.get(weight.dim() - 1) - if block_size is None: - raise ValueError( - f"{prefix} block_sizes has no -1 or last-dim key; cannot compute per-block amax." - ) - return block_size - - -def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None: - """Set quantizer _amax buffer from tensor; copy in-place if same shape, else replace buffer.""" - if ( - hasattr(weight_quantizer, "_amax") - and weight_quantizer._amax is not None - and weight_quantizer._amax.shape == tensor.shape - ): - weight_quantizer._amax.data.copy_(tensor.to(weight_quantizer._amax.device)) - else: - if hasattr(weight_quantizer, "_amax"): - delattr(weight_quantizer, "_amax") - weight_quantizer.register_buffer("_amax", tensor.clone().detach()) - - -def _ensure_weight_quantizer_calibrated( - weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = "" -) -> None: - """Calibrate weight quantizer if amax is not set. 
- - This is a lazy calibration pattern used during export when weight quantizers - may not have been calibrated during the main calibration phase. - - For NVFP4StaticQuantizer, _amax is per-block amax and _global_amax is the max over - blocks; both are computed from the weight when missing. - - Args: - weight_quantizer: The weight quantizer to calibrate - weight: The weight tensor to use for calibration - module_name: Optional module name for better warning messages - """ - if isinstance(weight_quantizer, NVFP4StaticQuantizer): - - def _amax_is_invalid(t: torch.Tensor | None) -> bool: - # MCore distcp may register but not fill amax — treat missing/non-finite/negative as recompute. - if t is None: - return True - t = t.detach() - if not torch.is_floating_point(t): - return False - return bool((~torch.isfinite(t) | (t < 0)).any().item()) - - need_per_block = ( - not hasattr(weight_quantizer, "_amax") - or weight_quantizer._amax is None - or _amax_is_invalid(weight_quantizer._amax) - ) - need_global = ( - not hasattr(weight_quantizer, "_global_amax") - or weight_quantizer.global_amax is None - or _amax_is_invalid(weight_quantizer.global_amax) - ) - if not (need_per_block or need_global): - return - block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name) - warn( - f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. " - f"Computing per-block amax and global_amax from weights. 
This may occur if: " - f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size" - ) - per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size}) - if need_per_block: - _set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device)) - if need_global: - weight_quantizer.global_amax = per_block_amax.max() - return - - if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None: - warn( - f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. " - f"Computing amax from weights. This may occur if: " - f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size" - ) - weight_quantizer.reset_amax() - enable_stats_collection(weight_quantizer) - weight_quantizer(weight) - finish_stats_collection(weight_quantizer) - - def get_activation_scaling_factor( module: nn.Module, input_quantizer_name: str = "input_quantizer" ) -> torch.Tensor: @@ -379,10 +284,6 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. # This is because the kernel dequantizes weight to fp8, which is in range 448. 
@@ -424,11 +325,6 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set - weight = getattr(module, weight_name) - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. # This is because the kernel dequantizes weight to fp8, which is in range 448. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ea95d77450b..223994b3c46 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -110,47 +110,6 @@ def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None: quantizer(weight) -@torch.no_grad() -def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: - """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. - - Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE - doesn't drop them and break the gate==up ``weight_scale_2`` export invariant. - Activation quantizers on those modules remain uncalibrated; emits a warning. 
- """ - name_to_module = dict(model.named_modules()) - n = 0 - for module in name_to_module.values(): - if not isinstance(module, QuantModule): - continue - with enable_weight_access_and_writeback(module, model, name_to_module): - for weight, q in module.iter_weights_for_calibration(): - if ( - not isinstance(q, TensorQuantizer) - or q._disabled - or q._dynamic - or q._calibrator is None - ): - continue - if weight.is_meta: - continue - amax = q.amax - if amax is not None and (amax.is_meta or not torch.all(amax == 0)): - continue - _run_and_load_max_stats(q, partial(_collect_weight_stats, weight=weight)) - if hasattr(q._calibrator, "reset"): - q._calibrator.reset() - n += 1 - if n > 0: - warnings.warn( - f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; " - f"their activation quantizers (if any) remain uncalibrated. " - f"Increase calib size/seq len to activate all experts.", - stacklevel=2, - ) - return n - - @torch.no_grad() def _sync_grouped_weight_global_amax(model: nn.Module) -> int: """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. @@ -304,7 +263,14 @@ def max_calibrate( See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for details on the remaining arguments. """ - _run_and_load_max_stats(model, forward_loop) + # Always run weight calibration on the weight tensor directly so every weight + # quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms + # (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``. + enable_stats_collection(model) + weight_only_quantize(model) + if forward_loop is not None: + forward_loop(model) + finish_stats_collection(model) # Sync quantizer amax across local experts within each rank (for SequentialMLP) for name, module in model.named_modules(): @@ -314,8 +280,6 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). 
_check_nvfp4_static_tp_supported(model) - _bootstrap_uncalibrated_weight_quantizers(model) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so # the static blockwise fake-quant path is used in forward and export picks up the # two-level (per-block + global) scaling. Run before the ``distributed_sync`` early diff --git a/tests/gpu/torch/export/test_export_weight_gpu.py b/tests/gpu/torch/export/test_export_weight_gpu.py index f2b0ef404ea..2167f1e7936 100644 --- a/tests/gpu/torch/export/test_export_weight_gpu.py +++ b/tests/gpu/torch/export/test_export_weight_gpu.py @@ -17,17 +17,16 @@ import torch import torch.nn as nn -from _test_utils.torch.export.utils import ToyModel, partial_nvfp4_config, partial_w4a8_config +from _test_utils.torch.export.utils import ToyModel, partial_w4a8_config from torch.nn import functional as F from torch.nn import init import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer from modelopt.torch.quantization.nn.modules.quant_module import QuantModule, QuantModuleRegistry from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_PER_TENSOR -from modelopt.torch.quantization.utils import quantizer_attr_names, reduce_block_amax +from modelopt.torch.quantization.utils import quantizer_attr_names class ToyLinear(nn.Module): @@ -122,61 +121,3 @@ def test_export_per_block_quantized_weight(): assert hasattr(model.linears[2], quantizer_attrs.output_quantizer) assert not getattr(model.linears[2], quantizer_attrs.output_quantizer).is_enabled assert not hasattr(model.linears[2], quantizer_attrs.output_scale) - - -def test_export_nvfp4_static_weight_dynamic_vs_static_match(): - """Dynamic vs static NVFP4 export: same weight and scales after export even when amaxs are - cleared on one layer (lazy 
calibration via _ensure_weight_quantizer_calibrated fills them from weights). - """ - device = "cuda" - dims = [32, 32, 32, 32] - block_size = 16 - calib_input = torch.randn(1, 4, 32, device=device) - nvfp4_layer_indices = [1, 2] # layers with NVFP4 enabled in partial_nvfp4_config - - torch.manual_seed(42) - model_dynamic = ToyModel(dims=dims).to(device) - mtq.quantize(model_dynamic, partial_nvfp4_config, lambda x: x(calib_input)) - - torch.manual_seed(42) - model_static = ToyModel(dims=dims).to(device) - mtq.quantize(model_static, partial_nvfp4_config, lambda x: x(calib_input)) - - # Convert NVFP4 layers to NVFP4StaticQuantizer with per-block and global amax - for idx in nvfp4_layer_indices: - layer = model_static.linears[idx] - weight = layer.weight.data - per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size}) - tq = layer.weight_quantizer - if hasattr(tq, "_amax"): - delattr(tq, "_amax") - tq.register_buffer("_amax", per_block_amax.to(weight.device).clone().detach()) - NVFP4StaticQuantizer.from_tensor_quantizer(tq, global_amax=per_block_amax.max()) - - # Clear amaxs on layer 1 to exercise lazy calibration during export - for linear, is_static in [(model_dynamic.linears[1], False), (model_static.linears[1], True)]: - wq = linear.weight_quantizer - if hasattr(wq, "_amax"): - delattr(wq, "_amax") - if is_static and hasattr(wq, "_global_amax"): - delattr(wq, "_global_amax") - - quantizer_attrs = quantizer_attr_names("weight") - for idx in nvfp4_layer_indices: - _export_quantized_weight(model_dynamic.linears[idx], torch.float32, "weight") - _export_quantized_weight(model_static.linears[idx], torch.float32, "weight") - - for idx in nvfp4_layer_indices: - dyn_linear = model_dynamic.linears[idx] - sta_linear = model_static.linears[idx] - assert torch.equal(dyn_linear.weight, sta_linear.weight), ( - f"Layer {idx}: exported NVFP4 weight should match (dynamic vs static)" - ) - assert torch.allclose( - getattr(dyn_linear, 
quantizer_attrs.weight_scale).float(), - getattr(sta_linear, quantizer_attrs.weight_scale).float(), - ), f"Layer {idx}: weight_scale should match" - assert torch.allclose( - getattr(dyn_linear, quantizer_attrs.weight_scale_2).float(), - getattr(sta_linear, quantizer_attrs.weight_scale_2).float(), - ), f"Layer {idx}: weight_scale_2 should match"