From 090d307d3b983f7b934a4c91d0e2e5f28d0b1dc2 Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 29 May 2026 21:50:00 +0000 Subject: [PATCH 1/6] Nest layerwise calibration config; add get_qdq_activations_from_prev_layer Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 43 ++++-- examples/llm_ptq/hf_ptq.py | 9 +- modelopt/torch/quantization/config.py | 129 +++++++++++++++--- modelopt/torch/quantization/mode.py | 11 +- modelopt/torch/quantization/model_calib.py | 34 ++++- .../quantization/test_config_validation.py | 123 +++++++++++++---- .../quantization/test_layerwise_calibrate.py | 57 +++++++- 7 files changed, 332 insertions(+), 74 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ec4bb2b0519..240f68a6531 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -850,22 +850,37 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod print("No custom model files found to copy") -def needs_checkpoint_path_update(quant_cfg: dict) -> bool: - """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" - algorithm = quant_cfg.get("algorithm") +def _layerwise_checkpoint_dir_location(algorithm) -> tuple[str, str] | None: + """Return ``("flat"/"nested", checkpoint_dir)`` for the layerwise checkpoint dir, or None.""" if not isinstance(algorithm, dict): - return False - return algorithm.get("layerwise_checkpoint_dir") is not None + return None + flat = algorithm.get("layerwise_checkpoint_dir") + if flat is not None: + return "flat", flat + nested = algorithm.get("layerwise") or {} + ckpt = nested.get("checkpoint_dir") if isinstance(nested, dict) else None + return ("nested", ckpt) if ckpt is not None else None + + +def needs_checkpoint_path_update(quant_cfg: dict) -> bool: + """Check if quant_cfg has a layerwise checkpoint_dir that should be auto-resolved to a unique subpath.""" + return _layerwise_checkpoint_dir_location(quant_cfg.get("algorithm")) is not None -def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: - """Append a unique ``_`` subdirectory to layerwise_checkpoint_dir. +def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> tuple[dict, str]: + """Append a unique ``_`` subdirectory to the layerwise checkpoint_dir. Allows a single recipe to be reused across models without checkpoint collisions. + Supports both the legacy flat ``layerwise_checkpoint_dir`` and the nested + ``layerwise.checkpoint_dir`` shape, writing back to whichever the user provided. Must only be called when :func:`needs_checkpoint_path_update` returns True. + + Returns ``(updated_quant_cfg, resolved_path)`` so the caller can log or + reference the resolved path without re-deriving the dict shape. """ - algorithm = quant_cfg["algorithm"] - base_dir = algorithm["layerwise_checkpoint_dir"] + location = _layerwise_checkpoint_dir_location(quant_cfg["algorithm"]) + assert location is not None # guaranteed by needs_checkpoint_path_update + shape, base_dir = location name = model_path.rstrip("/") if "/" in name and not os.path.isabs(name): @@ -874,9 +889,11 @@ def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: name = Path(name).name config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8] + resolved = os.path.join(base_dir, f"{name}_{config_hash}") quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( - base_dir, f"{name}_{config_hash}" - ) - return quant_cfg + if shape == "flat": + quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = resolved + else: + quant_cfg["algorithm"]["layerwise"]["checkpoint_dir"] = resolved + return quant_cfg, resolved diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d27aa593f6..7483deb10c0 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1000,7 +1000,8 @@ def _is_layerwise(obj): return _is_layerwise(obj.quantize.algorithm) if isinstance(obj, list): return any(_is_layerwise(a) for a in obj) - return bool(getattr(obj, "layerwise", False)) + layerwise = getattr(obj, "layerwise", None) + return bool(getattr(layerwise, "enable", False)) is_layerwise = _is_layerwise(recipe) @@ -1135,10 +1136,8 @@ def _is_layerwise(obj): _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) if needs_checkpoint_path_update(quant_cfg): - quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) - print( - f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" - ) + quant_cfg, resolved_dir = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) + print(f"Auto-resolved layerwise checkpoint_dir: {resolved_dir}") if args.cast_mxfp4_to_nvfp4: quant_cfg = copy.deepcopy(quant_cfg) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 602375d2bf6..6f91a4fa79a 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -154,7 +154,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Literal -from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator +from pydantic import AliasChoices, Field, ValidationInfo, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config @@ -633,6 +633,59 @@ def validate_calibrator(cls, v, info: ValidationInfo): ) +class LayerwiseConfig(ModeloptBaseConfig): + """Nested config for layer-by-layer calibration behavior.""" + + enable: bool = ModeloptField( + default=False, + title="Enable layerwise (layer-by-layer) calibration.", + description=( + "If True, the calibration algorithm is applied layer by layer. " + "Each layer's inputs are captured via a forward pass that reflects the " + "quantization of all preceding layers, incurring O(N) forward passes for N layers." + ), + ) + + get_qdq_activations_from_prev_layer: bool = ModeloptField( + default=False, + title="Cache next-layer inputs from QDQ outputs of prior layers.", + description=( + "If True (GPTQ default), layer N's calibration sees inputs carrying " + "the quantize-dequantize error of layers 0..N-1, so quantization " + "error compounds across layers. If False (max-calib default), " + "quantizers are temporarily disabled during the capture forward, so " + "layer N sees the same full-precision activations as a non-layerwise " + "calibration pass." + ), + ) + + checkpoint_dir: str | None = ModeloptField( + default=None, + title="Per-layer checkpoint directory (resume on restart).", + description=( + "If set, per-layer checkpoints are saved here during calibration. " + "On restart, calibration resumes from the last completed layer." + ), + ) + + +def _coerce_layerwise_input(value): + """Normalize a raw ``layerwise`` value to a dict; warn on deprecated bool.""" + if isinstance(value, bool): + warnings.warn( + "Passing the layerwise field as a bool is deprecated; use a dict, " + "e.g. `{'enable': True}`.", + DeprecationWarning, + stacklevel=2, + ) + return {"enable": value} + if value is None: + return {} + if isinstance(value, LayerwiseConfig): + return value.model_dump() + return value + + class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" @@ -656,34 +709,55 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - layerwise: bool = ModeloptField( - default=False, + layerwise: LayerwiseConfig = Field( + default_factory=LayerwiseConfig, validation_alias=AliasChoices("layerwise", "use_sequential"), - title="Enable layerwise (layer-by-layer) calibration.", + title="Layerwise calibration configuration.", description=( - "If True, the calibration algorithm is applied layer by layer. " - "Each layer's inputs are captured via a forward pass that reflects the " - "quantization of all preceding layers, incurring O(N) forward passes for N layers." + "Nested config controlling layer-by-layer calibration. Pass a dict, " + "e.g. ``{'enable': True, 'checkpoint_dir': '/path'}``. Bool input is " + "accepted for backward compatibility but deprecated." ), ) - layerwise_checkpoint_dir: str | None = ModeloptField( - default=None, - title="Checkpoint directory for layerwise calibration.", - description=( - "If set together with layerwise=True, per-layer checkpoints are saved to this " - "directory during calibration. On restart, calibration resumes from the last " - "completed layer." - ), - ) + @model_validator(mode="before") + @classmethod + def _migrate_layerwise_checkpoint_dir(cls, data): + """Merge the legacy flat ``layerwise_checkpoint_dir`` key into ``layerwise``. + + Raises if both the flat key and a nested ``checkpoint_dir`` are set with conflicting values. + """ + if not isinstance(data, dict) or "layerwise_checkpoint_dir" not in data: + return data + data = dict(data) + flat_dir = data.pop("layerwise_checkpoint_dir") + # Resolve the legacy ``use_sequential`` alias before writing ``layerwise``, + # otherwise the alias value is silently dropped when AliasChoices picks the + # newly-written ``layerwise`` key over ``use_sequential``. + raw_layerwise = data.pop("layerwise", data.pop("use_sequential", None)) + layerwise = _coerce_layerwise_input(raw_layerwise) + existing = layerwise.get("checkpoint_dir") + if existing is not None and existing != flat_dir: + raise ValueError( + f"Conflicting checkpoint_dir: layerwise_checkpoint_dir={flat_dir!r} " + f"differs from layerwise.checkpoint_dir={existing!r}. Set only one." + ) + data["layerwise"] = {**layerwise, "checkpoint_dir": flat_dir} + return data + + @field_validator("layerwise", mode="before") + @classmethod + def _coerce_layerwise(cls, value): + """Coerce ``layerwise=bool/None`` to dict form; also handles the alias path.""" + return _coerce_layerwise_input(value) @model_validator(mode="after") def validate_layerwise_checkpoint_dir(self): - """Raise if layerwise_checkpoint_dir is set but layerwise is False.""" - if self.layerwise_checkpoint_dir is not None and not self.layerwise: + """Raise if layerwise.checkpoint_dir is set but layerwise.enable is False.""" + if self.layerwise.checkpoint_dir is not None and not self.layerwise.enable: raise ValueError( - "layerwise_checkpoint_dir requires layerwise=True. " - "Set layerwise=True or remove layerwise_checkpoint_dir." + "layerwise.checkpoint_dir requires layerwise.enable=True. " + "Set layerwise.enable=True or remove layerwise.checkpoint_dir." ) return self @@ -996,6 +1070,21 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): per-column error propagation into one launch per GPTQ block.""", ) + @model_validator(mode="after") + def _gptq_qdq_default(self): + """Inject ``get_qdq_activations_from_prev_layer=True`` unless the user set it. + + GPTQ's Hessian correctness depends on prior-layer QDQ activations, so the + default differs from the base class. Uses ``model_fields_set`` to detect + whether the user explicitly set the field — covers every input shape + (empty constructor, bool, dict) without a per-shape special case. + """ + if "get_qdq_activations_from_prev_layer" not in self.layerwise.model_fields_set: + self.layerwise = self.layerwise.model_copy( + update={"get_qdq_activations_from_prev_layer": True} + ) + return self + QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 530e90aaa00..05fbe2d223d 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -223,8 +223,10 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") - layerwise = kwargs.pop("layerwise", False) - checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None) + layerwise_cfg = kwargs.pop("layerwise", None) or {} + layerwise = layerwise_cfg.get("enable", False) + checkpoint_dir = layerwise_cfg.get("checkpoint_dir") + qdq_from_prev = layerwise_cfg.get("get_qdq_activations_from_prev_layer", False) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -244,8 +246,8 @@ def wrapped_calib_func( # future algorithms that need full-model context must add a guard here. if not supports_layerwise: raise ValueError( - f"Calibration algorithm '{method}' does not support layerwise=True. " - "Set layerwise=False, or override `_supports_layerwise = True` on the " + f"Calibration algorithm '{method}' does not support layerwise.enable=True. " + "Set layerwise.enable=False, or override `_supports_layerwise = True` on the " "corresponding CalibrateModeDescriptor once the algorithm is made " "compatible with per-layer calibration." ) @@ -257,6 +259,7 @@ def wrapped_calib_func( forward_loop=forward_loop, calib_func=func, checkpoint_dir=checkpoint_dir, + get_qdq_activations_from_prev_layer=qdq_from_prev, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 978f660f9d1..e08fb9bad17 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -19,6 +19,7 @@ import time import warnings from collections.abc import Callable +from contextlib import AbstractContextManager, nullcontext from functools import partial from typing import TypeAlias @@ -1752,8 +1753,14 @@ def layerwise_calibrate( If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints are saved after each layer completes. On restart, calibration resumes from the last completed layer. + + ``get_qdq_activations_from_prev_layer`` (via ``calib_kwargs``) controls + whether the cached inputs handed to layer N+1 come from a forward through + the just-calibrated layer with quantizers active (True; e.g. GPTQ) or + temporarily disabled (False; matches non-layerwise max-calib semantics). """ checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) + qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False) if forward_loop is None: raise ValueError( @@ -1814,7 +1821,20 @@ def _layer_forward_loop(m, _inputs=layer_inputs): # output_meta on the just-calibrated layer (via "run" mode). is_last = layer_idx + 1 >= num_layers if not is_last: - next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) + # When qdq_from_prev is False, temporarily disable every quantizer + # under the just-calibrated layer so the next layer sees full-precision + # activations (matches non-layerwise calibration semantics). + capture_ctx: AbstractContextManager = ( + nullcontext() + if qdq_from_prev + else set_quantizer_by_cfg_context( + layer, [{"quantizer_name": "*", "enable": False}] + ) + ) + with capture_ctx: + next_inputs = input_getter.cache_outputs_for_next_layer_calib( + layer, forward_loop + ) else: next_inputs = None @@ -1843,13 +1863,13 @@ def gptq( ): """GPTQ quantization. - Works in two modes depending on ``layerwise`` in the config: + Works in two modes depending on ``layerwise.enable`` in the config: - * **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this - function once per decoder layer with updated activations, producing more - accurate Hessian estimates. - * **Non-layerwise** (``layerwise=False``): called once on the full model. - All layers are quantized in parallel from the original activations. + * **Layerwise** (``layerwise.enable=True``): ``layerwise_calibrate`` calls + this function once per decoder layer with updated activations, producing + more accurate Hessian estimates. + * **Non-layerwise** (``layerwise.enable=False``): called once on the full + model. All layers are quantized in parallel from the original activations. Per-module steps: diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index ce98f989f51..dd610c18b80 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,8 @@ """Test of quantization config validations.""" +import warnings + import pytest from pydantic import ValidationError @@ -25,6 +27,7 @@ INT4_AWQ_CFG, NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, + GPTQCalibConfig, MaxCalibConfig, QuantizeConfig, find_quant_cfg_entry_by_path, @@ -573,32 +576,104 @@ def test_validate_quant_cfg_entries_accepts_valid_cfg(self): class TestLayerwiseUseSequentialAlias: - """`layerwise` accepts the legacy `use_sequential` name via validation_alias. - - Old PTQ checkpoints serialized the field as `use_sequential` before #1251 renamed - it to `layerwise`. AliasChoices lets those checkpoints load without a migration - validator while still serializing under the current name. - """ + """`use_sequential` is the legacy alias for `layerwise` (pre-#1251 checkpoints).""" + + @pytest.mark.parametrize("value", [True, False]) + def test_use_sequential_resolves_to_layerwise(self, value): + with pytest.warns(DeprecationWarning): + cfg = MaxCalibConfig(use_sequential=value) + assert cfg.layerwise.enable is value + + def test_serializes_under_layerwise_not_alias(self): + with pytest.warns(DeprecationWarning): + dumped = MaxCalibConfig(use_sequential=True).model_dump() + assert dumped["layerwise"]["enable"] is True + assert "use_sequential" not in dumped - def test_use_sequential_true_sets_layerwise(self): - cfg = MaxCalibConfig(use_sequential=True) - assert cfg.layerwise is True - def test_use_sequential_false_sets_layerwise(self): - cfg = MaxCalibConfig(use_sequential=False) - assert cfg.layerwise is False +class TestLayerwiseNestedConfig: + """Layerwise expands from a bool to a nested ``LayerwiseConfig``. - def test_layerwise_name_still_accepted(self): - cfg = MaxCalibConfig(layerwise=True) - assert cfg.layerwise is True + Backward compatibility: bool input is coerced with a DeprecationWarning, and + the legacy flat ``layerwise_checkpoint_dir`` key is silently absorbed. + """ - def test_serializes_under_current_name(self): - """Dump must use `layerwise`, not the legacy alias.""" - dumped = MaxCalibConfig(use_sequential=True).model_dump() - assert dumped["layerwise"] is True - assert "use_sequential" not in dumped + def test_nested_form_accepted(self): + cfg = MaxCalibConfig(layerwise={"enable": True, "checkpoint_dir": "/x"}) + assert cfg.layerwise.enable is True + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_bool_form_deprecated_but_accepted(self): + with pytest.warns(DeprecationWarning, match="bool is deprecated"): + cfg = MaxCalibConfig(layerwise=True) + assert cfg.layerwise.enable is True + + def test_dict_form_no_deprecation(self): + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + MaxCalibConfig(layerwise={"enable": True}) + + def test_flat_checkpoint_dir_migrated_to_nested(self): + with pytest.warns(DeprecationWarning): + cfg = MaxCalibConfig(layerwise=True, layerwise_checkpoint_dir="/x") + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_use_sequential_alias_survives_flat_checkpoint_migration(self): + """``use_sequential`` + flat ``layerwise_checkpoint_dir`` must not drop the alias value.""" + with pytest.warns(DeprecationWarning): + cfg = MaxCalibConfig(use_sequential=True, layerwise_checkpoint_dir="/x") + assert cfg.layerwise.enable is True + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_conflicting_flat_and_nested_checkpoint_dir_raises(self): + with pytest.raises(ValidationError, match="Conflicting checkpoint_dir"): + MaxCalibConfig( + layerwise={"enable": True, "checkpoint_dir": "/a"}, + layerwise_checkpoint_dir="/b", + ) - def test_unknown_field_still_rejected(self): - """extra='forbid' must still reject unrelated unknown fields.""" - with pytest.raises(ValidationError): - MaxCalibConfig(not_a_real_field=True) + @pytest.mark.parametrize( + "kwargs", + [ + {"layerwise": {"checkpoint_dir": "/x"}}, + {"layerwise_checkpoint_dir": "/x"}, + ], + ) + def test_checkpoint_dir_requires_enable(self, kwargs): + with pytest.raises(ValidationError, match="requires layerwise.enable=True"): + MaxCalibConfig(**kwargs) + + @pytest.mark.parametrize( + ("cfg_cls", "expected_qdq"), + [(MaxCalibConfig, False), (GPTQCalibConfig, True)], + ) + def test_per_algorithm_qdq_default(self, cfg_cls, expected_qdq): + assert cfg_cls().layerwise.get_qdq_activations_from_prev_layer is expected_qdq + + @pytest.mark.parametrize( + "layerwise", + [ + {"enable": True}, + pytest.param(True, marks=pytest.mark.filterwarnings("ignore::DeprecationWarning")), + ], + ) + def test_gptq_qdq_default_survives_user_layerwise_input(self, layerwise): + """GPTQ must default qdq=True even when the user supplies a layerwise dict/bool.""" + cfg = GPTQCalibConfig(layerwise=layerwise) + assert cfg.layerwise.get_qdq_activations_from_prev_layer is True + + def test_gptq_user_explicit_qdq_false_wins(self): + """An explicit ``get_qdq_activations_from_prev_layer=False`` must override the GPTQ default.""" + cfg = GPTQCalibConfig( + layerwise={"enable": True, "get_qdq_activations_from_prev_layer": False} + ) + assert cfg.layerwise.get_qdq_activations_from_prev_layer is False + + def test_default_dump_shape(self): + dumped = MaxCalibConfig().model_dump() + assert dumped["layerwise"] == { + "enable": False, + "get_qdq_activations_from_prev_layer": False, + "checkpoint_dir": None, + } + assert "layerwise_checkpoint_dir" not in dumped diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index 3739feff969..0ff6023cffd 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -713,9 +713,64 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm(): config = _svdquant_layerwise_config() torch.manual_seed(0) model = _SimpleTransformerModel(n_layers=2, dim=16) - with pytest.raises(ValueError, match="does not support layerwise=True"): + with pytest.raises(ValueError, match="does not support layerwise.enable=True"): mtq.quantize( model, config, forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))), ) + + +def test_layerwise_no_qdq_matches_sequential_amax(monkeypatch): + """Layerwise + ``get_qdq_activations_from_prev_layer=False`` must produce the + same per-quantizer amax as the non-layerwise (sequential) max-calibration + flow. Both paths feed every layer full-precision activations, so the amax + statistics they collect must agree. + """ + _register_test_discoverer(monkeypatch) + + torch.manual_seed(0) + model_seq = _SimpleTransformerModel(n_layers=3, dim=16) + model_lw = copy.deepcopy(model_seq) + calib_data = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def fwd(m): + for batch in calib_data: + m(batch) + + seq_cfg = _int8_layerwise_config({"method": "max"}) + mtq.quantize(model_seq, seq_cfg, forward_loop=fwd) + + lw_cfg = _int8_layerwise_config( + { + "method": "max", + "layerwise": {"enable": True, "get_qdq_activations_from_prev_layer": False}, + } + ) + mtq.quantize(model_lw, lw_cfg, forward_loop=fwd) + + def collect_amax(model): + return { + name: q._amax.clone().detach() + for name, q in model.named_modules() + if isinstance(q, TensorQuantizer) + and q.is_enabled + and getattr(q, "_amax", None) is not None + } + + seq_amax = collect_amax(model_seq) + lw_amax = collect_amax(model_lw) + + assert seq_amax, "sequential calibration populated no amax values" + assert set(seq_amax) == set(lw_amax), ( + f"Different quantizers populated: only-seq={set(seq_amax) - set(lw_amax)}, " + f"only-lw={set(lw_amax) - set(seq_amax)}" + ) + for name in seq_amax: + torch.testing.assert_close( + lw_amax[name], + seq_amax[name], + rtol=1e-5, + atol=1e-6, + msg=f"amax mismatch at {name}: seq={seq_amax[name]}, lw={lw_amax[name]}", + ) From 7973f347814dbb551e3052e4b629a8e0f36c2325 Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 29 May 2026 21:59:06 +0000 Subject: [PATCH 2/6] Migrate PTQ recipes to the nested layerwise config form Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml | 5 +++-- modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml | 5 +++-- .../general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml | 3 ++- .../general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml | 5 +++-- .../general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml | 5 +++-- .../qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml | 3 ++- .../qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml | 3 ++- 7 files changed, 18 insertions(+), 11 deletions(-) diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml index 6dee51857c8..5dc5786f0ff 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml @@ -29,8 +29,9 @@ metadata: quantize: algorithm: method: gptq - layerwise: true - layerwise_checkpoint_dir: output/layerwise_ckpts/ + layerwise: + enable: true + checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - $import: base_disable_all - quantizer_name: '*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml index 08864c8a50d..af68684989b 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml @@ -30,9 +30,10 @@ quantize: algorithm: method: max # Max calibration is fast and does not typically need checkpointing. - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). - layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml index 9d1f470643f..2da5abb3b89 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml @@ -26,7 +26,8 @@ quantize: algorithm: method: max # Max calibration is fast and does not typically need checkpointing. - layerwise: true + layerwise: + enable: true quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml index 5bf9a36dc31..3043ed32951 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml @@ -31,9 +31,10 @@ quantize: algorithm: method: mse fp8_scale_sweep: true - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). - layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml index 2ea2c0ab13e..71c354ee1b1 100644 --- a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml @@ -31,9 +31,10 @@ quantize: algorithm: method: mse fp8_scale_sweep: true - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). - layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp*weight_quantizer' diff --git a/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml b/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml index 0386c70b13e..355bffee67c 100644 --- a/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml +++ b/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml @@ -34,6 +34,7 @@ metadata: quantize: algorithm: method: max - layerwise: false + layerwise: + enable: false quant_cfg: - $import: shared_quant_cfg diff --git a/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml b/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml index 2a3b99cdeb4..fa24bee97af 100644 --- a/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml +++ b/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml @@ -35,6 +35,7 @@ metadata: quantize: algorithm: method: max - layerwise: false + layerwise: + enable: false quant_cfg: - $import: shared_quant_cfg From adcf0611632844eb02316fcf9e9b353e48aac1ce Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 29 May 2026 22:55:17 +0000 Subject: [PATCH 3/6] Add save_every + save_quantizers_only checkpoint knobs Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 45 ++++- modelopt/torch/quantization/mode.py | 4 + modelopt/torch/quantization/model_calib.py | 11 +- .../quantization/utils/layerwise_calib.py | 163 +++++++++++++----- .../quantization/test_config_validation.py | 53 ++++-- .../quantization/test_layerwise_calibrate.py | 128 ++++++++++++++ 6 files changed, 348 insertions(+), 56 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 6f91a4fa79a..10f7b5db8ff 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,7 +152,7 @@ import warnings from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, ClassVar, Literal from pydantic import AliasChoices, Field, ValidationInfo, field_validator, model_validator @@ -668,6 +668,28 @@ class LayerwiseConfig(ModeloptBaseConfig): ), ) + save_every: int = ModeloptField( + default=1, + ge=1, + title="Flush resume metadata every N layers (final layer always flushes).", + description=( + "Only the boundary layer of each window writes the large " + "``next_inputs.pt`` activation cache; other per-layer files are " + "still written for every layer (resume needs them to replay skips). " + "Mid-window interrupts re-calibrate the unfinished window on resume." + ), + ) + + save_quantizers_only: bool = ModeloptField( + default=False, + title="Skip the per-layer weights blob; persist only quantizer state.", + description=( + "Only accepted by algorithms that update solely ``TensorQuantizer._amax`` " + "(max, mse, local_hessian). Rejected for weight-mutating algorithms " + "(GPTQ, AWQ, SmoothQuant) where it would silently lose updates on resume." + ), + ) + def _coerce_layerwise_input(value): """Normalize a raw ``layerwise`` value to a dict; warn on deprecated bool.""" @@ -689,6 +711,10 @@ def _coerce_layerwise_input(value): class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" + # Set True only for algorithms that update solely ``TensorQuantizer._amax`` + # (no ``layer.weight`` mutation). Gates ``layerwise.save_quantizers_only``. + _supports_save_quantizers_only: ClassVar[bool] = False + method: Literal[None] = ModeloptField( None, title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.", @@ -761,6 +787,17 @@ def validate_layerwise_checkpoint_dir(self): ) return self + @model_validator(mode="after") + def _validate_save_quantizers_only_supported(self): + """Enforce the ``_supports_save_quantizers_only`` whitelist.""" + if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only: + raise ValueError( + f"Algorithm '{self.method}' mutates layer weights in-place; " + "save_quantizers_only=True would lose those updates on resume. " + "Only max/mse/local_hessian (amax-only) support this flag." + ) + return self + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. @@ -770,6 +807,8 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): See `Integer Quantization `_ for the concepts. """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["max"] = ModeloptField("max") distributed_sync: bool | None = ModeloptField( @@ -801,6 +840,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig): When fp8_scale_sweep is enabled, step_size is ignored. """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["mse"] = ModeloptField("mse") step_size: float | None = ModeloptField( @@ -853,6 +894,8 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["local_hessian"] = ModeloptField("local_hessian") step_size: float | None = ModeloptField( diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 05fbe2d223d..28cadf3aa7f 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -227,6 +227,8 @@ def wrapped_calib_func( layerwise = layerwise_cfg.get("enable", False) checkpoint_dir = layerwise_cfg.get("checkpoint_dir") qdq_from_prev = layerwise_cfg.get("get_qdq_activations_from_prev_layer", False) + save_every = layerwise_cfg.get("save_every", 1) + save_quantizers_only = layerwise_cfg.get("save_quantizers_only", False) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -260,6 +262,8 @@ def wrapped_calib_func( calib_func=func, checkpoint_dir=checkpoint_dir, get_qdq_activations_from_prev_layer=qdq_from_prev, + save_every=save_every, + save_quantizers_only=save_quantizers_only, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e08fb9bad17..017e75840ee 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1761,6 +1761,8 @@ def layerwise_calibrate( """ checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False) + save_every = calib_kwargs.pop("save_every", 1) + save_quantizers_only = calib_kwargs.pop("save_quantizers_only", False) if forward_loop is None: raise ValueError( @@ -1778,7 +1780,12 @@ def layerwise_calibrate( num_layers = len(transformer_layers) print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers") - ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers) + ckpt = _CheckpointState.from_folder( + checkpoint_dir, + num_layers, + save_every=save_every, + save_quantizers_only=save_quantizers_only, + ) start_layer = ckpt.start_layer if ckpt else 0 input_getter = LayerActivationCollector(model) @@ -1839,7 +1846,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs): next_inputs = None if ckpt: - ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs) + ckpt.save(layer_idx, model, transformer_layers, next_inputs) del layer_inputs torch.cuda.empty_cache() diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index aed403ad87b..8856bb6f17e 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -492,18 +492,29 @@ def _layer_dir(checkpoint_dir: str, idx: int) -> str: def _save_layer( checkpoint_dir: str, idx: int, - weights: dict, + weights: dict | None, qstate: dict, + quantizer_buffers: dict | None, output_meta: tuple, next_inputs: list | None, num_layers: int, ) -> None: - """Save a single layer checkpoint and update the manifest atomically.""" + """Save a single layer checkpoint and update the manifest atomically. + + ``weights`` may be ``None`` when the caller opts into a quantizer-only save + (algorithms that do not mutate ``layer.weight``). In that case the per- + quantizer ``state_dict()`` slice — which carries ``_amax`` — is written to + ``quantizer_buffers.pt`` instead, so ``full_restore`` can still recover the + calibrated quantizer state without the full ``weights.pt``. + """ d = _layer_dir(checkpoint_dir, idx) if os.path.isdir(d): shutil.rmtree(d) os.makedirs(d) - torch.save(weights, os.path.join(d, "weights.pt")) + if weights is not None: + torch.save(weights, os.path.join(d, "weights.pt")) + elif quantizer_buffers is not None: + torch.save(quantizer_buffers, os.path.join(d, "quantizer_buffers.pt")) torch.save(qstate, os.path.join(d, "quantizer_state.pt")) torch.save(output_meta, os.path.join(d, "output_meta.pt")) if next_inputs is not None: @@ -541,7 +552,14 @@ class _CheckpointState: and broadcast restored state to all ranks during resume. """ - def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): + def __init__( + self, + checkpoint_dir: str, + num_layers: int, + start_layer: int = 0, + save_every: int = 1, + save_quantizers_only: bool = False, + ): if dist.is_initialized() and dist.size() > 1: raise RuntimeError( "Layerwise calibration checkpointing is not supported in " @@ -552,9 +570,21 @@ def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): self.checkpoint_dir = checkpoint_dir self.num_layers = num_layers self.start_layer = start_layer + self.save_every = save_every + self.save_quantizers_only = save_quantizers_only + # Tracks the most recent saved layer so save() can window-save the layers + # since the last save event. Initialized to start_layer - 1 so the first + # save event after resume covers the new work only. + self._last_saved_layer = start_layer - 1 @classmethod - def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None: + def from_folder( + cls, + checkpoint_dir: str | None, + num_layers: int, + save_every: int = 1, + save_quantizers_only: bool = False, + ) -> _CheckpointState | None: """Create from folder. Detects resume point. Returns None if no checkpoint_dir.""" if not checkpoint_dir: return None @@ -572,7 +602,13 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint print_rank_0( f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}" ) - return cls(checkpoint_dir, num_layers, start_layer=start) + return cls( + checkpoint_dir, + num_layers, + start_layer=start, + save_every=save_every, + save_quantizers_only=save_quantizers_only, + ) def setup_resume(self, layers: nn.ModuleList) -> list | None: """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. @@ -609,7 +645,10 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.""" from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import restore_quantizer_state - from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + from modelopt.torch.quantization.utils.core_utils import ( + enable_weight_access_and_writeback, + set_quantizer_state_dict, + ) if self.start_layer == 0: return @@ -630,55 +669,101 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: map_location=layer_device, weights_only=False, ) - weights = torch.load( - os.path.join(d, "weights.pt"), - map_location=layer_device, - weights_only=False, - ) restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) - layer.load_state_dict(weights, strict=False, assign=True) + weights_path = os.path.join(d, "weights.pt") + buffers_path = os.path.join(d, "quantizer_buffers.pt") + if os.path.isfile(weights_path): + weights = torch.load( + weights_path, map_location=layer_device, weights_only=False + ) + layer.load_state_dict(weights, strict=False, assign=True) + elif os.path.isfile(buffers_path): + # save_quantizers_only mode: restore just the TensorQuantizer + # state_dict (carries _amax). The layer's other weights + # weren't modified by the algorithm, so the in-memory values + # already match what would have been saved. + quantizer_buffers = torch.load( + buffers_path, map_location=layer_device, weights_only=False + ) + set_quantizer_state_dict(layer, quantizer_buffers) print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") def save( self, layer_idx: int, - layer: nn.Module, model: nn.Module, layers: nn.ModuleList, next_layer_inputs: list | None = None, ) -> None: - """Snapshot layer state and write checkpoint to disk in one step. + """Snapshot layer state and write checkpoint to disk. + + With ``save_every == 1`` (default), this writes a single layer per call. + With ``save_every > 1``, the call is a no-op except at window boundaries + (every Nth layer plus the final layer); at a boundary, the full window + of layers since the last save is flushed to disk so ``setup_resume`` can + replay the skip layers correctly. The large ``next_inputs.pt`` is only + written for the final layer in the window — the layer the next resume + would restart from. Args: layer_idx: Index of the layer just calibrated. - layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). model: The full model (needed for ``enable_weight_access_and_writeback``). - layers: The decoder layer list (to read ``output_meta``). - next_layer_inputs: Inputs for the next layer (``None`` for the final layer). + layers: The decoder layer list. + next_layer_inputs: Inputs for the layer following ``layer_idx`` + (``None`` for the final layer). """ from modelopt.torch.quantization.conversion import quantizer_state - from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + from modelopt.torch.quantization.utils.core_utils import ( + enable_weight_access_and_writeback, + get_quantizer_state_dict, + ) + + is_final = layer_idx + 1 == self.num_layers + is_window_end = (layer_idx + 1) % self.save_every == 0 + if not (is_final or is_window_end): + return _cpu = torch.device("cpu") - with enable_weight_access_and_writeback(layer, model): - weights = _move_to_device(layer.state_dict(), _cpu) - qstate = _move_to_device(quantizer_state(layer), _cpu) - - output_meta = getattr(layer._layerwise_calib, "output_meta", None) - if output_meta is None: - # Placeholder for the last layer: output_meta is never used for skip mode - # since there is no subsequent layer that needs a correctly shaped dummy output. - output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) - - _save_layer( - self.checkpoint_dir, - layer_idx, - weights, - qstate, - _move_to_device(output_meta, _cpu), - _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, - self.num_layers, - ) + window_start = self._last_saved_layer + 1 + for i in range(window_start, layer_idx + 1): + is_last_in_window = i == layer_idx + layer_i = layers[i] + with enable_weight_access_and_writeback(layer_i, model): + qstate = _move_to_device(quantizer_state(layer_i), _cpu) + if self.save_quantizers_only: + # Save just the TensorQuantizer state_dict slice (carries _amax) + # — the layer's other weights weren't modified by the algorithm. + weights = None + quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer_i), _cpu) + else: + weights = _move_to_device(layer_i.state_dict(), _cpu) + quantizer_buffers = None + + output_meta = getattr(layer_i._layerwise_calib, "output_meta", None) + if output_meta is None: + # Placeholder for the final layer: output_meta is never used for + # skip mode since there is no subsequent layer that needs a + # correctly shaped dummy output. + output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) + + _save_layer( + self.checkpoint_dir, + i, + weights, + qstate, + quantizer_buffers, + _move_to_device(output_meta, _cpu), + _move_to_device(next_layer_inputs, _cpu) + if is_last_in_window and next_layer_inputs is not None + else None, + self.num_layers, + ) + + self._last_saved_layer = layer_idx + window_size = layer_idx - window_start + 1 suffix = " (final)" if next_layer_inputs is None else "" - print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") + if window_size > 1: + print_rank_0(f"Checkpoint: saved layers {window_start}..{layer_idx}{suffix}") + else: + print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index dd610c18b80..c43bdedd09a 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -27,9 +27,13 @@ INT4_AWQ_CFG, NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, + AWQLiteCalibConfig, GPTQCalibConfig, + LocalHessianCalibConfig, MaxCalibConfig, + MseCalibConfig, QuantizeConfig, + SmoothQuantCalibConfig, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, @@ -651,23 +655,21 @@ def test_per_algorithm_qdq_default(self, cfg_cls, expected_qdq): assert cfg_cls().layerwise.get_qdq_activations_from_prev_layer is expected_qdq @pytest.mark.parametrize( - "layerwise", + ("layerwise_input", "expected_qdq"), [ - {"enable": True}, - pytest.param(True, marks=pytest.mark.filterwarnings("ignore::DeprecationWarning")), + # GPTQ default kicks in for user dict that doesn't mention qdq. + ({"enable": True}, True), + # GPTQ default kicks in for legacy bool form too. + pytest.param( + True, True, marks=pytest.mark.filterwarnings("ignore::DeprecationWarning") + ), + # User-explicit False overrides the GPTQ default. + ({"enable": True, "get_qdq_activations_from_prev_layer": False}, False), ], ) - def test_gptq_qdq_default_survives_user_layerwise_input(self, layerwise): - """GPTQ must default qdq=True even when the user supplies a layerwise dict/bool.""" - cfg = GPTQCalibConfig(layerwise=layerwise) - assert cfg.layerwise.get_qdq_activations_from_prev_layer is True - - def test_gptq_user_explicit_qdq_false_wins(self): - """An explicit ``get_qdq_activations_from_prev_layer=False`` must override the GPTQ default.""" - cfg = GPTQCalibConfig( - layerwise={"enable": True, "get_qdq_activations_from_prev_layer": False} - ) - assert cfg.layerwise.get_qdq_activations_from_prev_layer is False + def test_gptq_qdq_default_respects_user_explicit_value(self, layerwise_input, expected_qdq): + cfg = GPTQCalibConfig(layerwise=layerwise_input) + assert cfg.layerwise.get_qdq_activations_from_prev_layer is expected_qdq def test_default_dump_shape(self): dumped = MaxCalibConfig().model_dump() @@ -675,5 +677,28 @@ def test_default_dump_shape(self): "enable": False, "get_qdq_activations_from_prev_layer": False, "checkpoint_dir": None, + "save_every": 1, + "save_quantizers_only": False, } assert "layerwise_checkpoint_dir" not in dumped + + def test_save_every_must_be_positive(self): + with pytest.raises(ValidationError): + MaxCalibConfig(layerwise={"enable": True, "save_every": 0}) + + @pytest.mark.parametrize( + "cfg_cls", + [GPTQCalibConfig, AWQLiteCalibConfig, SmoothQuantCalibConfig], + ) + def test_save_quantizers_only_rejected_for_weight_mutating_algorithms(self, cfg_cls): + """Whitelist: only amax-only algorithms (max/mse/local_hessian) may set + save_quantizers_only=True. Weight-mutating algorithms (GPTQ folds Hessian + updates, AWQ/SmoothQuant fold pre-quant scales) must reject the flag. + """ + with pytest.raises(ValidationError, match="mutates layer weights in-place"): + cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) + + @pytest.mark.parametrize("cfg_cls", [MaxCalibConfig, MseCalibConfig, LocalHessianCalibConfig]) + def test_save_quantizers_only_accepted_for_amax_only_algorithms(self, cfg_cls): + cfg = cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) + assert cfg.layerwise.save_quantizers_only is True diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index 0ff6023cffd..46552fa604a 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -16,6 +16,7 @@ """Unit tests for layerwise_calibrate and LayerActivationCollector.""" import copy +import json from collections import deque import pytest @@ -774,3 +775,130 @@ def collect_amax(model): atol=1e-6, msg=f"amax mismatch at {name}: seq={seq_amax[name]}, lw={lw_amax[name]}", ) + + +def _layer_dir_names(checkpoint_dir): + return sorted(p.name for p in checkpoint_dir.iterdir() if p.name.startswith("layer_")) + + +def test_layerwise_save_every_writes_next_inputs_only_at_window_boundaries(monkeypatch, tmp_path): + """With save_every=2 on a 4-layer model, every layer dir is still written + (so resume can replay skip layers), but ``next_inputs.pt`` — the large + activation cache — appears only at window boundaries. + """ + _register_test_discoverer(monkeypatch) + + config = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + mtq.quantize(model, config, forward_loop=lambda m: [m(b) for b in calib_data]) + + # Window boundaries are layer_idx 1 and 3 -> all 4 layer dirs exist (window-save). + assert _layer_dir_names(tmp_path) == [ + "layer_0000", + "layer_0001", + "layer_0002", + "layer_0003", + ] + # next_inputs.pt is only at the boundary layers (the resume restart points). + assert not (tmp_path / "layer_0000" / "next_inputs.pt").exists() + assert (tmp_path / "layer_0001" / "next_inputs.pt").exists() + assert not (tmp_path / "layer_0002" / "next_inputs.pt").exists() + # Last layer never has next_inputs.pt (no subsequent layer). + assert not (tmp_path / "layer_0003" / "next_inputs.pt").exists() + + +def test_layerwise_save_quantizers_only_resume_matches_one_shot_amax(monkeypatch, tmp_path): + """End-to-end resume with ``save_quantizers_only=True`` matches a one-shot run. + + Run a full calibration with the flag enabled, then rewind the manifest to + ``last_completed_layer=0`` and re-run on a fresh model with the same + checkpoint dir. The resume path will: + 1. ``setup_resume`` reads only ``output_meta`` (no weights.pt needed). + 2. The main loop re-calibrates layers 1 and 2 from scratch. + 3. ``full_restore`` reloads layer 0 from disk — no ``weights.pt``, just + ``quantizer_buffers.pt`` (which carries ``_amax``). + Also asserts the on-disk shape (weights.pt absent, quantizer_buffers.pt + present) so a regression in the file layout fails here too. + """ + _register_test_discoverer(monkeypatch) + + def collect_amax(model): + return { + name: q._amax.clone().detach() + for name, q in model.named_modules() + if isinstance(q, TensorQuantizer) + and q.is_enabled + and getattr(q, "_amax", None) is not None + } + + def build_cfg(checkpoint_dir): + return _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(checkpoint_dir), + "save_quantizers_only": True, + }, + } + ) + + calib_data = [torch.randint(0, 32, (2, 8))] + forward_loop = lambda m: [m(b) for b in calib_data] # noqa: E731 + + # One-shot baseline (separate checkpoint dir so it doesn't interfere). + baseline_dir = tmp_path / "baseline" + torch.manual_seed(0) + baseline_model = _SimpleTransformerModel(n_layers=3, dim=16) + mtq.quantize(baseline_model, build_cfg(baseline_dir), forward_loop=forward_loop) + baseline_amax = collect_amax(baseline_model) + assert baseline_amax, "baseline produced no amax values" + + # Full run into resume_dir, then rewind the manifest to simulate an + # interrupt after layer 0. + resume_dir = tmp_path / "resume" + torch.manual_seed(0) + setup_model = _SimpleTransformerModel(n_layers=3, dim=16) + mtq.quantize(setup_model, build_cfg(resume_dir), forward_loop=forward_loop) + + # On-disk shape: no weights.pt, quantizer_buffers.pt present per layer. + for name in _layer_dir_names(resume_dir): + d = resume_dir / name + assert not (d / "weights.pt").exists(), ( + f"{name}: weights.pt present but save_quantizers_only=True" + ) + assert (d / "quantizer_buffers.pt").exists(), ( + f"{name}: quantizer_buffers.pt missing under save_quantizers_only" + ) + + manifest_path = resume_dir / "manifest.json" + manifest_path.write_text(json.dumps({"last_completed_layer": 0, "num_layers": 3})) + + # Resume: fresh model, same seed, same checkpoint dir. Resume reads + # output_meta + quantizer_state for layer 0 (no weights.pt) and + # re-calibrates layers 1, 2. + torch.manual_seed(0) + resumed_model = _SimpleTransformerModel(n_layers=3, dim=16) + mtq.quantize(resumed_model, build_cfg(resume_dir), forward_loop=forward_loop) + + resumed_amax = collect_amax(resumed_model) + assert set(resumed_amax) == set(baseline_amax) + for name in baseline_amax: + torch.testing.assert_close( + resumed_amax[name], + baseline_amax[name], + rtol=1e-5, + atol=1e-6, + msg=f"amax mismatch after resume at {name}", + ) From dbca1a5a5a29da6394788a63ff3d49f7febfd827 Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 29 May 2026 23:04:47 +0000 Subject: [PATCH 4/6] update changelog Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 337ed597301..ace0fb7d040 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -40,6 +40,7 @@ Changelog - Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). +- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``save_quantizers_only`` (skip the layer-weights blob for amax-only algorithms — whitelisted to ``max``/``mse``/``local_hessian``). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``. **Bug Fixes** From eba56ee768aac8939c2da3a9d9cb7d0461cd163a Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Sat, 30 May 2026 00:05:32 +0000 Subject: [PATCH 5/6] apply reviewer feedbacks Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 9 +- modelopt/torch/quantization/config.py | 10 +- .../quantization/utils/layerwise_calib.py | 157 +++++----- .../quantization/test_config_validation.py | 19 +- .../quantization/test_layerwise_calibrate.py | 267 ++++++++++++------ 5 files changed, 288 insertions(+), 174 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 240f68a6531..d95d06d3349 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -892,8 +892,9 @@ def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> tuple[dict, str] resolved = os.path.join(base_dir, f"{name}_{config_hash}") quant_cfg = copy.deepcopy(quant_cfg) - if shape == "flat": - quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = resolved - else: - quant_cfg["algorithm"]["layerwise"]["checkpoint_dir"] = resolved + algo = quant_cfg["algorithm"] + if "layerwise_checkpoint_dir" in algo: + algo["layerwise_checkpoint_dir"] = resolved + if isinstance(algo.get("layerwise"), dict) and "checkpoint_dir" in algo["layerwise"]: + algo["layerwise"]["checkpoint_dir"] = resolved return quant_cfg, resolved diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 10f7b5db8ff..b6bcdec1039 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -704,7 +704,9 @@ def _coerce_layerwise_input(value): if value is None: return {} if isinstance(value, LayerwiseConfig): - return value.model_dump() + # ``exclude_unset=True`` so downstream ``model_fields_set`` reflects the + # user's actual input + return value.model_dump(exclude_unset=True) return value @@ -755,6 +757,12 @@ def _migrate_layerwise_checkpoint_dir(cls, data): """ if not isinstance(data, dict) or "layerwise_checkpoint_dir" not in data: return data + warnings.warn( + "Passing `layerwise_checkpoint_dir` at the top level is deprecated; " + "nest it under `layerwise.checkpoint_dir` instead.", + DeprecationWarning, + stacklevel=2, + ) data = dict(data) flat_dir = data.pop("layerwise_checkpoint_dir") # Resolve the legacy ``use_sequential`` alias before writing ``layerwise``, diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index 8856bb6f17e..eeb47cf2820 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -473,13 +473,24 @@ def _read_manifest(checkpoint_dir: str) -> dict | None: return None -def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: - """Atomically write manifest.json.""" +def _write_manifest( + checkpoint_dir: str, + last_completed_layer: int, + num_layers: int, + save_every: int, + save_quantizers_only: bool, +) -> None: + """Atomically write manifest.json. Config keys are persisted so resume can detect drift.""" path = os.path.join(checkpoint_dir, "manifest.json") tmp = path + ".tmp" with open(tmp, "w") as f: json.dump( - {"last_completed_layer": last_completed_layer, "num_layers": num_layers}, + { + "last_completed_layer": last_completed_layer, + "num_layers": num_layers, + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, + }, f, ) os.replace(tmp, path) @@ -489,23 +500,21 @@ def _layer_dir(checkpoint_dir: str, idx: int) -> str: return os.path.join(checkpoint_dir, f"layer_{idx:04d}") -def _save_layer( +def _save_layer_files( checkpoint_dir: str, idx: int, weights: dict | None, qstate: dict, quantizer_buffers: dict | None, output_meta: tuple, - next_inputs: list | None, - num_layers: int, ) -> None: - """Save a single layer checkpoint and update the manifest atomically. + """Write the per-layer files for layer *idx*. - ``weights`` may be ``None`` when the caller opts into a quantizer-only save - (algorithms that do not mutate ``layer.weight``). In that case the per- - quantizer ``state_dict()`` slice — which carries ``_amax`` — is written to - ``quantizer_buffers.pt`` instead, so ``full_restore`` can still recover the - calibrated quantizer state without the full ``weights.pt``. + Exactly one of ``weights`` (full layer state_dict) or ``quantizer_buffers`` + (just the TensorQuantizer state_dict slice, used by ``save_quantizers_only``) + is written; ``full_restore`` falls back to whichever is present. + ``next_inputs.pt`` and ``manifest.json`` are deferred to window boundaries + in :meth:`_CheckpointState.save`. """ d = _layer_dir(checkpoint_dir, idx) if os.path.isdir(d): @@ -517,9 +526,6 @@ def _save_layer( torch.save(quantizer_buffers, os.path.join(d, "quantizer_buffers.pt")) torch.save(qstate, os.path.join(d, "quantizer_state.pt")) torch.save(output_meta, os.path.join(d, "output_meta.pt")) - if next_inputs is not None: - torch.save(next_inputs, os.path.join(d, "next_inputs.pt")) - _write_manifest(checkpoint_dir, idx, num_layers) def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: @@ -591,12 +597,20 @@ def from_folder( os.makedirs(checkpoint_dir, exist_ok=True) info = detect_resume_point(checkpoint_dir) if info is not None: - manifest_num_layers = info[1].get("num_layers") - if manifest_num_layers is not None and manifest_num_layers != num_layers: - raise ValueError( - f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} " - f"but model has {num_layers}. Use a fresh checkpoint directory." - ) + manifest = info[1] + # Pre-0.45 manifests omit save_every / save_quantizers_only; skip the + # check for keys absent from the on-disk manifest. + for key, new_value in ( + ("num_layers", num_layers), + ("save_every", save_every), + ("save_quantizers_only", save_quantizers_only), + ): + ckpt_value = manifest.get(key) + if ckpt_value is not None and ckpt_value != new_value: + raise ValueError( + f"Checkpoint {key} mismatch: manifest has {ckpt_value!r} but " + f"new run uses {new_value!r}. Use a fresh checkpoint directory." + ) start = info[0] if info else 0 if start > 0: print_rank_0( @@ -696,22 +710,14 @@ def save( layers: nn.ModuleList, next_layer_inputs: list | None = None, ) -> None: - """Snapshot layer state and write checkpoint to disk. - - With ``save_every == 1`` (default), this writes a single layer per call. - With ``save_every > 1``, the call is a no-op except at window boundaries - (every Nth layer plus the final layer); at a boundary, the full window - of layers since the last save is flushed to disk so ``setup_resume`` can - replay the skip layers correctly. The large ``next_inputs.pt`` is only - written for the final layer in the window — the layer the next resume - would restart from. - - Args: - layer_idx: Index of the layer just calibrated. - model: The full model (needed for ``enable_weight_access_and_writeback``). - layers: The decoder layer list. - next_layer_inputs: Inputs for the layer following ``layer_idx`` - (``None`` for the final layer). + """Snapshot the just-calibrated layer; commit the window at boundaries. + + Each call reads state from ``layers[layer_idx]`` *before* the next + iteration's capture forward swaps it to a ``_SkipLayer``, so state is + always read from the real calibrated layer. Per-layer files are written + every call; ``next_inputs.pt`` and the manifest are deferred to window + boundaries so a mid-window crash leaves the manifest pointing at the + previous boundary. """ from modelopt.torch.quantization.conversion import quantizer_state from modelopt.torch.quantization.utils.core_utils import ( @@ -719,51 +725,54 @@ def save( get_quantizer_state_dict, ) + _cpu = torch.device("cpu") + layer = layers[layer_idx] + with enable_weight_access_and_writeback(layer, model): + qstate = _move_to_device(quantizer_state(layer), _cpu) + if self.save_quantizers_only: + weights = None + quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer), _cpu) + else: + weights = _move_to_device(layer.state_dict(), _cpu) + quantizer_buffers = None + + output_meta = getattr(layer._layerwise_calib, "output_meta", None) + if output_meta is None: + # Final-layer placeholder: never consumed by skip mode (no successor). + output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) + + _save_layer_files( + self.checkpoint_dir, + layer_idx, + weights, + qstate, + quantizer_buffers, + _move_to_device(output_meta, _cpu), + ) + is_final = layer_idx + 1 == self.num_layers is_window_end = (layer_idx + 1) % self.save_every == 0 if not (is_final or is_window_end): return - _cpu = torch.device("cpu") - window_start = self._last_saved_layer + 1 - for i in range(window_start, layer_idx + 1): - is_last_in_window = i == layer_idx - layer_i = layers[i] - with enable_weight_access_and_writeback(layer_i, model): - qstate = _move_to_device(quantizer_state(layer_i), _cpu) - if self.save_quantizers_only: - # Save just the TensorQuantizer state_dict slice (carries _amax) - # — the layer's other weights weren't modified by the algorithm. - weights = None - quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer_i), _cpu) - else: - weights = _move_to_device(layer_i.state_dict(), _cpu) - quantizer_buffers = None - - output_meta = getattr(layer_i._layerwise_calib, "output_meta", None) - if output_meta is None: - # Placeholder for the final layer: output_meta is never used for - # skip mode since there is no subsequent layer that needs a - # correctly shaped dummy output. - output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) - - _save_layer( - self.checkpoint_dir, - i, - weights, - qstate, - quantizer_buffers, - _move_to_device(output_meta, _cpu), - _move_to_device(next_layer_inputs, _cpu) - if is_last_in_window and next_layer_inputs is not None - else None, - self.num_layers, + # Window boundary: write next_inputs.pt + manifest to commit the window. + if next_layer_inputs is not None: + torch.save( + _move_to_device(next_layer_inputs, _cpu), + os.path.join(_layer_dir(self.checkpoint_dir, layer_idx), "next_inputs.pt"), ) - + _write_manifest( + self.checkpoint_dir, + layer_idx, + self.num_layers, + save_every=self.save_every, + save_quantizers_only=self.save_quantizers_only, + ) + window_start = self._last_saved_layer + 1 self._last_saved_layer = layer_idx window_size = layer_idx - window_start + 1 suffix = " (final)" if next_layer_inputs is None else "" if window_size > 1: - print_rank_0(f"Checkpoint: saved layers {window_start}..{layer_idx}{suffix}") + print_rank_0(f"Checkpoint: committed window {window_start}..{layer_idx}{suffix}") else: - print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") + print_rank_0(f"Checkpoint: committed layer {layer_idx}{suffix}") diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index c43bdedd09a..82b9d24d7a4 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -29,6 +29,7 @@ W4A8_AWQ_BETA_CFG, AWQLiteCalibConfig, GPTQCalibConfig, + LayerwiseConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -617,9 +618,13 @@ def test_dict_form_no_deprecation(self): warnings.simplefilter("error", DeprecationWarning) MaxCalibConfig(layerwise={"enable": True}) - def test_flat_checkpoint_dir_migrated_to_nested(self): - with pytest.warns(DeprecationWarning): - cfg = MaxCalibConfig(layerwise=True, layerwise_checkpoint_dir="/x") + def test_flat_checkpoint_dir_migrated_with_deprecation(self): + """Legacy ``layerwise_checkpoint_dir`` is migrated into the nested config + and emits a deprecation warning naming the flat key (independent of the + bool-form deprecation tested above). + """ + with pytest.warns(DeprecationWarning, match="layerwise_checkpoint_dir.*deprecated"): + cfg = MaxCalibConfig(layerwise={"enable": True}, layerwise_checkpoint_dir="/x") assert cfg.layerwise.checkpoint_dir == "/x" def test_use_sequential_alias_survives_flat_checkpoint_migration(self): @@ -665,6 +670,14 @@ def test_per_algorithm_qdq_default(self, cfg_cls, expected_qdq): ), # User-explicit False overrides the GPTQ default. ({"enable": True, "get_qdq_activations_from_prev_layer": False}, False), + # ``LayerwiseConfig`` instance: ``_coerce_layerwise_input`` must + # preserve ``model_fields_set`` so the GPTQ default still kicks in + # for fields the user didn't explicitly set. + (LayerwiseConfig(enable=True), True), + ( + LayerwiseConfig(enable=True, get_qdq_activations_from_prev_layer=False), + False, + ), ], ) def test_gptq_qdq_default_respects_user_explicit_value(self, layerwise_input, expected_qdq): diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index 46552fa604a..c64eb464cff 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -722,11 +722,33 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm(): ) +def _collect_amax(model): + return { + name: q._amax.clone().detach() + for name, q in model.named_modules() + if isinstance(q, TensorQuantizer) and q.is_enabled and getattr(q, "_amax", None) is not None + } + + +def _assert_amax_close(actual, expected, label): + assert set(actual) == set(expected), ( + f"Different quantizers populated for {label}: " + f"missing={set(expected) - set(actual)}, extra={set(actual) - set(expected)}" + ) + for name in expected: + torch.testing.assert_close( + actual[name], + expected[name], + rtol=1e-5, + atol=1e-6, + msg=f"{label}: amax mismatch at {name}", + ) + + def test_layerwise_no_qdq_matches_sequential_amax(monkeypatch): """Layerwise + ``get_qdq_activations_from_prev_layer=False`` must produce the same per-quantizer amax as the non-layerwise (sequential) max-calibration - flow. Both paths feed every layer full-precision activations, so the amax - statistics they collect must agree. + flow. Both paths feed every layer full-precision activations. """ _register_test_discoverer(monkeypatch) @@ -739,42 +761,21 @@ def fwd(m): for batch in calib_data: m(batch) - seq_cfg = _int8_layerwise_config({"method": "max"}) - mtq.quantize(model_seq, seq_cfg, forward_loop=fwd) - - lw_cfg = _int8_layerwise_config( - { - "method": "max", - "layerwise": {"enable": True, "get_qdq_activations_from_prev_layer": False}, - } + mtq.quantize(model_seq, _int8_layerwise_config({"method": "max"}), forward_loop=fwd) + mtq.quantize( + model_lw, + _int8_layerwise_config( + { + "method": "max", + "layerwise": {"enable": True, "get_qdq_activations_from_prev_layer": False}, + } + ), + forward_loop=fwd, ) - mtq.quantize(model_lw, lw_cfg, forward_loop=fwd) - - def collect_amax(model): - return { - name: q._amax.clone().detach() - for name, q in model.named_modules() - if isinstance(q, TensorQuantizer) - and q.is_enabled - and getattr(q, "_amax", None) is not None - } - - seq_amax = collect_amax(model_seq) - lw_amax = collect_amax(model_lw) + seq_amax = _collect_amax(model_seq) assert seq_amax, "sequential calibration populated no amax values" - assert set(seq_amax) == set(lw_amax), ( - f"Different quantizers populated: only-seq={set(seq_amax) - set(lw_amax)}, " - f"only-lw={set(lw_amax) - set(seq_amax)}" - ) - for name in seq_amax: - torch.testing.assert_close( - lw_amax[name], - seq_amax[name], - rtol=1e-5, - atol=1e-6, - msg=f"amax mismatch at {name}: seq={seq_amax[name]}, lw={lw_amax[name]}", - ) + _assert_amax_close(_collect_amax(model_lw), seq_amax, "layerwise vs sequential") def _layer_dir_names(checkpoint_dir): @@ -818,87 +819,169 @@ def test_layerwise_save_every_writes_next_inputs_only_at_window_boundaries(monke assert not (tmp_path / "layer_0003" / "next_inputs.pt").exists() -def test_layerwise_save_quantizers_only_resume_matches_one_shot_amax(monkeypatch, tmp_path): - """End-to-end resume with ``save_quantizers_only=True`` matches a one-shot run. - - Run a full calibration with the flag enabled, then rewind the manifest to - ``last_completed_layer=0`` and re-run on a fresh model with the same - checkpoint dir. The resume path will: - 1. ``setup_resume`` reads only ``output_meta`` (no weights.pt needed). - 2. The main loop re-calibrates layers 1 and 2 from scratch. - 3. ``full_restore`` reloads layer 0 from disk — no ``weights.pt``, just - ``quantizer_buffers.pt`` (which carries ``_amax``). - Also asserts the on-disk shape (weights.pt absent, quantizer_buffers.pt - present) so a regression in the file layout fails here too. +@pytest.mark.parametrize( + ("scenario", "n_layers", "save_every", "save_quantizers_only", "rewind_to"), + [ + # Pins the quantizer_buffers.pt restore path (no weights.pt on disk). + ("quantizers_only", 3, 1, True, 0), + # Pins the per-call snapshot fix: each save() captures the + # just-calibrated layer's state before the next-layer capture forward + # swaps it to _SkipLayer. + ("save_every", 4, 2, False, 1), + ], +) +def test_layerwise_checkpoint_resume_matches_one_shot_amax( + monkeypatch, tmp_path, scenario, n_layers, save_every, save_quantizers_only, rewind_to +): + """Full run → rewind manifest → fresh resume reproduces one-shot ``_amax``. + + Single test covering both checkpoint optimizations. For the + ``save_quantizers_only`` case also asserts the on-disk shape (no + ``weights.pt``, ``quantizer_buffers.pt`` present per layer). """ _register_test_discoverer(monkeypatch) - def collect_amax(model): - return { - name: q._amax.clone().detach() - for name, q in model.named_modules() - if isinstance(q, TensorQuantizer) - and q.is_enabled - and getattr(q, "_amax", None) is not None - } + calib_data = [torch.randint(0, 32, (2, 8))] + forward_loop = lambda m: [m(b) for b in calib_data] # noqa: E731 - def build_cfg(checkpoint_dir): + def build_cfg(ckpt_dir): return _int8_layerwise_config( { "method": "max", "layerwise": { "enable": True, - "checkpoint_dir": str(checkpoint_dir), - "save_quantizers_only": True, + "checkpoint_dir": str(ckpt_dir), + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, }, } ) - calib_data = [torch.randint(0, 32, (2, 8))] - forward_loop = lambda m: [m(b) for b in calib_data] # noqa: E731 - - # One-shot baseline (separate checkpoint dir so it doesn't interfere). baseline_dir = tmp_path / "baseline" torch.manual_seed(0) - baseline_model = _SimpleTransformerModel(n_layers=3, dim=16) + baseline_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) mtq.quantize(baseline_model, build_cfg(baseline_dir), forward_loop=forward_loop) - baseline_amax = collect_amax(baseline_model) - assert baseline_amax, "baseline produced no amax values" + baseline_amax = _collect_amax(baseline_model) + assert baseline_amax - # Full run into resume_dir, then rewind the manifest to simulate an - # interrupt after layer 0. resume_dir = tmp_path / "resume" torch.manual_seed(0) - setup_model = _SimpleTransformerModel(n_layers=3, dim=16) + setup_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) mtq.quantize(setup_model, build_cfg(resume_dir), forward_loop=forward_loop) - # On-disk shape: no weights.pt, quantizer_buffers.pt present per layer. - for name in _layer_dir_names(resume_dir): - d = resume_dir / name - assert not (d / "weights.pt").exists(), ( - f"{name}: weights.pt present but save_quantizers_only=True" - ) - assert (d / "quantizer_buffers.pt").exists(), ( - f"{name}: quantizer_buffers.pt missing under save_quantizers_only" - ) + if scenario == "quantizers_only": + for name in _layer_dir_names(resume_dir): + d = resume_dir / name + assert not (d / "weights.pt").exists() + assert (d / "quantizer_buffers.pt").exists() - manifest_path = resume_dir / "manifest.json" - manifest_path.write_text(json.dumps({"last_completed_layer": 0, "num_layers": 3})) + (resume_dir / "manifest.json").write_text( + json.dumps( + { + "last_completed_layer": rewind_to, + "num_layers": n_layers, + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, + } + ) + ) - # Resume: fresh model, same seed, same checkpoint dir. Resume reads - # output_meta + quantizer_state for layer 0 (no weights.pt) and - # re-calibrates layers 1, 2. torch.manual_seed(0) - resumed_model = _SimpleTransformerModel(n_layers=3, dim=16) + resumed_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) mtq.quantize(resumed_model, build_cfg(resume_dir), forward_loop=forward_loop) - resumed_amax = collect_amax(resumed_model) - assert set(resumed_amax) == set(baseline_amax) - for name in baseline_amax: - torch.testing.assert_close( - resumed_amax[name], - baseline_amax[name], - rtol=1e-5, - atol=1e-6, - msg=f"amax mismatch after resume at {name}", + _assert_amax_close(_collect_amax(resumed_model), baseline_amax, f"{scenario} resume") + + +def test_layerwise_save_every_mid_window_crash_recovers_at_prev_boundary(monkeypatch, tmp_path): + """A crash inside a window must not advance ``last_completed_layer``; resume + re-calibrates the unfinished window from the previous boundary. + + Monkeypatches ``torch.save`` to raise on the second window's first per-layer + write (layer 2), then asserts the on-disk manifest still points at layer 1 + (the previous boundary). + """ + _register_test_discoverer(monkeypatch) + + cfg = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + + real_torch_save = torch.save + state = {"crashed": False} + + def crashing_torch_save(obj, path, *args, **kwargs): + # Crash on the first per-layer file write for layer 2 (mid-window). + if not state["crashed"] and "layer_0002" in str(path): + state["crashed"] = True + raise RuntimeError("simulated crash during layer 2 save") + return real_torch_save(obj, path, *args, **kwargs) + + monkeypatch.setattr( + "modelopt.torch.quantization.utils.layerwise_calib.torch.save", + crashing_torch_save, + ) + with pytest.raises(RuntimeError, match="simulated crash"): + mtq.quantize(model, cfg, forward_loop=lambda m: [m(b) for b in calib_data]) + + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["last_completed_layer"] == 1, f"manifest leaked mid-window state: {manifest}" + + +def test_layerwise_checkpoint_mismatch_save_every_raises(monkeypatch, tmp_path): + """Resuming with a different ``save_every`` than the checkpoint was produced + with must raise — the on-disk window layout assumes a fixed value. + """ + _register_test_discoverer(monkeypatch) + + cfg_first = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + mtq.quantize(model, cfg_first, forward_loop=lambda m: [m(b) for b in calib_data]) + + # Rewind manifest so the second run sees an in-progress resume, + # then change save_every to trigger the mismatch check. + (tmp_path / "manifest.json").write_text( + json.dumps( + { + "last_completed_layer": 1, + "num_layers": 4, + "save_every": 2, + "save_quantizers_only": False, + } ) + ) + cfg_mismatched = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 4, + }, + } + ) + torch.manual_seed(0) + fresh_model = _SimpleTransformerModel(n_layers=4, dim=16) + with pytest.raises(ValueError, match="save_every mismatch"): + mtq.quantize(fresh_model, cfg_mismatched, forward_loop=lambda m: [m(b) for b in calib_data]) From af1bb4f2b5485a5a3c9413b999e9121ed23c7281 Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:36:56 +0000 Subject: [PATCH 6/6] Capture next-layer inputs before calib_func under qdq_from_prev=False Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 10 ++-- modelopt/torch/quantization/model_calib.py | 40 ++++++++-------- .../quantization/test_layerwise_calibrate.py | 46 +++++++++++++++++++ 3 files changed, 71 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index b6bcdec1039..726a78d129b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -650,12 +650,10 @@ class LayerwiseConfig(ModeloptBaseConfig): default=False, title="Cache next-layer inputs from QDQ outputs of prior layers.", description=( - "If True (GPTQ default), layer N's calibration sees inputs carrying " - "the quantize-dequantize error of layers 0..N-1, so quantization " - "error compounds across layers. If False (max-calib default), " - "quantizers are temporarily disabled during the capture forward, so " - "layer N sees the same full-precision activations as a non-layerwise " - "calibration pass." + "If True (GPTQ default), capture each layer's next-layer inputs " + "after it is calibrated, so QDQ error and in-place weight updates " + "propagate forward. If False (max/mse default), capture before, so " + "the next layer sees the same FP activations as a non-layerwise pass." ), ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 017e75840ee..b35beef50d1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -19,7 +19,6 @@ import time import warnings from collections.abc import Callable -from contextlib import AbstractContextManager, nullcontext from functools import partial from typing import TypeAlias @@ -1821,28 +1820,31 @@ def _layer_forward_loop(m, _inputs=layer_inputs): kwargs_input["past_key_values"] = None m(*args, **kwargs_input) - with persistent_materialization(layer): - calib_func(layer, _layer_forward_loop, **calib_kwargs) - - # Run one more forward to get next layer's inputs and set - # output_meta on the just-calibrated layer (via "run" mode). is_last = layer_idx + 1 >= num_layers - if not is_last: - # When qdq_from_prev is False, temporarily disable every quantizer - # under the just-calibrated layer so the next layer sees full-precision - # activations (matches non-layerwise calibration semantics). - capture_ctx: AbstractContextManager = ( - nullcontext() - if qdq_from_prev - else set_quantizer_by_cfg_context( - layer, [{"quantizer_name": "*", "enable": False}] - ) - ) - with capture_ctx: + + # qdq_from_prev=False: capture before calib_func so the forward + # replay uses the original FP weights. Disable quantizers too in + # case any pre-calibration observer behavior would perturb the + # captured activations. + if not is_last and not qdq_from_prev: + with set_quantizer_by_cfg_context( + layer, [{"quantizer_name": "*", "enable": False}] + ): next_inputs = input_getter.cache_outputs_for_next_layer_calib( layer, forward_loop ) - else: + # cache_outputs left this layer in "run" mode with an empty + # deque; reset so calib_func's replay hits the real forward. + layer._layerwise_calib.mode = "original" + + with persistent_materialization(layer): + calib_func(layer, _layer_forward_loop, **calib_kwargs) + + # qdq_from_prev=True: capture after calib_func so the next layer + # sees QDQ error and any in-place weight updates from this layer. + if not is_last and qdq_from_prev: + next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) + elif is_last: next_inputs = None if ckpt: diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index c64eb464cff..5ec4cc2daed 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -778,6 +778,52 @@ def fwd(m): _assert_amax_close(_collect_amax(model_lw), seq_amax, "layerwise vs sequential") +def test_layerwise_no_qdq_captures_inputs_before_calib_func_mutates_weights(monkeypatch): + """A destructive ``calib_func`` (zeros weights) must not affect what is + captured for downstream layers under ``qdq_from_prev=False`` — otherwise + weight-mutating algorithms (GPTQ/AWQ/SmoothQuant) silently propagate + updates forward and break the "identical to non-layerwise pass" contract. + """ + _register_test_discoverer(monkeypatch) + calib_data = [torch.randint(0, 32, (2, 8))] + + def run_and_capture(calib_func): + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=3, dim=16) + captured: dict[int, torch.Tensor] = {} + real = LayerActivationCollector.cache_outputs_for_next_layer_calib + + def spy(self, layer, fwd): + result = real(self, layer, fwd) + captured[self._layer_to_idx[layer] + 1] = result[0][0][0].clone().detach() + return result + + with monkeypatch.context() as m: + m.setattr(LayerActivationCollector, "cache_outputs_for_next_layer_calib", spy) + layerwise_calibrate( + model, + forward_loop=lambda mm: [mm(b) for b in calib_data], + calib_func=calib_func, + get_qdq_activations_from_prev_layer=False, + ) + return captured + + def identity(layer, fwd, **_): + fwd(layer) + + def destructive(layer, fwd, **_): + fwd(layer) + for sub in layer.modules(): + if isinstance(sub, nn.Linear): + sub.weight.data.zero_() + + benign = run_and_capture(identity) + mutated = run_and_capture(destructive) + + for i in (1, 2): + torch.testing.assert_close(mutated[i], benign[i], rtol=1e-5, atol=1e-6) + + def _layer_dir_names(checkpoint_dir): return sorted(p.name for p in checkpoint_dir.iterdir() if p.name.startswith("layer_"))