diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index bf8e99ff5af..0656da065f7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -150,6 +150,7 @@ """ +import re import warnings from collections.abc import Mapping, Sequence from typing import Any, Literal @@ -715,6 +716,50 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): ), ) + shared_patterns: dict[str, list[str]] | None = ModeloptField( + default=None, + title="Regex patterns for groups that share quantization state", + description=( + "Optional dict keyed by quantizer kind (``'weight'`` and/or ``'input'``), each a list " + "of regexes matched (full-match) against module fully-qualified names. They must list " + "every group you want for that kind. Modules whose match yields the same capture-group " + "tuple form one group; the capture boundary chooses granularity: capture the immediate " + "parent for per-parent / per-expert groups (e.g. ``r'(.*)\\.(?:q_proj|k_proj|v_proj)'``, " + "``r'(.*)\\.(?:w1|w3)'``); leave the expert index uncaptured for one cross-expert group " + "(``r'(.*)\\.experts\\.\\d+\\.(?:w1|w3)'``). Only ``'weight'`` is used today; ``'input'`` is " + "reserved for future input-quantizer sharing. When the ``'weight'`` list is omitted, " + "the default fusible patterns (q/k/v, gate/up, w1/w3) are used — these match exactly " + "the sibling groups export fuses, avoiding the over-grouping a shared-input heuristic " + "would cause (e.g. a ``shared_expert_gate`` that reads the same input but is not fused)." 
@field_validator("shared_patterns")
@classmethod
def validate_shared_patterns(cls, v):
    """Validate ``shared_patterns`` when the config is constructed.

    ``None`` is passed through untouched. Otherwise every key must be one of the
    supported quantizer kinds (``"weight"`` / ``"input"``) and every pattern must
    compile as a regular expression.

    Raises:
        ValueError: on an unsupported quantizer kind, or on a pattern that
            ``re.compile`` rejects (the original ``re.error`` is chained).
    """
    if v is None:
        return v

    allowed_kinds = {"weight", "input"}
    unexpected = sorted(kind for kind in v if kind not in allowed_kinds)
    if unexpected:
        raise ValueError(
            f"shared_patterns has unsupported quantizer kind(s) {unexpected}; "
            f"expected keys from {sorted(allowed_kinds)}."
        )

    # Flatten to (kind, pattern) pairs so one ``try`` wraps the whole scan; the loop
    # variables still name the offending entry when ``re.compile`` raises.
    candidates = [(kind, pattern) for kind, patterns in v.items() for pattern in patterns]
    current_kind, current_pattern = "", ""
    try:
        for current_kind, current_pattern in candidates:
            re.compile(current_pattern)
    except re.error as err:
        raise ValueError(
            f"shared_patterns[{current_kind!r}] has an invalid regex {current_pattern!r}: {err}"
        ) from err
    return v
_check_nvfp4_static_tp_supported(model: nn.Module) -> None: # no-op without ] -def _is_calibrated_nvfp4_static(q) -> bool: - """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" - return ( - isinstance(q, NVFP4StaticQuantizer) - and not q._disabled - and q.is_nvfp4_static - and getattr(q, "_amax", None) is not None - ) - - -def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: - """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" - # Inline: layer_utils → quant_utils → model_calib cycle. - from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS - - # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant - # in export). Single source for the gate/up half avoids parallel lists. - patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) - groups: list[list[nn.Module]] = [] - wq_attr = quantizer_attr_names("weight").weight_quantizer - for parent in model.modules(): - for sibling_names in patterns: - members = [ - child - for child in (getattr(parent, n, None) for n in sibling_names) - if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) - ] - if len(members) >= 2: - groups.append(members) - return groups - - def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None: quantizer(weight) -@torch.no_grad() -def _sync_grouped_weight_global_amax(model: nn.Module) -> int: - """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. - - Run after ``max_calibrate``. Sibling discovery is name-based via - ``_collect_grouped_linears``; non-matching architectures (wqkv, fused - qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently - fall back to per-module global_amax. Fused-experts containers already - share a single quantizer across gate/up halves and need no sync. - """ - # quant_utils imports back from this module; top-level would cycle. 
- from modelopt.torch.export.quant_utils import preprocess_linear_fusion - - n_groups = 0 - for group in _collect_grouped_linears(model): - preprocess_linear_fusion(group) - n_groups += 1 - return n_groups - - -@torch.no_grad() -def _promote_nvfp4_static_quantizers_with_global_amax_sync(model: nn.Module) -> None: - """Promote static NVFP4 weight quantizers and sync grouped global amax.""" - promote_nvfp4_static_quantizers(model) - _sync_grouped_weight_global_amax(model) - - CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -249,6 +192,7 @@ def max_calibrate( forward_loop: ForwardLoop | None = None, distributed_sync=True, sync_expert_weight_amax=False, + shared_patterns: Mapping[str, Sequence[str]] | None = None, ): """Calibrate the model using max. @@ -259,10 +203,31 @@ def max_calibrate( distributed_sync: Whether to sync input_quantizer amax across distributed processes. sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts in a MoE layer (within-rank sync + EP all-reduce when EP>1). + shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a + list of regexes over module FQNs. When the ``"weight"`` list is omitted, + :data:`DEFAULT_WEIGHT_SHARED_PATTERNS` (q/k/v, gate/up, w1/w3) is used. Modules whose + regex match yields the same capture-group tuple form one group — capture the immediate + parent for per-parent (per-expert) grouping, or leave the expert index uncaptured for + cross-expert. Only ``"weight"`` is used today; ``"input"`` is reserved for future + input-quantizer sharing. See :class:`MaxCalibConfig ` for details on the remaining arguments. """ + # Discover fusible sibling groups by name regex and attach the (initially empty) shared + # state up front, so the SharedQuantState container exists for the whole calibration — + # forward-time fields can accumulate into it. 
Discovery is structural (a pattern over the + # module tree), so it needs no ``_amax``; per-member values are aggregated later by + # populate_shared_state, after the forward and any cross-rank ``_amax`` sync. Default to + # q/k/v + gate/up when no "weight" key is given; an explicit (possibly empty) list + # overrides it — key presence, not truthiness, so {"weight": []} disables grouping. + # Only "weight" is consumed today; "input" is reserved. + if shared_patterns is not None and "weight" in shared_patterns: + weight_patterns = list(shared_patterns["weight"]) + else: + weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS + attach_shared_quant_states(model, patterns=weight_patterns) + # Always run weight calibration on the weight tensor directly so every weight # quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms # (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``. @@ -280,14 +245,10 @@ def max_calibrate( # Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated). _check_nvfp4_static_tp_supported(model) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so - # the static blockwise fake-quant path is used in forward and export picks up the - # two-level (per-block + global) scaling. Run before the ``distributed_sync`` early - # return so single-process callers also get the promotion. No-op for dynamic-block - # / non-NVFP4 configs. - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) - if not distributed_sync: + # Single-process: _amax is final — aggregate shared global_amax, then promote. 
+ populate_shared_state(model) + promote_nvfp4_static_quantizers(model) return # Check MoE calibration completeness before sync @@ -404,6 +365,11 @@ def sync_quantizer_amax_across_tp( module.parallel_state.tensor_parallel_group ) + # _amax is now cross-rank consistent — aggregate shared global_amax then + # promote, so siblings read the unified value instead of their own _amax. + populate_shared_state(model) + promote_nvfp4_static_quantizers(model) + def _mse_quant_func(x, amax, quantizer): """Quantization function for MSE calibration.""" diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index eef525f21d5..573f39d7d7c 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -172,6 +172,10 @@ class TensorQuantizer(nn.Module): "pre_bwd_fn", # quantizer cache for custom backends, like luts "_quantizer_cache", + # Runtime-only back-reference to a sibling group's SharedQuantState; it is + # re-established during calibration and must not be serialized (it points to a + # live module whose dynamic QuantLinear members are not picklable). 
+ "_shared_quant_state_ref", } def __init__( diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index dc6daa00842..f72b26820f9 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -18,15 +18,21 @@ from .core_utils import * from .layerwise_calib import LayerActivationCollector +from .shared_input import * __all__ = [ + "DEFAULT_WEIGHT_SHARED_PATTERNS", "EXPORT_MODE", + "SharedQuantState", + "attach_shared_quant_states", "convert_quantization_axis_to_reduce_axis", "export_torch_mode", + "find_shared_input_groups", "is_quantized", "is_quantized_column_parallel_linear", "is_quantized_linear", "is_quantized_row_parallel_linear", + "populate_shared_state", "reduce_amax", "reduce_sum", "replace_function", diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index e95e175b71d..10fd326b0ff 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -954,10 +954,26 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: need to be promoted so they use the two-level scaling path (global amax + per-block amax) instead of the generic E4M3 path. + If the quantizer has a ``_shared_quant_state_ref`` with a populated + ``weight_global_amax`` (sibling group) whose owning state lives within ``model``, + that shared value is used instead of this quantizer's own ``_amax`` reduction, + keeping siblings on a common FP8 grid. + Returns the number of quantizers converted. """ from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer + # Shared states owned within THIS promotion root. This function also runs on + # submodules / individual linears; a quantizer may still carry a back-reference from + # an earlier full-model calibration whose owning ``_shared_quant_state`` is outside + # ``model``. 
Only trust refs reachable here — otherwise the global_amax would come + # from an unrelated prior run; fall back to the quantizer's own amax instead. + valid_shared_states = { + id(state) + for owner in model.modules() + if (state := getattr(owner, "_shared_quant_state", None)) is not None + } + converted = 0 for _name, module in list(model.named_modules()): if not isinstance(module, TensorQuantizer) or not module.is_enabled: @@ -968,8 +984,18 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: if amax is None: continue + # Grouped siblings share one ``weight_global_amax`` (common FP8 grid); + # otherwise fall back to this quantizer's own per-block amax. already_promoted = isinstance(module, NVFP4StaticQuantizer) - global_amax = reduce_amax(amax.clone().detach(), axis=None) + shared = getattr(module, "_shared_quant_state_ref", None) + if ( + shared is not None + and id(shared) in valid_shared_states + and shared.weight_global_amax is not None + ): + global_amax = shared.weight_global_amax + else: + global_amax = reduce_amax(amax.clone().detach(), axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not already_promoted: converted += 1 diff --git a/modelopt/torch/quantization/utils/shared_input.py b/modelopt/torch/quantization/utils/shared_input.py new file mode 100644 index 00000000000..44372e28182 --- /dev/null +++ b/modelopt/torch/quantization/utils/shared_input.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-group shared quantization state for fusible sibling modules. + +Weight ``global_amax`` must be unified across modules that get **fused** at export +(q/k/v -> qkv, gate/up -> gate_up) so they quantize with one per-tensor scale. +:func:`find_shared_input_groups` discovers these groups by regex over module FQNs; +:data:`DEFAULT_WEIGHT_SHARED_PATTERNS` covers the standard q/k/v, gate/up and w1/w3 +names, and callers may override per quantizer kind via ``MaxCalibConfig.shared_patterns``. + +Discovery is name/pattern-based (not input-hook-based) on purpose: "shares an input +tensor" is broader than "gets fused" — e.g. a ``shared_expert_gate`` reads the same +hidden states as the GLU pair but is never fused with it, so a hook would over-group +it. Patterns match exactly the roles export fuses. + +:class:`SharedQuantState` is an ``nn.Module`` attached to a group's parent as a +calibration-time submodule; its buffers are non-persistent, so they are NOT part of +``state_dict`` (the resolved scale is persisted through each member's promoted +quantizer instead). Holds only ``weight_global_amax`` today; +designed to grow (act scales, LoRA factors, ...). The ``(parent, members)`` tuples +from :func:`find_shared_input_groups` are consumed by +:func:`attach_shared_quant_states` and :func:`populate_shared_state`.
+""" + +import re +from collections.abc import Sequence + +import torch +import torch.distributed as dist +import torch.nn as nn + +from modelopt.torch.utils.distributed import ParallelState + +from .core_utils import quantizer_attr_names, reduce_amax + +__all__ = [ + "DEFAULT_WEIGHT_SHARED_PATTERNS", + "SharedQuantState", + "attach_shared_quant_states", + "find_shared_input_groups", + "populate_shared_state", +] + +# Default fusible-sibling patterns for WEIGHT global_amax — the groups export fuses: +# q/k/v -> qkv, gate/up (incl. Mixtral w1/w3) -> gate_up. These reproduce the legacy +# name-based grouping exactly. Regexes are ``re.fullmatch``-ed against module FQNs; +# ``(?:(.*)\.)?`` captures the immediate parent so grouping is per-parent (per-expert +# for MoE experts). Override per quantizer kind via ``MaxCalibConfig.shared_patterns``. +DEFAULT_WEIGHT_SHARED_PATTERNS = [ + r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)", + r"(?:(.*)\.)?(?:gate_proj|up_proj)", + r"(?:(.*)\.)?(?:w1|w3)", +] + + +# --------------------------------------------------------------------------- +# Per-group shared state +# --------------------------------------------------------------------------- + + +class SharedQuantState(nn.Module): + """State shared across a sibling group of quantized modules. + + Attached to the group's parent (e.g. ``self_attn``, ``block_sparse_moe``) as a + submodule, but its buffers are **non-persistent**: this is a calibration-time + artifact, not part of the checkpoint. Members that quantize the same input resolve + shared values here during calibration; the resolved value is then baked into each + member's promoted quantizer (``NVFP4StaticQuantizer._global_amax``, which *is* + serialized). So the scale survives save/restore via the members, and restore need + not re-create this submodule (it isn't in ``state_dict``) — see + :func:`attach_shared_quant_states`. + + Holds only ``weight_global_amax`` today, and it is mirrored onto every member. 
Any + future field that is **not** mirrored on a member would not survive save/restore as + a non-persistent buffer and would need its own restore path. + """ + + def __init__(self) -> None: + """Initialize with an unset ``weight_global_amax`` and no registered members.""" + super().__init__() + # NVFP4 two-level FP8 grid scale = max over members' per-block ``_amax``. + # Non-persistent: a calibration-time artifact. The resolved value is baked into + # each member's ``NVFP4StaticQuantizer._global_amax`` (which IS serialized), so + # restore rebuilds scales from members and need not re-create this submodule. + self.register_buffer("weight_global_amax", None, persistent=False) + # Back-references to member modules. ``object.__setattr__`` keeps them out of + # ``_modules`` so members' params don't re-enter our ``state_dict``. + object.__setattr__(self, "_members", []) + + def sync_weight_global_amax(self, parallel_state: ParallelState | None) -> None: + """All-reduce (MAX) ``weight_global_amax`` across EP, plus TP defensively. + + Weights are DP-replicated (no DP sync needed). EP sync is required since + ranks hold different experts; TP sync guards against per-child ``_amax`` TP + sync skipping block-quantized weights. Raises on a failed all-reduce: a silent + failure would leave ranks with different scales and still promote/export them. 
def _has_enabled_weight_quantizer(child: nn.Module, wq_attr: str) -> bool:
    """Return True when ``child`` carries an enabled weight quantizer under ``wq_attr``.

    Eligibility is purely structural: the attribute must exist, expose a
    ``_disabled`` flag, and that flag must be falsy. ``_amax`` is deliberately NOT
    required, so group discovery can run before any weight calibration.
    """
    wq = getattr(child, wq_attr, None)
    return wq is not None and hasattr(wq, "_disabled") and not wq._disabled


def _build_parent_map(model: nn.Module) -> dict[nn.Module, nn.Module]:
    """Build a ``{child_module: direct_parent}`` map by walking each module's ``children()``.

    If the same module instance is registered under several parents, the parent
    visited last wins.
    """
    return {child: parent for parent in model.modules() for child in parent.children()}


# Containers whose own behaviour is driven directly by ``_modules``: registering an
# extra submodule on them changes their length / iteration (ModuleList, ModuleDict)
# or inserts an extra stage into ``forward`` (Sequential).
_ATTACH_UNSAFE_CONTAINERS = (nn.ModuleList, nn.ModuleDict, nn.Sequential)


def _climb_past_modulelist(
    module: nn.Module,
    parent_map: dict[nn.Module, nn.Module],
    fallback: nn.Module,
) -> nn.Module:
    """Walk up past container ancestors to the first regular module.

    Attaching ``SharedQuantState`` to an ``nn.ModuleList`` registers it in that
    container's ``_modules`` and corrupts its iteration/length (the state shows up
    alongside the experts). ``nn.ModuleDict`` and ``nn.Sequential`` are affected the
    same way — the state would appear as an extra key, respectively as an extra
    layer that ``Sequential.forward`` would try to call — so all three container
    types are skipped.

    Args:
        module: Candidate owner to start from.
        parent_map: ``{child: direct_parent}`` map, see :func:`_build_parent_map`.
        fallback: Returned when the climb runs out of parents (e.g. the model root
            is itself a container).

    Returns:
        The first non-container ancestor (``module`` itself if it already is one),
        or ``fallback``.
    """
    cur = module
    while isinstance(cur, _ATTACH_UNSAFE_CONTAINERS):
        parent = parent_map.get(cur)
        if parent is None or parent is cur:
            return fallback
        cur = parent
    return cur


def _lowest_common_ancestor(
    members: Sequence[nn.Module],
    parent_map: dict[nn.Module, nn.Module],
    fallback: nn.Module,
) -> nn.Module:
    """Return the deepest module that is a proper ancestor of every member.

    The result is passed through :func:`_climb_past_modulelist` so that it can host
    ``SharedQuantState`` as a submodule. ``fallback`` is returned when ``members``
    is empty, when the first member has no parent in ``parent_map``, or when the
    members share no ancestor.
    """
    if not members:
        return fallback

    def ancestors(m: nn.Module) -> list[nn.Module]:
        # Proper ancestors only, nearest first.
        chain = []
        cur = m
        while cur in parent_map:
            cur = parent_map[cur]
            chain.append(cur)
        return chain

    first_chain, *other_chains = (ancestors(m) for m in members)
    shared = set(first_chain).intersection(*other_chains)
    # ``first_chain`` is ordered nearest-first, so the first hit is the deepest one.
    for candidate in first_chain:
        if candidate in shared:
            return _climb_past_modulelist(candidate, parent_map, fallback)
    return fallback
def attach_shared_quant_states(
    model: nn.Module,
    patterns: Sequence[str] | None = None,
) -> int:
    """Create ``SharedQuantState`` on each group's parent and link members.

    Groups are discovered by ``patterns`` (regexes over module FQNs; see
    :func:`find_shared_input_groups`). The parent owns the state under
    ``_shared_quant_state`` (normal setattr, so it is a registered submodule; its
    ``weight_global_amax`` buffer is registered non-persistent and therefore does
    not appear in ``state_dict``). Each member's weight quantizer gets a
    back-reference under the distinct name ``_shared_quant_state_ref``, set with
    ``object.__setattr__`` so it is not registered as a submodule. The distinct
    names let :func:`populate_shared_state` select owners with a plain ``getattr``.

    If several groups discovered in the same call resolve to the same parent, they
    share that parent's single state: their member lists are merged (and a warning
    is emitted) so every quantizer that references the state also contributes to
    its aggregated ``weight_global_amax``.

    Idempotent (an existing parent state is reused and its member list is rebuilt
    from the current discovery).

    Args:
        model: Root module to search.
        patterns: Regexes forwarded to :func:`find_shared_input_groups`.

    Returns:
        The number of ``SharedQuantState`` objects newly created.

    Raises:
        AssertionError: if a member's weight quantizer already references a
            different ``SharedQuantState``.
    """
    import warnings  # local import: not part of this module's top-level imports

    n_created = 0
    wq_attr = quantizer_attr_names("weight").weight_quantizer
    # ids of states whose member list has already been rebuilt during THIS call.
    refreshed: set[int] = set()
    for parent, members in find_shared_input_groups(model, patterns=patterns):
        if not hasattr(parent, "_shared_quant_state"):
            parent._shared_quant_state = SharedQuantState()
            n_created += 1
        state = parent._shared_quant_state

        # Record members so populate_shared_state needn't re-run discovery.
        if id(state) in refreshed:
            # A second group landed on the same parent. Overwriting ``_members``
            # would leave the earlier group's quantizers pointing at a state that
            # no longer aggregates their ``_amax``; merge instead.
            warnings.warn(
                f"Multiple shared-quantization groups resolve to the same parent "
                f"{type(parent).__name__}; merging them into one group that shares "
                "a single weight_global_amax.",
                stacklevel=2,
            )
            recorded = [*getattr(state, "_members", []), *members]
        else:
            refreshed.add(id(state))
            recorded = list(members)
        object.__setattr__(state, "_members", recorded)

        for child in members:
            wq = getattr(child, wq_attr, None)
            if wq is None:
                continue
            # Discovery puts each module in exactly one group, so a quantizer gets
            # one state per call and a re-attach reuses the same object; a different
            # existing state means an inconsistent re-attach. Raised explicitly
            # (rather than via ``assert``) so the check survives ``python -O``.
            existing = getattr(wq, "_shared_quant_state_ref", None)
            if existing is not None and existing is not state:
                raise AssertionError(
                    f"{type(wq).__name__} already belongs to a different shared-input "
                    "group; groups should be disjoint after merging."
                )
            object.__setattr__(wq, "_shared_quant_state_ref", state)
    return n_created


@torch.no_grad()
def populate_shared_state(model: nn.Module) -> int:
    """Aggregate per-member stats into each group's ``SharedQuantState``.

    For every module in ``model`` that owns a ``_shared_quant_state``, sets
    ``weight_global_amax`` to the max over its members' reduced ``_amax``, runs
    ``state.sync_weight_global_amax`` with the first ``parallel_state`` found on a
    contributing member, then writes the result back to the ``global_amax`` of
    every member whose weight quantizer is already an ``NVFP4StaticQuantizer``.

    Call after members' ``_amax`` is final for this rank set. Members that are
    uncalibrated or whose ``_amax`` is a meta tensor are skipped during
    aggregation; groups with no usable member are left untouched.

    Returns:
        The number of groups populated.
    """
    from modelopt.torch.quantization.nn import NVFP4StaticQuantizer

    wq_attr = quantizer_attr_names("weight").weight_quantizer
    n_groups = 0

    for parent in model.modules():
        # Owners hold the state under ``_shared_quant_state`` (members use the
        # distinct ``_shared_quant_state_ref``), so getattr matches owners only.
        state = getattr(parent, "_shared_quant_state", None)
        if not isinstance(state, SharedQuantState):
            continue

        members = getattr(state, "_members", [])
        if not members:
            continue

        child_maxes: list[torch.Tensor] = []
        parallel_state: ParallelState | None = None
        for child in members:
            wq = getattr(child, wq_attr, None)
            amax = getattr(wq, "_amax", None) if wq is not None else None
            # Skip uncalibrated or meta (no-data) amax: a meta amax would turn
            # weight_global_amax into a meta buffer.
            if amax is None or amax.is_meta:
                continue
            child_maxes.append(reduce_amax(amax, axis=None))
            if parallel_state is None:
                parallel_state = getattr(child, "parallel_state", None)

        if not child_maxes:
            continue

        # Local max first, then cross-rank sync (in place on the buffer).
        state.weight_global_amax = torch.max(torch.stack(child_maxes))
        state.sync_weight_global_amax(parallel_state)

        # Overwrite any value left on already-promoted members by an earlier
        # promotion; not-yet-promoted members are left for the next promotion.
        synced = state.weight_global_amax
        for child in members:
            wq = getattr(child, wq_attr, None)
            if isinstance(wq, NVFP4StaticQuantizer):
                wq.global_amax = synced
        n_groups += 1

    return n_groups
assert isinstance(model.weight_quantizer, NVFP4StaticQuantizer) assert torch.equal(model.weight_quantizer.global_amax, torch.tensor(5.0)) @@ -770,7 +774,9 @@ def test_grouped_static_nvfp4_quantizers_share_global_amax(self): model.k_proj = self._LinearLike(torch.tensor([3.0, 4.0])) model.v_proj = self._LinearLike(torch.tensor([5.0, 6.0])) - _promote_nvfp4_static_quantizers_with_global_amax_sync(model) + attach_shared_quant_states(model, patterns=[r"(?:(.*)\.)?(?:q_proj|k_proj|v_proj)"]) + populate_shared_state(model) + promote_nvfp4_static_quantizers(model) for child in (model.q_proj, model.k_proj, model.v_proj): assert isinstance(child.weight_quantizer, NVFP4StaticQuantizer) diff --git a/tests/unit/torch/quantization/test_shared_input.py b/tests/unit/torch/quantization/test_shared_input.py new file mode 100644 index 00000000000..ca7a1a53fee --- /dev/null +++ b/tests/unit/torch/quantization/test_shared_input.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SharedQuantState — group-level quantization state on parent modules. + +These tests use a hand-built CPU model with Q/K/V siblings under a dummy +``self_attn`` parent, then drive ``max_calibrate`` directly so the run is +fast and deterministic without needing real attention/MoE layers. 
+""" + +import pytest +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.config import MaxCalibConfig, QuantizerAttributeConfig +from modelopt.torch.quantization.model_calib import max_calibrate +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.utils import ( + DEFAULT_WEIGHT_SHARED_PATTERNS, + SharedQuantState, + attach_shared_quant_states, + find_shared_input_groups, + populate_shared_state, + promote_nvfp4_static_quantizers, + quantizer_attr_names, + reduce_amax, +) + +# The production default patterns (q/k/v, gate/up, w1/w3) are exactly what these tests +# need; reuse them so the tests also exercise the real default. ``re.fullmatch``-ed +# against module FQNs; ``(?:(.*)\.)?`` captures the immediate parent (or None at the +# model root, since these test models hold roles directly) -> per-parent / per-expert. +SIBLING_PATTERNS = DEFAULT_WEIGHT_SHARED_PATTERNS + + +NVFP4_BLOCK = 16 + + +def _make_nvfp4_static_cfg() -> QuantizerAttributeConfig: + """NVFP4 static block quantization config (E2M1 weights + E4M3 per-block scales).""" + return QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: NVFP4_BLOCK, "type": "static", "scale_bits": (4, 3)}, + ) + + +class _DummyAttention(nn.Module): + """A toy parent module that exposes ``q_proj``, ``k_proj``, ``v_proj`` siblings.""" + + def __init__(self, in_features: int = 32, out_features: int = 32) -> None: + super().__init__() + self.q_proj = nn.Linear(in_features, out_features, bias=False) + self.k_proj = nn.Linear(in_features, out_features, bias=False) + self.v_proj = nn.Linear(in_features, out_features, bias=False) + + def forward(self, x): + return self.q_proj(x) + self.k_proj(x) + self.v_proj(x) + + +def _populate_amax(linear: nn.Module, value: float) -> None: + """Directly set ``_amax`` on a linear's weight_quantizer for deterministic 
testing.""" + wq_attr = quantizer_attr_names("weight").weight_quantizer + wq = getattr(linear, wq_attr) + # Match the per-block shape the real calibrator would produce + out_features, in_features = linear.weight.shape + n_blocks = in_features // NVFP4_BLOCK + wq._amax = torch.full( + (out_features, n_blocks), value, dtype=torch.float32, device=linear.weight.device + ) + + +class TestSharedQuantStateBasics: + """Direct exercise of the SharedQuantState container and its helpers.""" + + def test_init_unset(self): + s = SharedQuantState() + assert s.weight_global_amax is None + + def test_attach_creates_state_on_parent(self): + attn = _DummyAttention() + # Configure with NVFP4-static weight quantizers and seed _amax so attach finds them. + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=1.0) + + n = attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + assert n == 1, f"expected one new state, got {n}" + assert hasattr(attn, "_shared_quant_state") + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert proj.weight_quantizer._shared_quant_state_ref is attn._shared_quant_state + + def test_find_groups_skips_singletons(self): + """A parent with only one matching child must NOT form a group.""" + + class _OnlyOne(nn.Module): + def __init__(self): + super().__init__() + self.q_proj = nn.Linear(16, 16, bias=False) + + m = _OnlyOne() + mtq.replace_quant_module(m) + cfg = _make_nvfp4_static_cfg() + m.q_proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(m.q_proj, value=1.0) + + groups = find_shared_input_groups(m, patterns=SIBLING_PATTERNS) + assert groups == [], "single sibling must not form a group" + + def test_default_patterns_skip_non_fusible_gate(self): + """A gate sharing the block input but never fused (e.g. 
``shared_expert_gate``) + must NOT be grouped with the gate_proj/up_proj pair. + + This is why grouping is name/pattern-based, not shared-input-hook-based: a hook + would lump ``shared_expert_gate`` in with the GLU pair (same input tensor) and + wrongly unify its global_amax. The default patterns match only the fused roles. + """ + + class _MLPWithGate(nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = nn.Linear(32, 32, bias=False) + self.up_proj = nn.Linear(32, 32, bias=False) + self.shared_expert_gate = nn.Linear(32, 1, bias=False) # shares input, not fused + + m = _MLPWithGate() + mtq.replace_quant_module(m) + cfg = _make_nvfp4_static_cfg() + for lin in (m.gate_proj, m.up_proj, m.shared_expert_gate): + lin.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(lin, value=1.0) + + groups = find_shared_input_groups(m, patterns=DEFAULT_WEIGHT_SHARED_PATTERNS) + assert len(groups) == 1 + _parent, members = groups[0] + assert set(members) == {m.gate_proj, m.up_proj} + assert m.shared_expert_gate not in members + + def test_populate_writes_max_across_siblings(self): + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + # Seed deterministic _amax values: q=1.0, k=3.0, v=2.0 → max=3.0 + attn.q_proj.weight_quantizer.set_from_attribute_config(cfg) + attn.k_proj.weight_quantizer.set_from_attribute_config(cfg) + attn.v_proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(attn.q_proj, value=1.0) + _populate_amax(attn.k_proj, value=3.0) + _populate_amax(attn.v_proj, value=2.0) + + attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + n_groups = populate_shared_state(attn) + + assert n_groups == 1 + shared = attn._shared_quant_state.weight_global_amax + assert shared is not None + assert torch.isclose(shared, torch.tensor(3.0)), f"expected 3.0, got {shared.item()}" + + def test_populate_skips_meta_amax(self): + """Meta (no-data) ``_amax`` must not become a meta 
``weight_global_amax`` buffer. + + Quantizing an ``init_empty_weights`` model produces meta ``_amax``; aggregating it + would make ``weight_global_amax`` a meta buffer that breaks the later meta->device + ``.to()`` during dispatch. The group is skipped instead, leaving the buffer ``None``. + """ + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + out_features, in_features = proj.weight.shape + proj.weight_quantizer._amax = torch.empty( + (out_features, in_features // NVFP4_BLOCK), device="meta" + ) + + attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) # groups q/k/v (no amax needed) + n_groups = populate_shared_state(attn) + + assert n_groups == 0 # nothing real to aggregate + assert attn._shared_quant_state.weight_global_amax is None # not a meta tensor + + def test_promote_ignores_shared_state_outside_root(self): + """Promoting a submodule must ignore a back-ref whose owning state is outside it. + + ``promote_nvfp4_static_quantizers`` also runs on submodules/individual linears; a + quantizer may still carry ``_shared_quant_state_ref`` from an earlier full-model + run. If the owning ``_shared_quant_state`` is not within the promotion root, the + quantizer must fall back to its OWN amax, not the stale group value. + """ + attn = _DummyAttention() + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj, val in ((attn.q_proj, 1.0), (attn.k_proj, 3.0), (attn.v_proj, 2.0)): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=val) + # Group on the parent → shared weight_global_amax = max = 3.0. 
+ attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) + populate_shared_state(attn) + assert torch.isclose(attn._shared_quant_state.weight_global_amax, torch.tensor(3.0)) + + # Promote with q_proj as the root: it does NOT contain attn._shared_quant_state + # (that lives on the parent), so the stale ref is ignored → own amax (1.0), not 3.0. + promote_nvfp4_static_quantizers(attn.q_proj) + ga = attn.q_proj.weight_quantizer.global_amax + assert torch.isclose(ga, torch.tensor(1.0)), f"expected own amax 1.0, got {ga.item()}" + + +class _MoEExpert(nn.Module): + """A toy MoE expert with Mixtral-style ``w1`` (gate) and ``w3`` (up) projections.""" + + def __init__(self, hidden=32, intermediate=32) -> None: + super().__init__() + self.w1 = nn.Linear(hidden, intermediate, bias=False) + self.w3 = nn.Linear(hidden, intermediate, bias=False) + + +class _MoEBlock(nn.Module): + """MoE block holding experts in an ``nn.ModuleList``.""" + + def __init__(self, n_experts: int = 4) -> None: + super().__init__() + self.experts = nn.ModuleList(_MoEExpert() for _ in range(n_experts)) + + +class TestMoESharedState: + """SharedQuantState groups ``w1``/``w3`` per expert via sibling-scope patterns. + + Cross-expert grouping is intentionally not covered here: weight amax is + per-expert (gate==up *within* an expert), so a single scale across experts + is only meaningful for the input quantizer — to be added with that feature. + """ + + def _setup_moe(self, n_experts: int = 4) -> _MoEBlock: + block = _MoEBlock(n_experts=n_experts) + mtq.replace_quant_module(block) + cfg = _make_nvfp4_static_cfg() + for i, expert in enumerate(block.experts): + # Distinct amax values so the max is determined and identifiable. 
+ for j, proj in enumerate((expert.w1, expert.w3)): + proj.weight_quantizer.set_from_attribute_config(cfg) + _populate_amax(proj, value=1.0 + i + 0.1 * j) + return block + + def test_sibling_patterns_group_per_expert(self): + """Sibling-scope w1/w3 patterns group each expert independently (per-expert).""" + block = self._setup_moe(n_experts=4) + groups = find_shared_input_groups(block, patterns=SIBLING_PATTERNS) + # One [w1, w3] group per expert, parented at that expert (not the block). + assert len(groups) == 4 + experts = list(block.experts) + for parent, members in groups: + assert parent in experts + assert len(members) == 2 + + attach_shared_quant_states(block, patterns=SIBLING_PATTERNS) + n_groups = populate_shared_state(block) + assert n_groups == 4 + assert not hasattr(block, "_shared_quant_state") # state is on each expert, not the block + # Each expert's max is within-expert: max(w1=1.0+i, w3=1.0+i+0.1) = 1.1 + i. + for i, expert in enumerate(block.experts): + shared = expert._shared_quant_state.weight_global_amax + assert torch.isclose(shared, torch.tensor(1.1 + i)), f"expert {i}: {shared.item()}" + + +class TestMaxCalibrateEndToEnd: + """End-to-end through ``max_calibrate``: same shared global_amax across siblings.""" + + def _setup_attention(self, scales=(0.5, 2.0, 1.0)) -> _DummyAttention: + """Build an attention block with NVFP4-static weight quantizers; distinct weight scales.""" + attn = _DummyAttention(in_features=32, out_features=32) + # Bias weight magnitudes so per-projection amaxes differ. + with torch.no_grad(): + attn.q_proj.weight.mul_(scales[0]) + attn.k_proj.weight.mul_(scales[1]) + attn.v_proj.weight.mul_(scales[2]) + mtq.replace_quant_module(attn) + cfg = _make_nvfp4_static_cfg() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + proj.weight_quantizer.set_from_attribute_config(cfg) + # Other quantizers on the wrapped module would interfere with the + # weight-only calibration path, so disable them. 
+ for name in ("input_quantizer", "output_quantizer"): + q = getattr(proj, name, None) + if isinstance(q, TensorQuantizer): + q.disable() + return attn + + @pytest.mark.parametrize("distributed_sync", [True, False]) + def test_qkv_share_global_amax_via_max_calibrate(self, distributed_sync): + """After max_calibrate, q/k/v_proj have identical global_amax (the group max).""" + attn = self._setup_attention() + + # Drive a forward pass so input shape gets observed by the weight calibrators. + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate(attn, forward_loop=fwd, distributed_sync=distributed_sync) + + # All siblings should be promoted to NVFP4StaticQuantizer with same value. + global_amaxes = [] + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert isinstance(proj.weight_quantizer, NVFP4StaticQuantizer) + ga = proj.weight_quantizer.global_amax + assert ga is not None + global_amaxes.append(ga.item()) + + assert global_amaxes[0] == global_amaxes[1] == global_amaxes[2], ( + f"siblings should share global_amax, got {global_amaxes}" + ) + + # The shared value must equal the max over each child's own _amax reduction. 
+ per_child_max = max( + reduce_amax(proj.weight_quantizer._amax, axis=None).item() + for proj in (attn.q_proj, attn.k_proj, attn.v_proj) + ) + assert global_amaxes[0] == pytest.approx(per_child_max), ( + f"shared global_amax {global_amaxes[0]} != per-child max {per_child_max}" + ) + + def test_standalone_linear_no_shared_state(self): + """A linear without siblings has no shared state attached.""" + + class _Lonely(nn.Module): + def __init__(self): + super().__init__() + self.proj = nn.Linear(32, 32, bias=False) + + m = _Lonely() + mtq.replace_quant_module(m) + m.proj.weight_quantizer.set_from_attribute_config(_make_nvfp4_static_cfg()) + for name in ("input_quantizer", "output_quantizer"): + q = getattr(m.proj, name, None) + if isinstance(q, TensorQuantizer): + q.disable() + + max_calibrate(m, forward_loop=lambda m: m.proj(torch.randn(2, 32)), distributed_sync=False) + + # Promoted, but no shared state attached because there's no sibling group. + assert isinstance(m.proj.weight_quantizer, NVFP4StaticQuantizer) + assert not hasattr(m.proj.weight_quantizer, "_shared_quant_state_ref") + # Standalone global_amax should equal its own reduce_amax(_amax). + expected = reduce_amax(m.proj.weight_quantizer._amax, axis=None).item() + assert m.proj.weight_quantizer.global_amax.item() == pytest.approx(expected) + + def test_shared_state_buffer_is_non_persistent(self): + """The shared buffer is a calibration-time artifact and must NOT be in state_dict. + + The scale is carried by each member's promoted quantizer (``_global_amax``), which + IS serialized. If the shared buffer were persistent it would add + ``_shared_quant_state.weight_global_amax`` keys that restore can't match (the + submodule isn't re-created on load) — the regression covered end-to-end below. 
+ """ + attn = self._setup_attention() + + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate(attn, forward_loop=fwd, distributed_sync=False) + + assert hasattr(attn, "_shared_quant_state") # exists at runtime (calibration artifact) + sd = attn.state_dict() + # Non-persistent: the shared buffer must NOT appear in state_dict. + assert not [k for k in sd if k.endswith("_shared_quant_state.weight_global_amax")] + # The value lives on each member's quantizer (``_global_amax``), which IS persisted, + # and the runtime back-reference is set (but not as a child submodule). + for role in ("q_proj", "k_proj", "v_proj"): + proj = getattr(attn, role) + assert proj.weight_quantizer._shared_quant_state_ref is attn._shared_quant_state + assert "_shared_quant_state_ref" not in proj.weight_quantizer._modules + assert f"{role}.weight_quantizer._global_amax" in sd + + def test_modelopt_save_restore_with_shared_state(self, tmp_path): + """``mtq.quantize`` -> ``mto.save`` -> ``mto.restore`` on a FRESH model round-trips. + + Regression for two save/restore bugs the shared state introduced: (a) the runtime + back-ref must be excluded from ``get_modelopt_state`` (else save pickles a live + ``QuantLinear``), and (b) the shared buffer must be non-persistent (else + ``load_state_dict`` on the fresh, submodule-less model fails on the unexpected key). + """ + cfg = { + "quant_cfg": [ + {"enable": False, "quantizer_name": "*"}, + { + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: NVFP4_BLOCK, "type": "static", "scale_bits": (4, 3)}, + }, + "quantizer_name": "*weight_quantizer", + }, + ], + "algorithm": "max", + } + attn = _DummyAttention() # fresh (un-quantized) — mtq.quantize rejects re-quantizing + mtq.quantize(attn, cfg, lambda m: m(torch.randn(2, 32))) + + # Grouping happened, and each member carries its own promoted global_amax. 
+ assert hasattr(attn, "_shared_quant_state") + expected = { + role: getattr(attn, role).weight_quantizer.global_amax.item() + for role in ("q_proj", "k_proj", "v_proj") + } + + # Save must not raise (pickling metadata.quantizer_state). + path = tmp_path / "model.pth" + mto.save(attn, path) + + # Restore into a fresh model must not raise (no _shared_quant_state.* key to match). + restored = _DummyAttention() + mto.restore(restored, path) + + for role in ("q_proj", "k_proj", "v_proj"): + wq = getattr(restored, role).weight_quantizer + assert isinstance(wq, NVFP4StaticQuantizer) + assert wq.global_amax.item() == pytest.approx(expected[role]) + + def test_empty_weight_patterns_disable_grouping(self): + """``shared_patterns={"weight": []}`` disables grouping (key presence, not truthiness).""" + attn = self._setup_attention(scales=(0.5, 2.0, 1.0)) # distinct per-proj amaxes + + def fwd(m): + x = torch.randn(2, 32) + m.q_proj(x) + m.k_proj(x) + m.v_proj(x) + + max_calibrate( + attn, forward_loop=fwd, distributed_sync=False, shared_patterns={"weight": []} + ) + + # No sibling group: no shared state, and each proj keeps its OWN global_amax. 
+ assert not hasattr(attn, "_shared_quant_state") + gas = [] + for proj in (attn.q_proj, attn.k_proj, attn.v_proj): + assert isinstance(proj.weight_quantizer, NVFP4StaticQuantizer) + assert not hasattr(proj.weight_quantizer, "_shared_quant_state_ref") + gas.append(proj.weight_quantizer.global_amax.item()) + assert len(set(gas)) > 1, f"grouping should be disabled, but global_amax all equal: {gas}" + + def test_config_rejects_invalid_shared_patterns(self): + """Bad keys and bad regexes are rejected when the config is parsed, not at calib time.""" + MaxCalibConfig(shared_patterns={"weight": [r"(?:(.*)\.)?(?:q_proj|k_proj)"]}) # valid + MaxCalibConfig(shared_patterns={"weight": []}) # empty list is valid (disables grouping) + with pytest.raises(ValueError, match="unsupported quantizer kind"): + MaxCalibConfig(shared_patterns={"weigth": [r".*"]}) # typo'd key + with pytest.raises(ValueError, match="invalid regex"): + MaxCalibConfig(shared_patterns={"weight": ["("]}) # unbalanced paren