Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Changelog
- Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries.
- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache.
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).
- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``save_quantizers_only`` (skip the layer-weights blob for amax-only algorithms — whitelisted to ``max``/``mse``/``local_hessian``). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``.
- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``calib_mutates_weights`` (set False to skip layer-weight checkpointing/writeback when calibration only updates quantizer state). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``.

**Bug Fixes**

Expand Down
23 changes: 12 additions & 11 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,13 +678,14 @@ class LayerwiseConfig(ModeloptBaseConfig):
),
)

save_quantizers_only: bool = ModeloptField(
default=False,
title="Skip the per-layer weights blob; persist only quantizer state.",
calib_mutates_weights: bool = ModeloptField(
default=True,
title="Whether layerwise calibration mutates layer weights.",
description=(
"Only accepted by algorithms that update solely ``TensorQuantizer._amax`` "
"(max, mse, local_hessian). Rejected for weight-mutating algorithms "
"(GPTQ, AWQ, SmoothQuant) where it would silently lose updates on resume."
"Set to False only for algorithms that update solely "
"``TensorQuantizer._amax`` (max, mse, local_hessian). Rejected for "
"weight-mutating algorithms (GPTQ, AWQ, SmoothQuant) where it would "
"silently lose updates on resume."
),
)

Expand Down Expand Up @@ -712,7 +713,7 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
"""Calibration algorithm config base."""

# Set True only for algorithms that update solely ``TensorQuantizer._amax``
# (no ``layer.weight`` mutation). Gates ``layerwise.save_quantizers_only``.
# (no ``layer.weight`` mutation). Gates ``layerwise.calib_mutates_weights=False``.
_supports_save_quantizers_only: ClassVar[bool] = False

method: Literal[None] = ModeloptField(
Expand Down Expand Up @@ -794,12 +795,12 @@ def validate_layerwise_checkpoint_dir(self):
return self

@model_validator(mode="after")
def _validate_save_quantizers_only_supported(self):
"""Enforce the ``_supports_save_quantizers_only`` whitelist."""
if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only:
def _validate_non_mutating_layerwise_supported(self):
"""Enforce the ``calib_mutates_weights=False`` whitelist."""
if not self.layerwise.calib_mutates_weights and not self._supports_save_quantizers_only:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we could rename _supports_save_quantizers_only to something like _calib_is_amax_only to align with the new vocabulary; otherwise a reader has to learn that _supports_save_quantizers_only=True means "amax-only algorithm, so calib_mutates_weights=False is allowed".

raise ValueError(
f"Algorithm '{self.method}' mutates layer weights in-place; "
"save_quantizers_only=True would lose those updates on resume. "
"calib_mutates_weights=False would lose those updates on resume. "
"Only max/mse/local_hessian (amax-only) support this flag."
)
return self
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def wrapped_calib_func(
checkpoint_dir = layerwise_cfg.get("checkpoint_dir")
qdq_from_prev = layerwise_cfg.get("get_qdq_activations_from_prev_layer", False)
save_every = layerwise_cfg.get("save_every", 1)
save_quantizers_only = layerwise_cfg.get("save_quantizers_only", False)
calib_mutates_weights = layerwise_cfg.get("calib_mutates_weights", True)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand Down Expand Up @@ -263,7 +263,7 @@ def wrapped_calib_func(
checkpoint_dir=checkpoint_dir,
get_qdq_activations_from_prev_layer=qdq_from_prev,
save_every=save_every,
save_quantizers_only=save_quantizers_only,
calib_mutates_weights=calib_mutates_weights,
**kwargs,
)
else:
Expand Down
73 changes: 44 additions & 29 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_CheckpointState,
)
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
Expand Down Expand Up @@ -1761,7 +1761,7 @@ def layerwise_calibrate(
checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)
qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False)
save_every = calib_kwargs.pop("save_every", 1)
save_quantizers_only = calib_kwargs.pop("save_quantizers_only", False)
calib_mutates_weights = calib_kwargs.pop("calib_mutates_weights", True)

if forward_loop is None:
raise ValueError(
Expand All @@ -1783,16 +1783,27 @@ def layerwise_calibrate(
checkpoint_dir,
num_layers,
save_every=save_every,
save_quantizers_only=save_quantizers_only,
calib_mutates_weights=calib_mutates_weights,
)
start_layer = ckpt.start_layer if ckpt else 0

input_getter = LayerActivationCollector(model)
input_getter._patch_all_layers(decoder_layers=transformer_layers)
layer_pbar = tqdm(
total=num_layers,
initial=start_layer,
desc="Layerwise calibration",
disable=not is_master(),
dynamic_ncols=True,
)

def _set_layer_status(status: str):
layer_pbar.set_postfix_str(status, refresh=True)

resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None
input_getter = LayerActivationCollector(model, status_callback=_set_layer_status)

try:
input_getter._patch_all_layers(decoder_layers=transformer_layers)
resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None

# Bootstrap: get first layer's inputs (or use resumed inputs).
layer_inputs = input_getter.get_first_layer_inputs(
start_layer, resumed_inputs, forward_loop
Expand Down Expand Up @@ -1822,39 +1833,43 @@ def _layer_forward_loop(m, _inputs=layer_inputs):

is_last = layer_idx + 1 >= num_layers

# qdq_from_prev=False: capture before calib_func so the forward
# replay uses the original FP weights. Disable quantizers too in
# case any pre-calibration observer behavior would perturb the
# captured activations.
if not is_last and not qdq_from_prev:
with set_quantizer_by_cfg_context(
layer, [{"quantizer_name": "*", "enable": False}]
):
next_inputs = input_getter.cache_outputs_for_next_layer_calib(
layer, forward_loop
)
# cache_outputs left this layer in "run" mode with an empty
# deque; reset so calib_func's replay hits the real forward.
layer._layerwise_calib.mode = "original"
with persistent_materialization(layer, writeback=calib_mutates_weights):
# qdq_from_prev=False: capture before calib_func so the forward
# replay uses the original FP weights. Disable quantizers too in
# case any pre-calibration observer behavior would perturb the
# captured activations.
if not is_last and not qdq_from_prev:
with set_quantizer_by_cfg_context(
layer, [{"quantizer_name": "*", "enable": False}]
):
next_inputs = input_getter.cache_outputs_for_next_layer_calib(
layer, forward_loop
)
# cache_outputs left this layer in "run" mode with an empty
# deque; reset so calib_func's replay hits the real forward.
layer._layerwise_calib.mode = "original"

with persistent_materialization(layer):
calib_func(layer, _layer_forward_loop, **calib_kwargs)

# qdq_from_prev=True: capture after calib_func so the next layer
# sees QDQ error and any in-place weight updates from this layer.
if not is_last and qdq_from_prev:
next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop)
elif is_last:
next_inputs = None
# qdq_from_prev=True: capture after calib_func so the next layer
# sees QDQ error and any in-place weight updates from this layer.
if not is_last and qdq_from_prev:
next_inputs = input_getter.cache_outputs_for_next_layer_calib(
layer, forward_loop
)
elif is_last:
next_inputs = None

if ckpt:
ckpt.save(layer_idx, model, transformer_layers, next_inputs)
if ckpt:
ckpt.save(layer_idx, model, transformer_layers, next_inputs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you explain why we move ckpt.save inside the persistent_materialization context?


layer_pbar.update(1)
del layer_inputs
torch.cuda.empty_cache()
layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure)
finally:
input_getter._unpatch_all_layers()
layer_pbar.close()

if ckpt:
ckpt.full_restore(transformer_layers, model)
Expand Down
16 changes: 12 additions & 4 deletions modelopt/torch/quantization/plugins/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,22 @@ def _writeback_params_to_weights_map(module, align_hook):
# on-disk version. OffloadedWeightsLoader.__getitem__ gives
# state_dict priority over index, so this is sufficient.
w_map[key] = tensor.detach().cpu()
warnings.warn(
"Accelerate disk-offload writeback is currently kept in CPU state_dict; "
"writing updates back to disk offload files is TODO.",
RuntimeWarning,
stacklevel=2,
)


@contextmanager
def weight_access_and_writeback_context(module):
def weight_access_and_writeback_context(module, writeback: bool = True):
"""Context manager for weight access and writeback for modules managed by accelerate.

Handles CPU-offloaded and disk-offloaded models. Iterates over the module and all
its descendants, materializing weights from any offload hook found and writing them
back on exit. ``pre_forward`` is skipped on modules whose weights are already
its descendants, materializing weights from any offload hook found. If ``writeback``
is True, writes materialized tensors back on exit. ``pre_forward`` is skipped on
modules whose weights are already
materialized (not on meta) to avoid overwriting them with stale CPU copies.
"""
assert hasattr(module, "_hf_hook")
Expand All @@ -99,7 +106,8 @@ def weight_access_and_writeback_context(module):
finally:
for mod, hook, was_materialized in materialized:
hook.offload = True
_writeback_params_to_weights_map(mod, hook)
if writeback:
_writeback_params_to_weights_map(mod, hook)
if was_materialized:
hook.post_forward(mod, None)

Expand Down
48 changes: 40 additions & 8 deletions modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from modelopt.torch.quantization.config import QuantizerCfgEntry
from modelopt.torch.utils import get_unwrapped_name, print_rank_0
from modelopt.torch.utils.network import temporarily_remove_accelerate_hook

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -471,7 +472,24 @@ def _set_parameter(module: nn.Module, name: str, value: nn.Parameter):


@contextmanager
def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module):
def _fsdp2_unshard_context(fsdp_module: FSDPModule):
"""Unshard an FSDP2 module without replacing individual DTensor parameters.

Calls ``fsdp_module.unshard()`` on entry if the module's FSDP param group reports
``is_sharded``, runs the body under ``_disable_fsdp_unshard_reshard``, and calls
``fsdp_module.reshard()`` on exit only if the module was sharded on entry.
This context performs no writeback of its own on exit.

Args:
    fsdp_module: The FSDP2-wrapped module whose parameters should be unsharded.
"""
# Record the entry state so an already-unsharded module is left as it was found.
fsdp_param_group = fully_shard.state(fsdp_module)._fsdp_param_group
was_sharded = fsdp_param_group.is_sharded
if was_sharded:
# NOTE(review): presumably a no-op if the caller has already entered
# ``_disable_fsdp_unshard_reshard`` (which patches ``FSDPParamGroup.unshard``) -- confirm.
fsdp_module.unshard()
try:
with _disable_fsdp_unshard_reshard(fsdp_module):
yield
finally:
if was_sharded:
fsdp_module.reshard()


@contextmanager
def fsdp2_weight_access_and_writeback_context(
module: nn.Module, root_model: nn.Module, writeback: bool = True
):
"""Context manager for FSDP2 weight access and writeback.

Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be
Expand All @@ -486,6 +504,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.
assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks"
fsdp_module = _get_enclosing_fsdp_module(module, root_model)
assert fsdp_module is not None, "Module is not wrapped by FSDP"
if not writeback:
with _fsdp2_unshard_context(fsdp_module):
yield
return

Comment on lines +507 to +511
Copy link
Copy Markdown
Contributor Author

@realAsma realAsma Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sugunav14 here is an easy perf improvement for layerwise FSDP2

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@realAsma Claude flagged a correctness issue on this branch. It makes sense to me — please check from your side:

Bug: FSDP2 writeback=False path never all-gathers weights (non-mutating layerwise calibration computes on local shards)

Where: _fsdp2_unshard_context (core_utils.py:480) reached via persistent_materialization(layer, writeback=False).

Root cause — a collision between two layers of the same context stack.

persistent_materialization enters its context managers in this order:

with (
    _disable_fsdp_unshard_reshard(layer),                              # ① enters FIRST
    enable_weight_access_and_writeback(layer, layer, writeback=False), # ② enters SECOND
    temporarily_remove_accelerate_hook(layer),
):
  • ① (_disable_fsdp_unshard_reshard) monkeypatches the class method FSDPParamGroup.unshard to a no-op (to stop FSDP from re-sharding weights between
    calibration batches). This patch is global and is now active.

  • ② (enable_weight_access_and_writeback) with writeback=False routes into _fsdp2_unshard_context, whose gather step is:

    if was_sharded:
        fsdp_module.unshard()   # core_utils.py:480

    But FSDPModule.unshard() delegates the actual all-gather to fsdp_param_group.unshard() — which is exactly the method
    just patched to a no-op. So no all-gather happens; the params stay sharded DTensors, and the calibration/capture forward
    runs on partial shards → wrong _amax (or a shape/dtype error).

The layer's own FSDP pre-forward hook can't save it either: that hook also calls FSDPParamGroup.unshard, which is still the
no-op.

Why writeback=True (GPTQ/AWQ/SmoothQuant) is unaffected: that branch of fsdp2_weight_access_and_writeback_context gathers
via param.redistribute(... Replicate ...) — a DTensor collective that doesn't touch the patched unshard(). So only the
writeback=False (max/mse/local_hessian) path, i.e. the non-mutating optimization this PR adds, is broken on FSDP2.

_fsdp2_unshard_context is internally self-contradictory under this caller: line 480 calls unshard() (needs it to work) while
line 482 calls _disable_fsdp_unshard_reshard (disables it). It assumes it runs before anyone disables unshard, but its only
production caller disables it first.

Impact / trigger: FSDP2 layerwise calibration with calib_mutates_weights=False. The default (True) takes the redistribute
path and works, so the bug only fires when a user opts into the non-mutating feature.

Test that should catch this (likely unrun — it's a multi-GPU GPU test on a draft PR):
tests/gpu/torch/quantization/test_fsdp2.py:264

with persistent_materialization(layer, writeback=False):
    assert not isinstance(layer[0].weight, DTensor)   # fails: still a DTensor
    layer(inputs)

Suggested fix (don't double-disable):

  1. Have _fsdp2_unshard_context call the unpatched unshard (_disable_fsdp_unshard_reshard already captures orig_unshard;
    expose it), so the gather works while per-forward reshard stays suppressed; or
  2. Skip persistent_materialization's outer _disable_fsdp_unshard_reshard on the FSDP writeback=False path, since
    _fsdp2_unshard_context already suppresses reshard internally after a real unshard.

fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module)
fsdp_dim = fsdp_device_mesh.ndim

Expand Down Expand Up @@ -525,7 +548,9 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.


@contextmanager
def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None):
def enable_weight_access_and_writeback(
module, root_model, name_to_module: dict | None = None, writeback: bool = True
):
"""Enable weight access and writeback for a module.

Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or
Expand All @@ -539,16 +564,18 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict
total cost when called in a loop. This causes significant CPU overhead on large
models, particularly Sparse MoE architectures where each expert is typically
implemented as its own module.
writeback: Whether modified weights must be written back to the owning sharded/offload
representation when exiting the context.
"""
if _get_enclosing_fsdp_module(module, root_model, name_to_module) is not None:
context = fsdp2_weight_access_and_writeback_context(module, root_model)
context = fsdp2_weight_access_and_writeback_context(module, root_model, writeback)
elif is_quantized_parallel_linear(module) and hasattr(module, "_hf_tp_plan"):
# HF transformers TP sharded linear layer
context = module.enable_weight_access_and_writeback()
elif hasattr(module, "_hf_hook"):
from ..plugins.accelerate import weight_access_and_writeback_context

context = weight_access_and_writeback_context(module)
context = weight_access_and_writeback_context(module, writeback)
else:
context = nullcontext()

Expand All @@ -557,18 +584,23 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict


@contextmanager
def persistent_materialization(layer):
def persistent_materialization(layer, writeback: bool = True):
"""Keep all layer weights materialized on GPU for the duration.

Suppresses per-forward weight transfers so that N calibration batches
pay the cost of one load/unload instead of N.

- **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then
gathers weights once via ``enable_weight_access_and_writeback``.
- **Accelerate**: materializes weights and sets ``hook.offload = False``
so per-forward hooks skip materialization/offloading.
- **Accelerate**: materializes weights, sets ``hook.offload = False``,
and bypasses the layer's top-level accelerate hook while the weights are
materialized.
"""
with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer):
with (
_disable_fsdp_unshard_reshard(layer),
enable_weight_access_and_writeback(layer, layer, writeback=writeback),
temporarily_remove_accelerate_hook(layer),
):
yield


Expand Down
Loading