Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@

"""

import re
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, Literal
Expand Down Expand Up @@ -715,6 +716,50 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
),
)

shared_patterns: dict[str, list[str]] | None = ModeloptField(
default=None,
title="Regex patterns for groups that share quantization state",
description=(
"Optional dict keyed by quantizer kind (``'weight'`` and/or ``'input'``), each a list "
"of regexes matched (full-match) against module fully-qualified names. They must list "
"every group you want for that kind. Modules whose match yields the same capture-group "
"tuple form one group; the capture boundary chooses granularity: capture the immediate "
"parent for per-parent / per-expert groups (e.g. ``r'(.*)\\.(?:q_proj|k_proj|v_proj)'``, "
"``r'(.*)\\.(?:w1|w3)'``); leave the expert index uncaptured for one cross-expert group "
"(``r'(.*)\\.experts\\.\\d+\\.(?:w1|w3)'``). Only ``'weight'`` is used today; ``'input'`` is "
"reserved for future input-quantizer sharing. When the ``'weight'`` list is omitted, "
"the default fusible patterns (q/k/v, gate/up, w1/w3) are used — these match exactly "
"the sibling groups export fuses, avoiding the over-grouping a shared-input heuristic "
"would cause (e.g. a ``shared_expert_gate`` that reads the same input but is not fused)."
),
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@field_validator("shared_patterns")
@classmethod
def validate_shared_patterns(cls, v):
"""Reject unknown quantizer kinds and invalid regexes at the config boundary."""
if v is None:
return v
supported = {"weight", "input"}
unknown = set(v) - supported
if unknown:
raise ValueError(
f"shared_patterns has unsupported quantizer kind(s) {sorted(unknown)}; "
f"expected keys from {sorted(supported)}."
)
offending = ("", "") # (kind, pattern) of the last regex tried; set before each compile
try:
for kind, patterns in v.items():
for pattern in patterns:
offending = (kind, pattern)
re.compile(pattern)
except re.error as e:
bad_kind, bad_pattern = offending
raise ValueError(
f"shared_patterns[{bad_kind!r}] has an invalid regex {bad_pattern!r}: {e}"
) from e
return v


class MseCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for per-tensor MSE calibration.
Expand Down
102 changes: 34 additions & 68 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import math
import time
import warnings
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import TypeAlias

Expand All @@ -41,6 +41,8 @@
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
DEFAULT_WEIGHT_SHARED_PATTERNS,
attach_shared_quant_states,
disable_calib,
enable_fake_quant,
enable_quant,
Expand All @@ -49,8 +51,8 @@
is_quantized_linear,
is_quantized_row_parallel_linear,
persistent_materialization,
populate_shared_state,
promote_nvfp4_static_quantizers,
quantizer_attr_names,
reduce_amax,
)
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
Expand All @@ -74,69 +76,10 @@ def _check_nvfp4_static_tp_supported(model: nn.Module) -> None: # no-op without
]


def _is_calibrated_nvfp4_static(q) -> bool:
"""True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set."""
return (
isinstance(q, NVFP4StaticQuantizer)
and not q._disabled
and q.is_nvfp4_static
and getattr(q, "_amax", None) is not None
)


def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
"""Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
# Inline: layer_utils → quant_utils → model_calib cycle.
from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS

# Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant
# in export). Single source for the gate/up half avoids parallel lists.
patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS)
groups: list[list[nn.Module]] = []
wq_attr = quantizer_attr_names("weight").weight_quantizer
for parent in model.modules():
for sibling_names in patterns:
members = [
child
for child in (getattr(parent, n, None) for n in sibling_names)
if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None))
]
if len(members) >= 2:
groups.append(members)
return groups


def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None:
quantizer(weight)


@torch.no_grad()
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.

