From da6eae49d92bfa44d57bdb807e1a1ee8326a77d4 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Mon, 1 Jun 2026 20:54:56 -0700 Subject: [PATCH 1/6] [OMNIML-3994] Share NVFP4 weight global_amax across sibling modules - Add SharedQuantState: sibling weight quantizers (q/k/v, gate/up, per-expert w1/w3) calibrate to a single NVFP4 global_amax (the group max). - Discover groups via input-sharing hooks or regex patterns (MaxCalibConfig.shared_patterns); max_calibrate attaches the state on the group parent, populates, and promotes. - Extract out the collect_shared_input_modules helper and reuse it for this and export Signed-off-by: Shiyang Chen --- modelopt/torch/export/unified_export_hf.py | 79 +-- modelopt/torch/quantization/config.py | 16 + modelopt/torch/quantization/model_calib.py | 90 +--- modelopt/torch/quantization/utils/__init__.py | 12 + .../torch/quantization/utils/core_utils.py | 12 +- .../torch/quantization/utils/shared_input.py | 470 ++++++++++++++++++ .../torch/quantization/test_mse_calibrator.py | 16 +- .../torch/quantization/test_shared_input.py | 419 ++++++++++++++++ 8 files changed, 978 insertions(+), 136 deletions(-) create mode 100644 modelopt/torch/quantization/utils/shared_input.py create mode 100644 tests/unit/torch/quantization/test_shared_input.py diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 00e4a7008a9..0660ec1dcf6 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -54,11 +54,12 @@ from torch.distributed.fsdp import FSDPModule -from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor +from modelopt.torch.quantization.utils import ( + collect_shared_input_modules as _collect_shared_input_modules, +) from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names -from modelopt.torch.utils.dataset_utils import _disable_use_cache try: from modelopt.torch.sparsity.attention_sparsity.conversion import export_sparse_attention_config @@ -255,68 +256,20 @@ def collect_shared_input_modules( ) -> tuple[dict, dict | None]: """Collect modules that share the same input using forward hooks. - This is a common helper for both LLM and diffusion model fusion. - - Args: - model: The model to analyze. - dummy_forward_fn: A callable that runs a dummy forward pass on the model. - Should be a function that takes no arguments. - collect_layernorms: If True, also collect layernorm output mappings (for AWQ). - - Returns: - A tuple of (input_to_linear, output_to_layernorm). - input_to_linear: Dict mapping input tensor to list of modules sharing that input. - output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). + Thin wrapper around + :func:`modelopt.torch.quantization.utils.collect_shared_input_modules`, + parameterized for the export use case: hooks every ``is_quantlinear`` + module with at least one enabled quantizer, and optionally tracks + layernorm outputs when ``collect_layernorms=True`` (for AWQ + pre_quant_scale folding into the preceding layernorm). """ - input_to_linear: dict = defaultdict(list) - output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None - - def _input_hook(module, input, output): - """Update dictionary with list of all modules that share the same input.""" - if len(input) > 0 and isinstance(input[0], torch.Tensor): - # TODO: Handle DBRX MoE case - input_to_linear[input[0]].append(module) - - def _output_hook(module, input, output): - """Update dictionary with mapping of layernorms and their outputs.""" - if output_to_layernorm is not None and isinstance(output, torch.Tensor): - output_to_layernorm[output] = module - - handles = [] - - # Register hooks on all quantized linear modules (and optionally layernorms) - for name, module in model.named_modules(): - if collect_layernorms and is_layernorm(module): - module.name = name - handle = module.register_forward_hook(_output_hook) - handles.append(handle) - elif is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) - - if not handles: - return input_to_linear, output_to_layernorm - - # Run dummy forward pass to collect modules sharing same input. - # `_disable_use_cache` keeps the probe forward working on configs that don't - # set `use_cache` (e.g., stepfun-ai/Step-3.5-Flash's Step3p5Config). - try: - with ( - torch.no_grad(), - set_quantizer_by_cfg_context(model, [{"quantizer_name": "*", "enable": False}]), - _disable_use_cache(model), - ): - dummy_forward_fn() - finally: - # Always remove hooks - for handle in handles: - handle.remove() - - return input_to_linear, output_to_layernorm + return _collect_shared_input_modules( + model, + dummy_forward_fn, + module_filter=lambda m: is_quantlinear(m) + and (_is_enabled_quantizer(m.input_quantizer) or _is_enabled_quantizer(m.weight_quantizer)), + output_filter=is_layernorm if collect_layernorms else None, + ) def _fuse_shared_input_modules( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index bf8e99ff5af..7746536fd74 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -715,6 +715,22 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): ), ) + shared_patterns: dict[str, list[str]] | None = ModeloptField( + default=None, + title="Regex patterns for groups that share quantization state", + description=( + "Optional dict keyed by quantizer kind (``'weight'`` and/or ``'input'``), each a list " + "of regexes matched (full-match) against module fully-qualified names. For a kind, when " + "patterns are given they are the sole discovery source (hook-based discovery is skipped), " + "so they must list every group you want. Modules whose match yields the same capture-group " + "tuple form one group; the capture boundary chooses granularity: capture the immediate " + "parent for per-parent / per-expert groups (e.g. ``r'(.*)\\.(?:q_proj|k_proj|v_proj)'``, " + "``r'(.*)\\.(?:w1|w3)'``); leave the expert index uncaptured for one cross-expert group " + "(``r'(.*)\\.experts\\.\\d+\\.(?:w1|w3)'``). Only ``'weight'`` is used today; ``'input'`` is " + "reserved for future input-quantizer sharing. Default None -> hook-based discovery." + ), + ) + class MseCalibConfig(QuantizeAlgorithmConfig): """Configuration for per-tensor MSE calibration. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 223994b3c46..94431670113 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -18,7 +18,7 @@ import math import time import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TypeAlias @@ -41,6 +41,7 @@ from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( + attach_shared_quant_states, disable_calib, enable_fake_quant, enable_quant, @@ -49,8 +50,8 @@ is_quantized_linear, is_quantized_row_parallel_linear, persistent_materialization, + populate_shared_state, promote_nvfp4_static_quantizers, - quantizer_attr_names, reduce_amax, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -74,69 +75,10 @@ def _check_nvfp4_static_tp_supported(model: nn.Module) -> None: # no-op without ] -def _is_calibrated_nvfp4_static(q) -> bool: - """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" - return ( - isinstance(q, NVFP4StaticQuantizer) - and not q._disabled - and q.is_nvfp4_static - and getattr(q, "_amax", None) is not None - ) - - -def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: - """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" - # Inline: layer_utils → quant_utils → model_calib cycle. - from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS - - # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant - # in export). Single source for the gate/up half avoids parallel lists. - patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) - groups: list[list[nn.Module]] = [] - wq_attr = quantizer_attr_names("weight").weight_quantizer - for parent in model.modules(): - for sibling_names in patterns: - members = [ - child - for child in (getattr(parent, n, None) for n in sibling_names) - if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) - ] - if len(members) >= 2: - groups.append(members) - return groups - - def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None: quantizer(weight) -@torch.no_grad() -def _sync_grouped_weight_global_amax(model: nn.Module) -> int: - """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. - - Run after ``max_calibrate``. Sibling discovery is name-based via - ``_collect_grouped_linears``; non-matching architectures (wqkv, fused - qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently - fall back to per-module global_amax. Fused-experts containers already - share a single quantizer across gate/up halves and need no sync. - """ - # quant_utils imports back from this module; top-level would cycle. - from modelopt.torch.export.quant_utils import preprocess_linear_fusion - - n_groups = 0 - for group in _collect_grouped_linears(model): - preprocess_linear_fusion(group) - n_groups += 1 - return n_groups - - -@torch.no_grad() -def _promote_nvfp4_static_quantizers_with_global_amax_sync(model: nn.Module) -> None: - """Promote static NVFP4 weight quantizers and sync grouped global amax.""" - promote_nvfp4_static_quantizers(model) - _sync_grouped_weight_global_amax(model) - - CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -249,6 +191,7 @@ def max_calibrate( forward_loop: ForwardLoop | None = None, distributed_sync=True, sync_expert_weight_amax=False, + shared_patterns: Mapping[str, Sequence[str]] | None = None, ): """Calibrate the model using max. @@ -259,6 +202,12 @@ def max_calibrate( distributed_sync: Whether to sync input_quantizer amax across distributed processes. sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts in a MoE layer (within-rank sync + EP all-reduce when EP>1). + shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a + list of regexes over module FQNs. For a kind, when patterns are given they are the sole + discovery source (hooks skipped). Modules whose regex match yields the same capture-group + tuple form one group — capture the immediate parent for per-parent (per-expert) grouping, + or leave the expert index uncaptured for cross-expert. Only ``"weight"`` is used today; + ``"input"`` is reserved for future input-quantizer sharing. See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -280,14 +229,16 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). _check_nvfp4_static_tp_supported(model) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so - # the static blockwise fake-quant path is used in forward and export picks up the - # two-level (per-block + global) scaling. Run before the ``distributed_sync`` early - # return so single-process callers also get the promotion. No-op for dynamic-block - # / non-NVFP4 configs. - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) + # Discover sibling groups (forward_loop hooks, or the "weight" regexes when given) + # and attach shared state; populated below, after any cross-rank _amax sync. + # Only "weight" patterns are consumed today; "input" is reserved for the future. + weight_patterns = shared_patterns.get("weight") if shared_patterns else None + attach_shared_quant_states(model, forward_loop=forward_loop, patterns=weight_patterns) if not distributed_sync: + # Single-process: _amax is final — aggregate shared global_amax, then promote. + populate_shared_state(model) + promote_nvfp4_static_quantizers(model) return # Check MoE calibration completeness before sync @@ -404,6 +355,11 @@ def sync_quantizer_amax_across_tp( module.parallel_state.tensor_parallel_group ) + # _amax is now cross-rank consistent — aggregate shared global_amax then + # promote, so siblings read the unified value instead of their own _amax. + populate_shared_state(model) + promote_nvfp4_static_quantizers(model) + def _mse_quant_func(x, amax, quantizer): """Quantization function for MSE calibration.""" diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index dc6daa00842..3469ce7b16a 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -18,15 +18,27 @@ from .core_utils import * from .layerwise_calib import LayerActivationCollector +from .shared_input import ( + SharedQuantState, + attach_shared_quant_states, + collect_shared_input_modules, + find_shared_input_groups, + populate_shared_state, +) __all__ = [ "EXPORT_MODE", + "SharedQuantState", + "attach_shared_quant_states", + "collect_shared_input_modules", "convert_quantization_axis_to_reduce_axis", "export_torch_mode", + "find_shared_input_groups", "is_quantized", "is_quantized_column_parallel_linear", "is_quantized_linear", "is_quantized_row_parallel_linear", + "populate_shared_state", "reduce_amax", "reduce_sum", "replace_function", diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index e95e175b71d..f49b75d53e7 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -954,6 +954,10 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: need to be promoted so they use the two-level scaling path (global amax + per-block amax) instead of the generic E4M3 path. + If the quantizer has a ``_shared_quant_state_ref`` with a populated + ``weight_global_amax`` (sibling group), that shared value is used instead of + this quantizer's own ``_amax`` reduction, keeping siblings on a common FP8 grid. + Returns the number of quantizers converted. """ from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer @@ -968,8 +972,14 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: if amax is None: continue + # Grouped siblings share one ``weight_global_amax`` (common FP8 grid); + # otherwise fall back to this quantizer's own per-block amax. already_promoted = isinstance(module, NVFP4StaticQuantizer) - global_amax = reduce_amax(amax.clone().detach(), axis=None) + shared = getattr(module, "_shared_quant_state_ref", None) + if shared is not None and shared.weight_global_amax is not None: + global_amax = shared.weight_global_amax + else: + global_amax = reduce_amax(amax.clone().detach(), axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not already_promoted: converted += 1 diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py new file mode 100644 index 00000000000..3b75126c626 --- /dev/null +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -0,0 +1,470 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Discovery and per-group shared state for modules that consume the same input. + +- :func:`collect_shared_input_modules` runs a probe forward with hooks and + returns the modules that received the same input tensor. Used by both + calibration (here) and export (``unified_export_hf``). +- :class:`SharedQuantState` is an ``nn.Module`` attached to a sibling group's + parent so its tensors ride along in ``state_dict``. Holds only + ``weight_global_amax`` today; designed to grow (act scales, LoRA factors, ...). + +:func:`find_shared_input_groups` (hooks + name patterns) produces the +``(parent, members)`` tuples consumed by :func:`attach_shared_quant_states` and +:func:`populate_shared_state`. +""" + +import re +import warnings +from collections import defaultdict +from collections.abc import Callable, Sequence + +import torch +import torch.distributed as dist +import torch.nn as nn + +from modelopt.torch.utils.distributed import ParallelState + +from .core_utils import is_quantized_linear, quantizer_attr_names, reduce_amax + +ForwardLoop = Callable[[nn.Module], None] + +__all__ = [ + "SharedQuantState", + "attach_shared_quant_states", + "collect_shared_input_modules", + "find_shared_input_groups", + "populate_shared_state", +] + + +# --------------------------------------------------------------------------- +# Discovery primitive +# --------------------------------------------------------------------------- + + +def collect_shared_input_modules( + model: nn.Module, + forward_fn: Callable[[], None], + module_filter: Callable[[nn.Module], bool] | None = None, + output_filter: Callable[[nn.Module], bool] | None = None, +) -> tuple[dict, dict | None]: + """Hook the model, run a probe forward, group modules by shared input tensor. + + Args: + model: model to probe. + forward_fn: zero-arg callable running a forward pass on ``model``; + quantizers are disabled during it so probe outputs aren't perturbed. + module_filter: which modules to hook on input (default + :func:`is_quantized_linear`). + output_filter: optional, which modules to hook on output. AWQ export uses + it to map a layernorm's output to itself so the pre_quant_scale can be + folded into the layernorm; ``None`` skips output tracking. + + Returns: + ``(input_to_modules, output_to_modules)``: input tensor -> modules that + received it (the shared-input group), and output tensor -> producing + module (``None`` when ``output_filter`` is not given). + """ + # Inline import to avoid a cycle (conversion/dataset_utils import from utils); + # safe because this runs at calibration/export time, not module load. + from modelopt.torch.quantization.conversion import set_quantizer_by_cfg_context + from modelopt.torch.utils.dataset_utils import _disable_use_cache + + if module_filter is None: + module_filter = is_quantized_linear + + input_to_modules: dict = defaultdict(list) + output_to_modules: dict | None = defaultdict(lambda: None) if output_filter else None + + def _input_hook(module, args, output): + if len(args) > 0 and isinstance(args[0], torch.Tensor): + input_to_modules[args[0]].append(module) + + def _output_hook(module, args, output): + if output_to_modules is not None and isinstance(output, torch.Tensor): + output_to_modules[output] = module + + handles = [] + for name, module in model.named_modules(): + if output_filter is not None and output_filter(module): + module.name = name + handles.append(module.register_forward_hook(_output_hook)) + elif module_filter(module): + module.name = name + handles.append(module.register_forward_hook(_input_hook)) + + if not handles: + return input_to_modules, output_to_modules + + try: + with ( + torch.no_grad(), + set_quantizer_by_cfg_context(model, [{"quantizer_name": "*", "enable": False}]), + _disable_use_cache(model), + ): + forward_fn() + finally: + for handle in handles: + handle.remove() + + return input_to_modules, output_to_modules + + +# --------------------------------------------------------------------------- +# Per-group shared state +# --------------------------------------------------------------------------- + + +class SharedQuantState(nn.Module): + """State shared across a sibling group of quantized modules. + + Attached to the group's parent (e.g. ``self_attn``, ``block_sparse_moe``) as a + registered submodule so its buffers ride along in ``state_dict``. Members that + quantize the same input resolve shared values here instead of computing them + independently and reconciling at export. + + Holds only ``weight_global_amax`` today; new tensor fields (act scales, AWQ + ``pre_quant_scale``, SVDQuant/FlatQuant factors) should use ``register_buffer`` + so they serialize too. + """ + + def __init__(self) -> None: + """Initialize with an unset ``weight_global_amax`` and no registered members.""" + super().__init__() + # NVFP4 two-level FP8 grid scale = max over members' per-block ``_amax``. + # Unset for non-NVFP4 configs. + self.register_buffer("weight_global_amax", None) + # Back-references to member modules. ``object.__setattr__`` keeps them out of + # ``_modules`` so members' params don't re-enter our ``state_dict``. + object.__setattr__(self, "_members", []) + + def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: + """All-reduce (MAX) ``weight_global_amax`` across EP, plus TP defensively. + + Weights are DP-replicated (no DP sync needed). EP sync is required since + ranks hold different experts; TP sync guards against per-child ``_amax`` TP + sync skipping block-quantized weights. + """ + if self.weight_global_amax is None or parallel_state is None: + return + for group in ( + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ): + if group is None or not group.is_initialized(): + continue + try: + dist.all_reduce( + self.weight_global_amax, + op=dist.ReduceOp.MAX, + group=group.group, + ) + except RuntimeError as e: + warnings.warn(f"Failed to sync shared weight_global_amax: {e}") + + +# --------------------------------------------------------------------------- +# Group discovery (hook + pattern) → (parent, members) tuples +# --------------------------------------------------------------------------- + + +def _has_calibratable_weight_quantizer(child: nn.Module, wq_attr: str) -> bool: + """A child is eligible if its weight quantizer is enabled and has ``_amax`` set.""" + wq = getattr(child, wq_attr, None) + if wq is None or not hasattr(wq, "_disabled") or wq._disabled: + return False + return getattr(wq, "_amax", None) is not None + + +def _build_parent_map(model: nn.Module) -> dict[nn.Module, nn.Module]: + """Build a ``{child_module: direct_parent}`` map by walking ``named_children``.""" + parent_map: dict[nn.Module, nn.Module] = {} + for parent in model.modules(): + for child in parent.children(): + parent_map[child] = parent + return parent_map + + +def _climb_past_modulelist( + module: nn.Module, + parent_map: dict[nn.Module, nn.Module], + fallback: nn.Module, +) -> nn.Module: + """Walk up past any ``nn.ModuleList`` ancestors to a regular module. + + Attaching ``SharedQuantState`` to a ``ModuleList`` registers it in that + container's ``_modules`` and corrupts its iteration/length (the state shows up + alongside the experts), so attach to the first non-ModuleList ancestor. + """ + cur = module + while isinstance(cur, nn.ModuleList): + parent = parent_map.get(cur) + if parent is None or parent is cur: + return fallback + cur = parent + return cur + + +def _lowest_common_ancestor( + members: Sequence[nn.Module], + parent_map: dict[nn.Module, nn.Module], + fallback: nn.Module, +) -> nn.Module: + """LCA of ``members`` in the module tree (``fallback`` if none). + + Climbs past ``nn.ModuleList`` ancestors so the result can host + ``SharedQuantState`` as a submodule. + """ + if not members: + return fallback + + def ancestors(m: nn.Module) -> list[nn.Module]: + chain = [] + cur = m + while cur in parent_map: + cur = parent_map[cur] + chain.append(cur) + return chain + + chains = [ancestors(m) for m in members] + if not chains[0]: + return fallback + common = set(chains[0]) + for c in chains[1:]: + common &= set(c) + # Deepest common ancestor: first in member[0]'s chain that's in every chain. + for a in chains[0]: + if a in common: + return _climb_past_modulelist(a, parent_map, fallback) + return fallback + + +def _groups_from_hooks( + model: nn.Module, + forward_loop: ForwardLoop, +) -> list[tuple[nn.Module, list[nn.Module]]]: + """Discover sibling groups from a hook-based probe forward. + + Modules whose first input is the same tensor object form a group, parented at + their LCA. Only catches the *literal same tensor*, so cross-expert MoE sharing + (one input per expert) is missed — patterns cover that. + """ + input_to_modules, _ = collect_shared_input_modules( + model, forward_fn=lambda: forward_loop(model) + ) + if not input_to_modules: + return [] + + parent_map = _build_parent_map(model) + wq_attr = quantizer_attr_names("weight").weight_quantizer + groups: list[tuple[nn.Module, list[nn.Module]]] = [] + for members in input_to_modules.values(): + # Dedup (a module may be hooked twice) and keep calibrated members only. + unique: list[nn.Module] = [] + seen: set[int] = set() + for m in members: + if id(m) in seen: + continue + if not _has_calibratable_weight_quantizer(m, wq_attr): + continue + seen.add(id(m)) + unique.append(m) + if len(unique) >= 2: + parent = _lowest_common_ancestor(unique, parent_map, fallback=model) + groups.append((parent, unique)) + return groups + + +def _groups_from_patterns( + model: nn.Module, + patterns: Sequence[str], +) -> list[tuple[nn.Module, list[nn.Module]]]: + r"""Discover groups by regex over module FQNs; capture groups define the grouping key. + + Each pattern is a regex ``re.fullmatch``-ed against every quantized module's + fully-qualified name; modules whose match yields the same capture-group tuple form + one group, parented at their LCA. Granularity is set by *what you capture*: + + - Capture the immediate parent -> per-parent grouping: q/k/v per attention block, + and **per-expert** ``w1``/``w3`` (each expert is the immediate parent), e.g. + ``r"(.*)\.(?:w1|w3)$"``. + - Capture only a level above the expert index, leaving the index uncaptured -> + one **cross-expert** group, e.g. ``r"(.*)\.experts\.\d+\.(?:w1|w3)$"``. + + Roles to fuse together go in a non-capturing alternation ``(?:w1|w3)`` so they + don't split the key; what you wrap in ``(...)`` is the group boundary. + """ + wq_attr = quantizer_attr_names("weight").weight_quantizer + compiled = [re.compile(p) for p in patterns] + buckets: dict[tuple, list[nn.Module]] = {} + order: list[tuple] = [] + for name, module in model.named_modules(): + if not _has_calibratable_weight_quantizer(module, wq_attr): + continue + for pattern_idx, regex in enumerate(compiled): + match = regex.fullmatch(name) + if match is not None: + key = ( + pattern_idx, + match.groups(), + ) # include pattern_idx in case 1+ partterns collide. + if key not in buckets: + buckets[key] = [] + order.append(key) + buckets[key].append(module) + break # each module belongs to its first matching pattern + parent_map = _build_parent_map(model) + groups: list[tuple[nn.Module, list[nn.Module]]] = [] + for key in order: + members = buckets[key] + if len(members) >= 2: + parent = _lowest_common_ancestor(members, parent_map, fallback=model) + groups.append((parent, members)) + return groups + + +def find_shared_input_groups( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + patterns: Sequence[str] | None = None, +) -> list[tuple[nn.Module, list[nn.Module]]]: + """Find sibling groups from regex patterns if given, else from a hook probe. + + Patterns and hooks are mutually exclusive: when ``patterns`` is set it is the + sole source (and must list every group you want); otherwise hook discovery runs + on the ``forward_loop`` probe. + + - ``patterns`` — regexes over module FQNs whose capture groups define the grouping + key (see :func:`_groups_from_patterns`); the capture boundary chooses per-expert + vs cross-expert granularity. The caller selects which quantizer these groups apply + to (today only the weight quantizer; see ``MaxCalibConfig.shared_patterns``). + - ``forward_loop`` — hook probe grouping modules that receive the *literal same + tensor* (Q/K/V, gate/up within one block/expert; per-expert for MoE). + + Returns ``(parent, members)`` tuples. + """ + if patterns: + return _groups_from_patterns(model, patterns) + if forward_loop is not None: + return _groups_from_hooks(model, forward_loop) + return [] + + +# --------------------------------------------------------------------------- +# Attach / populate lifecycle +# --------------------------------------------------------------------------- + + +def attach_shared_quant_states( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + patterns: Sequence[str] | None = None, +) -> int: + """Create ``SharedQuantState`` on each group's parent and link members. + + The parent owns the state under ``_shared_quant_state`` (normal setattr → a + registered submodule, so its buffer rides along in ``state_dict``). Each + member's weight quantizer — the only consumer, via + ``promote_nvfp4_static_quantizers`` — gets a back-reference under the distinct + name ``_shared_quant_state_ref`` set with ``object.__setattr__`` (not a + submodule, so the buffer isn't duplicated per member). The distinct names let + ``populate_shared_state`` select owners with a plain ``getattr``. + + Idempotent (reuses an existing parent state). Returns the number created. + """ + n_created = 0 + wq_attr = quantizer_attr_names("weight").weight_quantizer + for parent, members in find_shared_input_groups( + model, forward_loop=forward_loop, patterns=patterns + ): + if not hasattr(parent, "_shared_quant_state"): + parent._shared_quant_state = SharedQuantState() + n_created += 1 + state = parent._shared_quant_state + # Record members so populate_shared_state needn't re-run discovery. + object.__setattr__(state, "_members", list(members)) + for child in members: + wq = getattr(child, wq_attr, None) + if wq is None: + continue + # Groups are disjoint after merging, so each quantizer gets one state per + # call and a re-attach reuses the same object; a different existing state + # would mean an inconsistent re-attach. + existing = getattr(wq, "_shared_quant_state_ref", None) + assert existing is None or existing is state, ( + f"{type(wq).__name__} already belongs to a different shared-input " + "group; groups should be disjoint after merging." + ) + object.__setattr__(wq, "_shared_quant_state_ref", state) + return n_created + + +@torch.no_grad() +def populate_shared_state(model: nn.Module) -> int: + """Aggregate per-member stats into each group's ``SharedQuantState``. + + Currently sets ``weight_global_amax`` = max over members' reduced ``_amax``, + EP-synced so all ranks agree, then writes it back to each member's + ``global_amax`` (overriding any stale value from an earlier promotion). Future + fields plug in here as extra aggregation steps. + + Call after members' ``_amax`` is cross-rank consistent (post TP/DP/EP sync in + ``max_calibrate``). Members not yet promoted to ``NVFP4StaticQuantizer`` are + skipped on write-back; the next promotion reads the shared value instead. + Returns the number of groups populated. + """ + from modelopt.torch.quantization.nn import NVFP4StaticQuantizer + + wq_attr = quantizer_attr_names("weight").weight_quantizer + n_groups = 0 + + for parent in model.modules(): + # Owners hold the state under ``_shared_quant_state`` (members use the + # distinct ``_shared_quant_state_ref``), so getattr matches owners only. + state = getattr(parent, "_shared_quant_state", None) + if not isinstance(state, SharedQuantState): + continue + + members = getattr(state, "_members", []) + if not members: + continue + + child_maxes: list[torch.Tensor] = [] + parallel_state: ParallelState | None = None + for child in members: + wq = getattr(child, wq_attr, None) + if wq is None or getattr(wq, "_amax", None) is None: + continue + child_maxes.append(reduce_amax(wq._amax, axis=None)) + if parallel_state is None: + parallel_state = getattr(child, "parallel_state", None) + + if not child_maxes: + continue + + local_max = torch.max(torch.stack(child_maxes)) + state.weight_global_amax = local_max + state.sync_weight_global_amax(parallel_state) + + synced = state.weight_global_amax + for child in members: + wq = getattr(child, wq_attr, None) + if isinstance(wq, NVFP4StaticQuantizer): + wq.global_amax = synced + n_groups += 1 + + return n_groups diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 064a77a247e..0e7127dfe4d 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -23,7 +23,6 @@ from modelopt.torch.quantization.model_calib import ( _FP8_SWEEP_CALIBRATOR_REGISTRY, _make_weight_mse_calibrator, - _promote_nvfp4_static_quantizers_with_global_amax_sync, _register_fp8_sweep_calibrator, mse_calibrate, ) @@ -32,7 +31,12 @@ _QUANT_FUNCTIONAL_BACKENDS, register_quant_backend, ) -from modelopt.torch.quantization.utils import enable_fake_quant, promote_nvfp4_static_quantizers +from modelopt.torch.quantization.utils import ( + attach_shared_quant_states, + enable_fake_quant, + populate_shared_state, + promote_nvfp4_static_quantizers, +) # TODO: avoid code duplication in this file @@ -644,7 +648,7 @@ def test_modelopt_static_nvfp4_uses_fp8_scale_sweep(self): ) model = torch.nn.Module() model.weight_quantizer = q - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) + promote_nvfp4_static_quantizers(model) cal = _make_weight_mse_calibrator( q, @@ -750,7 +754,7 @@ def __init__(self, amax): def test_standalone_static_nvfp4_quantizer_is_promoted(self): model = self._LinearLike(torch.tensor([1.0, 5.0])) - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) + promote_nvfp4_static_quantizers(model) assert isinstance(model.weight_quantizer, NVFP4StaticQuantizer) assert torch.equal(model.weight_quantizer.global_amax, torch.tensor(5.0)) @@ -770,7 +774,9 @@ def test_grouped_static_nvfp4_quantizers_share_global_amax(self): model.k_proj = self._LinearLike(torch.tensor([3.0, 4.0])) model.v_proj = self._LinearLike(torch.tensor([5.0, 6.0])) - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) + attach_shared_quant_states(model, patterns=[r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)"]) + populate_shared_state(model) + promote_nvfp4_static_quantizers(model) for child in (model.q_proj, model.k_proj, model.v_proj): assert isinstance(child.weight_quantizer, NVFP4StaticQuantizer) diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py new file mode 100644 index 00000000000..9fcd89871b3 --- /dev/null +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Unit tests for SharedQuantState — group-level quantization state on parent modules. + +These tests use a hand-built CPU model with Q/K/V siblings under a dummy +``self_attn`` parent, then drive ``max_calibrate`` directly so the run is +fast and deterministic without needing real attention/MoE layers. +""" + +import pytest +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.model_calib import max_calibrate +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.utils import ( + SharedQuantState, + attach_shared_quant_states, + find_shared_input_groups, + populate_shared_state, + quantizer_attr_names, + reduce_amax, +) + +# Test-only patterns: lists of regexes (the per-kind value of ``shared_patterns``, +# e.g. the ``"weight"`` list). ``re.fullmatch``-ed against module FQNs; capture groups +# define the grouping key. ``(?:(.*)\.)?`` captures the immediate parent (or None at the +# model root, since these test models hold roles directly) -> per-parent / per-expert. +SIBLING_PATTERNS = [ + r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)", + r"(?:(.*)\.)?(?:gate_proj|up_proj)", + r"(?:(.*)\.)?(?:w1|w3)", +] + + +NVFP4_BLOCK = 16 + + +def _make_nvfp4_static_cfg() -> QuantizerAttributeConfig: + """NVFP4 static block quantization config (E2M1 weights + E4M3 per-block scales).""" + return QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: NVFP4_BLOCK, "type": "static", "scale_bits": (4, 3)}, + ) + + +class _DummyAttention(nn.Module): + """A toy parent module that exposes ``q_proj``, ``k_proj``, ``v_proj`` siblings.""" + + def __init__(self, in_features: int = 32, out_features: int = 32) -> None: + super().__init__() + self.q_proj = nn.Linear(in_features, out_features, bias=False) + self.k_proj = nn.Linear(in_features, out_features, bias=False) + self.v_proj = nn.Linear(in_features, out_features, bias=False) + + +def _populate_amax(linear: nn.Module, value: float) -> None: + """Directly set ``_amax`` on a linear's weight_quantizer for deterministic testing.""" + wq_attr = quantizer_attr_names("weight").weight_quantizer + wq = getattr(linear, wq_attr) + # Match the per-block shape the real calibrator would produce + out_features, in_features = linear.weight.shape + n_blocks = in_features // NVFP4_BLOCK + wq._amax = torch.full( + (out_features, n_blocks), value, dtype=torch.float32, device=linear.weight.device + ) + + +class TestSharedQuantStateBasics: + """Direct exercise of the SharedQuantState container and its helpers.""" + + def test_init_unset(self): + s = SharedQuantState() + assert s.weight_global_amax is None + + def test_attach_creates_state_on_parent(self): + attn = _DummyAttention() + # Configure with NVFP4-static weight quantizers and seed _amax so attach finds them. + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=1.0) + + n = attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + assert n == 1, f"expected one new state, got {n}" + assert hasattr(attn, "_shared_quant_state") + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert proj.weight_quantizer._shared_quant_state_ref is attn._shared_quant_state + + def test_find_groups_skips_singletons(self): + """A parent with only one matching child must NOT form a group.""" + + class _OnlyOne(nn.Module): + def __init__(self): + super().__init__() + self.q_proj = nn.Linear(16, 16, bias=False) + + m = _OnlyOne() + mtq.replace_quant_module(m) + cfg = _make_nvfp4_static_cfg() + m.q_proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(m.q_proj, value=1.0) + + groups = find_shared_input_groups(m, patterns=SIBLING_PATTERNS) + assert groups == [], "single sibling must not form a group" + + def test_patterns_override_hooks(self): + """When patterns are given, hooks are skipped (patterns are the sole source).""" + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=1.0) + + def fwd(m): + x = torch.randn(2, 32) + for proj in (m.q_proj, m.k_proj, m.v_proj): + proj(x) + + # Hooks alone would group all three (shared input). A q/k-only pattern must + # win — v_proj absent proves hooks are skipped when patterns are provided. + groups = find_shared_input_groups( + attn, forward_loop=fwd, patterns=[r"(?:(.*)\.)?(?:q_proj|k_proj)"] + ) + assert len(groups) == 1 + _parent, members = groups[0] + assert len(members) == 2 and attn.v_proj not in members + + def test_populate_writes_max_across_siblings(self): + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + # Seed deterministic _amax values: q=1.0, k=3.0, v=2.0 → max=3.0 + attn.q_proj.weight_quantizer.set_from_attribute_config(cfg) + attn.k_proj.weight_quantizer.set_from_attribute_config(cfg) + attn.v_proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(attn.q_proj, value=1.0) + _populate_amax(attn.k_proj, value=3.0) + _populate_amax(attn.v_proj, value=2.0) + + attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + n_groups = populate_shared_state(attn) + + assert n_groups == 1 + shared = attn._shared_quant_state.weight_global_amax + assert shared is not None + assert torch.isclose(shared, torch.tensor(3.0)), f"expected 3.0, got {shared.item()}" + + +class _MoEExpert(nn.Module): + """A toy MoE expert with Mixtral-style ``w1`` (gate) and ``w3`` (up) projections.""" + + def __init__(self, hidden=32, intermediate=32) -> None: + super().__init__() + self.w1 = nn.Linear(hidden, intermediate, bias=False) + self.w3 = nn.Linear(hidden, intermediate, bias=False) + + +class _MoEBlock(nn.Module): + """MoE block holding experts in an ``nn.ModuleList``.""" + + def __init__(self, n_experts: int = 4) -> None: + super().__init__() + self.experts = nn.ModuleList(_MoEExpert() for _ in range(n_experts)) + + +class TestMoESharedState: + """SharedQuantState groups ``w1``/``w3`` per expert via sibling-scope patterns. + + Cross-expert grouping is intentionally not covered here: weight amax is + per-expert (gate==up *within* an expert), so a single scale across experts + is only meaningful for the input quantizer — to be added with that feature. + """ + + def _setup_moe(self, n_experts: int = 4) -> _MoEBlock: + block = _MoEBlock(n_experts=n_experts) + mtq.replace_quant_module(block) + cfg = _make_nvfp4_static_cfg() + for i, expert in enumerate(block.experts): + # Distinct amax values so the max is determined and identifiable. + for j, proj in enumerate((expert.w1, expert.w3)): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=1.0 + i + 0.1 * j) + return block + + def test_sibling_patterns_group_per_expert(self): + """Sibling-scope w1/w3 patterns group each expert independently (per-expert).""" + block = self._setup_moe(n_experts=4) + groups = find_shared_input_groups(block, patterns=SIBLING_PATTERNS) + # One [w1, w3] group per expert, parented at that expert (not the block). + assert len(groups) == 4 + experts = list(block.experts) + for parent, members in groups: + assert parent in experts + assert len(members) == 2 + + attach_shared_quant_states(block, patterns=SIBLING_PATTERNS) + n_groups = populate_shared_state(block) + assert n_groups == 4 + assert not hasattr(block, "_shared_quant_state") # state is on each expert, not the block + # Each expert's max is within-expert: max(w1=1.0+i, w3=1.0+i+0.1) = 1.1 + i. + for i, expert in enumerate(block.experts): + shared = expert._shared_quant_state.weight_global_amax + assert torch.isclose(shared, torch.tensor(1.1 + i)), f"expert {i}: {shared.item()}" + + +class TestMaxCalibrateEndToEnd: + """End-to-end through ``max_calibrate``: same shared global_amax across siblings.""" + + def _setup_attention(self, scales=(0.5, 2.0, 1.0)) -> _DummyAttention: + """Build an attention block with NVFP4-static weight quantizers; distinct weight scales.""" + attn = _DummyAttention(in_features=32, out_features=32) + # Bias weight magnitudes so per-projection amaxes differ. + with torch.no_grad(): + attn.q_proj.weight.mul_(scales[0]) + attn.k_proj.weight.mul_(scales[1]) + attn.v_proj.weight.mul_(scales[2]) + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + # Other quantizers on the wrapped module would interfere with the + # weight-only calibration path, so disable them. + for name in ("input_quantizer", "output_quantizer"): + q = getattr(proj, name, None) + if isinstance(q, TensorQuantizer): + q.disable() + return attn + + @pytest.mark.parametrize("distributed_sync", [True, False]) + def test_qkv_share_global_amax_via_max_calibrate(self, distributed_sync): + """After max_calibrate, q/k/v_proj have identical global_amax (the group max).""" + attn = self._setup_attention() + + # Drive a forward pass so input shape gets observed by the weight calibrators. + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate(attn, forward_loop=fwd, distributed_sync=distributed_sync) + + # All siblings should be promoted to NVFP4StaticQuantizer with same value. + global_amaxes = [] + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert isinstance(proj.weight_quantizer, NVFP4StaticQuantizer) + ga = proj.weight_quantizer.global_amax + assert ga is not None + global_amaxes.append(ga.item()) + + assert global_amaxes[0] == global_amaxes[1] == global_amaxes[2], ( + f"siblings should share global_amax, got {global_amaxes}" + ) + + # The shared value must equal the max over each child's own _amax reduction. + per_child_max = max( + reduce_amax(proj.weight_quantizer._amax, axis=None).item() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj) + ) + assert global_amaxes[0] == pytest.approx(per_child_max), ( + f"shared global_amax {global_amaxes[0]} != per-child max {per_child_max}" + ) + + def test_standalone_linear_no_shared_state(self): + """A linear without siblings has no shared state attached.""" + + class _Lonely(nn.Module): + def __init__(self): + super().__init__() + self.proj = nn.Linear(32, 32, bias=False) + + m = _Lonely() + mtq.replace_quant_module(m) + m.proj.weight_quantizer.set_from_attribute_config(_make_nvfp4_static_cfg()) + for name in ("input_quantizer", "output_quantizer"): + q = getattr(m.proj, name, None) + if isinstance(q, TensorQuantizer): + q.disable() + + max_calibrate(m, forward_loop=lambda m: m.proj(torch.randn(2, 32)), distributed_sync=False) + + # Promoted, but no shared state attached because there's no sibling group. + assert isinstance(m.proj.weight_quantizer, NVFP4StaticQuantizer) + assert not hasattr(m.proj.weight_quantizer, "_shared_quant_state_ref") + # Standalone global_amax should equal its own reduce_amax(_amax). + expected = reduce_amax(m.proj.weight_quantizer._amax, axis=None).item() + assert m.proj.weight_quantizer.global_amax.item() == pytest.approx(expected) + + def test_shared_state_survives_state_dict_round_trip(self, tmp_path): + """``SharedQuantState`` is an nn.Module; its buffer must round-trip through state_dict.""" + attn = self._setup_attention() + + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate(attn, forward_loop=fwd, distributed_sync=False) + + # Shared state survives — its buffer lives in the parent's state_dict + # under ``_shared_quant_state.weight_global_amax`` exactly once. + assert hasattr(attn, "_shared_quant_state") + sd = attn.state_dict() + shared_buffer_keys = [k for k in sd if k.endswith("_shared_quant_state.weight_global_amax")] + assert shared_buffer_keys == ["_shared_quant_state.weight_global_amax"], ( + f"buffer should appear exactly once on the parent, got {shared_buffer_keys}" + ) + + # Each weight_quantizer holds the back-reference (via object.__setattr__) + # but it does NOT register as a child submodule — otherwise the buffer + # would be duplicated under each member's prefix. + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert proj.weight_quantizer._shared_quant_state_ref is attn._shared_quant_state + assert "_shared_quant_state_ref" not in proj.weight_quantizer._modules + + # state_dict round-trips under ``weights_only=True``. + path = tmp_path / "sd.pt" + torch.save(sd, path) + loaded = torch.load(path, weights_only=True) + assert loaded["_shared_quant_state.weight_global_amax"].item() == pytest.approx( + sd["_shared_quant_state.weight_global_amax"].item() + ) + + +class _AttnQKV(nn.Module): + """Attention-like module whose q/k/v consume the same input tensor.""" + + def __init__(self, hidden: int = 16) -> None: + super().__init__() + self.q_proj = nn.Linear(hidden, hidden, bias=False) + self.k_proj = nn.Linear(hidden, hidden, bias=False) + self.v_proj = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x): + return self.q_proj(x) + self.k_proj(x) + self.v_proj(x) + + +class _GLUExpert(nn.Module): + """Expert whose w1/w3 consume the same (per-expert) input tensor.""" + + def __init__(self, hidden: int = 16) -> None: + super().__init__() + self.w1 = nn.Linear(hidden, hidden, bias=False) + self.w3 = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x): + return self.w1(x) * self.w3(x) + + +class _AttnMoE(nn.Module): + """q/k/v attention + MoE experts; each expert gets a distinct input slice.""" + + def __init__(self, n_experts: int = 2, hidden: int = 16) -> None: + super().__init__() + self.self_attn = _AttnQKV(hidden) + self.experts = nn.ModuleList(_GLUExpert(hidden) for _ in range(n_experts)) + + def forward(self, x): + a = self.self_attn(x) + for i, expert in enumerate(self.experts): + expert(a[i : i + 1]) # distinct slice per expert -> per-expert hook grouping + return a + + +class TestHookPatternEquivalence: + """The sibling regex patterns reproduce exactly what hook discovery finds.""" + + def test_hooks_match_sibling_patterns(self): + model = _AttnMoE(n_experts=2, hidden=16) + mtq.replace_quant_module(model) + cfg = _make_nvfp4_static_cfg() + linears = [model.self_attn.q_proj, model.self_attn.k_proj, model.self_attn.v_proj] + for expert in model.experts: + linears += [expert.w1, expert.w3] + for lin in linears: + lin.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(lin, value=1.0) + + def fwd(m): + m(torch.randn(2, 16)) + + hook_groups = find_shared_input_groups(model, forward_loop=fwd) + pattern_groups = find_shared_input_groups(model, patterns=SIBLING_PATTERNS) + + def normalize(groups): + # Group identity = (parent module, set of member modules), order-independent. + return {(id(parent), frozenset(id(m) for m in members)) for parent, members in groups} + + assert normalize(hook_groups) == normalize(pattern_groups) + # Sanity: one q/k/v group + one w1/w3 group per expert (per-expert, not cross). + assert len(hook_groups) == 3 From a1529f9b69d92749535ed2264eaf097392e47b06 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 2 Jun 2026 07:41:25 -0700 Subject: [PATCH 2/6] remove hook based route Signed-off-by: Shiyang Chen --- modelopt/torch/export/unified_export_hf.py | 79 ++++-- modelopt/torch/quantization/config.py | 10 +- modelopt/torch/quantization/model_calib.py | 22 +- modelopt/torch/quantization/utils/__init__.py | 4 +- .../torch/quantization/utils/shared_input.py | 242 +++++------------- .../torch/quantization/test_shared_input.py | 124 +++------ 6 files changed, 173 insertions(+), 308 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0660ec1dcf6..00e4a7008a9 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -54,12 +54,11 @@ from torch.distributed.fsdp import FSDPModule +from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor -from modelopt.torch.quantization.utils import ( - collect_shared_input_modules as _collect_shared_input_modules, -) from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names +from modelopt.torch.utils.dataset_utils import _disable_use_cache try: from modelopt.torch.sparsity.attention_sparsity.conversion import export_sparse_attention_config @@ -256,20 +255,68 @@ def collect_shared_input_modules( ) -> tuple[dict, dict | None]: """Collect modules that share the same input using forward hooks. - Thin wrapper around - :func:`modelopt.torch.quantization.utils.collect_shared_input_modules`, - parameterized for the export use case: hooks every ``is_quantlinear`` - module with at least one enabled quantizer, and optionally tracks - layernorm outputs when ``collect_layernorms=True`` (for AWQ - pre_quant_scale folding into the preceding layernorm). + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model to analyze. + dummy_forward_fn: A callable that runs a dummy forward pass on the model. + Should be a function that takes no arguments. + collect_layernorms: If True, also collect layernorm output mappings (for AWQ). + + Returns: + A tuple of (input_to_linear, output_to_layernorm). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). """ - return _collect_shared_input_modules( - model, - dummy_forward_fn, - module_filter=lambda m: is_quantlinear(m) - and (_is_enabled_quantizer(m.input_quantizer) or _is_enabled_quantizer(m.weight_quantizer)), - output_filter=is_layernorm if collect_layernorms else None, - ) + input_to_linear: dict = defaultdict(list) + output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None + + def _input_hook(module, input, output): + """Update dictionary with list of all modules that share the same input.""" + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) + + def _output_hook(module, input, output): + """Update dictionary with mapping of layernorms and their outputs.""" + if output_to_layernorm is not None and isinstance(output, torch.Tensor): + output_to_layernorm[output] = module + + handles = [] + + # Register hooks on all quantized linear modules (and optionally layernorms) + for name, module in model.named_modules(): + if collect_layernorms and is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + return input_to_linear, output_to_layernorm + + # Run dummy forward pass to collect modules sharing same input. + # `_disable_use_cache` keeps the probe forward working on configs that don't + # set `use_cache` (e.g., stepfun-ai/Step-3.5-Flash's Step3p5Config). + try: + with ( + torch.no_grad(), + set_quantizer_by_cfg_context(model, [{"quantizer_name": "*", "enable": False}]), + _disable_use_cache(model), + ): + dummy_forward_fn() + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + return input_to_linear, output_to_layernorm def _fuse_shared_input_modules( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 7746536fd74..fecbdb90945 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -720,14 +720,16 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): title="Regex patterns for groups that share quantization state", description=( "Optional dict keyed by quantizer kind (``'weight'`` and/or ``'input'``), each a list " - "of regexes matched (full-match) against module fully-qualified names. For a kind, when " - "patterns are given they are the sole discovery source (hook-based discovery is skipped), " - "so they must list every group you want. Modules whose match yields the same capture-group " + "of regexes matched (full-match) against module fully-qualified names. They must list " + "every group you want for that kind. Modules whose match yields the same capture-group " "tuple form one group; the capture boundary chooses granularity: capture the immediate " "parent for per-parent / per-expert groups (e.g. ``r'(.*)\\.(?:q_proj|k_proj|v_proj)'``, " "``r'(.*)\\.(?:w1|w3)'``); leave the expert index uncaptured for one cross-expert group " "(``r'(.*)\\.experts\\.\\d+\\.(?:w1|w3)'``). Only ``'weight'`` is used today; ``'input'`` is " - "reserved for future input-quantizer sharing. Default None -> hook-based discovery." + "reserved for future input-quantizer sharing. When the ``'weight'`` list is omitted, " + "the default fusible patterns (q/k/v, gate/up, w1/w3) are used — these match exactly " + "the sibling groups export fuses, avoiding the over-grouping a shared-input heuristic " + "would cause (e.g. a ``shared_expert_gate`` that reads the same input but is not fused)." ), ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 94431670113..54ba1d88b2b 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -41,6 +41,7 @@ from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( + DEFAULT_WEIGHT_SHARED_PATTERNS, attach_shared_quant_states, disable_calib, enable_fake_quant, @@ -203,11 +204,12 @@ def max_calibrate( sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts in a MoE layer (within-rank sync + EP all-reduce when EP>1). shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a - list of regexes over module FQNs. For a kind, when patterns are given they are the sole - discovery source (hooks skipped). Modules whose regex match yields the same capture-group - tuple form one group — capture the immediate parent for per-parent (per-expert) grouping, - or leave the expert index uncaptured for cross-expert. Only ``"weight"`` is used today; - ``"input"`` is reserved for future input-quantizer sharing. + list of regexes over module FQNs. When the ``"weight"`` list is omitted, + :data:`DEFAULT_WEIGHT_SHARED_PATTERNS` (q/k/v, gate/up, w1/w3) is used. Modules whose + regex match yields the same capture-group tuple form one group — capture the immediate + parent for per-parent (per-expert) grouping, or leave the expert index uncaptured for + cross-expert. Only ``"weight"`` is used today; ``"input"`` is reserved for future + input-quantizer sharing. See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -229,11 +231,11 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). _check_nvfp4_static_tp_supported(model) - # Discover sibling groups (forward_loop hooks, or the "weight" regexes when given) - # and attach shared state; populated below, after any cross-rank _amax sync. - # Only "weight" patterns are consumed today; "input" is reserved for the future. - weight_patterns = shared_patterns.get("weight") if shared_patterns else None - attach_shared_quant_states(model, forward_loop=forward_loop, patterns=weight_patterns) + # Discover fusible sibling groups by name regex (default q/k/v + gate/up, or the + # caller's "weight" patterns) and attach shared state; populated below, after any + # cross-rank _amax sync. Only "weight" is consumed today; "input" is reserved. + weight_patterns = (shared_patterns or {}).get("weight") or DEFAULT_WEIGHT_SHARED_PATTERNS + attach_shared_quant_states(model, patterns=weight_patterns) if not distributed_sync: # Single-process: _amax is final — aggregate shared global_amax, then promote. diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index 3469ce7b16a..900c06d50d9 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -19,18 +19,18 @@ from .core_utils import * from .layerwise_calib import LayerActivationCollector from .shared_input import ( + DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState, attach_shared_quant_states, - collect_shared_input_modules, find_shared_input_groups, populate_shared_state, ) __all__ = [ + "DEFAULT_WEIGHT_SHARED_PATTERNS", "EXPORT_MODE", "SharedQuantState", "attach_shared_quant_states", - "collect_shared_input_modules", "convert_quantization_axis_to_reduce_axis", "export_torch_mode", "find_shared_input_groups", diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py index 3b75126c626..407e35ecc56 100644 --- a/modelopt/torch/quantization/utils/shared_input.py +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -13,24 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Discovery and per-group shared state for modules that consume the same input. - -- :func:`collect_shared_input_modules` runs a probe forward with hooks and - returns the modules that received the same input tensor. Used by both - calibration (here) and export (``unified_export_hf``). -- :class:`SharedQuantState` is an ``nn.Module`` attached to a sibling group's - parent so its tensors ride along in ``state_dict``. Holds only - ``weight_global_amax`` today; designed to grow (act scales, LoRA factors, ...). - -:func:`find_shared_input_groups` (hooks + name patterns) produces the -``(parent, members)`` tuples consumed by :func:`attach_shared_quant_states` and -:func:`populate_shared_state`. +"""Per-group shared quantization state for fusible sibling modules. + +Weight ``global_amax`` must be unified across modules that get **fused** at export +(q/k/v -> qkv, gate/up -> gate_up) so they quantize with one per-tensor scale. +:func:`find_shared_input_groups` discovers these groups by regex over module FQNs; +:data:`DEFAULT_WEIGHT_SHARED_PATTERNS` covers the standard q/k/v, gate/up and w1/w3 +names, and callers may override per quantizer kind via ``MaxCalibConfig.shared_patterns``. + +Discovery is name/pattern-based (not input-hook-based) on purpose: "shares an input +tensor" is broader than "gets fused" — e.g. a ``shared_expert_gate`` reads the same +hidden states as the GLU pair but is never fused with it, so a hook would over-group +it. Patterns match exactly the roles export fuses. + +:class:`SharedQuantState` is an ``nn.Module`` attached to a group's parent so its +tensors ride along in ``state_dict``. Holds only ``weight_global_amax`` today; +designed to grow (act scales, LoRA factors, ...). The ``(parent, members)`` tuples +from :func:`find_shared_input_groups` are consumed by +:func:`attach_shared_quant_states` and :func:`populate_shared_state`. """ import re import warnings -from collections import defaultdict -from collections.abc import Callable, Sequence +from collections.abc import Sequence import torch import torch.distributed as dist @@ -38,90 +43,26 @@ from modelopt.torch.utils.distributed import ParallelState -from .core_utils import is_quantized_linear, quantizer_attr_names, reduce_amax - -ForwardLoop = Callable[[nn.Module], None] +from .core_utils import quantizer_attr_names, reduce_amax __all__ = [ + "DEFAULT_WEIGHT_SHARED_PATTERNS", "SharedQuantState", "attach_shared_quant_states", - "collect_shared_input_modules", "find_shared_input_groups", "populate_shared_state", ] - -# --------------------------------------------------------------------------- -# Discovery primitive -# --------------------------------------------------------------------------- - - -def collect_shared_input_modules( - model: nn.Module, - forward_fn: Callable[[], None], - module_filter: Callable[[nn.Module], bool] | None = None, - output_filter: Callable[[nn.Module], bool] | None = None, -) -> tuple[dict, dict | None]: - """Hook the model, run a probe forward, group modules by shared input tensor. - - Args: - model: model to probe. - forward_fn: zero-arg callable running a forward pass on ``model``; - quantizers are disabled during it so probe outputs aren't perturbed. - module_filter: which modules to hook on input (default - :func:`is_quantized_linear`). - output_filter: optional, which modules to hook on output. AWQ export uses - it to map a layernorm's output to itself so the pre_quant_scale can be - folded into the layernorm; ``None`` skips output tracking. - - Returns: - ``(input_to_modules, output_to_modules)``: input tensor -> modules that - received it (the shared-input group), and output tensor -> producing - module (``None`` when ``output_filter`` is not given). - """ - # Inline import to avoid a cycle (conversion/dataset_utils import from utils); - # safe because this runs at calibration/export time, not module load. - from modelopt.torch.quantization.conversion import set_quantizer_by_cfg_context - from modelopt.torch.utils.dataset_utils import _disable_use_cache - - if module_filter is None: - module_filter = is_quantized_linear - - input_to_modules: dict = defaultdict(list) - output_to_modules: dict | None = defaultdict(lambda: None) if output_filter else None - - def _input_hook(module, args, output): - if len(args) > 0 and isinstance(args[0], torch.Tensor): - input_to_modules[args[0]].append(module) - - def _output_hook(module, args, output): - if output_to_modules is not None and isinstance(output, torch.Tensor): - output_to_modules[output] = module - - handles = [] - for name, module in model.named_modules(): - if output_filter is not None and output_filter(module): - module.name = name - handles.append(module.register_forward_hook(_output_hook)) - elif module_filter(module): - module.name = name - handles.append(module.register_forward_hook(_input_hook)) - - if not handles: - return input_to_modules, output_to_modules - - try: - with ( - torch.no_grad(), - set_quantizer_by_cfg_context(model, [{"quantizer_name": "*", "enable": False}]), - _disable_use_cache(model), - ): - forward_fn() - finally: - for handle in handles: - handle.remove() - - return input_to_modules, output_to_modules +# Default fusible-sibling patterns for WEIGHT global_amax — the groups export fuses: +# q/k/v -> qkv, gate/up (incl. Mixtral w1/w3) -> gate_up. These reproduce the legacy +# name-based grouping exactly. Regexes are ``re.fullmatch``-ed against module FQNs; +# ``(?:(.*)\.)?`` captures the immediate parent so grouping is per-parent (per-expert +# for MoE experts). Override per quantizer kind via ``MaxCalibConfig.shared_patterns``. +DEFAULT_WEIGHT_SHARED_PATTERNS = [ + r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)", + r"(?:(.*)\.)?(?:gate_proj|up_proj)", + r"(?:(.*)\.)?(?:w1|w3)", +] # --------------------------------------------------------------------------- @@ -178,7 +119,7 @@ def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: # --------------------------------------------------------------------------- -# Group discovery (hook + pattern) → (parent, members) tuples +# Group discovery (regex over FQNs) → (parent, members) tuples # --------------------------------------------------------------------------- @@ -253,61 +194,31 @@ def ancestors(m: nn.Module) -> list[nn.Module]: return fallback -def _groups_from_hooks( - model: nn.Module, - forward_loop: ForwardLoop, -) -> list[tuple[nn.Module, list[nn.Module]]]: - """Discover sibling groups from a hook-based probe forward. - - Modules whose first input is the same tensor object form a group, parented at - their LCA. Only catches the *literal same tensor*, so cross-expert MoE sharing - (one input per expert) is missed — patterns cover that. - """ - input_to_modules, _ = collect_shared_input_modules( - model, forward_fn=lambda: forward_loop(model) - ) - if not input_to_modules: - return [] - - parent_map = _build_parent_map(model) - wq_attr = quantizer_attr_names("weight").weight_quantizer - groups: list[tuple[nn.Module, list[nn.Module]]] = [] - for members in input_to_modules.values(): - # Dedup (a module may be hooked twice) and keep calibrated members only. - unique: list[nn.Module] = [] - seen: set[int] = set() - for m in members: - if id(m) in seen: - continue - if not _has_calibratable_weight_quantizer(m, wq_attr): - continue - seen.add(id(m)) - unique.append(m) - if len(unique) >= 2: - parent = _lowest_common_ancestor(unique, parent_map, fallback=model) - groups.append((parent, unique)) - return groups - - -def _groups_from_patterns( +def find_shared_input_groups( model: nn.Module, - patterns: Sequence[str], + patterns: Sequence[str] | None = None, ) -> list[tuple[nn.Module, list[nn.Module]]]: - r"""Discover groups by regex over module FQNs; capture groups define the grouping key. + r"""Find fusible sibling groups by regex over module FQNs; capture groups define the key. - Each pattern is a regex ``re.fullmatch``-ed against every quantized module's - fully-qualified name; modules whose match yields the same capture-group tuple form - one group, parented at their LCA. Granularity is set by *what you capture*: + Each pattern is ``re.fullmatch``-ed against every quantized module's fully-qualified + name; modules whose match yields the same capture-group tuple form one group, parented + at their LCA. Granularity is set by *what you capture*: - - Capture the immediate parent -> per-parent grouping: q/k/v per attention block, - and **per-expert** ``w1``/``w3`` (each expert is the immediate parent), e.g. + - Capture the immediate parent -> per-parent grouping: q/k/v per attention block, and + **per-expert** ``w1``/``w3`` (each expert is the immediate parent), e.g. ``r"(.*)\.(?:w1|w3)$"``. - - Capture only a level above the expert index, leaving the index uncaptured -> - one **cross-expert** group, e.g. ``r"(.*)\.experts\.\d+\.(?:w1|w3)$"``. - - Roles to fuse together go in a non-capturing alternation ``(?:w1|w3)`` so they - don't split the key; what you wrap in ``(...)`` is the group boundary. + - Capture only a level above the expert index, leaving the index uncaptured -> one + **cross-expert** group, e.g. ``r"(.*)\.experts\.\d+\.(?:w1|w3)$"``. + + Roles to fuse together go in a non-capturing alternation ``(?:w1|w3)`` so they don't + split the key; what you wrap in ``(...)`` is the group boundary. Pass + :data:`DEFAULT_WEIGHT_SHARED_PATTERNS` for the standard q/k/v + gate/up groups, or + override via ``MaxCalibConfig.shared_patterns``. The caller selects which quantizer + these groups apply to (today only the weight quantizer). Returns ``(parent, members)`` + tuples; empty when no patterns are given. """ + if not patterns: + return [] wq_attr = quantizer_attr_names("weight").weight_quantizer compiled = [re.compile(p) for p in patterns] buckets: dict[tuple, list[nn.Module]] = {} @@ -318,10 +229,8 @@ def _groups_from_patterns( for pattern_idx, regex in enumerate(compiled): match = regex.fullmatch(name) if match is not None: - key = ( - pattern_idx, - match.groups(), - ) # include pattern_idx in case 1+ partterns collide. + # include pattern_idx in case 2+ patterns yield the same capture tuple + key = (pattern_idx, match.groups()) if key not in buckets: buckets[key] = [] order.append(key) @@ -337,33 +246,6 @@ def _groups_from_patterns( return groups -def find_shared_input_groups( - model: nn.Module, - forward_loop: ForwardLoop | None = None, - patterns: Sequence[str] | None = None, -) -> list[tuple[nn.Module, list[nn.Module]]]: - """Find sibling groups from regex patterns if given, else from a hook probe. - - Patterns and hooks are mutually exclusive: when ``patterns`` is set it is the - sole source (and must list every group you want); otherwise hook discovery runs - on the ``forward_loop`` probe. - - - ``patterns`` — regexes over module FQNs whose capture groups define the grouping - key (see :func:`_groups_from_patterns`); the capture boundary chooses per-expert - vs cross-expert granularity. The caller selects which quantizer these groups apply - to (today only the weight quantizer; see ``MaxCalibConfig.shared_patterns``). - - ``forward_loop`` — hook probe grouping modules that receive the *literal same - tensor* (Q/K/V, gate/up within one block/expert; per-expert for MoE). - - Returns ``(parent, members)`` tuples. - """ - if patterns: - return _groups_from_patterns(model, patterns) - if forward_loop is not None: - return _groups_from_hooks(model, forward_loop) - return [] - - # --------------------------------------------------------------------------- # Attach / populate lifecycle # --------------------------------------------------------------------------- @@ -371,26 +253,24 @@ def find_shared_input_groups( def attach_shared_quant_states( model: nn.Module, - forward_loop: ForwardLoop | None = None, patterns: Sequence[str] | None = None, ) -> int: """Create ``SharedQuantState`` on each group's parent and link members. - The parent owns the state under ``_shared_quant_state`` (normal setattr → a - registered submodule, so its buffer rides along in ``state_dict``). Each - member's weight quantizer — the only consumer, via - ``promote_nvfp4_static_quantizers`` — gets a back-reference under the distinct - name ``_shared_quant_state_ref`` set with ``object.__setattr__`` (not a - submodule, so the buffer isn't duplicated per member). The distinct names let + Groups are discovered by ``patterns`` (regexes over module FQNs; see + :func:`find_shared_input_groups`). The parent owns the state under + ``_shared_quant_state`` (normal setattr → a registered submodule, so its buffer + rides along in ``state_dict``). Each member's weight quantizer — the only consumer, + via ``promote_nvfp4_static_quantizers`` — gets a back-reference under the distinct + name ``_shared_quant_state_ref`` set with ``object.__setattr__`` (not a submodule, + so the buffer isn't duplicated per member). The distinct names let ``populate_shared_state`` select owners with a plain ``getattr``. Idempotent (reuses an existing parent state). Returns the number created. """ n_created = 0 wq_attr = quantizer_attr_names("weight").weight_quantizer - for parent, members in find_shared_input_groups( - model, forward_loop=forward_loop, patterns=patterns - ): + for parent, members in find_shared_input_groups(model, patterns=patterns): if not hasattr(parent, "_shared_quant_state"): parent._shared_quant_state = SharedQuantState() n_created += 1 diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py index 9fcd89871b3..2ac5fd0dab2 100644 --- a/tests/unit/torch/quantization/test_shared_input.py +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -38,6 +38,7 @@ from modelopt.torch.quantization.model_calib import max_calibrate from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.quantization.utils import ( + DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState, attach_shared_quant_states, find_shared_input_groups, @@ -46,15 +47,11 @@ reduce_amax, ) -# Test-only patterns: lists of regexes (the per-kind value of ``shared_patterns``, -# e.g. the ``"weight"`` list). ``re.fullmatch``-ed against module FQNs; capture groups -# define the grouping key. ``(?:(.*)\.)?`` captures the immediate parent (or None at the +# The production default patterns (q/k/v, gate/up, w1/w3) are exactly what these tests +# need; reuse them so the tests also exercise the real default. ``re.fullmatch``-ed +# against module FQNs; ``(?:(.*)\.)?`` captures the immediate parent (or None at the # model root, since these test models hold roles directly) -> per-parent / per-expert. -SIBLING_PATTERNS = [ - r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)", - r"(?:(.*)\.)?(?:gate_proj|up_proj)", - r"(?:(.*)\.)?(?:w1|w3)", -] +SIBLING_PATTERNS = DEFAULT_WEIGHT_SHARED_PATTERNS NVFP4_BLOCK = 16 @@ -129,28 +126,34 @@ def __init__(self): groups = find_shared_input_groups(m, patterns=SIBLING_PATTERNS) assert groups == [], "single sibling must not form a group" - def test_patterns_override_hooks(self): - """When patterns are given, hooks are skipped (patterns are the sole source).""" - attn = _DummyAttention() - mtq.replace_quant_module(attn) - cfg = _make_nvfp4_static_cfg() - for proj in (attn.q_proj, attn.k_proj, attn.v_proj): - proj.weight_quantizer.set_from_attribute_config(cfg) - _populate_amax(proj, value=1.0) + def test_default_patterns_skip_non_fusible_gate(self): + """A gate sharing the block input but never fused (e.g. ``shared_expert_gate``) + must NOT be grouped with the gate_proj/up_proj pair. - def fwd(m): - x = torch.randn(2, 32) - for proj in (m.q_proj, m.k_proj, m.v_proj): - proj(x) + This is why grouping is name/pattern-based, not shared-input-hook-based: a hook + would lump ``shared_expert_gate`` in with the GLU pair (same input tensor) and + wrongly unify its global_amax. The default patterns match only the fused roles. + """ - # Hooks alone would group all three (shared input). A q/k-only pattern must - # win — v_proj absent proves hooks are skipped when patterns are provided. - groups = find_shared_input_groups( - attn, forward_loop=fwd, patterns=[r"(?:(.*)\.)?(?:q_proj|k_proj)"] - ) + class _MLPWithGate(nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = nn.Linear(32, 32, bias=False) + self.up_proj = nn.Linear(32, 32, bias=False) + self.shared_expert_gate = nn.Linear(32, 1, bias=False) # shares input, not fused + + m = _MLPWithGate() + mtq.replace_quant_module(m) + cfg = _make_nvfp4_static_cfg() + for lin in (m.gate_proj, m.up_proj, m.shared_expert_gate): + lin.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(lin, value=1.0) + + groups = find_shared_input_groups(m, patterns=DEFAULT_WEIGHT_SHARED_PATTERNS) assert len(groups) == 1 _parent, members = groups[0] - assert len(members) == 2 and attn.v_proj not in members + assert set(members) == {m.gate_proj, m.up_proj} + assert m.shared_expert_gate not in members def test_populate_writes_max_across_siblings(self): attn = _DummyAttention() @@ -348,72 +351,3 @@ def fwd(m): assert loaded["_shared_quant_state.weight_global_amax"].item() == pytest.approx( sd["_shared_quant_state.weight_global_amax"].item() ) - - -class _AttnQKV(nn.Module): - """Attention-like module whose q/k/v consume the same input tensor.""" - - def __init__(self, hidden: int = 16) -> None: - super().__init__() - self.q_proj = nn.Linear(hidden, hidden, bias=False) - self.k_proj = nn.Linear(hidden, hidden, bias=False) - self.v_proj = nn.Linear(hidden, hidden, bias=False) - - def forward(self, x): - return self.q_proj(x) + self.k_proj(x) + self.v_proj(x) - - -class _GLUExpert(nn.Module): - """Expert whose w1/w3 consume the same (per-expert) input tensor.""" - - def __init__(self, hidden: int = 16) -> None: - super().__init__() - self.w1 = nn.Linear(hidden, hidden, bias=False) - self.w3 = nn.Linear(hidden, hidden, bias=False) - - def forward(self, x): - return self.w1(x) * self.w3(x) - - -class _AttnMoE(nn.Module): - """q/k/v attention + MoE experts; each expert gets a distinct input slice.""" - - def __init__(self, n_experts: int = 2, hidden: int = 16) -> None: - super().__init__() - self.self_attn = _AttnQKV(hidden) - self.experts = nn.ModuleList(_GLUExpert(hidden) for _ in range(n_experts)) - - def forward(self, x): - a = self.self_attn(x) - for i, expert in enumerate(self.experts): - expert(a[i : i + 1]) # distinct slice per expert -> per-expert hook grouping - return a - - -class TestHookPatternEquivalence: - """The sibling regex patterns reproduce exactly what hook discovery finds.""" - - def test_hooks_match_sibling_patterns(self): - model = _AttnMoE(n_experts=2, hidden=16) - mtq.replace_quant_module(model) - cfg = _make_nvfp4_static_cfg() - linears = [model.self_attn.q_proj, model.self_attn.k_proj, model.self_attn.v_proj] - for expert in model.experts: - linears += [expert.w1, expert.w3] - for lin in linears: - lin.weight_quantizer.set_from_attribute_config(cfg) - _populate_amax(lin, value=1.0) - - def fwd(m): - m(torch.randn(2, 16)) - - hook_groups = find_shared_input_groups(model, forward_loop=fwd) - pattern_groups = find_shared_input_groups(model, patterns=SIBLING_PATTERNS) - - def normalize(groups): - # Group identity = (parent module, set of member modules), order-independent. - return {(id(parent), frozenset(id(m) for m in members)) for parent, members in groups} - - assert normalize(hook_groups) == normalize(pattern_groups) - # Sanity: one q/k/v group + one w1/w3 group per expert (per-expert, not cross). - assert len(hook_groups) == 3 From 85fe5602a4847277dcef6b60b35f6acd9d2a8d29 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 2 Jun 2026 12:55:35 -0700 Subject: [PATCH 3/6] update based on comments Signed-off-by: Shiyang Chen --- modelopt/torch/quantization/config.py | 27 ++++ modelopt/torch/quantization/model_calib.py | 12 +- .../nn/modules/tensor_quantizer.py | 4 + modelopt/torch/quantization/utils/__init__.py | 8 +- .../torch/quantization/utils/core_utils.py | 22 ++- .../torch/quantization/utils/shared_input.py | 33 ++-- .../torch/quantization/test_shared_input.py | 144 +++++++++++++++--- 7 files changed, 204 insertions(+), 46 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fecbdb90945..0656da065f7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -150,6 +150,7 @@ """ +import re import warnings from collections.abc import Mapping, Sequence from typing import Any, Literal @@ -733,6 +734,32 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): ), ) + @field_validator("shared_patterns") + @classmethod + def validate_shared_patterns(cls, v): + """Reject unknown quantizer kinds and invalid regexes at the config boundary.""" + if v is None: + return v + supported = {"weight", "input"} + unknown = set(v) - supported + if unknown: + raise ValueError( + f"shared_patterns has unsupported quantizer kind(s) {sorted(unknown)}; " + f"expected keys from {sorted(supported)}." + ) + offending = ("", "") # (kind, pattern) of the last regex tried; set before each compile + try: + for kind, patterns in v.items(): + for pattern in patterns: + offending = (kind, pattern) + re.compile(pattern) + except re.error as e: + bad_kind, bad_pattern = offending + raise ValueError( + f"shared_patterns[{bad_kind!r}] has an invalid regex {bad_pattern!r}: {e}" + ) from e + return v + class MseCalibConfig(QuantizeAlgorithmConfig): """Configuration for per-tensor MSE calibration. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 54ba1d88b2b..a91b9f6263e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -231,10 +231,14 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). _check_nvfp4_static_tp_supported(model) - # Discover fusible sibling groups by name regex (default q/k/v + gate/up, or the - # caller's "weight" patterns) and attach shared state; populated below, after any - # cross-rank _amax sync. Only "weight" is consumed today; "input" is reserved. - weight_patterns = (shared_patterns or {}).get("weight") or DEFAULT_WEIGHT_SHARED_PATTERNS + # Discover fusible sibling groups by name regex and attach shared state; populated + # below, after any cross-rank _amax sync. Default to q/k/v + gate/up when no "weight" + # key is given; an explicit (possibly empty) list overrides it — key presence, not + # truthiness, so {"weight": []} disables grouping. Only "weight" is consumed today. + if shared_patterns is not None and "weight" in shared_patterns: + weight_patterns = list(shared_patterns["weight"]) + else: + weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS attach_shared_quant_states(model, patterns=weight_patterns) if not distributed_sync: diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index eef525f21d5..573f39d7d7c 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -172,6 +172,10 @@ class TensorQuantizer(nn.Module): "pre_bwd_fn", # quantizer cache for custom backends, like luts "_quantizer_cache", + # Runtime-only back-reference to a sibling group's SharedQuantState; it is + # re-established during calibration and must not be serialized (it points to a + # live module whose dynamic QuantLinear members are not picklable). + "_shared_quant_state_ref", } def __init__( diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index 900c06d50d9..f72b26820f9 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -18,13 +18,7 @@ from .core_utils import * from .layerwise_calib import LayerActivationCollector -from .shared_input import ( - DEFAULT_WEIGHT_SHARED_PATTERNS, - SharedQuantState, - attach_shared_quant_states, - find_shared_input_groups, - populate_shared_state, -) +from .shared_input import * __all__ = [ "DEFAULT_WEIGHT_SHARED_PATTERNS", diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index f49b75d53e7..10fd326b0ff 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -955,13 +955,25 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: per-block amax) instead of the generic E4M3 path. If the quantizer has a ``_shared_quant_state_ref`` with a populated - ``weight_global_amax`` (sibling group), that shared value is used instead of - this quantizer's own ``_amax`` reduction, keeping siblings on a common FP8 grid. + ``weight_global_amax`` (sibling group) whose owning state lives within ``model``, + that shared value is used instead of this quantizer's own ``_amax`` reduction, + keeping siblings on a common FP8 grid. Returns the number of quantizers converted. """ from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer + # Shared states owned within THIS promotion root. This function also runs on + # submodules / individual linears; a quantizer may still carry a back-reference from + # an earlier full-model calibration whose owning ``_shared_quant_state`` is outside + # ``model``. Only trust refs reachable here — otherwise the global_amax would come + # from an unrelated prior run; fall back to the quantizer's own amax instead. + valid_shared_states = { + id(state) + for owner in model.modules() + if (state := getattr(owner, "_shared_quant_state", None)) is not None + } + converted = 0 for _name, module in list(model.named_modules()): if not isinstance(module, TensorQuantizer) or not module.is_enabled: @@ -976,7 +988,11 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: # otherwise fall back to this quantizer's own per-block amax. already_promoted = isinstance(module, NVFP4StaticQuantizer) shared = getattr(module, "_shared_quant_state_ref", None) - if shared is not None and shared.weight_global_amax is not None: + if ( + shared is not None + and id(shared) in valid_shared_states + and shared.weight_global_amax is not None + ): global_amax = shared.weight_global_amax else: global_amax = reduce_amax(amax.clone().detach(), axis=None) diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py index 407e35ecc56..8c090c0a6fa 100644 --- a/modelopt/torch/quantization/utils/shared_input.py +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -34,7 +34,6 @@ """ import re -import warnings from collections.abc import Sequence import torch @@ -74,21 +73,27 @@ class SharedQuantState(nn.Module): """State shared across a sibling group of quantized modules. Attached to the group's parent (e.g. ``self_attn``, ``block_sparse_moe``) as a - registered submodule so its buffers ride along in ``state_dict``. Members that - quantize the same input resolve shared values here instead of computing them - independently and reconciling at export. - - Holds only ``weight_global_amax`` today; new tensor fields (act scales, AWQ - ``pre_quant_scale``, SVDQuant/FlatQuant factors) should use ``register_buffer`` - so they serialize too. + submodule, but its buffers are **non-persistent**: this is a calibration-time + artifact, not part of the checkpoint. Members that quantize the same input resolve + shared values here during calibration; the resolved value is then baked into each + member's promoted quantizer (``NVFP4StaticQuantizer._global_amax``, which *is* + serialized). So the scale survives save/restore via the members, and restore need + not re-create this submodule (it isn't in ``state_dict``) — see + :func:`attach_shared_quant_states`. + + Holds only ``weight_global_amax`` today, and it is mirrored onto every member. Any + future field that is **not** mirrored on a member would not survive save/restore as + a non-persistent buffer and would need its own restore path. """ def __init__(self) -> None: """Initialize with an unset ``weight_global_amax`` and no registered members.""" super().__init__() # NVFP4 two-level FP8 grid scale = max over members' per-block ``_amax``. - # Unset for non-NVFP4 configs. - self.register_buffer("weight_global_amax", None) + # Non-persistent: a calibration-time artifact. The resolved value is baked into + # each member's ``NVFP4StaticQuantizer._global_amax`` (which IS serialized), so + # restore rebuilds scales from members and need not re-create this submodule. + self.register_buffer("weight_global_amax", None, persistent=False) # Back-references to member modules. ``object.__setattr__`` keeps them out of # ``_modules`` so members' params don't re-enter our ``state_dict``. object.__setattr__(self, "_members", []) @@ -98,7 +103,8 @@ def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: Weights are DP-replicated (no DP sync needed). EP sync is required since ranks hold different experts; TP sync guards against per-child ``_amax`` TP - sync skipping block-quantized weights. + sync skipping block-quantized weights. Raises on a failed all-reduce: a silent + failure would leave ranks with different scales and still promote/export them. """ if self.weight_global_amax is None or parallel_state is None: return @@ -115,7 +121,7 @@ def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: group=group.group, ) except RuntimeError as e: - warnings.warn(f"Failed to sync shared weight_global_amax: {e}") + raise RuntimeError("Failed to sync shared weight_global_amax") from e # --------------------------------------------------------------------------- @@ -150,6 +156,9 @@ def _climb_past_modulelist( Attaching ``SharedQuantState`` to a ``ModuleList`` registers it in that container's ``_modules`` and corrupts its iteration/length (the state shows up alongside the experts), so attach to the first non-ModuleList ancestor. + + Only ``nn.ModuleList`` is handled today. It can be extended in the future to include modules + like nn.ModuleDict``. """ cur = module while isinstance(cur, nn.ModuleList): diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py index 2ac5fd0dab2..879010a7f68 100644 --- a/tests/unit/torch/quantization/test_shared_input.py +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -74,6 +74,9 @@ def __init__(self, in_features: int = 32, out_features: int = 32) -> None: self.k_proj = nn.Linear(in_features, out_features, bias=False) self.v_proj = nn.Linear(in_features, out_features, bias=False) + def forward(self, x): + return self.q_proj(x) + self.k_proj(x) + self.v_proj(x) + def _populate_amax(linear: nn.Module, value: float) -> None: """Directly set ``_amax`` on a linear's weight_quantizer for deterministic testing.""" @@ -175,6 +178,33 @@ def test_populate_writes_max_across_siblings(self): assert shared is not None assert torch.isclose(shared, torch.tensor(3.0)), f"expected 3.0, got {shared.item()}" + def test_promote_ignores_shared_state_outside_root(self): + """Promoting a submodule must ignore a back-ref whose owning state is outside it. + + ``promote_nvfp4_static_quantizers`` also runs on submodules/individual linears; a + quantizer may still carry ``_shared_quant_state_ref`` from an earlier full-model + run. If the owning ``_shared_quant_state`` is not within the promotion root, the + quantizer must fall back to its OWN amax, not the stale group value. + """ + from modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers + + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj, val in ((attn.q_proj, 1.0), (attn.k_proj, 3.0), (attn.v_proj, 2.0)): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=val) + # Group on the parent → shared weight_global_amax = max = 3.0. + attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + populate_shared_state(attn) + assert torch.isclose(attn._shared_quant_state.weight_global_amax, torch.tensor(3.0)) + + # Promote with q_proj as the root: it does NOT contain attn._shared_quant_state + # (that lives on the parent), so the stale ref is ignored → own amax (1.0), not 3.0. + promote_nvfp4_static_quantizers(attn.q_proj) + ga = attn.q_proj.weight_quantizer.global_amax + assert torch.isclose(ga, torch.tensor(1.0)), f"expected own amax 1.0, got {ga.item()}" + class _MoEExpert(nn.Module): """A toy MoE expert with Mixtral-style ``w1`` (gate) and ``w3`` (up) projections.""" @@ -316,8 +346,14 @@ def __init__(self): expected = reduce_amax(m.proj.weight_quantizer._amax, axis=None).item() assert m.proj.weight_quantizer.global_amax.item() == pytest.approx(expected) - def test_shared_state_survives_state_dict_round_trip(self, tmp_path): - """``SharedQuantState`` is an nn.Module; its buffer must round-trip through state_dict.""" + def test_shared_state_buffer_is_non_persistent(self): + """The shared buffer is a calibration-time artifact and must NOT be in state_dict. + + The scale is carried by each member's promoted quantizer (``_global_amax``), which + IS serialized. If the shared buffer were persistent it would add + ``_shared_quant_state.weight_global_amax`` keys that restore can't match (the + submodule isn't re-created on load) — the regression covered end-to-end below. + """ attn = self._setup_attention() def fwd(m): @@ -328,26 +364,94 @@ def fwd(m): max_calibrate(attn, forward_loop=fwd, distributed_sync=False) - # Shared state survives — its buffer lives in the parent's state_dict - # under ``_shared_quant_state.weight_global_amax`` exactly once. - assert hasattr(attn, "_shared_quant_state") + assert hasattr(attn, "_shared_quant_state") # exists at runtime (calibration artifact) sd = attn.state_dict() - shared_buffer_keys = [k for k in sd if k.endswith("_shared_quant_state.weight_global_amax")] - assert shared_buffer_keys == ["_shared_quant_state.weight_global_amax"], ( - f"buffer should appear exactly once on the parent, got {shared_buffer_keys}" - ) - - # Each weight_quantizer holds the back-reference (via object.__setattr__) - # but it does NOT register as a child submodule — otherwise the buffer - # would be duplicated under each member's prefix. - for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + # Non-persistent: the shared buffer must NOT appear in state_dict. + assert not [k for k in sd if k.endswith("_shared_quant_state.weight_global_amax")] + # The value lives on each member's quantizer (``_global_amax``), which IS persisted, + # and the runtime back-reference is set (but not as a child submodule). + for role in ("q_proj", "k_proj", "v_proj"): + proj = getattr(attn, role) assert proj.weight_quantizer._shared_quant_state_ref is attn._shared_quant_state assert "_shared_quant_state_ref" not in proj.weight_quantizer._modules + assert f"{role}.weight_quantizer._global_amax" in sd - # state_dict round-trips under ``weights_only=True``. - path = tmp_path / "sd.pt" - torch.save(sd, path) - loaded = torch.load(path, weights_only=True) - assert loaded["_shared_quant_state.weight_global_amax"].item() == pytest.approx( - sd["_shared_quant_state.weight_global_amax"].item() + def test_modelopt_save_restore_with_shared_state(self, tmp_path): + """``mtq.quantize`` -> ``mto.save`` -> ``mto.restore`` on a FRESH model round-trips. + + Regression for two save/restore bugs the shared state introduced: (a) the runtime + back-ref must be excluded from ``get_modelopt_state`` (else save pickles a live + ``QuantLinear``), and (b) the shared buffer must be non-persistent (else + ``load_state_dict`` on the fresh, submodule-less model fails on the unexpected key). + """ + import modelopt.torch.opt as mto + + cfg = { + "quant_cfg": [ + {"enable": False, "quantizer_name": "*"}, + { + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: NVFP4_BLOCK, "type": "static", "scale_bits": (4, 3)}, + }, + "quantizer_name": "*weight_quantizer", + }, + ], + "algorithm": "max", + } + attn = _DummyAttention() # fresh (un-quantized) — mtq.quantize rejects re-quantizing + mtq.quantize(attn, cfg, lambda m: m(torch.randn(2, 32))) + + # Grouping happened, and each member carries its own promoted global_amax. + assert hasattr(attn, "_shared_quant_state") + expected = { + role: getattr(attn, role).weight_quantizer.global_amax.item() + for role in ("q_proj", "k_proj", "v_proj") + } + + # Save must not raise (pickling metadata.quantizer_state). + path = tmp_path / "model.pth" + mto.save(attn, path) + + # Restore into a fresh model must not raise (no _shared_quant_state.* key to match). + restored = _DummyAttention() + mto.restore(restored, path) + + for role in ("q_proj", "k_proj", "v_proj"): + wq = getattr(restored, role).weight_quantizer + assert isinstance(wq, NVFP4StaticQuantizer) + assert wq.global_amax.item() == pytest.approx(expected[role]) + + def test_empty_weight_patterns_disable_grouping(self): + """``shared_patterns={"weight": []}`` disables grouping (key presence, not truthiness).""" + attn = self._setup_attention(scales=(0.5, 2.0, 1.0)) # distinct per-proj amaxes + + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate( + attn, forward_loop=fwd, distributed_sync=False, shared_patterns={"weight": []} ) + + # No sibling group: no shared state, and each proj keeps its OWN global_amax. + assert not hasattr(attn, "_shared_quant_state") + gas = [] + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert isinstance(proj.weight_quantizer, NVFP4StaticQuantizer) + assert not hasattr(proj.weight_quantizer, "_shared_quant_state_ref") + gas.append(proj.weight_quantizer.global_amax.item()) + assert len(set(gas)) > 1, f"grouping should be disabled, but global_amax all equal: {gas}" + + def test_config_rejects_invalid_shared_patterns(self): + """Bad keys and bad regexes are rejected when the config is parsed, not at calib time.""" + from modelopt.torch.quantization.config import MaxCalibConfig + + MaxCalibConfig(shared_patterns={"weight": [r"(?:(.*)\.)?(?:q_proj|k_proj)"]}) # valid + MaxCalibConfig(shared_patterns={"weight": []}) # empty list is valid (disables grouping) + with pytest.raises(ValueError, match="unsupported quantizer kind"): + MaxCalibConfig(shared_patterns={"weigth": [r".*"]}) # typo'd key + with pytest.raises(ValueError, match="invalid regex"): + MaxCalibConfig(shared_patterns={"weight": ["("]}) # unbalanced paren From 0596a9f423d94f8f964d2c7c1b87d6f8139013fb Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 2 Jun 2026 14:00:45 -0700 Subject: [PATCH 4/6] move attach earlier Signed-off-by: Shiyang Chen --- modelopt/torch/quantization/model_calib.py | 24 +++++++++++-------- .../torch/quantization/utils/shared_input.py | 16 ++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index a91b9f6263e..000e00347d6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -214,6 +214,20 @@ def max_calibrate( See :class:`MaxCalibConfig ` for details on the remaining arguments. """ + # Discover fusible sibling groups by name regex and attach the (initially empty) shared + # state up front, so the SharedQuantState container exists for the whole calibration — + # forward-time fields can accumulate into it. Discovery is structural (a pattern over the + # module tree), so it needs no ``_amax``; per-member values are aggregated later by + # populate_shared_state, after the forward and any cross-rank ``_amax`` sync. Default to + # q/k/v + gate/up when no "weight" key is given; an explicit (possibly empty) list + # overrides it — key presence, not truthiness, so {"weight": []} disables grouping. + # Only "weight" is consumed today; "input" is reserved. + if shared_patterns is not None and "weight" in shared_patterns: + weight_patterns = list(shared_patterns["weight"]) + else: + weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS + attach_shared_quant_states(model, patterns=weight_patterns) + # Always run weight calibration on the weight tensor directly so every weight # quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms # (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``. @@ -231,16 +245,6 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). _check_nvfp4_static_tp_supported(model) - # Discover fusible sibling groups by name regex and attach shared state; populated - # below, after any cross-rank _amax sync. Default to q/k/v + gate/up when no "weight" - # key is given; an explicit (possibly empty) list overrides it — key presence, not - # truthiness, so {"weight": []} disables grouping. Only "weight" is consumed today. - if shared_patterns is not None and "weight" in shared_patterns: - weight_patterns = list(shared_patterns["weight"]) - else: - weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS - attach_shared_quant_states(model, patterns=weight_patterns) - if not distributed_sync: # Single-process: _amax is final — aggregate shared global_amax, then promote. populate_shared_state(model) diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py index 8c090c0a6fa..402feb89c5d 100644 --- a/modelopt/torch/quantization/utils/shared_input.py +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -129,12 +129,16 @@ def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: # --------------------------------------------------------------------------- -def _has_calibratable_weight_quantizer(child: nn.Module, wq_attr: str) -> bool: - """A child is eligible if its weight quantizer is enabled and has ``_amax`` set.""" +def _has_enabled_weight_quantizer(child: nn.Module, wq_attr: str) -> bool: + """A child is eligible if it has an enabled weight quantizer. + + Group membership is structural (a pattern over the module tree), independent of + calibration — so this does NOT require ``_amax``. That lets attach run before + ``weight_only_quantize``; per-member ``_amax`` is aggregated later in + :func:`populate_shared_state`. + """ wq = getattr(child, wq_attr, None) - if wq is None or not hasattr(wq, "_disabled") or wq._disabled: - return False - return getattr(wq, "_amax", None) is not None + return wq is not None and hasattr(wq, "_disabled") and not wq._disabled def _build_parent_map(model: nn.Module) -> dict[nn.Module, nn.Module]: @@ -233,7 +237,7 @@ def find_shared_input_groups( buckets: dict[tuple, list[nn.Module]] = {} order: list[tuple] = [] for name, module in model.named_modules(): - if not _has_calibratable_weight_quantizer(module, wq_attr): + if not _has_enabled_weight_quantizer(module, wq_attr): continue for pattern_idx, regex in enumerate(compiled): match = regex.fullmatch(name) From c149968f38dadf5878709cfa2a0a6c5422510def Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 2 Jun 2026 14:12:23 -0700 Subject: [PATCH 5/6] update based on comments Signed-off-by: Shiyang Chen --- .../torch/quantization/test_shared_input.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py index 879010a7f68..047d8246072 100644 --- a/tests/unit/torch/quantization/test_shared_input.py +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -13,15 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 - """Unit tests for SharedQuantState — group-level quantization state on parent modules. These tests use a hand-built CPU model with Q/K/V siblings under a dummy @@ -33,8 +24,9 @@ import torch import torch.nn as nn +import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.config import MaxCalibConfig, QuantizerAttributeConfig from modelopt.torch.quantization.model_calib import max_calibrate from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.quantization.utils import ( @@ -43,6 +35,7 @@ attach_shared_quant_states, find_shared_input_groups, populate_shared_state, + promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, ) @@ -186,8 +179,6 @@ def test_promote_ignores_shared_state_outside_root(self): run. If the owning ``_shared_quant_state`` is not within the promotion root, the quantizer must fall back to its OWN amax, not the stale group value. """ - from modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers - attn = _DummyAttention() mtq.replace_quant_module(attn) cfg = _make_nvfp4_static_cfg() @@ -384,8 +375,6 @@ def test_modelopt_save_restore_with_shared_state(self, tmp_path): ``QuantLinear``), and (b) the shared buffer must be non-persistent (else ``load_state_dict`` on the fresh, submodule-less model fails on the unexpected key). """ - import modelopt.torch.opt as mto - cfg = { "quant_cfg": [ {"enable": False, "quantizer_name": "*"}, @@ -447,8 +436,6 @@ def fwd(m): def test_config_rejects_invalid_shared_patterns(self): """Bad keys and bad regexes are rejected when the config is parsed, not at calib time.""" - from modelopt.torch.quantization.config import MaxCalibConfig - MaxCalibConfig(shared_patterns={"weight": [r"(?:(.*)\.)?(?:q_proj|k_proj)"]}) # valid MaxCalibConfig(shared_patterns={"weight": []}) # empty list is valid (disables grouping) with pytest.raises(ValueError, match="unsupported quantizer kind"): From 4e08ef65ca14c403ff46dd952651204877a2c5af Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Wed, 3 Jun 2026 10:15:53 -0700 Subject: [PATCH 6/6] fix ci with meta amax Signed-off-by: Shiyang Chen --- .../torch/quantization/utils/shared_input.py | 8 +++++-- .../torch/quantization/test_shared_input.py | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py index 402feb89c5d..44372e28182 100644 --- a/modelopt/torch/quantization/utils/shared_input.py +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -340,9 +340,13 @@ def populate_shared_state(model: nn.Module) -> int: parallel_state: ParallelState | None = None for child in members: wq = getattr(child, wq_attr, None) - if wq is None or getattr(wq, "_amax", None) is None: + amax = getattr(wq, "_amax", None) if wq is not None else None + # Skip uncalibrated or meta (no-data) amax. A meta amax — e.g. quantizing an + # ``init_empty_weights`` model before dispatch — would make weight_global_amax a + # meta buffer that then breaks the meta->device ``.to()`` (it needs ``to_empty``). + if amax is None or amax.is_meta: continue - child_maxes.append(reduce_amax(wq._amax, axis=None)) + child_maxes.append(reduce_amax(amax, axis=None)) if parallel_state is None: parallel_state = getattr(child, "parallel_state", None) diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py index 047d8246072..ca7a1a53fee 100644 --- a/tests/unit/torch/quantization/test_shared_input.py +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -171,6 +171,29 @@ def test_populate_writes_max_across_siblings(self): assert shared is not None assert torch.isclose(shared, torch.tensor(3.0)), f"expected 3.0, got {shared.item()}" + def test_populate_skips_meta_amax(self): + """Meta (no-data) ``_amax`` must not become a meta ``weight_global_amax`` buffer. + + Quantizing an ``init_empty_weights`` model produces meta ``_amax``; aggregating it + would make ``weight_global_amax`` a meta buffer that breaks the later meta->device + ``.to()`` during dispatch. The group is skipped instead, leaving the buffer ``None``. + """ + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + out_features, in_features = proj.weight.shape + proj.weight_quantizer._amax = torch.empty( + (out_features, in_features // NVFP4_BLOCK), device="meta" + ) + + attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) # groups q/k/v (no amax needed) + n_groups = populate_shared_state(attn) + + assert n_groups == 0 # nothing real to aggregate + assert attn._shared_quant_state.weight_global_amax is None # not a meta tensor + def test_promote_ignores_shared_state_outside_root(self): """Promoting a submodule must ignore a back-ref whose owning state is outside it.