diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ea95d77450b..68aee5cde86 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -33,8 +33,10 @@ LayerActivationCollector, _CheckpointState, ) -from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils import print_rank_0, warn_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState +from modelopt.torch.utils.distributed import is_initialized as dist_is_initialized +from modelopt.torch.utils.distributed import size as dist_size from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator @@ -51,7 +53,6 @@ persistent_materialization, promote_nvfp4_static_quantizers, quantizer_attr_names, - reduce_amax, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -473,8 +474,13 @@ def _make_weight_mse_calibrator( start_multiplier: float, stop_multiplier: float, fp8_scale_sweep: bool, + error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, ) -> _Calibrator | None: - """Create the MSE calibrator for one eligible weight quantizer.""" + """Create the MSE calibrator for one eligible weight quantizer (``None`` if ineligible). + + ``error_func`` overrides the squared-error metric (local-Hessian's per-block weighting); + when set, NVFP4's Triton fast path is bypassed for the reference sweep. + """ if ( not isinstance(weight_quantizer, TensorQuantizer) or not weight_quantizer.is_enabled @@ -494,6 +500,13 @@ def _make_weight_mse_calibrator( _FP8_SWEEP_CALIBRATOR_REGISTRY.get(backend) if backend is not None else None ) if backend is not None and backend_factory is not None: + if error_func is not None: + # Registered backends can't take a custom error_func; skip Hessian refinement. + warnings.warn( + f"local_hessian: backend '{backend}' does not support a custom error " + "function; skipping Hessian-weighted calibration for this quantizer." + ) + return None return backend_factory(initial_amax, axis, quant_func) if _uses_modelopt_fp8_weight_scales(weight_quantizer): return NVFP4MSECalibrator( @@ -501,12 +514,12 @@ def _make_weight_mse_calibrator( axis=axis, global_amax=weight_quantizer.global_amax, quant_func=quant_func, + error_func=error_func, ) - # fp8_scale_sweep covers only registered backends and static NVFP4 weights; - # skip MSE calibration for all other quantizers (no multiplier search). + # fp8_scale_sweep applies only to registered backends and static NVFP4; skip others. return None - # fp8_scale_sweep disabled: multiplier-search MSE calibration for all quantizers. + # No fp8_scale_sweep: multiplier-search MSE for all quantizers. return MseCalibrator( amax=initial_amax, axis=axis, @@ -514,6 +527,7 @@ def _make_weight_mse_calibrator( start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, + error_func=error_func, ) @@ -553,6 +567,31 @@ def mse_calibrate( # max_calibrate initializes activations and weights; MSE only refines weights below. max_calibrate(model, forward_loop, distributed_sync) name_to_module = dict(model.named_modules()) + _mse_calibrate_weights( + model, + name_to_module, + step_size=step_size, + start_multiplier=start_multiplier, + stop_multiplier=stop_multiplier, + fp8_scale_sweep=fp8_scale_sweep, + ) + + +@torch.no_grad() +def _mse_calibrate_weights( + model: nn.Module, + name_to_module: dict[str, nn.Module], + step_size: float, + start_multiplier: float, + stop_multiplier: float, + fp8_scale_sweep: bool, + error_func_for: Callable[[TensorQuantizer], Callable | None] | None = None, +): + """Run MSE weight calibration over all eligible quantizers (shared by mse / local-Hessian). + + ``error_func_for`` maps a weight quantizer to an optional per-weight error function + (local-Hessian's Hessian metric); ``None`` means plain squared error. + """ seen_modules: set[int] = set() pbar = tqdm(desc="MSE weight calibration") for parent_module in name_to_module.values(): @@ -561,12 +600,14 @@ def mse_calibrate( seen_modules.add(id(parent_module)) with enable_weight_access_and_writeback(parent_module, model, name_to_module): for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): + error_func = error_func_for(weight_quantizer) if error_func_for else None cal = _make_weight_mse_calibrator( weight_quantizer, step_size, start_multiplier, stop_multiplier, fp8_scale_sweep, + error_func=error_func, ) if cal is None: continue @@ -581,6 +622,90 @@ def mse_calibrate( pbar.close() +class _LocalHessianAccumulator: + """Per-block local Hessian ``H = ΣXᵀX`` for one weight quantizer. + + Partitioned over ``cin`` into ``cin // block_size`` blocks to match the NVFP4 per-block + scale; the buffer is allocated lazily so never-routed experts cost nothing. + """ + + def __init__(self, cout: int, cin: int, block_size: int): + self.cout = cout + self.cin = cin + self.block_size = block_size + self.num_blocks_per_cin = cin // block_size + # Not block-divisible -> no Hessian (falls back to plain MSE). + self.is_enabled = cin % block_size == 0 + self.hessian_per_block: torch.Tensor | None = None + self.num_samples = 0 + + @torch.no_grad() + def accumulate(self, input_tensor: torch.Tensor) -> None: + """Accumulate ``XᵀX`` per block from an activation of shape ``(..., cin)``.""" + if not self.is_enabled: + return + # fp32 GEMM avoids bf16/fp16 precision loss; (cin, tokens) -> (n_blocks, bs, tokens). + x = input_tensor.reshape(-1, self.cin).to(torch.float32).T + x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) + hessian_batch = x @ x.transpose(-1, -2) + if self.hessian_per_block is None: + self.hessian_per_block = hessian_batch + else: + self.hessian_per_block += hessian_batch + self.num_samples += input_tensor.numel() // self.cin + + def build_error_func( + self, keep_buffer: bool = False + ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None: + """Hessian-weighted error function (``None`` if no samples). + + Frees the raw Hessian buffer unless ``keep_buffer`` (kept for debug inspection). + """ + if self.hessian_per_block is None or self.num_samples == 0: + return None + cout = self.cout + bs = self.block_size + hessian = self.hessian_per_block / self.num_samples + if not keep_buffer: + self.hessian_per_block = None + + def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + original_shape = x.shape + # Per-block weighted error: dw (cout,n,bs) · H (n,bs,bs) -> (cout,n). + dw = (x - xq).view(cout, -1, bs) + block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) + return block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) + + return local_hessian_error + + +def _warn_if_block_size_mismatch(weight_quantizer: TensorQuantizer, block_size: int, name: str): + """Warn if the Hessian block_size differs from the quantizer's scale block (misaligns).""" + block_sizes = getattr(weight_quantizer, "block_sizes", None) + quant_block = block_sizes.get(-1) if block_sizes else None + if quant_block is not None and quant_block != block_size: + warn_rank_0( + f"local_hessian: block_size ({block_size}) != quantizer scale block " + f"({quant_block}) for {name}; Hessian weighting will not align with the scale blocks." + ) + + +def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned: set): + """Warn once per ``(name, cin)`` when a captured layer falls back to plain MSE.""" + if weight.dim() < 2: + return + cin = weight.shape[1] + if (name, cin) in warned: + return + warned.add((name, cin)) + if cin % block_size != 0: + warn_rank_0( + f"local_hessian: {name} input features ({cin}) not divisible by block_size " + f"({block_size}); falling back to plain MSE for these weights." + ) + _warn_if_block_size_mismatch(weight_quantizer, block_size, name) + + @torch.no_grad() def local_hessian_calibrate( model: nn.Module, @@ -593,14 +718,16 @@ def local_hessian_calibrate( block_size: int = 16, debug: bool = False, ): - """Calibrate the model using local Hessian-weighted MSE search. + """Calibrate weight quantizers by minimizing the Hessian-weighted error. - Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error - ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction - error ``||WX - WqX||²``. + Minimizes ``(W - Wq)ᵀ H (W - Wq)`` with per-block Hessian ``H = ΣXᵀX`` (approximating the + output error ``||WX - WqX||²``), built from a forward with weight fake-quant disabled + (input quantizers untouched) and fed to :func:`mse_calibrate`'s weight search via ``error_func``. - Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are accumulated - during forward pass and used to weight the MSE loss during scale search. + Like :func:`mse_calibrate`, TensorQuantizer weights are calibrated — with the Hessian + metric where a weight pairs with its input activations (dense linears and HF fused-MoE + experts), plain MSE otherwise. Other quantizer types (e.g. SequentialQuantizer) are + unsupported and left at their max-calibrated scale. Args: model: Model to be calibrated. @@ -613,7 +740,8 @@ def local_hessian_calibrate( fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values for NVFP4 per-block quantization (default: True). block_size: Block size for local Hessian computation (default: 16). - debug: If True, keep the local Hessian metadata on modules. + debug: If True, retain the per-quantizer Hessian accumulators on the model + (``model._local_hessian_accumulators``) for inspection. See :class:`LocalHessianCalibConfig ` for details on the configuration options. @@ -622,221 +750,107 @@ def local_hessian_calibrate( warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") return - class LocalHessianHelper: - """Helper class to collect activations and compute local Hessian per module.""" - - cache_mode: bool = False - - def __init__(self, module, name): - self.name = name - self.module = module - self.weight_shape = module.weight.shape # (cout, cin) - self.cout, self.cin = self.weight_shape - self.block_size = block_size - self.num_blocks_per_cin = self.cin // block_size - self.is_enabled = True - - # Accumulated Hessian per block: (cin // block_size, block_size, block_size) - self.hessian_per_block = torch.zeros( - self.num_blocks_per_cin, - block_size, - block_size, - dtype=torch.float32, - device=module.weight.device, - ) - self.num_samples = 0 - - def setup(self): - """Set up the forward hook to collect activations.""" - module = self.module - bind_forward_method(module, forward, "_forward_no_local_hessian") - - # Check if cin is divisible by block_size - if self.cin % self.block_size != 0: - warnings.warn( - f"Module {self.name}: input features ({self.cin}) not divisible by " - f"block_size ({self.block_size}). Skipping local Hessian for this module." - ) - self.is_enabled = False - - def cleanup(self): - """Clean up the forward hook.""" - unpatch_forward_method(self.module, "_forward_no_local_hessian") - if not debug: - if hasattr(self.module, "hessian_helper"): - delattr(self.module, "hessian_helper") - - def accumulate_hessian(self, input_tensor: torch.Tensor): - """Accumulate local Hessian from input activations. - - Args: - input_tensor: Input tensor of shape (..., cin) - """ - if not self.is_enabled: - return - - # Flatten to (num_tokens, cin) - x = input_tensor.reshape(-1, self.cin).T # (cin, num_tokens) - x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) # (num_blocks, bs, n) - - # Compute H = X @ X.T for each block and accumulate - hessian_batch = (x @ x.transpose(-1, -2)).to(torch.float32) - self.hessian_per_block += hessian_batch - self.num_samples += input_tensor.numel() // self.cin - - def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: - """Get the local Hessian error function for MSE calibration.""" - cout = self.cout - bs = self.block_size - # Normalize hessian by number of samples - hessian = self.hessian_per_block / max(self.num_samples, 1) - - def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: - """Compute local Hessian-weighted error.""" - original_shape = x.shape - # Reshape to (cout, num_blocks_per_cin, block_size) - dw = (x - xq).view(cout, -1, bs) - # Use einsum to avoid materializing cout-repeated Hessian - # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks) - block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw) - block_loss = block_loss.reshape(-1) - error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) - return error - - return local_hessian_error - - def forward(self, input, *args, **kwargs): - """Custom forward that collects activations in cache mode.""" - if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled: - # Get local tensor from DTensor if applicable - input_local = input.to_local() if hasattr(input, "to_local") else input - self.hessian_helper.accumulate_hessian(input_local) - - # Forward without quantization during caching - if LocalHessianHelper.cache_mode: - self.weight_quantizer.disable() - out = self._forward_no_local_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - return self._forward_no_local_hessian(input, *args, **kwargs) - - # First, run max_calibrate on the whole model to get initial amax for all quantizers - # This calibrates both weight_quantizer and input_quantizer with max calibration + # Phase 1: max-calibrate (also bootstraps dead experts + promotes/syncs NVFP4 static). print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) - # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) - weight_quantizers_info = [] - all_patched_modules = [] # Track all modules for cleanup (including disabled ones) + # Hessians keyed by id(weight_quantizer); modules pair weights<->activations via the hook. + accumulators: dict[int, _LocalHessianAccumulator] = {} + + def capture(weight_quantizer, weight, input_tensor): + input_local = input_tensor.to_local() if hasattr(input_tensor, "to_local") else input_tensor + acc = accumulators.get(id(weight_quantizer)) + if acc is None: + acc = _LocalHessianAccumulator(weight.shape[0], weight.shape[1], block_size) + accumulators[id(weight_quantizer)] = acc + acc.accumulate(input_local) + + # Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is, + # matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it. + handles: list = [] + silenced_weight_quantizers: list[TensorQuantizer] = [] + warned: set = set() + seen_modules: set[int] = set() for name, module in name_to_module.items(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - with enable_weight_access_and_writeback(module, model, name_to_module): - module.hessian_helper = LocalHessianHelper(module, name) - module.hessian_helper.setup() - all_patched_modules.append((name, module)) - if module.hessian_helper.is_enabled: - weight_quantizers_info.append((name, module)) - - # Cache activations by running forward loop - LocalHessianHelper.cache_mode = True - print_rank_0("local_hessian: Caching activations and computing local Hessian...") - forward_loop(model) - - # TODO(fridah-nv): Sync Hessian across distributed processes if needed - - # Replace calibrators with MseCalibrator using local Hessian error function - print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") - for name, module in weight_quantizers_info: - weight_quantizer = module.weight_quantizer - helper = module.hessian_helper - - if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None: + if not isinstance(module, QuantModule) or id(module) in seen_modules: continue - - initial_amax = weight_quantizer._amax.clone().detach() - - def quant_func(x, amax, quantizer=weight_quantizer): - return _mse_quant_func(x, amax, quantizer) - - is_nvfp4_static = weight_quantizer.is_nvfp4_static - - if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) - - error_func = helper.get_error_func() - - if fp8_scale_sweep and is_nvfp4_static: - weight_quantizer._calibrator = NVFP4MSECalibrator( - amax=initial_amax, - axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, - global_amax=weight_quantizer.global_amax, - quant_func=quant_func, - error_func=error_func, - ) - else: - weight_quantizer._calibrator = MseCalibrator( - amax=initial_amax, - axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, - step_size=step_size, - start_multiplier=start_multiplier, - stop_multiplier=stop_multiplier, - quant_func=quant_func, - error_func=error_func, - ) - - # Free cached memory before heavy calibration - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Process weights ONE AT A TIME with immediate amax computation and cleanup - weight_list = [ - (name, module) - for name, module in weight_quantizers_info - if module.weight_quantizer._calibrator is not None - ] - - for idx, (name, module) in enumerate(weight_list): - weight_quantizer = module.weight_quantizer - cal = weight_quantizer._calibrator - - # Step 1: Calibrate this weight - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() + seen_modules.add(id(module)) with enable_weight_access_and_writeback(module, model, name_to_module): - weight = module.weight - weight_quantizer(weight) - - # Step 2: IMMEDIATELY compute amax (before calibration data grows) - if cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() - - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() - - # Step 3: Sync all devices and reset calibrator for next weight - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + captures = module.register_calibration_input_hooks(capture) + handles.extend(captures) + for weight, weight_quantizer in module.iter_weights_for_calibration(): + # Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture + # forward uses full-precision weights and downstream Hessians aren't corrupted. + leaves = ( + list(weight_quantizer) + if isinstance(weight_quantizer, SequentialQuantizer) + else [weight_quantizer] + ) + silenced_weight_quantizers.extend( + q + for q in leaves + if isinstance(q, TensorQuantizer) and q.is_enabled and q._if_quant + ) + # Only TensorQuantizer weights are refined (same as mse_calibrate); other types + # (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale. + if not isinstance(weight_quantizer, TensorQuantizer): + if weight_quantizer.is_enabled and "unsupported" not in warned: + warned.add("unsupported") + warn_rank_0( + "local_hessian: only TensorQuantizer weights are calibrated; other " + "types (e.g. SequentialQuantizer) stay at their max-calibrated scale." + ) + continue + if captures: + _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned) - if hasattr(cal, "reset"): - cal.reset() + for weight_quantizer in silenced_weight_quantizers: + weight_quantizer.disable_quant() + print_rank_0("local_hessian: Caching activations and computing local Hessian...") + try: + forward_loop(model) + finally: + for weight_quantizer in silenced_weight_quantizers: + weight_quantizer.enable_quant() + for handle in handles: + handle.remove() + + # TODO(fridah-nv): the per-block Hessian is not synced across TP/DP ranks (max_calibrate's + # amax sync runs before this), so refined amaxes can diverge. All-reduce Hessian / re-sync. + if dist_is_initialized() and dist_size() > 1: + warn_rank_0( + "local_hessian: Hessian is not synced across ranks; refined weight amaxes may " + "diverge under tensor/data parallelism. Treat local_hessian as single-rank for now." + ) - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - torch.cuda.empty_cache() + # Phase 3: build error funcs and run the shared MSE weight loop. + error_funcs = { + qid: acc.build_error_func(keep_buffer=debug) for qid, acc in accumulators.items() + } + print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") + _mse_calibrate_weights( + model, + name_to_module, + step_size=step_size, + start_multiplier=start_multiplier, + stop_multiplier=stop_multiplier, + fp8_scale_sweep=fp8_scale_sweep, + error_func_for=lambda q: error_funcs.get(id(q)), + ) + # Free the per-block Hessians (pinned by error_func closures) and the sweep's cached + # allocations so export starts from a defragmented allocator. + error_funcs.clear() + for module in name_to_module.values(): + if isinstance(module, TensorQuantizer) and isinstance(module._calibrator, MseCalibrator): + module._calibrator._error_func = None if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) torch.cuda.empty_cache() - # Cleanup and free memory - LocalHessianHelper.cache_mode = False - for name, module in all_patched_modules: - module.hessian_helper.cleanup() + if debug: + model._local_hessian_accumulators = accumulators print_rank_0("local_hessian: Calibration complete.") diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 419c6f4924f..e533ab8848a 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -17,6 +17,7 @@ import contextlib import warnings +from collections.abc import Callable from typing import Any import torch @@ -127,6 +128,17 @@ def iter_weights_for_calibration(self): weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer) yield getattr(self, weight_name), weight_quantizer + def register_calibration_input_hooks( + self, callback: Callable[[TensorQuantizer, torch.Tensor, torch.Tensor], None] + ) -> list: + """Register forward hooks calling ``callback(weight_quantizer, weight, input)`` per weight. + + Activation-side counterpart to :meth:`iter_weights_for_calibration`, used by + activation-aware calibration (e.g. local-Hessian). Returns removable handles; the base + default is ``[]`` (no pairing available -> plain weight calibration). Override per module. + """ + return [] + def fold_weight(self, keep_attrs: bool = False): """Fold the weight for faster eval.""" # Handle all attributes that end with _weight_quantizer @@ -247,6 +259,27 @@ def _setup(self): self._register_temp_attribute("_enable_weight_quantization", False) self._register_dynamic_attribute("weight", self._get_quantized_weight) + def register_calibration_input_hooks(self, callback): + """Pair the weight quantizer with the forward input. + + Only a 2-D weight with an enabled ``TensorQuantizer`` is hooked; conv (4-D) and + ``SequentialQuantizer`` weights are unsupported and fall back to plain calibration. + """ + weight = getattr(self, "weight", None) + if ( + weight is None + or weight.dim() != 2 + or not isinstance(self.weight_quantizer, TensorQuantizer) + or not self.weight_quantizer.is_enabled + ): + return [] + + def _pre_hook(module, args): + if args: + callback(module.weight_quantizer, module.weight, args[0]) + + return [self.register_forward_pre_hook(_pre_hook)] + class _LegacyQuantInputBaseMixin: """A mixin to support legacy quantized modules which needs to have an __init__ method.""" diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 1873ecda528..a4c5cca0663 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -918,6 +918,36 @@ def iter_weights_for_calibration(self): for idx, q in enumerate(quantizers): yield weight[idx], q + def register_calibration_input_hooks(self, callback): + """Pair each per-expert weight quantizer with its routed input activation. + + Hooks the shared input quantizers, which the eager ``F.linear`` path calls per expert + while ``_current_expert_idx`` is set. Batched/grouped kernels never call them, so those + experts get no capture (fall back to plain weight calibration). + """ + handles = [] + for weight_name, quantizers_name, input_quantizer_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers", "gate_up_proj_input_quantizer"), + ("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + input_quantizer = getattr(self, input_quantizer_name, None) + if weight is None or quantizers is None or input_quantizer is None: + continue + + def _pre_hook(_iq, args, _weight_name=weight_name, _quantizers=quantizers): + if not args: + return + idx = self._current_expert_idx + weight_quantizer = _quantizers[idx] + if weight_quantizer.is_enabled: + # Read the weight fresh (valid under accelerate/FSDP re-materialization). + callback(weight_quantizer, getattr(self, _weight_name)[idx], args[0]) + + handles.append(input_quantizer.register_forward_pre_hook(_pre_hook)) + return handles + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 5a7fe8cd9a6..eb244e1036e 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -22,7 +22,9 @@ pytest.importorskip("transformers") +import modelopt.torch.quantization as mtq from modelopt.torch.quantization.conversion import _normalize_fused_experts_quantizer_name +from modelopt.torch.quantization.model_calib import local_hessian_calibrate from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.huggingface import ( _is_fused_experts_module, @@ -656,6 +658,63 @@ def forward_loop(m): self._cleanup_registry(expert_type) + def test_local_hessian_per_expert_capture_and_refinement(self): + """The plugin's extension point pairs each per-expert weight quantizer with its routed + input, and local_hessian uses that to refine every expert's weight amax.""" + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + weight_quant = {"num_bits": 8, "axis": 0} + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + {"quantizer_name": "*gate_up_proj_weight_quantizer", "cfg": weight_quant}, + {"quantizer_name": "*down_proj_weight_quantizer", "cfg": weight_quant}, + ], + "algorithm": "max", + } + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(3): + m(torch.randn(1, 8, HIDDEN_DIM)) + + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + experts = model.moe.experts + expert_quantizers = list(experts.gate_up_proj_weight_quantizers) + list( + experts.down_proj_weight_quantizers + ) + + # Extension point captures per-expert (weight_quantizer, weight_slice, cin). + captured = [] + handles = experts.register_calibration_input_hooks( + lambda wq, w, x: captured.append((id(wq), tuple(w.shape), x.shape[-1])) + ) + assert len(handles) == 2 # one pre-hook per shared input quantizer (gate_up, down) + with torch.no_grad(): + model(torch.randn(1, 8, HIDDEN_DIM)) + for h in handles: + h.remove() + valid_ids = {id(q) for q in expert_quantizers} + shapes = {(2 * INTERMEDIATE_DIM, HIDDEN_DIM), (HIDDEN_DIM, INTERMEDIATE_DIM)} + assert captured and all( + wq_id in valid_ids and shape in shapes and cin == shape[1] + for wq_id, shape, cin in captured + ) + + # End-to-end: local_hessian refines per-expert weight amax via that capture. + max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None} + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) + assert any(a.num_samples > 0 for a in model._local_hessian_accumulators.values()) + assert all(q.amax is not None and torch.isfinite(q.amax).all() for q in expert_quantizers) + assert any( + id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)]) + for q in expert_quantizers + ) + + self._cleanup_registry(expert_type) + def test_max_calibrate_populates_dead_static_nvfp4_expert_quantizers(self): """max calibration fills static NVFP4 ``_amax`` on experts the forward never routed to. diff --git a/tests/unit/torch/quantization/test_local_hessian.py b/tests/unit/torch/quantization/test_local_hessian.py new file mode 100644 index 00000000000..2a610aa87e6 --- /dev/null +++ b/tests/unit/torch/quantization/test_local_hessian.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for local Hessian-weighted MSE calibration (CPU).""" + +import warnings + +import pytest +import torch +import torch.nn as nn +from _test_utils.torch.quantization.models import SimpleConv, SimpleLinear + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization import calib +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.model_calib import ( + _FP8_SWEEP_CALIBRATOR_REGISTRY, + _LocalHessianAccumulator, + _make_weight_mse_calibrator, + _register_fp8_sweep_calibrator, + _warn_if_block_size_mismatch, + local_hessian_calibrate, + mse_calibrate, +) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + register_quant_backend, +) + +# Weight-only INT8 per-channel; calibration is re-run explicitly per test. +INT8_WEIGHT_CFG = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + "algorithm": "max", +} + + +def _weight_amaxes(model): + return { + n: m.amax + for n, m in model.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + + +def _make_forward_loop(seed=0): + def forward_loop(model): + torch.manual_seed(seed) + for _ in range(3): + x = torch.randn(8, 16) + x[:, 0] *= 40.0 # skew so the Hessian is non-trivial vs plain weight MSE + model(x) + + return forward_loop + + +class TestLocalHessianAccumulator: + def test_accumulate_shape_samples_fp32_buffer(self): + torch.manual_seed(0) + acc = _LocalHessianAccumulator(8, 32, 16) + assert acc.is_enabled + acc.accumulate(torch.randn(10, 32, dtype=torch.bfloat16)) + assert acc.hessian_per_block.shape == (2, 16, 16) + assert acc.hessian_per_block.dtype == torch.float32 # fp32 despite bf16 input + acc.accumulate(torch.randn(5, 32)) + assert acc.num_samples == 15 + assert acc.build_error_func() is not None + assert acc.hessian_per_block is None # raw buffer freed + + def test_error_func_matches_explicit_hessian_weighted_loss(self): + torch.manual_seed(1) + cout, cin, bs = 4, 32, 16 + n_blocks = cin // bs + acc = _LocalHessianAccumulator(cout, cin, bs) + x = torch.randn(7, cin) + acc.accumulate(x) + error_func = acc.build_error_func() + + xb = x.reshape(-1, cin).T.reshape(n_blocks, bs, -1) + hessian = (xb @ xb.transpose(-1, -2)) / acc.num_samples + w = torch.randn(cout * n_blocks, bs) + wq = w + 0.05 * torch.randn_like(w) + err = error_func(w, wq).view(-1, bs) + + assert err.shape == (cout * n_blocks, bs) + assert torch.allclose(err, err[:, :1].expand(-1, bs)) # per-block scalar broadcast + dw = (w - wq).view(cout, n_blocks, bs) + expected = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) + assert torch.allclose(err[:, 0], expected, atol=1e-5) + + def test_returns_none_when_disabled_or_no_samples(self): + not_divisible = _LocalHessianAccumulator(8, 30, 16) + assert not not_divisible.is_enabled + not_divisible.accumulate(torch.randn(4, 30)) # no-op + assert not_divisible.build_error_func() is None + assert _LocalHessianAccumulator(8, 32, 16).build_error_func() is None # no samples + + +class TestLocalHessianCalibrateDense: + def test_refines_amax_beyond_max_and_plain_mse(self): + forward_loop = _make_forward_loop() + torch.manual_seed(0) + model_lh = SimpleLinear() + mtq.quantize(model_lh, INT8_WEIGHT_CFG, forward_loop=forward_loop) + max_amax = {n: a.clone() for n, a in _weight_amaxes(model_lh).items()} + local_hessian_calibrate(model_lh, forward_loop, fp8_scale_sweep=False, debug=True) + + torch.manual_seed(0) + model_mse = SimpleLinear() + mtq.quantize(model_mse, INT8_WEIGHT_CFG, forward_loop=forward_loop) + mse_calibrate(model_mse, forward_loop, fp8_scale_sweep=False) + + accs = model_lh._local_hessian_accumulators + assert accs and all(a.num_samples > 0 for a in accs.values()) + lh, mse = _weight_amaxes(model_lh), _weight_amaxes(model_mse) + assert all(torch.isfinite(a).all() and (a > 0).all() for a in lh.values()) + assert any(not torch.allclose(lh[n], max_amax[n]) for n in lh) # refined past max-cal + assert any(not torch.allclose(lh[n], mse[n]) for n in lh) # Hessian changed the choice + + def test_warns_with_module_name_when_cin_not_divisible(self): + class _OddModel(nn.Module): + def __init__(self): + super().__init__() + self.odd = nn.Linear(24, 32) # 24 not divisible by block_size 16 + + def forward(self, x): + return self.odd(x) + + torch.manual_seed(0) + model = _OddModel() + forward_loop = lambda m: m(torch.randn(4, 24)) # noqa: E731 + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop) + with pytest.warns(UserWarning, match=r"odd input features \(24\) not divisible"): + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False) + + def test_no_forward_loop_is_skipped(self): + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + before = {n: a.clone() for n, a in _weight_amaxes(model).items()} + with pytest.warns(UserWarning, match="forward_loop must be provided"): + local_hessian_calibrate(model, forward_loop=None) + assert all(torch.equal(before[n], a) for n, a in _weight_amaxes(model).items()) + + +class TestActivationCaptureExtensionPoint: + """The extension point that decouples local-Hessian capture from module type.""" + + def test_dense_captures_and_conv_falls_back(self): + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + captured = [] + handles = model.net[0].register_calibration_input_hooks( + lambda wq, w, x: captured.append((tuple(w.shape), x.shape[-1])) + ) + assert len(handles) == 1 + with torch.no_grad(): + model(torch.randn(2, 16)) + for h in handles: + h.remove() + assert captured and captured[0] == ((32, 16), 16) # cin from activation matches weight + + conv = SimpleConv() + mtq.quantize(conv, INT8_WEIGHT_CFG, forward_loop=lambda m: m(SimpleConv.get_input())) + assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight + + def test_sequential_quantizer_weight_not_hooked(self): + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + linear = model.net[0] + linear.weight_quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + assert linear.register_calibration_input_hooks(lambda *a: None) == [] # unsupported + + def test_block_size_mismatch_warns_only_on_mismatch(self): + def q(block): + return TensorQuantizer( + QuantizerAttributeConfig( + num_bits=(2, 1), block_sizes={-1: block, "type": "static", "scale_bits": (4, 3)} + ) + ) + + with pytest.warns(UserWarning, match="will not align"): + _warn_if_block_size_mismatch(q(32), 16, "layer") + with warnings.catch_warnings(): + warnings.simplefilter("error") + _warn_if_block_size_mismatch(q(16), 16, "layer") # matching block + per_channel = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=0)) + _warn_if_block_size_mismatch(per_channel, 16, "layer") # no block_sizes + + +class TestMakeWeightMseCalibratorErrorFunc: + def setup_method(self): + self._orig_fp8_registry = dict(_FP8_SWEEP_CALIBRATOR_REGISTRY) + self._orig_quant_backends = dict(_QUANT_FUNCTIONAL_BACKENDS) + + def teardown_method(self): + _FP8_SWEEP_CALIBRATOR_REGISTRY.clear() + _FP8_SWEEP_CALIBRATOR_REGISTRY.update(self._orig_fp8_registry) + _QUANT_FUNCTIONAL_BACKENDS.clear() + _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends) + + def _make_quantizer(self, backend=None): + q = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=None, backend=backend)) + q.amax = torch.tensor(1.0) + return q + + def test_error_func_threaded_to_mse_calibrator(self): + marker = lambda x, xq: (x - xq) ** 2 # noqa: E731 + cal = _make_weight_mse_calibrator( + self._make_quantizer(), 0.1, 0.25, 4.0, fp8_scale_sweep=False, error_func=marker + ) + assert isinstance(cal, calib.MseCalibrator) + assert cal._error_func is marker + + def test_registered_backend_with_error_func_is_skipped(self): + register_quant_backend("_lh_test_backend", lambda x, tq: x) + _register_fp8_sweep_calibrator( + "_lh_test_backend", + lambda amax, axis, qf: calib.MseCalibrator(amax=amax, axis=axis, quant_func=qf), + ) + q = self._make_quantizer(backend="_lh_test_backend") + with pytest.warns(UserWarning, match="does not support a custom error"): + cal = _make_weight_mse_calibrator( + q, 0.1, 0.25, 4.0, fp8_scale_sweep=True, error_func=lambda x, xq: x + ) + assert cal is None