From 6e607518090353387944b3a76b9c1c9047a03091 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:42:43 +0000 Subject: [PATCH 1/5] Refactor local_hessian onto shared MSE flow + fused-MoE expert support Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 490 ++++++++++-------- .../plugins/test_fused_experts.py | 54 ++ .../torch/quantization/test_local_hessian.py | 270 ++++++++++ 3 files changed, 604 insertions(+), 210 deletions(-) create mode 100644 tests/unit/torch/quantization/test_local_hessian.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ea95d77450b..5311823d7e6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -33,8 +33,10 @@ LayerActivationCollector, _CheckpointState, ) -from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils import print_rank_0, warn_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState +from modelopt.torch.utils.distributed import is_initialized as dist_is_initialized +from modelopt.torch.utils.distributed import size as dist_size from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator @@ -51,7 +53,6 @@ persistent_materialization, promote_nvfp4_static_quantizers, quantizer_attr_names, - reduce_amax, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -473,8 +474,14 @@ def _make_weight_mse_calibrator( start_multiplier: float, stop_multiplier: float, fp8_scale_sweep: bool, + error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, ) -> _Calibrator | None: - """Create the MSE calibrator for one eligible weight quantizer.""" + """Create the MSE calibrator for one eligible weight quantizer. + + ``error_func`` overrides the default squared-error metric (used by local-Hessian + calibration to weight the per-block error). When set, the NVFP4 Triton fast path is + bypassed in favor of the reference sweep so the custom metric is honored. + """ if ( not isinstance(weight_quantizer, TensorQuantizer) or not weight_quantizer.is_enabled @@ -494,6 +501,14 @@ def _make_weight_mse_calibrator( _FP8_SWEEP_CALIBRATOR_REGISTRY.get(backend) if backend is not None else None ) if backend is not None and backend_factory is not None: + if error_func is not None: + # Registered backend factories don't accept a custom error_func, so a + # Hessian-weighted metric can't be honored; leave at max/MSE amax. + warnings.warn( + f"local_hessian: backend '{backend}' does not support a custom error " + "function; skipping Hessian-weighted calibration for this quantizer." + ) + return None return backend_factory(initial_amax, axis, quant_func) if _uses_modelopt_fp8_weight_scales(weight_quantizer): return NVFP4MSECalibrator( @@ -501,6 +516,7 @@ def _make_weight_mse_calibrator( axis=axis, global_amax=weight_quantizer.global_amax, quant_func=quant_func, + error_func=error_func, ) # fp8_scale_sweep covers only registered backends and static NVFP4 weights; # skip MSE calibration for all other quantizers (no multiplier search). @@ -514,6 +530,7 @@ def _make_weight_mse_calibrator( start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, + error_func=error_func, ) @@ -553,6 +570,32 @@ def mse_calibrate( # max_calibrate initializes activations and weights; MSE only refines weights below. max_calibrate(model, forward_loop, distributed_sync) name_to_module = dict(model.named_modules()) + _mse_calibrate_weights( + model, + name_to_module, + step_size=step_size, + start_multiplier=start_multiplier, + stop_multiplier=stop_multiplier, + fp8_scale_sweep=fp8_scale_sweep, + ) + + +@torch.no_grad() +def _mse_calibrate_weights( + model: nn.Module, + name_to_module: dict[str, nn.Module], + step_size: float, + start_multiplier: float, + stop_multiplier: float, + fp8_scale_sweep: bool, + error_func_for: Callable[[TensorQuantizer], Callable | None] | None = None, +): + """Replace each eligible weight quantizer's calibrator with an MSE calibrator and run it. + + Shared by ``mse_calibrate`` and ``local_hessian_calibrate``. ``error_func_for`` maps a + weight quantizer to an optional per-weight error function (used by local-Hessian to + inject the Hessian-weighted metric); it defaults to ``None`` (plain squared error). + """ seen_modules: set[int] = set() pbar = tqdm(desc="MSE weight calibration") for parent_module in name_to_module.values(): @@ -561,12 +604,14 @@ def mse_calibrate( seen_modules.add(id(parent_module)) with enable_weight_access_and_writeback(parent_module, model, name_to_module): for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): + error_func = error_func_for(weight_quantizer) if error_func_for else None cal = _make_weight_mse_calibrator( weight_quantizer, step_size, start_multiplier, stop_multiplier, fp8_scale_sweep, + error_func=error_func, ) if cal is None: continue @@ -581,6 +626,93 @@ def mse_calibrate( pbar.close() +class _LocalHessianAccumulator: + """Accumulates a per-block local Hessian ``H = ΣXᵀX`` for one weight quantizer. + + The Hessian is partitioned over the input (``cin``) dimension into + ``cin // block_size`` blocks of shape ``(block_size, block_size)`` so the + Hessian-weighted error matches the NVFP4 per-block scale granularity. The raw + accumulator buffer is allocated lazily on the first ``accumulate`` call, so + never-routed MoE experts cost no memory. + """ + + def __init__(self, cout: int, cin: int, block_size: int): + self.cout = cout + self.cin = cin + self.block_size = block_size + self.num_blocks_per_cin = cin // block_size + # Quantizers whose cin doesn't tile evenly fall back to plain MSE (error_func None). + self.is_enabled = cin % block_size == 0 + self.hessian_per_block: torch.Tensor | None = None + self.num_samples = 0 + + @torch.no_grad() + def accumulate(self, input_tensor: torch.Tensor) -> None: + """Accumulate ``XᵀX`` per block from an activation of shape ``(..., cin)``.""" + if not self.is_enabled: + return + # Accumulate the XᵀX GEMM in fp32 to avoid bf16/fp16 precision loss on the sum. + # (cin, num_tokens) -> (num_blocks, block_size, num_tokens) + x = input_tensor.reshape(-1, self.cin).to(torch.float32).T + x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) + hessian_batch = x @ x.transpose(-1, -2) + if self.hessian_per_block is None: + self.hessian_per_block = hessian_batch + else: + self.hessian_per_block += hessian_batch + self.num_samples += input_tensor.numel() // self.cin + + def build_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None: + """Build the Hessian-weighted error function, or ``None`` if no samples were seen. + + Releases the raw Hessian buffer; the returned closure keeps only the normalized copy. + """ + if self.hessian_per_block is None or self.num_samples == 0: + return None + cout = self.cout + bs = self.block_size + hessian = self.hessian_per_block / self.num_samples + self.hessian_per_block = None # free the raw accumulator + + def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + original_shape = x.shape + # Blocked weight error -> (cout, num_blocks_per_cin, block_size). + dw = (x - xq).view(cout, -1, bs) + # einsum avoids materializing the cout-repeated Hessian: + # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks) + block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) + return block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) + + return local_hessian_error + + +def _is_quant_fused_experts(module: nn.Module) -> bool: + """Whether ``module`` is a *converted* fused-MoE-experts wrapper with per-expert quantizers. + + Distinct from ``plugins.huggingface._is_fused_experts_module`` (which detects the + *unconverted* HF module by structure); this checks for the per-expert weight-quantizer + lists and routing index added by ``_QuantFusedExperts``. + """ + return hasattr(module, "_current_expert_idx") and hasattr( + module, "gate_up_proj_weight_quantizers" + ) + + +def _warn_if_block_size_mismatch(weight_quantizer: TensorQuantizer, block_size: int, name: str): + """Warn if the Hessian block_size disagrees with the quantizer's last-axis scale block. + + They must match for the per-block Hessian to weight the same blocks the scale search + optimizes; a mismatch silently produces a finite-but-misaligned amax. + """ + block_sizes = getattr(weight_quantizer, "block_sizes", None) + quant_block = block_sizes.get(-1) if block_sizes else None + if quant_block is not None and quant_block != block_size: + warn_rank_0( + f"local_hessian: block_size ({block_size}) != quantizer scale block " + f"({quant_block}) for {name}; Hessian weighting will not align with the scale blocks." + ) + + @torch.no_grad() def local_hessian_calibrate( model: nn.Module, @@ -595,12 +727,20 @@ def local_hessian_calibrate( ): """Calibrate the model using local Hessian-weighted MSE search. - Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error - ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction - error ``||WX - WqX||²``. + Instead of minimizing weight error ``||W - Wq||²``, this minimizes the Hessian-weighted + error ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = ΣXᵀX`` approximates the output + reconstruction error ``||WX - WqX||²``. - Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are accumulated - during forward pass and used to weight the MSE loss during scale search. + Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are + accumulated during a dedicated full-precision forward pass, then used to weight the + per-block MSE error while searching the weight scale (reusing :func:`mse_calibrate`'s + weight-calibration machinery via a custom ``error_func``). + + Coverage: dense quantized linears and HF fused-MoE experts (per-expert weight + quantizers, Hessian built from each expert's routed activations). Quantizers without a + usable Hessian — never-routed experts, ``cin`` not divisible by ``block_size``, + registered custom backends, or non-eager fused-expert kernels that bypass ``F.linear`` + — fall back to the plain max/MSE amax. Args: model: Model to be calibrated. @@ -613,7 +753,8 @@ def local_hessian_calibrate( fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values for NVFP4 per-block quantization (default: True). block_size: Block size for local Hessian computation (default: 16). - debug: If True, keep the local Hessian metadata on modules. + debug: If True, retain the per-quantizer Hessian accumulators on the model + (``model._local_hessian_accumulators``) for inspection. See :class:`LocalHessianCalibConfig ` for details on the configuration options. @@ -622,221 +763,150 @@ def local_hessian_calibrate( warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") return - class LocalHessianHelper: - """Helper class to collect activations and compute local Hessian per module.""" - - cache_mode: bool = False + # Phase 1: max-calibrate the whole model. max_calibrate also bootstraps dead-expert + # weight quantizers and promotes / global-amax-syncs NVFP4 static quantizers, so amax + # is initialized for every quantizer before the Hessian refinement below. + print_rank_0("local_hessian: Running max calibration for all quantizers...") + max_calibrate(model, forward_loop, distributed_sync) - def __init__(self, module, name): - self.name = name - self.module = module - self.weight_shape = module.weight.shape # (cout, cin) - self.cout, self.cin = self.weight_shape - self.block_size = block_size - self.num_blocks_per_cin = self.cin // block_size - self.is_enabled = True + name_to_module = dict(model.named_modules()) - # Accumulated Hessian per block: (cin // block_size, block_size, block_size) - self.hessian_per_block = torch.zeros( - self.num_blocks_per_cin, - block_size, - block_size, - dtype=torch.float32, - device=module.weight.device, + # Per-block Hessian accumulators keyed by id(weight_quantizer), so dense linears and + # per-expert MoE quantizers share one uniform lookup during weight calibration. + accumulators: dict[int, _LocalHessianAccumulator] = {} + cleanup_callbacks: list[Callable[[], None]] = [] + # Fused per-expert weight quantizers to silence during caching (their fake-quant runs + # inside the intercepted F.linear, not in a hookable module forward). + fused_weight_quantizers: list[TensorQuantizer] = [] + # Mutable toggle shared with the capture hooks; only on during the caching forward. + cache_state = {"on": False} + + def _setup_dense(module: nn.Module, name: str) -> None: + weight_quantizer = module.weight_quantizer + cout, cin = module.weight.shape[0], module.weight.shape[1] + acc = _LocalHessianAccumulator(cout, cin, block_size) + accumulators[id(weight_quantizer)] = acc + if not acc.is_enabled: + warn_rank_0( + f"local_hessian: {name} input features ({cin}) not divisible by block_size " + f"({block_size}); falling back to plain MSE for this module." ) - self.num_samples = 0 - - def setup(self): - """Set up the forward hook to collect activations.""" - module = self.module - bind_forward_method(module, forward, "_forward_no_local_hessian") + _warn_if_block_size_mismatch(weight_quantizer, block_size, name) + + def forward(self, input, *args, **kwargs): + if cache_state["on"]: + input_local = input.to_local() if hasattr(input, "to_local") else input + acc.accumulate(input_local) + # Capture full-precision-weight activations during the caching pass. + # Snapshot _if_quant so a quantizer that was not fake-quanting stays off. + was_quant = self.weight_quantizer._if_quant + self.weight_quantizer.disable_quant() + try: + return self._forward_no_local_hessian(input, *args, **kwargs) + finally: + if was_quant: + self.weight_quantizer.enable_quant() + return self._forward_no_local_hessian(input, *args, **kwargs) + + bind_forward_method(module, forward, "_forward_no_local_hessian") + cleanup_callbacks.append( + partial(unpatch_forward_method, module, "_forward_no_local_hessian") + ) - # Check if cin is divisible by block_size - if self.cin % self.block_size != 0: - warnings.warn( - f"Module {self.name}: input features ({self.cin}) not divisible by " - f"block_size ({self.block_size}). Skipping local Hessian for this module." + def _setup_fused_experts(module: nn.Module) -> None: + # _QuantFusedExperts calls each shared input quantizer with the current expert's + # routed activations while module._current_expert_idx holds that expert's index. + # A pre-hook on the input quantizer therefore captures the per-expert Hessian input. + for proj in ("gate_up", "down"): + weight = getattr(module, f"{proj}_proj", None) + input_quantizer = getattr(module, f"{proj}_proj_input_quantizer", None) + weight_quantizers = getattr(module, f"{proj}_proj_weight_quantizers", None) + if weight is None or input_quantizer is None or weight_quantizers is None: + continue + # All experts of a projection share cin, so validate / warn once per projection. + cin = weight[0].shape[1] + if cin % block_size != 0: + warn_rank_0( + f"local_hessian: fused {proj}_proj input features ({cin}) not divisible by " + f"block_size ({block_size}); falling back to plain MSE for these experts." + ) + _warn_if_block_size_mismatch(weight_quantizers[0], block_size, f"fused {proj}_proj") + for idx, weight_quantizer in enumerate(weight_quantizers): + cout = weight[idx].shape[0] + accumulators[id(weight_quantizer)] = _LocalHessianAccumulator(cout, cin, block_size) + fused_weight_quantizers.append(weight_quantizer) + + def pre_hook(_input_q, args, _module=module, _weight_quantizers=weight_quantizers): + if not cache_state["on"] or not args: + return + input_tensor = args[0] + input_tensor = ( + input_tensor.to_local() if hasattr(input_tensor, "to_local") else input_tensor + ) + # _current_expert_idx indexes the same list these accumulators are keyed by; + # index directly so a routing/ordering desync surfaces instead of silently + # dropping activations. + accumulators[id(_weight_quantizers[_module._current_expert_idx])].accumulate( + input_tensor ) - self.is_enabled = False - - def cleanup(self): - """Clean up the forward hook.""" - unpatch_forward_method(self.module, "_forward_no_local_hessian") - if not debug: - if hasattr(self.module, "hessian_helper"): - delattr(self.module, "hessian_helper") - - def accumulate_hessian(self, input_tensor: torch.Tensor): - """Accumulate local Hessian from input activations. - - Args: - input_tensor: Input tensor of shape (..., cin) - """ - if not self.is_enabled: - return - - # Flatten to (num_tokens, cin) - x = input_tensor.reshape(-1, self.cin).T # (cin, num_tokens) - x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) # (num_blocks, bs, n) - - # Compute H = X @ X.T for each block and accumulate - hessian_batch = (x @ x.transpose(-1, -2)).to(torch.float32) - self.hessian_per_block += hessian_batch - self.num_samples += input_tensor.numel() // self.cin - - def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: - """Get the local Hessian error function for MSE calibration.""" - cout = self.cout - bs = self.block_size - # Normalize hessian by number of samples - hessian = self.hessian_per_block / max(self.num_samples, 1) - - def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: - """Compute local Hessian-weighted error.""" - original_shape = x.shape - # Reshape to (cout, num_blocks_per_cin, block_size) - dw = (x - xq).view(cout, -1, bs) - # Use einsum to avoid materializing cout-repeated Hessian - # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks) - block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw) - block_loss = block_loss.reshape(-1) - error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) - return error - - return local_hessian_error - - def forward(self, input, *args, **kwargs): - """Custom forward that collects activations in cache mode.""" - if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled: - # Get local tensor from DTensor if applicable - input_local = input.to_local() if hasattr(input, "to_local") else input - self.hessian_helper.accumulate_hessian(input_local) - - # Forward without quantization during caching - if LocalHessianHelper.cache_mode: - self.weight_quantizer.disable() - out = self._forward_no_local_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - return self._forward_no_local_hessian(input, *args, **kwargs) - - # First, run max_calibrate on the whole model to get initial amax for all quantizers - # This calibrates both weight_quantizer and input_quantizer with max calibration - print_rank_0("local_hessian: Running max calibration for all quantizers...") - max_calibrate(model, forward_loop, distributed_sync) - # Setup helpers for all quantized linear modules - name_to_module = dict(model.named_modules()) - weight_quantizers_info = [] - all_patched_modules = [] # Track all modules for cleanup (including disabled ones) + handle = input_quantizer.register_forward_pre_hook(pre_hook) + cleanup_callbacks.append(handle.remove) + # Phase 2: install per-quantizer activation-capture hooks. for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): - module.hessian_helper = LocalHessianHelper(module, name) - module.hessian_helper.setup() - all_patched_modules.append((name, module)) - if module.hessian_helper.is_enabled: - weight_quantizers_info.append((name, module)) - - # Cache activations by running forward loop - LocalHessianHelper.cache_mode = True + _setup_dense(module, name) + elif _is_quant_fused_experts(module): + with enable_weight_access_and_writeback(module, model, name_to_module): + _setup_fused_experts(module) + + # Phase 3: cache activations / accumulate Hessians with a full-precision forward. + # Silence fused per-expert weight quantizers so down-proj inputs see FP gate_up weights, + # mirroring the dense path which disables its own weight quantizer per forward. + fused_disabled = [q for q in fused_weight_quantizers if q.is_enabled and q._if_quant] + for q in fused_disabled: + q.disable_quant() + cache_state["on"] = True print_rank_0("local_hessian: Caching activations and computing local Hessian...") - forward_loop(model) - - # TODO(fridah-nv): Sync Hessian across distributed processes if needed + try: + forward_loop(model) + finally: + cache_state["on"] = False + for q in fused_disabled: + q.enable_quant() + + # TODO(fridah-nv): the per-block Hessian is sharded for row-parallel linears and differs + # across data-parallel ranks; max_calibrate's amax sync runs *before* this refinement, so + # refined amaxes can diverge across ranks. All-reduce the Hessian (and/or re-sync amax) + # for correct TP/DP behavior. + if dist_is_initialized() and dist_size() > 1: + warn_rank_0( + "local_hessian: Hessian is not synced across ranks; refined weight amaxes may " + "diverge under tensor/data parallelism. Treat local_hessian as single-rank for now." + ) - # Replace calibrators with MseCalibrator using local Hessian error function + # Phase 4: build per-quantizer error functions and refine weight scales via the shared + # MSE weight-calibration loop. Quantizers without a usable Hessian map to None and fall + # back to plain max/MSE. + error_funcs = {qid: acc.build_error_func() for qid, acc in accumulators.items()} print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") - for name, module in weight_quantizers_info: - weight_quantizer = module.weight_quantizer - helper = module.hessian_helper - - if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None: - continue - - initial_amax = weight_quantizer._amax.clone().detach() - - def quant_func(x, amax, quantizer=weight_quantizer): - return _mse_quant_func(x, amax, quantizer) - - is_nvfp4_static = weight_quantizer.is_nvfp4_static - - if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) - - error_func = helper.get_error_func() - - if fp8_scale_sweep and is_nvfp4_static: - weight_quantizer._calibrator = NVFP4MSECalibrator( - amax=initial_amax, - axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, - global_amax=weight_quantizer.global_amax, - quant_func=quant_func, - error_func=error_func, - ) - else: - weight_quantizer._calibrator = MseCalibrator( - amax=initial_amax, - axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, - step_size=step_size, - start_multiplier=start_multiplier, - stop_multiplier=stop_multiplier, - quant_func=quant_func, - error_func=error_func, - ) - - # Free cached memory before heavy calibration - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Process weights ONE AT A TIME with immediate amax computation and cleanup - weight_list = [ - (name, module) - for name, module in weight_quantizers_info - if module.weight_quantizer._calibrator is not None - ] - - for idx, (name, module) in enumerate(weight_list): - weight_quantizer = module.weight_quantizer - cal = weight_quantizer._calibrator - - # Step 1: Calibrate this weight - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() - with enable_weight_access_and_writeback(module, model, name_to_module): - weight = module.weight - weight_quantizer(weight) - - # Step 2: IMMEDIATELY compute amax (before calibration data grows) - if cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() - - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() - - # Step 3: Sync all devices and reset calibrator for next weight - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - - if hasattr(cal, "reset"): - cal.reset() - - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - torch.cuda.empty_cache() - - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - torch.cuda.empty_cache() + _mse_calibrate_weights( + model, + name_to_module, + step_size=step_size, + start_multiplier=start_multiplier, + stop_multiplier=stop_multiplier, + fp8_scale_sweep=fp8_scale_sweep, + error_func_for=lambda q: error_funcs.get(id(q)), + ) - # Cleanup and free memory - LocalHessianHelper.cache_mode = False - for name, module in all_patched_modules: - module.hessian_helper.cleanup() + # Phase 5: remove hooks / patched forwards. + for cleanup in cleanup_callbacks: + cleanup() + if debug: + model._local_hessian_accumulators = accumulators print_rank_0("local_hessian: Calibration complete.") diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 5a7fe8cd9a6..61c2e4a07e6 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -656,6 +656,60 @@ def forward_loop(m): self._cleanup_registry(expert_type) + def test_local_hessian_calibrates_per_expert_weights(self): + """local_hessian builds a per-expert Hessian (from each expert's routed inputs) + and refines every expert's weight amax via the shared MSE calibration loop.""" + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import local_hessian_calibrate + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(3): + m(torch.randn(1, 8, HIDDEN_DIM)) + + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + experts = model.moe.experts + expert_quantizers = list(experts.gate_up_proj_weight_quantizers) + list( + experts.down_proj_weight_quantizers + ) + max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None} + + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) + + # At least one expert's routed activations produced a per-block Hessian. + accumulators = model._local_hessian_accumulators + assert any(acc.num_samples > 0 for acc in accumulators.values()), ( + "no per-expert Hessian captured — F.linear hook likely bypassed." + ) + + refined = False + for q in expert_quantizers: + assert q.amax is not None and torch.isfinite(q.amax).all() + if id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)]): + refined = True + assert refined, "local_hessian did not refine any expert weight amax" + + self._cleanup_registry(expert_type) + def test_max_calibrate_populates_dead_static_nvfp4_expert_quantizers(self): """max calibration fills static NVFP4 ``_amax`` on experts the forward never routed to. diff --git a/tests/unit/torch/quantization/test_local_hessian.py b/tests/unit/torch/quantization/test_local_hessian.py new file mode 100644 index 00000000000..013dd3ecbaa --- /dev/null +++ b/tests/unit/torch/quantization/test_local_hessian.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for local Hessian-weighted MSE calibration (CPU).""" + +import pytest +import torch +from _test_utils.torch.quantization.models import SimpleLinear + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization import calib +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.model_calib import ( + _FP8_SWEEP_CALIBRATOR_REGISTRY, + _LocalHessianAccumulator, + _make_weight_mse_calibrator, + _register_fp8_sweep_calibrator, + local_hessian_calibrate, + mse_calibrate, +) +from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + register_quant_backend, +) + +# Weight-only INT8 per-channel config; calibration is re-run explicitly per test. +INT8_WEIGHT_CFG = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + "algorithm": "max", +} + + +class TestLocalHessianAccumulator: + def test_shape_samples_and_buffer_release(self): + torch.manual_seed(0) + cout, cin, bs = 8, 32, 16 + acc = _LocalHessianAccumulator(cout, cin, bs) + assert acc.is_enabled + + x = torch.randn(10, cin) + acc.accumulate(x) + assert acc.hessian_per_block.shape == (cin // bs, bs, bs) + assert acc.num_samples == 10 + + # A second batch accumulates (sum over samples). + acc.accumulate(torch.randn(5, cin)) + assert acc.num_samples == 15 + + error_func = acc.build_error_func() + assert error_func is not None + # build_error_func releases the raw accumulator (keeps only the normalized copy). + assert acc.hessian_per_block is None + + def test_error_func_matches_explicit_hessian_weighted_loss(self): + torch.manual_seed(1) + cout, cin, bs = 4, 32, 16 + n_blocks = cin // bs + acc = _LocalHessianAccumulator(cout, cin, bs) + + x = torch.randn(7, cin) + acc.accumulate(x) + num_samples = acc.num_samples + error_func = acc.build_error_func() + + # Reference normalized per-block Hessian. + xb = x.reshape(-1, cin).T.reshape(n_blocks, bs, -1) + hessian = (xb @ xb.transpose(-1, -2)) / num_samples + + total_blocks = cout * n_blocks + w = torch.randn(total_blocks, bs) + wq = w + 0.05 * torch.randn_like(w) + err = error_func(w, wq) + + assert err.shape == w.shape + # The per-block scalar loss is broadcast across the block's bs entries. + err_blocks = err.view(-1, bs) + assert torch.allclose(err_blocks, err_blocks[:, :1].expand(-1, bs)) + + dw = (w - wq).view(cout, n_blocks, bs) + expected = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) + assert torch.allclose(err_blocks[:, 0], expected, atol=1e-5) + + def test_disabled_when_cin_not_divisible(self): + acc = _LocalHessianAccumulator(8, 30, 16) + assert not acc.is_enabled + acc.accumulate(torch.randn(4, 30)) # no-op + assert acc.hessian_per_block is None + assert acc.build_error_func() is None + + def test_no_samples_returns_none(self): + acc = _LocalHessianAccumulator(8, 32, 16) + assert acc.build_error_func() is None + + def test_accumulates_in_fp32_for_low_precision_input(self): + acc = _LocalHessianAccumulator(4, 16, 16) + acc.accumulate(torch.randn(8, 16, dtype=torch.bfloat16)) + assert acc.hessian_per_block.dtype == torch.float32 + + +class TestBlockSizeMismatchWarning: + def _block_quantizer(self, block): + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), block_sizes={-1: block, "type": "static", "scale_bits": (4, 3)} + ) + return TensorQuantizer(quant_attribute_cfg=cfg) + + def test_warns_on_mismatch(self): + from modelopt.torch.quantization.model_calib import _warn_if_block_size_mismatch + + with pytest.warns(UserWarning, match="will not align"): + _warn_if_block_size_mismatch(self._block_quantizer(32), 16, "layer") + + def test_silent_when_matching_or_no_block_sizes(self): + import warnings + + from modelopt.torch.quantization.model_calib import _warn_if_block_size_mismatch + + with warnings.catch_warnings(): + warnings.simplefilter("error") + _warn_if_block_size_mismatch(self._block_quantizer(16), 16, "layer") + per_channel = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=0)) + _warn_if_block_size_mismatch(per_channel, 16, "layer") + + +def _make_forward_loop(seed=0, skew=True): + def forward_loop(model): + torch.manual_seed(seed) + for _ in range(3): + x = torch.randn(8, 16) + if skew: + # Skew one input feature so the activation Hessian is non-trivial and + # the Hessian-weighted optimum diverges from the plain weight MSE. + x[:, 0] *= 40.0 + model(x) + + return forward_loop + + +class TestLocalHessianCalibrateDense: + def test_runs_and_refines_amax(self): + torch.manual_seed(0) + model = SimpleLinear() + forward_loop = _make_forward_loop() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop) + + # Snapshot the post-max-calibration amax for each weight quantizer. + max_amax = { + name: module.amax.clone() + for name, module in model.named_modules() + if isinstance(module, TensorQuantizer) and module.is_enabled and module.amax is not None + } + assert max_amax, "expected enabled weight quantizers after quantize" + + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) + + # Every enabled weight quantizer got a Hessian accumulator with collected samples. + accumulators = model._local_hessian_accumulators + assert accumulators + assert all(acc.num_samples > 0 for acc in accumulators.values()) + + # The Hessian-weighted search moved at least one amax away from the max value. + changed = False + for name, module in model.named_modules(): + if name in max_amax: + assert torch.isfinite(module.amax).all() and (module.amax > 0).all() + if not torch.allclose(module.amax, max_amax[name]): + changed = True + assert changed, "local_hessian did not refine any amax away from max-calibration" + + def test_differs_from_plain_mse(self): + forward_loop = _make_forward_loop(seed=3) + + torch.manual_seed(0) + model_lh = SimpleLinear() + mtq.quantize(model_lh, INT8_WEIGHT_CFG, forward_loop=forward_loop) + local_hessian_calibrate(model_lh, forward_loop, fp8_scale_sweep=False) + + torch.manual_seed(0) + model_mse = SimpleLinear() + mtq.quantize(model_mse, INT8_WEIGHT_CFG, forward_loop=forward_loop) + mse_calibrate(model_mse, forward_loop, fp8_scale_sweep=False) + + lh = { + n: m.amax + for n, m in model_lh.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + mse = { + n: m.amax + for n, m in model_mse.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + assert lh.keys() == mse.keys() + # Hessian weighting should change the chosen scale for at least one quantizer. + assert any(not torch.allclose(lh[n], mse[n]) for n in lh) + + def test_no_forward_loop_is_skipped(self): + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + before = { + n: m.amax.clone() + for n, m in model.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + with pytest.warns(UserWarning, match="forward_loop must be provided"): + local_hessian_calibrate(model, forward_loop=None) + after = { + n: m.amax + for n, m in model.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + for n in before: + assert torch.equal(before[n], after[n]) + + +class TestMakeWeightMseCalibratorErrorFunc: + def setup_method(self): + self._orig_fp8_registry = dict(_FP8_SWEEP_CALIBRATOR_REGISTRY) + self._orig_quant_backends = dict(_QUANT_FUNCTIONAL_BACKENDS) + + def teardown_method(self): + _FP8_SWEEP_CALIBRATOR_REGISTRY.clear() + _FP8_SWEEP_CALIBRATOR_REGISTRY.update(self._orig_fp8_registry) + _QUANT_FUNCTIONAL_BACKENDS.clear() + _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends) + + def _make_quantizer(self, backend=None): + cfg = QuantizerAttributeConfig(num_bits=8, axis=None, backend=backend) + q = TensorQuantizer(quant_attribute_cfg=cfg) + q.amax = torch.tensor(1.0) + return q + + def test_error_func_threaded_to_mse_calibrator(self): + q = self._make_quantizer() + marker = lambda x, xq: (x - xq) ** 2 # noqa: E731 + cal = _make_weight_mse_calibrator( + q, 0.1, 0.25, 4.0, fp8_scale_sweep=False, error_func=marker + ) + assert isinstance(cal, calib.MseCalibrator) + assert cal._error_func is marker + + def test_registered_backend_with_error_func_is_skipped(self): + register_quant_backend("_lh_test_backend", lambda x, tq: x) + _register_fp8_sweep_calibrator( + "_lh_test_backend", + lambda amax, axis, qf: calib.MseCalibrator(amax=amax, axis=axis, quant_func=qf), + ) + q = self._make_quantizer(backend="_lh_test_backend") + with pytest.warns(UserWarning, match="does not support a custom error"): + cal = _make_weight_mse_calibrator( + q, 0.1, 0.25, 4.0, fp8_scale_sweep=True, error_func=lambda x, xq: x + ) + assert cal is None From 0bfe1d4ae429aa5c1600f6a9d9e380e82b05c3e4 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 1 Jun 2026 20:45:12 +0000 Subject: [PATCH 2/5] Decouple fused-MoE local-Hessian via activation-capture hook Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 258 ++++++----------- .../quantization/nn/modules/quant_module.py | 24 ++ .../torch/quantization/plugins/huggingface.py | 30 ++ .../plugins/test_fused_experts.py | 54 ++-- .../torch/quantization/test_local_hessian.py | 261 ++++++++---------- 5 files changed, 282 insertions(+), 345 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5311823d7e6..9755004cf19 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -476,11 +476,10 @@ def _make_weight_mse_calibrator( fp8_scale_sweep: bool, error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, ) -> _Calibrator | None: - """Create the MSE calibrator for one eligible weight quantizer. + """Create the MSE calibrator for one eligible weight quantizer (``None`` if ineligible). - ``error_func`` overrides the default squared-error metric (used by local-Hessian - calibration to weight the per-block error). When set, the NVFP4 Triton fast path is - bypassed in favor of the reference sweep so the custom metric is honored. + ``error_func`` overrides the squared-error metric (local-Hessian's per-block weighting); + when set, NVFP4's Triton fast path is bypassed for the reference sweep. """ if ( not isinstance(weight_quantizer, TensorQuantizer) @@ -502,8 +501,7 @@ def _make_weight_mse_calibrator( ) if backend is not None and backend_factory is not None: if error_func is not None: - # Registered backend factories don't accept a custom error_func, so a - # Hessian-weighted metric can't be honored; leave at max/MSE amax. + # Registered backends can't take a custom error_func; skip Hessian refinement. warnings.warn( f"local_hessian: backend '{backend}' does not support a custom error " "function; skipping Hessian-weighted calibration for this quantizer." @@ -518,11 +516,10 @@ def _make_weight_mse_calibrator( quant_func=quant_func, error_func=error_func, ) - # fp8_scale_sweep covers only registered backends and static NVFP4 weights; - # skip MSE calibration for all other quantizers (no multiplier search). + # fp8_scale_sweep applies only to registered backends and static NVFP4; skip others. return None - # fp8_scale_sweep disabled: multiplier-search MSE calibration for all quantizers. + # No fp8_scale_sweep: multiplier-search MSE for all quantizers. return MseCalibrator( amax=initial_amax, axis=axis, @@ -590,11 +587,10 @@ def _mse_calibrate_weights( fp8_scale_sweep: bool, error_func_for: Callable[[TensorQuantizer], Callable | None] | None = None, ): - """Replace each eligible weight quantizer's calibrator with an MSE calibrator and run it. + """Run MSE weight calibration over all eligible quantizers (shared by mse / local-Hessian). - Shared by ``mse_calibrate`` and ``local_hessian_calibrate``. ``error_func_for`` maps a - weight quantizer to an optional per-weight error function (used by local-Hessian to - inject the Hessian-weighted metric); it defaults to ``None`` (plain squared error). + ``error_func_for`` maps a weight quantizer to an optional per-weight error function + (local-Hessian's Hessian metric); ``None`` means plain squared error. """ seen_modules: set[int] = set() pbar = tqdm(desc="MSE weight calibration") @@ -627,13 +623,10 @@ def _mse_calibrate_weights( class _LocalHessianAccumulator: - """Accumulates a per-block local Hessian ``H = ΣXᵀX`` for one weight quantizer. + """Per-block local Hessian ``H = ΣXᵀX`` for one weight quantizer. - The Hessian is partitioned over the input (``cin``) dimension into - ``cin // block_size`` blocks of shape ``(block_size, block_size)`` so the - Hessian-weighted error matches the NVFP4 per-block scale granularity. The raw - accumulator buffer is allocated lazily on the first ``accumulate`` call, so - never-routed MoE experts cost no memory. + Partitioned over ``cin`` into ``cin // block_size`` blocks to match the NVFP4 per-block + scale; the buffer is allocated lazily so never-routed experts cost nothing. """ def __init__(self, cout: int, cin: int, block_size: int): @@ -641,7 +634,7 @@ def __init__(self, cout: int, cin: int, block_size: int): self.cin = cin self.block_size = block_size self.num_blocks_per_cin = cin // block_size - # Quantizers whose cin doesn't tile evenly fall back to plain MSE (error_func None). + # Not block-divisible -> no Hessian (falls back to plain MSE). self.is_enabled = cin % block_size == 0 self.hessian_per_block: torch.Tensor | None = None self.num_samples = 0 @@ -651,8 +644,7 @@ def accumulate(self, input_tensor: torch.Tensor) -> None: """Accumulate ``XᵀX`` per block from an activation of shape ``(..., cin)``.""" if not self.is_enabled: return - # Accumulate the XᵀX GEMM in fp32 to avoid bf16/fp16 precision loss on the sum. - # (cin, num_tokens) -> (num_blocks, block_size, num_tokens) + # fp32 GEMM avoids bf16/fp16 precision loss; (cin, tokens) -> (n_blocks, bs, tokens). x = input_tensor.reshape(-1, self.cin).to(torch.float32).T x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) hessian_batch = x @ x.transpose(-1, -2) @@ -663,47 +655,26 @@ def accumulate(self, input_tensor: torch.Tensor) -> None: self.num_samples += input_tensor.numel() // self.cin def build_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None: - """Build the Hessian-weighted error function, or ``None`` if no samples were seen. - - Releases the raw Hessian buffer; the returned closure keeps only the normalized copy. - """ + """Hessian-weighted error function (``None`` if no samples); frees the raw buffer.""" if self.hessian_per_block is None or self.num_samples == 0: return None cout = self.cout bs = self.block_size hessian = self.hessian_per_block / self.num_samples - self.hessian_per_block = None # free the raw accumulator + self.hessian_per_block = None def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: original_shape = x.shape - # Blocked weight error -> (cout, num_blocks_per_cin, block_size). + # Per-block weighted error: dw (cout,n,bs) · H (n,bs,bs) -> (cout,n). dw = (x - xq).view(cout, -1, bs) - # einsum avoids materializing the cout-repeated Hessian: - # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks) block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) return block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) return local_hessian_error -def _is_quant_fused_experts(module: nn.Module) -> bool: - """Whether ``module`` is a *converted* fused-MoE-experts wrapper with per-expert quantizers. - - Distinct from ``plugins.huggingface._is_fused_experts_module`` (which detects the - *unconverted* HF module by structure); this checks for the per-expert weight-quantizer - lists and routing index added by ``_QuantFusedExperts``. - """ - return hasattr(module, "_current_expert_idx") and hasattr( - module, "gate_up_proj_weight_quantizers" - ) - - def _warn_if_block_size_mismatch(weight_quantizer: TensorQuantizer, block_size: int, name: str): - """Warn if the Hessian block_size disagrees with the quantizer's last-axis scale block. - - They must match for the per-block Hessian to weight the same blocks the scale search - optimizes; a mismatch silently produces a finite-but-misaligned amax. - """ + """Warn if the Hessian block_size differs from the quantizer's scale block (misaligns).""" block_sizes = getattr(weight_quantizer, "block_sizes", None) quant_block = block_sizes.get(-1) if block_sizes else None if quant_block is not None and quant_block != block_size: @@ -713,6 +684,22 @@ def _warn_if_block_size_mismatch(weight_quantizer: TensorQuantizer, block_size: ) +def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned: set): + """Warn once per ``(name, cin)`` when a captured layer falls back to plain MSE.""" + if weight.dim() < 2: + return + cin = weight.shape[1] + if (name, cin) in warned: + return + warned.add((name, cin)) + if cin % block_size != 0: + warn_rank_0( + f"local_hessian: {name} input features ({cin}) not divisible by block_size " + f"({block_size}); falling back to plain MSE for these weights." + ) + _warn_if_block_size_mismatch(weight_quantizer, block_size, name) + + @torch.no_grad() def local_hessian_calibrate( model: nn.Module, @@ -725,22 +712,15 @@ def local_hessian_calibrate( block_size: int = 16, debug: bool = False, ): - """Calibrate the model using local Hessian-weighted MSE search. - - Instead of minimizing weight error ``||W - Wq||²``, this minimizes the Hessian-weighted - error ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = ΣXᵀX`` approximates the output - reconstruction error ``||WX - WqX||²``. + """Calibrate weight quantizers by minimizing the Hessian-weighted error. - Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are - accumulated during a dedicated full-precision forward pass, then used to weight the - per-block MSE error while searching the weight scale (reusing :func:`mse_calibrate`'s - weight-calibration machinery via a custom ``error_func``). + Minimizes ``(W - Wq)ᵀ H (W - Wq)`` with per-block Hessian ``H = ΣXᵀX`` (approximating the + output error ``||WX - WqX||²``), built from a full-precision forward and fed to + :func:`mse_calibrate`'s weight search via a custom ``error_func``. - Coverage: dense quantized linears and HF fused-MoE experts (per-expert weight - quantizers, Hessian built from each expert's routed activations). Quantizers without a - usable Hessian — never-routed experts, ``cin`` not divisible by ``block_size``, - registered custom backends, or non-eager fused-expert kernels that bypass ``F.linear`` - — fall back to the plain max/MSE amax. + Like :func:`mse_calibrate`, every weight quantizer is calibrated; the Hessian metric is + applied only where a weight can be paired with its input activations (dense linears and + HF fused-MoE experts). All other weights fall back to plain MSE. Args: model: Model to be calibrated. @@ -763,133 +743,68 @@ def local_hessian_calibrate( warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") return - # Phase 1: max-calibrate the whole model. max_calibrate also bootstraps dead-expert - # weight quantizers and promotes / global-amax-syncs NVFP4 static quantizers, so amax - # is initialized for every quantizer before the Hessian refinement below. + # Phase 1: max-calibrate (also bootstraps dead experts + promotes/syncs NVFP4 static). print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) name_to_module = dict(model.named_modules()) - # Per-block Hessian accumulators keyed by id(weight_quantizer), so dense linears and - # per-expert MoE quantizers share one uniform lookup during weight calibration. + # Hessians keyed by id(weight_quantizer); modules pair weights<->activations via the hook. accumulators: dict[int, _LocalHessianAccumulator] = {} - cleanup_callbacks: list[Callable[[], None]] = [] - # Fused per-expert weight quantizers to silence during caching (their fake-quant runs - # inside the intercepted F.linear, not in a hookable module forward). - fused_weight_quantizers: list[TensorQuantizer] = [] - # Mutable toggle shared with the capture hooks; only on during the caching forward. - cache_state = {"on": False} - - def _setup_dense(module: nn.Module, name: str) -> None: - weight_quantizer = module.weight_quantizer - cout, cin = module.weight.shape[0], module.weight.shape[1] - acc = _LocalHessianAccumulator(cout, cin, block_size) - accumulators[id(weight_quantizer)] = acc - if not acc.is_enabled: - warn_rank_0( - f"local_hessian: {name} input features ({cin}) not divisible by block_size " - f"({block_size}); falling back to plain MSE for this module." - ) - _warn_if_block_size_mismatch(weight_quantizer, block_size, name) - - def forward(self, input, *args, **kwargs): - if cache_state["on"]: - input_local = input.to_local() if hasattr(input, "to_local") else input - acc.accumulate(input_local) - # Capture full-precision-weight activations during the caching pass. - # Snapshot _if_quant so a quantizer that was not fake-quanting stays off. - was_quant = self.weight_quantizer._if_quant - self.weight_quantizer.disable_quant() - try: - return self._forward_no_local_hessian(input, *args, **kwargs) - finally: - if was_quant: - self.weight_quantizer.enable_quant() - return self._forward_no_local_hessian(input, *args, **kwargs) - - bind_forward_method(module, forward, "_forward_no_local_hessian") - cleanup_callbacks.append( - partial(unpatch_forward_method, module, "_forward_no_local_hessian") - ) - def _setup_fused_experts(module: nn.Module) -> None: - # _QuantFusedExperts calls each shared input quantizer with the current expert's - # routed activations while module._current_expert_idx holds that expert's index. - # A pre-hook on the input quantizer therefore captures the per-expert Hessian input. - for proj in ("gate_up", "down"): - weight = getattr(module, f"{proj}_proj", None) - input_quantizer = getattr(module, f"{proj}_proj_input_quantizer", None) - weight_quantizers = getattr(module, f"{proj}_proj_weight_quantizers", None) - if weight is None or input_quantizer is None or weight_quantizers is None: - continue - # All experts of a projection share cin, so validate / warn once per projection. - cin = weight[0].shape[1] - if cin % block_size != 0: - warn_rank_0( - f"local_hessian: fused {proj}_proj input features ({cin}) not divisible by " - f"block_size ({block_size}); falling back to plain MSE for these experts." - ) - _warn_if_block_size_mismatch(weight_quantizers[0], block_size, f"fused {proj}_proj") - for idx, weight_quantizer in enumerate(weight_quantizers): - cout = weight[idx].shape[0] - accumulators[id(weight_quantizer)] = _LocalHessianAccumulator(cout, cin, block_size) - fused_weight_quantizers.append(weight_quantizer) - - def pre_hook(_input_q, args, _module=module, _weight_quantizers=weight_quantizers): - if not cache_state["on"] or not args: - return - input_tensor = args[0] - input_tensor = ( - input_tensor.to_local() if hasattr(input_tensor, "to_local") else input_tensor - ) - # _current_expert_idx indexes the same list these accumulators are keyed by; - # index directly so a routing/ordering desync surfaces instead of silently - # dropping activations. - accumulators[id(_weight_quantizers[_module._current_expert_idx])].accumulate( - input_tensor - ) - - handle = input_quantizer.register_forward_pre_hook(pre_hook) - cleanup_callbacks.append(handle.remove) - - # Phase 2: install per-quantizer activation-capture hooks. + def capture(weight_quantizer, weight, input_tensor): + input_local = input_tensor.to_local() if hasattr(input_tensor, "to_local") else input_tensor + acc = accumulators.get(id(weight_quantizer)) + if acc is None: + acc = _LocalHessianAccumulator(weight.shape[0], weight.shape[1], block_size) + accumulators[id(weight_quantizer)] = acc + acc.accumulate(input_local) + + # Phase 2: register capture hooks, silence weight quant (full-precision activations), + # run one forward to accumulate Hessians. Hooks live only for this forward. + handles: list = [] + silenced_weight_quantizers: list[TensorQuantizer] = [] + warned: set = set() + seen_modules: set[int] = set() for name, module in name_to_module.items(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - with enable_weight_access_and_writeback(module, model, name_to_module): - _setup_dense(module, name) - elif _is_quant_fused_experts(module): - with enable_weight_access_and_writeback(module, model, name_to_module): - _setup_fused_experts(module) - - # Phase 3: cache activations / accumulate Hessians with a full-precision forward. - # Silence fused per-expert weight quantizers so down-proj inputs see FP gate_up weights, - # mirroring the dense path which disables its own weight quantizer per forward. - fused_disabled = [q for q in fused_weight_quantizers if q.is_enabled and q._if_quant] - for q in fused_disabled: - q.disable_quant() - cache_state["on"] = True + if not isinstance(module, QuantModule) or id(module) in seen_modules: + continue + seen_modules.add(id(module)) + with enable_weight_access_and_writeback(module, model, name_to_module): + captures = module.register_calibration_input_hooks(capture) + handles.extend(captures) + for weight, weight_quantizer in module.iter_weights_for_calibration(): + # Only TensorQuantizer weights are calibrated (matches mse_calibrate), so only + # those are silenced. + if ( + isinstance(weight_quantizer, TensorQuantizer) + and weight_quantizer.is_enabled + and weight_quantizer._if_quant + ): + silenced_weight_quantizers.append(weight_quantizer) + if captures: + _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned) + + for weight_quantizer in silenced_weight_quantizers: + weight_quantizer.disable_quant() print_rank_0("local_hessian: Caching activations and computing local Hessian...") try: forward_loop(model) finally: - cache_state["on"] = False - for q in fused_disabled: - q.enable_quant() - - # TODO(fridah-nv): the per-block Hessian is sharded for row-parallel linears and differs - # across data-parallel ranks; max_calibrate's amax sync runs *before* this refinement, so - # refined amaxes can diverge across ranks. All-reduce the Hessian (and/or re-sync amax) - # for correct TP/DP behavior. + for weight_quantizer in silenced_weight_quantizers: + weight_quantizer.enable_quant() + for handle in handles: + handle.remove() + + # TODO(fridah-nv): the per-block Hessian is not synced across TP/DP ranks (max_calibrate's + # amax sync runs before this), so refined amaxes can diverge. All-reduce Hessian / re-sync. if dist_is_initialized() and dist_size() > 1: warn_rank_0( "local_hessian: Hessian is not synced across ranks; refined weight amaxes may " "diverge under tensor/data parallelism. Treat local_hessian as single-rank for now." ) - # Phase 4: build per-quantizer error functions and refine weight scales via the shared - # MSE weight-calibration loop. Quantizers without a usable Hessian map to None and fall - # back to plain max/MSE. + # Phase 3: build error funcs and run the shared MSE weight loop. error_funcs = {qid: acc.build_error_func() for qid, acc in accumulators.items()} print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") _mse_calibrate_weights( @@ -902,9 +817,6 @@ def pre_hook(_input_q, args, _module=module, _weight_quantizers=weight_quantizer error_func_for=lambda q: error_funcs.get(id(q)), ) - # Phase 5: remove hooks / patched forwards. - for cleanup in cleanup_callbacks: - cleanup() if debug: model._local_hessian_accumulators = accumulators diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 419c6f4924f..a54d2cf8def 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -17,6 +17,7 @@ import contextlib import warnings +from collections.abc import Callable from typing import Any import torch @@ -127,6 +128,17 @@ def iter_weights_for_calibration(self): weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer) yield getattr(self, weight_name), weight_quantizer + def register_calibration_input_hooks( + self, callback: Callable[[TensorQuantizer, torch.Tensor, torch.Tensor], None] + ) -> list: + """Register forward hooks calling ``callback(weight_quantizer, weight, input)`` per weight. + + Activation-side counterpart to :meth:`iter_weights_for_calibration`, used by + activation-aware calibration (e.g. local-Hessian). Returns removable handles; the base + default is ``[]`` (no pairing available -> plain weight calibration). Override per module. + """ + return [] + def fold_weight(self, keep_attrs: bool = False): """Fold the weight for faster eval.""" # Handle all attributes that end with _weight_quantizer @@ -247,6 +259,18 @@ def _setup(self): self._register_temp_attribute("_enable_weight_quantization", False) self._register_dynamic_attribute("weight", self._get_quantized_weight) + def register_calibration_input_hooks(self, callback): + """Pair the weight quantizer with the forward input (linear only; conv falls back).""" + weight = getattr(self, "weight", None) + if weight is None or weight.dim() != 2 or not self.weight_quantizer.is_enabled: + return [] + + def _pre_hook(module, args): + if args: + callback(module.weight_quantizer, module.weight, args[0]) + + return [self.register_forward_pre_hook(_pre_hook)] + class _LegacyQuantInputBaseMixin: """A mixin to support legacy quantized modules which needs to have an __init__ method.""" diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 1873ecda528..a4c5cca0663 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -918,6 +918,36 @@ def iter_weights_for_calibration(self): for idx, q in enumerate(quantizers): yield weight[idx], q + def register_calibration_input_hooks(self, callback): + """Pair each per-expert weight quantizer with its routed input activation. + + Hooks the shared input quantizers, which the eager ``F.linear`` path calls per expert + while ``_current_expert_idx`` is set. Batched/grouped kernels never call them, so those + experts get no capture (fall back to plain weight calibration). + """ + handles = [] + for weight_name, quantizers_name, input_quantizer_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers", "gate_up_proj_input_quantizer"), + ("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + input_quantizer = getattr(self, input_quantizer_name, None) + if weight is None or quantizers is None or input_quantizer is None: + continue + + def _pre_hook(_iq, args, _weight_name=weight_name, _quantizers=quantizers): + if not args: + return + idx = self._current_expert_idx + weight_quantizer = _quantizers[idx] + if weight_quantizer.is_enabled: + # Read the weight fresh (valid under accelerate/FSDP re-materialization). + callback(weight_quantizer, getattr(self, _weight_name)[idx], args[0]) + + handles.append(input_quantizer.register_forward_pre_hook(_pre_hook)) + return handles + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 61c2e4a07e6..ec8259c818f 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -656,9 +656,9 @@ def forward_loop(m): self._cleanup_registry(expert_type) - def test_local_hessian_calibrates_per_expert_weights(self): - """local_hessian builds a per-expert Hessian (from each expert's routed inputs) - and refines every expert's weight amax via the shared MSE calibration loop.""" + def test_local_hessian_per_expert_capture_and_refinement(self): + """The plugin's extension point pairs each per-expert weight quantizer with its routed + input, and local_hessian uses that to refine every expert's weight amax.""" import modelopt.torch.quantization as mtq from modelopt.torch.quantization.model_calib import local_hessian_calibrate @@ -666,17 +666,12 @@ def test_local_hessian_calibrates_per_expert_weights(self): expert_type = type(model.moe.experts) self._cleanup_registry(expert_type) + weight_quant = {"num_bits": 8, "axis": 0} quant_cfg = { "quant_cfg": [ {"quantizer_name": "*", "enable": False}, - { - "quantizer_name": "*gate_up_proj_weight_quantizer", - "cfg": {"num_bits": 8, "axis": 0}, - }, - { - "quantizer_name": "*down_proj_weight_quantizer", - "cfg": {"num_bits": 8, "axis": 0}, - }, + {"quantizer_name": "*gate_up_proj_weight_quantizer", "cfg": weight_quant}, + {"quantizer_name": "*down_proj_weight_quantizer", "cfg": weight_quant}, ], "algorithm": "max", } @@ -691,22 +686,33 @@ def forward_loop(m): expert_quantizers = list(experts.gate_up_proj_weight_quantizers) + list( experts.down_proj_weight_quantizers ) - max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None} - - local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) - # At least one expert's routed activations produced a per-block Hessian. - accumulators = model._local_hessian_accumulators - assert any(acc.num_samples > 0 for acc in accumulators.values()), ( - "no per-expert Hessian captured — F.linear hook likely bypassed." + # Extension point captures per-expert (weight_quantizer, weight_slice, cin). + captured = [] + handles = experts.register_calibration_input_hooks( + lambda wq, w, x: captured.append((id(wq), tuple(w.shape), x.shape[-1])) + ) + assert len(handles) == 2 # one pre-hook per shared input quantizer (gate_up, down) + with torch.no_grad(): + model(torch.randn(1, 8, HIDDEN_DIM)) + for h in handles: + h.remove() + valid_ids = {id(q) for q in expert_quantizers} + shapes = {(2 * INTERMEDIATE_DIM, HIDDEN_DIM), (HIDDEN_DIM, INTERMEDIATE_DIM)} + assert captured and all( + wq_id in valid_ids and shape in shapes and cin == shape[1] + for wq_id, shape, cin in captured ) - refined = False - for q in expert_quantizers: - assert q.amax is not None and torch.isfinite(q.amax).all() - if id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)]): - refined = True - assert refined, "local_hessian did not refine any expert weight amax" + # End-to-end: local_hessian refines per-expert weight amax via that capture. + max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None} + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) + assert any(a.num_samples > 0 for a in model._local_hessian_accumulators.values()) + assert all(q.amax is not None and torch.isfinite(q.amax).all() for q in expert_quantizers) + assert any( + id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)]) + for q in expert_quantizers + ) self._cleanup_registry(expert_type) diff --git a/tests/unit/torch/quantization/test_local_hessian.py b/tests/unit/torch/quantization/test_local_hessian.py index 013dd3ecbaa..1a7e6ccd9a7 100644 --- a/tests/unit/torch/quantization/test_local_hessian.py +++ b/tests/unit/torch/quantization/test_local_hessian.py @@ -15,9 +15,12 @@ """Tests for local Hessian-weighted MSE calibration (CPU).""" +import warnings + import pytest import torch -from _test_utils.torch.quantization.models import SimpleLinear +import torch.nn as nn +from _test_utils.torch.quantization.models import SimpleConv, SimpleLinear import modelopt.torch.quantization as mtq from modelopt.torch.quantization import calib @@ -27,6 +30,7 @@ _LocalHessianAccumulator, _make_weight_mse_calibrator, _register_fp8_sweep_calibrator, + _warn_if_block_size_mismatch, local_hessian_calibrate, mse_calibrate, ) @@ -36,7 +40,7 @@ register_quant_backend, ) -# Weight-only INT8 per-channel config; calibration is re-run explicitly per test. +# Weight-only INT8 per-channel; calibration is re-run explicitly per test. INT8_WEIGHT_CFG = { "quant_cfg": [ {"quantizer_name": "*", "enable": False}, @@ -46,188 +50,151 @@ } +def _weight_amaxes(model): + return { + n: m.amax + for n, m in model.named_modules() + if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None + } + + +def _make_forward_loop(seed=0): + def forward_loop(model): + torch.manual_seed(seed) + for _ in range(3): + x = torch.randn(8, 16) + x[:, 0] *= 40.0 # skew so the Hessian is non-trivial vs plain weight MSE + model(x) + + return forward_loop + + class TestLocalHessianAccumulator: - def test_shape_samples_and_buffer_release(self): + def test_accumulate_shape_samples_fp32_buffer(self): torch.manual_seed(0) - cout, cin, bs = 8, 32, 16 - acc = _LocalHessianAccumulator(cout, cin, bs) + acc = _LocalHessianAccumulator(8, 32, 16) assert acc.is_enabled - - x = torch.randn(10, cin) - acc.accumulate(x) - assert acc.hessian_per_block.shape == (cin // bs, bs, bs) - assert acc.num_samples == 10 - - # A second batch accumulates (sum over samples). - acc.accumulate(torch.randn(5, cin)) + acc.accumulate(torch.randn(10, 32, dtype=torch.bfloat16)) + assert acc.hessian_per_block.shape == (2, 16, 16) + assert acc.hessian_per_block.dtype == torch.float32 # fp32 despite bf16 input + acc.accumulate(torch.randn(5, 32)) assert acc.num_samples == 15 - - error_func = acc.build_error_func() - assert error_func is not None - # build_error_func releases the raw accumulator (keeps only the normalized copy). - assert acc.hessian_per_block is None + assert acc.build_error_func() is not None + assert acc.hessian_per_block is None # raw buffer freed def test_error_func_matches_explicit_hessian_weighted_loss(self): torch.manual_seed(1) cout, cin, bs = 4, 32, 16 n_blocks = cin // bs acc = _LocalHessianAccumulator(cout, cin, bs) - x = torch.randn(7, cin) acc.accumulate(x) - num_samples = acc.num_samples error_func = acc.build_error_func() - # Reference normalized per-block Hessian. xb = x.reshape(-1, cin).T.reshape(n_blocks, bs, -1) - hessian = (xb @ xb.transpose(-1, -2)) / num_samples - - total_blocks = cout * n_blocks - w = torch.randn(total_blocks, bs) + hessian = (xb @ xb.transpose(-1, -2)) / acc.num_samples + w = torch.randn(cout * n_blocks, bs) wq = w + 0.05 * torch.randn_like(w) - err = error_func(w, wq) - - assert err.shape == w.shape - # The per-block scalar loss is broadcast across the block's bs entries. - err_blocks = err.view(-1, bs) - assert torch.allclose(err_blocks, err_blocks[:, :1].expand(-1, bs)) + err = error_func(w, wq).view(-1, bs) + assert err.shape == (cout * n_blocks, bs) + assert torch.allclose(err, err[:, :1].expand(-1, bs)) # per-block scalar broadcast dw = (w - wq).view(cout, n_blocks, bs) expected = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw).reshape(-1) - assert torch.allclose(err_blocks[:, 0], expected, atol=1e-5) + assert torch.allclose(err[:, 0], expected, atol=1e-5) - def test_disabled_when_cin_not_divisible(self): - acc = _LocalHessianAccumulator(8, 30, 16) - assert not acc.is_enabled - acc.accumulate(torch.randn(4, 30)) # no-op - assert acc.hessian_per_block is None - assert acc.build_error_func() is None - - def test_no_samples_returns_none(self): - acc = _LocalHessianAccumulator(8, 32, 16) - assert acc.build_error_func() is None - - def test_accumulates_in_fp32_for_low_precision_input(self): - acc = _LocalHessianAccumulator(4, 16, 16) - acc.accumulate(torch.randn(8, 16, dtype=torch.bfloat16)) - assert acc.hessian_per_block.dtype == torch.float32 - - -class TestBlockSizeMismatchWarning: - def _block_quantizer(self, block): - cfg = QuantizerAttributeConfig( - num_bits=(2, 1), block_sizes={-1: block, "type": "static", "scale_bits": (4, 3)} - ) - return TensorQuantizer(quant_attribute_cfg=cfg) - - def test_warns_on_mismatch(self): - from modelopt.torch.quantization.model_calib import _warn_if_block_size_mismatch - - with pytest.warns(UserWarning, match="will not align"): - _warn_if_block_size_mismatch(self._block_quantizer(32), 16, "layer") - - def test_silent_when_matching_or_no_block_sizes(self): - import warnings - - from modelopt.torch.quantization.model_calib import _warn_if_block_size_mismatch - - with warnings.catch_warnings(): - warnings.simplefilter("error") - _warn_if_block_size_mismatch(self._block_quantizer(16), 16, "layer") - per_channel = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=0)) - _warn_if_block_size_mismatch(per_channel, 16, "layer") - - -def _make_forward_loop(seed=0, skew=True): - def forward_loop(model): - torch.manual_seed(seed) - for _ in range(3): - x = torch.randn(8, 16) - if skew: - # Skew one input feature so the activation Hessian is non-trivial and - # the Hessian-weighted optimum diverges from the plain weight MSE. - x[:, 0] *= 40.0 - model(x) - - return forward_loop + def test_returns_none_when_disabled_or_no_samples(self): + not_divisible = _LocalHessianAccumulator(8, 30, 16) + assert not not_divisible.is_enabled + not_divisible.accumulate(torch.randn(4, 30)) # no-op + assert not_divisible.build_error_func() is None + assert _LocalHessianAccumulator(8, 32, 16).build_error_func() is None # no samples class TestLocalHessianCalibrateDense: - def test_runs_and_refines_amax(self): - torch.manual_seed(0) - model = SimpleLinear() + def test_refines_amax_beyond_max_and_plain_mse(self): forward_loop = _make_forward_loop() - mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop) - - # Snapshot the post-max-calibration amax for each weight quantizer. - max_amax = { - name: module.amax.clone() - for name, module in model.named_modules() - if isinstance(module, TensorQuantizer) and module.is_enabled and module.amax is not None - } - assert max_amax, "expected enabled weight quantizers after quantize" - - local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True) - - # Every enabled weight quantizer got a Hessian accumulator with collected samples. - accumulators = model._local_hessian_accumulators - assert accumulators - assert all(acc.num_samples > 0 for acc in accumulators.values()) - - # The Hessian-weighted search moved at least one amax away from the max value. - changed = False - for name, module in model.named_modules(): - if name in max_amax: - assert torch.isfinite(module.amax).all() and (module.amax > 0).all() - if not torch.allclose(module.amax, max_amax[name]): - changed = True - assert changed, "local_hessian did not refine any amax away from max-calibration" - - def test_differs_from_plain_mse(self): - forward_loop = _make_forward_loop(seed=3) - torch.manual_seed(0) model_lh = SimpleLinear() mtq.quantize(model_lh, INT8_WEIGHT_CFG, forward_loop=forward_loop) - local_hessian_calibrate(model_lh, forward_loop, fp8_scale_sweep=False) + max_amax = {n: a.clone() for n, a in _weight_amaxes(model_lh).items()} + local_hessian_calibrate(model_lh, forward_loop, fp8_scale_sweep=False, debug=True) torch.manual_seed(0) model_mse = SimpleLinear() mtq.quantize(model_mse, INT8_WEIGHT_CFG, forward_loop=forward_loop) mse_calibrate(model_mse, forward_loop, fp8_scale_sweep=False) - lh = { - n: m.amax - for n, m in model_lh.named_modules() - if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None - } - mse = { - n: m.amax - for n, m in model_mse.named_modules() - if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None - } - assert lh.keys() == mse.keys() - # Hessian weighting should change the chosen scale for at least one quantizer. - assert any(not torch.allclose(lh[n], mse[n]) for n in lh) + accs = model_lh._local_hessian_accumulators + assert accs and all(a.num_samples > 0 for a in accs.values()) + lh, mse = _weight_amaxes(model_lh), _weight_amaxes(model_mse) + assert all(torch.isfinite(a).all() and (a > 0).all() for a in lh.values()) + assert any(not torch.allclose(lh[n], max_amax[n]) for n in lh) # refined past max-cal + assert any(not torch.allclose(lh[n], mse[n]) for n in lh) # Hessian changed the choice + + def test_warns_with_module_name_when_cin_not_divisible(self): + class _OddModel(nn.Module): + def __init__(self): + super().__init__() + self.odd = nn.Linear(24, 32) # 24 not divisible by block_size 16 + + def forward(self, x): + return self.odd(x) + + torch.manual_seed(0) + model = _OddModel() + forward_loop = lambda m: m(torch.randn(4, 24)) # noqa: E731 + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop) + with pytest.warns(UserWarning, match=r"odd input features \(24\) not divisible"): + local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False) def test_no_forward_loop_is_skipped(self): torch.manual_seed(0) model = SimpleLinear() mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) - before = { - n: m.amax.clone() - for n, m in model.named_modules() - if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None - } + before = {n: a.clone() for n, a in _weight_amaxes(model).items()} with pytest.warns(UserWarning, match="forward_loop must be provided"): local_hessian_calibrate(model, forward_loop=None) - after = { - n: m.amax - for n, m in model.named_modules() - if isinstance(m, TensorQuantizer) and m.is_enabled and m.amax is not None - } - for n in before: - assert torch.equal(before[n], after[n]) + assert all(torch.equal(before[n], a) for n, a in _weight_amaxes(model).items()) + + +class TestActivationCaptureExtensionPoint: + """The extension point that decouples local-Hessian capture from module type.""" + + def test_dense_captures_and_conv_falls_back(self): + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + captured = [] + handles = model.net[0].register_calibration_input_hooks( + lambda wq, w, x: captured.append((tuple(w.shape), x.shape[-1])) + ) + assert len(handles) == 1 + with torch.no_grad(): + model(torch.randn(2, 16)) + for h in handles: + h.remove() + assert captured and captured[0] == ((32, 16), 16) # cin from activation matches weight + + conv = SimpleConv() + mtq.quantize(conv, INT8_WEIGHT_CFG, forward_loop=lambda m: m(SimpleConv.get_input())) + assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight + + def test_block_size_mismatch_warns_only_on_mismatch(self): + def q(block): + return TensorQuantizer( + QuantizerAttributeConfig( + num_bits=(2, 1), block_sizes={-1: block, "type": "static", "scale_bits": (4, 3)} + ) + ) + + with pytest.warns(UserWarning, match="will not align"): + _warn_if_block_size_mismatch(q(32), 16, "layer") + with warnings.catch_warnings(): + warnings.simplefilter("error") + _warn_if_block_size_mismatch(q(16), 16, "layer") # matching block + per_channel = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=0)) + _warn_if_block_size_mismatch(per_channel, 16, "layer") # no block_sizes class TestMakeWeightMseCalibratorErrorFunc: @@ -242,16 +209,14 @@ def teardown_method(self): _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends) def _make_quantizer(self, backend=None): - cfg = QuantizerAttributeConfig(num_bits=8, axis=None, backend=backend) - q = TensorQuantizer(quant_attribute_cfg=cfg) + q = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, axis=None, backend=backend)) q.amax = torch.tensor(1.0) return q def test_error_func_threaded_to_mse_calibrator(self): - q = self._make_quantizer() marker = lambda x, xq: (x - xq) ** 2 # noqa: E731 cal = _make_weight_mse_calibrator( - q, 0.1, 0.25, 4.0, fp8_scale_sweep=False, error_func=marker + self._make_quantizer(), 0.1, 0.25, 4.0, fp8_scale_sweep=False, error_func=marker ) assert isinstance(cal, calib.MseCalibrator) assert cal._error_func is marker From 63ff148c2bb47cf3f45844c8ff0cac90aa81f5f3 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:23:56 +0000 Subject: [PATCH 3/5] address review feedbacks Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 61 +++++++++++++------ .../quantization/nn/modules/quant_module.py | 13 +++- .../plugins/test_fused_experts.py | 5 +- .../torch/quantization/test_local_hessian.py | 10 +++ 4 files changed, 65 insertions(+), 24 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9755004cf19..dd59d6911d4 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -654,14 +654,20 @@ def accumulate(self, input_tensor: torch.Tensor) -> None: self.hessian_per_block += hessian_batch self.num_samples += input_tensor.numel() // self.cin - def build_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None: - """Hessian-weighted error function (``None`` if no samples); frees the raw buffer.""" + def build_error_func( + self, keep_buffer: bool = False + ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None: + """Hessian-weighted error function (``None`` if no samples). + + Frees the raw Hessian buffer unless ``keep_buffer`` (kept for debug inspection). + """ if self.hessian_per_block is None or self.num_samples == 0: return None cout = self.cout bs = self.block_size hessian = self.hessian_per_block / self.num_samples - self.hessian_per_block = None + if not keep_buffer: + self.hessian_per_block = None def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: original_shape = x.shape @@ -715,12 +721,13 @@ def local_hessian_calibrate( """Calibrate weight quantizers by minimizing the Hessian-weighted error. Minimizes ``(W - Wq)ᵀ H (W - Wq)`` with per-block Hessian ``H = ΣXᵀX`` (approximating the - output error ``||WX - WqX||²``), built from a full-precision forward and fed to - :func:`mse_calibrate`'s weight search via a custom ``error_func``. + output error ``||WX - WqX||²``), built from a forward with weight fake-quant disabled + (input quantizers untouched) and fed to :func:`mse_calibrate`'s weight search via ``error_func``. - Like :func:`mse_calibrate`, every weight quantizer is calibrated; the Hessian metric is - applied only where a weight can be paired with its input activations (dense linears and - HF fused-MoE experts). All other weights fall back to plain MSE. + Like :func:`mse_calibrate`, TensorQuantizer weights are calibrated — with the Hessian + metric where a weight pairs with its input activations (dense linears and HF fused-MoE + experts), plain MSE otherwise. Other quantizer types (e.g. SequentialQuantizer) are + unsupported and left at their max-calibrated scale. Args: model: Model to be calibrated. @@ -760,8 +767,8 @@ def capture(weight_quantizer, weight, input_tensor): accumulators[id(weight_quantizer)] = acc acc.accumulate(input_local) - # Phase 2: register capture hooks, silence weight quant (full-precision activations), - # run one forward to accumulate Hessians. Hooks live only for this forward. + # Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is, + # matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it. handles: list = [] silenced_weight_quantizers: list[TensorQuantizer] = [] warned: set = set() @@ -774,14 +781,28 @@ def capture(weight_quantizer, weight, input_tensor): captures = module.register_calibration_input_hooks(capture) handles.extend(captures) for weight, weight_quantizer in module.iter_weights_for_calibration(): - # Only TensorQuantizer weights are calibrated (matches mse_calibrate), so only - # those are silenced. - if ( - isinstance(weight_quantizer, TensorQuantizer) - and weight_quantizer.is_enabled - and weight_quantizer._if_quant - ): - silenced_weight_quantizers.append(weight_quantizer) + # Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture + # forward uses full-precision weights and downstream Hessians aren't corrupted. + leaves = ( + list(weight_quantizer) + if isinstance(weight_quantizer, SequentialQuantizer) + else [weight_quantizer] + ) + silenced_weight_quantizers.extend( + q + for q in leaves + if isinstance(q, TensorQuantizer) and q.is_enabled and q._if_quant + ) + # Only TensorQuantizer weights are refined (same as mse_calibrate); other types + # (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale. + if not isinstance(weight_quantizer, TensorQuantizer): + if weight_quantizer.is_enabled and "unsupported" not in warned: + warned.add("unsupported") + warn_rank_0( + "local_hessian: only TensorQuantizer weights are calibrated; other " + "types (e.g. SequentialQuantizer) stay at their max-calibrated scale." + ) + continue if captures: _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned) @@ -805,7 +826,9 @@ def capture(weight_quantizer, weight, input_tensor): ) # Phase 3: build error funcs and run the shared MSE weight loop. - error_funcs = {qid: acc.build_error_func() for qid, acc in accumulators.items()} + error_funcs = { + qid: acc.build_error_func(keep_buffer=debug) for qid, acc in accumulators.items() + } print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") _mse_calibrate_weights( model, diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index a54d2cf8def..e533ab8848a 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -260,9 +260,18 @@ def _setup(self): self._register_dynamic_attribute("weight", self._get_quantized_weight) def register_calibration_input_hooks(self, callback): - """Pair the weight quantizer with the forward input (linear only; conv falls back).""" + """Pair the weight quantizer with the forward input. + + Only a 2-D weight with an enabled ``TensorQuantizer`` is hooked; conv (4-D) and + ``SequentialQuantizer`` weights are unsupported and fall back to plain calibration. + """ weight = getattr(self, "weight", None) - if weight is None or weight.dim() != 2 or not self.weight_quantizer.is_enabled: + if ( + weight is None + or weight.dim() != 2 + or not isinstance(self.weight_quantizer, TensorQuantizer) + or not self.weight_quantizer.is_enabled + ): return [] def _pre_hook(module, args): diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index ec8259c818f..eb244e1036e 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -22,7 +22,9 @@ pytest.importorskip("transformers") +import modelopt.torch.quantization as mtq from modelopt.torch.quantization.conversion import _normalize_fused_experts_quantizer_name +from modelopt.torch.quantization.model_calib import local_hessian_calibrate from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.huggingface import ( _is_fused_experts_module, @@ -659,9 +661,6 @@ def forward_loop(m): def test_local_hessian_per_expert_capture_and_refinement(self): """The plugin's extension point pairs each per-expert weight quantizer with its routed input, and local_hessian uses that to refine every expert's weight amax.""" - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.model_calib import local_hessian_calibrate - model = _TinyMoEModel() expert_type = type(model.moe.experts) self._cleanup_registry(expert_type) diff --git a/tests/unit/torch/quantization/test_local_hessian.py b/tests/unit/torch/quantization/test_local_hessian.py index 1a7e6ccd9a7..7f014cf9bdb 100644 --- a/tests/unit/torch/quantization/test_local_hessian.py +++ b/tests/unit/torch/quantization/test_local_hessian.py @@ -180,6 +180,16 @@ def test_dense_captures_and_conv_falls_back(self): mtq.quantize(conv, INT8_WEIGHT_CFG, forward_loop=lambda m: m(SimpleConv.get_input())) assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight + def test_sequential_quantizer_weight_not_hooked(self): + from modelopt.torch.quantization.nn import SequentialQuantizer + + torch.manual_seed(0) + model = SimpleLinear() + mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) + linear = model.net[0] + linear.weight_quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + assert linear.register_calibration_input_hooks(lambda *a: None) == [] # unsupported + def test_block_size_mismatch_warns_only_on_mismatch(self): def q(block): return TensorQuantizer( From 9a6bc61f67025fbc6315a6aac9090151df6669f1 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:36:54 +0000 Subject: [PATCH 4/5] minor: more feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- tests/unit/torch/quantization/test_local_hessian.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/torch/quantization/test_local_hessian.py b/tests/unit/torch/quantization/test_local_hessian.py index 7f014cf9bdb..2a610aa87e6 100644 --- a/tests/unit/torch/quantization/test_local_hessian.py +++ b/tests/unit/torch/quantization/test_local_hessian.py @@ -34,7 +34,7 @@ local_hessian_calibrate, mse_calibrate, ) -from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( _QUANT_FUNCTIONAL_BACKENDS, register_quant_backend, @@ -181,8 +181,6 @@ def test_dense_captures_and_conv_falls_back(self): assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight def test_sequential_quantizer_weight_not_hooked(self): - from modelopt.torch.quantization.nn import SequentialQuantizer - torch.manual_seed(0) model = SimpleLinear() mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop()) From 4a7e6751ddc075d8495217b310aadac49dcb1b9c Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:16:34 +0000 Subject: [PATCH 5/5] memory refinement Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index dd59d6911d4..68aee5cde86 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -840,6 +840,15 @@ def capture(weight_quantizer, weight, input_tensor): error_func_for=lambda q: error_funcs.get(id(q)), ) + # Free the per-block Hessians (pinned by error_func closures) and the sweep's cached + # allocations so export starts from a defragmented allocator. + error_funcs.clear() + for module in name_to_module.values(): + if isinstance(module, TensorQuantizer) and isinstance(module._calibrator, MseCalibrator): + module._calibrator._error_func = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if debug: model._local_hessian_accumulators = accumulators