Run after ``max_calibrate``. Sibling discovery is name-based via
``_collect_grouped_linears``; non-matching architectures (wqkv, fused
qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently
fall back to per-module global_amax. Fused-experts containers already
share a single quantizer across gate/up halves and need no sync.
"""
# quant_utils imports back from this module; top-level would cycle.
from modelopt.torch.export.quant_utils import preprocess_linear_fusion

n_groups = 0
for group in _collect_grouped_linears(model):
preprocess_linear_fusion(group)
n_groups += 1
return n_groups


@torch.no_grad()
def _promote_nvfp4_static_quantizers_with_global_amax_sync(model: nn.Module) -> None:
"""Promote static NVFP4 weight quantizers and sync grouped global amax."""
promote_nvfp4_static_quantizers(model)
_sync_grouped_weight_global_amax(model)


CalibratorFactory: TypeAlias = Callable[
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]
Expand Down Expand Up @@ -249,6 +192,7 @@ def max_calibrate(
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
sync_expert_weight_amax=False,
shared_patterns: Mapping[str, Sequence[str]] | None = None,
):
"""Calibrate the model using max.

Expand All @@ -259,10 +203,31 @@ def max_calibrate(
distributed_sync: Whether to sync input_quantizer amax across distributed processes.
sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts
in a MoE layer (within-rank sync + EP all-reduce when EP>1).
shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a
list of regexes over module FQNs. When the ``"weight"`` list is omitted,
:data:`DEFAULT_WEIGHT_SHARED_PATTERNS` (q/k/v, gate/up, w1/w3) is used. Modules whose
regex match yields the same capture-group tuple form one group — capture the immediate
parent for per-parent (per-expert) grouping, or leave the expert index uncaptured for
cross-expert. Only ``"weight"`` is used today; ``"input"`` is reserved for future
input-quantizer sharing.

See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
details on the remaining arguments.
"""
# Discover fusible sibling groups by name regex and attach the (initially empty) shared
# state up front, so the SharedQuantState container exists for the whole calibration —
# forward-time fields can accumulate into it. Discovery is structural (a pattern over the
# module tree), so it needs no ``_amax``; per-member values are aggregated later by
# populate_shared_state, after the forward and any cross-rank ``_amax`` sync. Default to
# q/k/v + gate/up when no "weight" key is given; an explicit (possibly empty) list
# overrides it — key presence, not truthiness, so {"weight": []} disables grouping.
# Only "weight" is consumed today; "input" is reserved.
if shared_patterns is not None and "weight" in shared_patterns:
weight_patterns = list(shared_patterns["weight"])
else:
weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS
attach_shared_quant_states(model, patterns=weight_patterns)

# Always run weight calibration on the weight tensor directly so every weight
# quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms
# (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``.
Expand All @@ -280,14 +245,10 @@ def max_calibrate(
# Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated).
_check_nvfp4_static_tp_supported(model)

# Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so
# the static blockwise fake-quant path is used in forward and export picks up the
# two-level (per-block + global) scaling. Run before the ``distributed_sync`` early
# return so single-process callers also get the promotion. No-op for dynamic-block
# / non-NVFP4 configs.
_promote_nvfp4_static_quantizers_with_global_amax_sync(model)

if not distributed_sync:
# Single-process: _amax is final — aggregate shared global_amax, then promote.
populate_shared_state(model)
promote_nvfp4_static_quantizers(model)
return

# Check MoE calibration completeness before sync
Expand Down Expand Up @@ -404,6 +365,11 @@ def sync_quantizer_amax_across_tp(
module.parallel_state.tensor_parallel_group
)

# _amax is now cross-rank consistent — aggregate shared global_amax then
# promote, so siblings read the unified value instead of their own _amax.
populate_shared_state(model)
promote_nvfp4_static_quantizers(model)


def _mse_quant_func(x, amax, quantizer):
"""Quantization function for MSE calibration."""
Expand Down
4 changes: 4 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class TensorQuantizer(nn.Module):
"pre_bwd_fn",
# quantizer cache for custom backends, like luts
"_quantizer_cache",
# Runtime-only back-reference to a sibling group's SharedQuantState; it is
# re-established during calibration and must not be serialized (it points to a
# live module whose dynamic QuantLinear members are not picklable).
"_shared_quant_state_ref",
}

def __init__(
Expand Down
6 changes: 6 additions & 0 deletions modelopt/torch/quantization/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@

from .core_utils import *
from .layerwise_calib import LayerActivationCollector
from .shared_input import *

__all__ = [
"DEFAULT_WEIGHT_SHARED_PATTERNS",
"EXPORT_MODE",
"SharedQuantState",
"attach_shared_quant_states",
"convert_quantization_axis_to_reduce_axis",
"export_torch_mode",
"find_shared_input_groups",
"is_quantized",
"is_quantized_column_parallel_linear",
"is_quantized_linear",
"is_quantized_row_parallel_linear",
"populate_shared_state",
"reduce_amax",
"reduce_sum",
"replace_function",
Expand Down
28 changes: 27 additions & 1 deletion modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,26 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
need to be promoted so they use the two-level scaling path (global amax +
per-block amax) instead of the generic E4M3 path.

If the quantizer has a ``_shared_quant_state_ref`` with a populated
``weight_global_amax`` (sibling group) whose owning state lives within ``model``,
that shared value is used instead of this quantizer's own ``_amax`` reduction,
keeping siblings on a common FP8 grid.

Returns the number of quantizers converted.
"""
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer

# Shared states owned within THIS promotion root. This function also runs on
# submodules / individual linears; a quantizer may still carry a back-reference from
# an earlier full-model calibration whose owning ``_shared_quant_state`` is outside
# ``model``. Only trust refs reachable here — otherwise the global_amax would come
# from an unrelated prior run; fall back to the quantizer's own amax instead.
valid_shared_states = {
id(state)
for owner in model.modules()
if (state := getattr(owner, "_shared_quant_state", None)) is not None
}

converted = 0
for _name, module in list(model.named_modules()):
if not isinstance(module, TensorQuantizer) or not module.is_enabled:
Expand All @@ -968,8 +984,18 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
if amax is None:
continue

# Grouped siblings share one ``weight_global_amax`` (common FP8 grid);
# otherwise fall back to this quantizer's own per-block amax.
already_promoted = isinstance(module, NVFP4StaticQuantizer)
global_amax = reduce_amax(amax.clone().detach(), axis=None)
shared = getattr(module, "_shared_quant_state_ref", None)
if (
shared is not None
and id(shared) in valid_shared_states
and shared.weight_global_amax is not None
):
global_amax = shared.weight_global_amax
else:
global_amax = reduce_amax(amax.clone().detach(), axis=None)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
if not already_promoted:
converted += 1
Expand Down
Loading
Loading