diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 337ed597301..ace0fb7d040 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -40,6 +40,7 @@ Changelog - Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). +- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``save_quantizers_only`` (skip the layer-weights blob for amax-only algorithms — whitelisted to ``max``/``mse``/``local_hessian``). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``. 
**Bug Fixes** diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ec4bb2b0519..d95d06d3349 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -850,22 +850,37 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod print("No custom model files found to copy") -def needs_checkpoint_path_update(quant_cfg: dict) -> bool: - """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" - algorithm = quant_cfg.get("algorithm") +def _layerwise_checkpoint_dir_location(algorithm) -> tuple[str, str] | None: + """Return ``("flat"/"nested", checkpoint_dir)`` for the layerwise checkpoint dir, or None.""" if not isinstance(algorithm, dict): - return False - return algorithm.get("layerwise_checkpoint_dir") is not None + return None + flat = algorithm.get("layerwise_checkpoint_dir") + if flat is not None: + return "flat", flat + nested = algorithm.get("layerwise") or {} + ckpt = nested.get("checkpoint_dir") if isinstance(nested, dict) else None + return ("nested", ckpt) if ckpt is not None else None + + +def needs_checkpoint_path_update(quant_cfg: dict) -> bool: + """Check if quant_cfg has a layerwise checkpoint_dir that should be auto-resolved to a unique subpath.""" + return _layerwise_checkpoint_dir_location(quant_cfg.get("algorithm")) is not None -def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: - """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir. +def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> tuple[dict, str]: + """Append a unique ``<model_name>_<config_hash>`` subdirectory to the layerwise checkpoint_dir. Allows a single recipe to be reused across models without checkpoint collisions. + Supports both the legacy flat ``layerwise_checkpoint_dir`` and the nested + ``layerwise.checkpoint_dir`` shape, writing back to whichever the user provided. 
Must only be called when :func:`needs_checkpoint_path_update` returns True. + + Returns ``(updated_quant_cfg, resolved_path)`` so the caller can log or + reference the resolved path without re-deriving the dict shape. """ - algorithm = quant_cfg["algorithm"] - base_dir = algorithm["layerwise_checkpoint_dir"] + location = _layerwise_checkpoint_dir_location(quant_cfg["algorithm"]) + assert location is not None # guaranteed by needs_checkpoint_path_update + shape, base_dir = location name = model_path.rstrip("/") if "/" in name and not os.path.isabs(name): @@ -874,9 +889,12 @@ def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: name = Path(name).name config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8] + resolved = os.path.join(base_dir, f"{name}_{config_hash}") quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( - base_dir, f"{name}_{config_hash}" - ) - return quant_cfg + algo = quant_cfg["algorithm"] + if "layerwise_checkpoint_dir" in algo: + algo["layerwise_checkpoint_dir"] = resolved + if isinstance(algo.get("layerwise"), dict) and "checkpoint_dir" in algo["layerwise"]: + algo["layerwise"]["checkpoint_dir"] = resolved + return quant_cfg, resolved diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d27aa593f6..7483deb10c0 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1000,7 +1000,8 @@ def _is_layerwise(obj): return _is_layerwise(obj.quantize.algorithm) if isinstance(obj, list): return any(_is_layerwise(a) for a in obj) - return bool(getattr(obj, "layerwise", False)) + layerwise = getattr(obj, "layerwise", None) + return bool(getattr(layerwise, "enable", False)) is_layerwise = _is_layerwise(recipe) @@ -1135,10 +1136,8 @@ def _is_layerwise(obj): _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) if needs_checkpoint_path_update(quant_cfg): - quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) 
- print( - f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" - ) + quant_cfg, resolved_dir = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) + print(f"Auto-resolved layerwise checkpoint_dir: {resolved_dir}") if args.cast_mxfp4_to_nvfp4: quant_cfg = copy.deepcopy(quant_cfg) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 602375d2bf6..726a78d129b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,9 +152,9 @@ import warnings from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, ClassVar, Literal -from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator +from pydantic import AliasChoices, Field, ValidationInfo, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config @@ -633,9 +633,88 @@ def validate_calibrator(cls, v, info: ValidationInfo): ) +class LayerwiseConfig(ModeloptBaseConfig): + """Nested config for layer-by-layer calibration behavior.""" + + enable: bool = ModeloptField( + default=False, + title="Enable layerwise (layer-by-layer) calibration.", + description=( + "If True, the calibration algorithm is applied layer by layer. " + "Each layer's inputs are captured via a forward pass that reflects the " + "quantization of all preceding layers, incurring O(N) forward passes for N layers." + ), + ) + + get_qdq_activations_from_prev_layer: bool = ModeloptField( + default=False, + title="Cache next-layer inputs from QDQ outputs of prior layers.", + description=( + "If True (GPTQ default), capture each layer's next-layer inputs " + "after it is calibrated, so QDQ error and in-place weight updates " + "propagate forward. 
If False (max/mse default), capture before, so " + "the next layer sees the same FP activations as a non-layerwise pass." + ), + ) + + checkpoint_dir: str | None = ModeloptField( + default=None, + title="Per-layer checkpoint directory (resume on restart).", + description=( + "If set, per-layer checkpoints are saved here during calibration. " + "On restart, calibration resumes from the last completed layer." + ), + ) + + save_every: int = ModeloptField( + default=1, + ge=1, + title="Flush resume metadata every N layers (final layer always flushes).", + description=( + "Only the boundary layer of each window writes the large " + "``next_inputs.pt`` activation cache; other per-layer files are " + "still written for every layer (resume needs them to replay skips). " + "Mid-window interrupts re-calibrate the unfinished window on resume." + ), + ) + + save_quantizers_only: bool = ModeloptField( + default=False, + title="Skip the per-layer weights blob; persist only quantizer state.", + description=( + "Only accepted by algorithms that update solely ``TensorQuantizer._amax`` " + "(max, mse, local_hessian). Rejected for weight-mutating algorithms " + "(GPTQ, AWQ, SmoothQuant) where it would silently lose updates on resume." + ), + ) + + +def _coerce_layerwise_input(value): + """Normalize a raw ``layerwise`` value to a dict; warn on deprecated bool.""" + if isinstance(value, bool): + warnings.warn( + "Passing the layerwise field as a bool is deprecated; use a dict, " + "e.g. 
`{'enable': True}`.", + DeprecationWarning, + stacklevel=2, + ) + return {"enable": value} + if value is None: + return {} + if isinstance(value, LayerwiseConfig): + # ``exclude_unset=True`` so downstream ``model_fields_set`` reflects the + # user's actual input + return value.model_dump(exclude_unset=True) + return value + + class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" + # Set True only for algorithms that update solely ``TensorQuantizer._amax`` + # (no ``layer.weight`` mutation). Gates ``layerwise.save_quantizers_only``. + _supports_save_quantizers_only: ClassVar[bool] = False + method: Literal[None] = ModeloptField( None, title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.", @@ -656,34 +735,72 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - layerwise: bool = ModeloptField( - default=False, + layerwise: LayerwiseConfig = Field( + default_factory=LayerwiseConfig, validation_alias=AliasChoices("layerwise", "use_sequential"), - title="Enable layerwise (layer-by-layer) calibration.", + title="Layerwise calibration configuration.", description=( - "If True, the calibration algorithm is applied layer by layer. " - "Each layer's inputs are captured via a forward pass that reflects the " - "quantization of all preceding layers, incurring O(N) forward passes for N layers." + "Nested config controlling layer-by-layer calibration. Pass a dict, " + "e.g. ``{'enable': True, 'checkpoint_dir': '/path'}``. Bool input is " + "accepted for backward compatibility but deprecated." ), ) - layerwise_checkpoint_dir: str | None = ModeloptField( - default=None, - title="Checkpoint directory for layerwise calibration.", - description=( - "If set together with layerwise=True, per-layer checkpoints are saved to this " - "directory during calibration. On restart, calibration resumes from the last " - "completed layer." 
- ), - ) + @model_validator(mode="before") + @classmethod + def _migrate_layerwise_checkpoint_dir(cls, data): + """Merge the legacy flat ``layerwise_checkpoint_dir`` key into ``layerwise``. + + Raises if both the flat key and a nested ``checkpoint_dir`` are set with conflicting values. + """ + if not isinstance(data, dict) or "layerwise_checkpoint_dir" not in data: + return data + warnings.warn( + "Passing `layerwise_checkpoint_dir` at the top level is deprecated; " + "nest it under `layerwise.checkpoint_dir` instead.", + DeprecationWarning, + stacklevel=2, + ) + data = dict(data) + flat_dir = data.pop("layerwise_checkpoint_dir") + # Resolve the legacy ``use_sequential`` alias before writing ``layerwise``, + # otherwise the alias value is silently dropped when AliasChoices picks the + # newly-written ``layerwise`` key over ``use_sequential``. + raw_layerwise = data.pop("layerwise", data.pop("use_sequential", None)) + layerwise = _coerce_layerwise_input(raw_layerwise) + existing = layerwise.get("checkpoint_dir") + if existing is not None and existing != flat_dir: + raise ValueError( + f"Conflicting checkpoint_dir: layerwise_checkpoint_dir={flat_dir!r} " + f"differs from layerwise.checkpoint_dir={existing!r}. Set only one." + ) + data["layerwise"] = {**layerwise, "checkpoint_dir": flat_dir} + return data + + @field_validator("layerwise", mode="before") + @classmethod + def _coerce_layerwise(cls, value): + """Coerce ``layerwise=bool/None`` to dict form; also handles the alias path.""" + return _coerce_layerwise_input(value) @model_validator(mode="after") def validate_layerwise_checkpoint_dir(self): - """Raise if layerwise_checkpoint_dir is set but layerwise is False.""" - if self.layerwise_checkpoint_dir is not None and not self.layerwise: + """Raise if layerwise.checkpoint_dir is set but layerwise.enable is False.""" + if self.layerwise.checkpoint_dir is not None and not self.layerwise.enable: raise ValueError( - "layerwise_checkpoint_dir requires layerwise=True. 
" - "Set layerwise=True or remove layerwise_checkpoint_dir." + "layerwise.checkpoint_dir requires layerwise.enable=True. " + "Set layerwise.enable=True or remove layerwise.checkpoint_dir." + ) + return self + + @model_validator(mode="after") + def _validate_save_quantizers_only_supported(self): + """Enforce the ``_supports_save_quantizers_only`` whitelist.""" + if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only: + raise ValueError( + f"Algorithm '{self.method}' mutates layer weights in-place; " + "save_quantizers_only=True would lose those updates on resume. " + "Only max/mse/local_hessian (amax-only) support this flag." ) return self @@ -696,6 +813,8 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): See `Integer Quantization `_ for the concepts. """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["max"] = ModeloptField("max") distributed_sync: bool | None = ModeloptField( @@ -727,6 +846,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig): When fp8_scale_sweep is enabled, step_size is ignored. """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["mse"] = ModeloptField("mse") step_size: float | None = ModeloptField( @@ -779,6 +900,8 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): """ + _supports_save_quantizers_only: ClassVar[bool] = True + method: Literal["local_hessian"] = ModeloptField("local_hessian") step_size: float | None = ModeloptField( @@ -996,6 +1119,21 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): per-column error propagation into one launch per GPTQ block.""", ) + @model_validator(mode="after") + def _gptq_qdq_default(self): + """Inject ``get_qdq_activations_from_prev_layer=True`` unless the user set it. + + GPTQ's Hessian correctness depends on prior-layer QDQ activations, so the + default differs from the base class. 
Uses ``model_fields_set`` to detect + whether the user explicitly set the field — covers every input shape + (empty constructor, bool, dict) without a per-shape special case. + """ + if "get_qdq_activations_from_prev_layer" not in self.layerwise.model_fields_set: + self.layerwise = self.layerwise.model_copy( + update={"get_qdq_activations_from_prev_layer": True} + ) + return self + QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 530e90aaa00..28cadf3aa7f 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -223,8 +223,12 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") - layerwise = kwargs.pop("layerwise", False) - checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None) + layerwise_cfg = kwargs.pop("layerwise", None) or {} + layerwise = layerwise_cfg.get("enable", False) + checkpoint_dir = layerwise_cfg.get("checkpoint_dir") + qdq_from_prev = layerwise_cfg.get("get_qdq_activations_from_prev_layer", False) + save_every = layerwise_cfg.get("save_every", 1) + save_quantizers_only = layerwise_cfg.get("save_quantizers_only", False) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -244,8 +248,8 @@ def wrapped_calib_func( # future algorithms that need full-model context must add a guard here. if not supports_layerwise: raise ValueError( - f"Calibration algorithm '{method}' does not support layerwise=True. " - "Set layerwise=False, or override `_supports_layerwise = True` on the " + f"Calibration algorithm '{method}' does not support layerwise.enable=True. " + "Set layerwise.enable=False, or override `_supports_layerwise = True` on the " "corresponding CalibrateModeDescriptor once the algorithm is made " "compatible with per-layer calibration." 
) @@ -257,6 +261,9 @@ def wrapped_calib_func( forward_loop=forward_loop, calib_func=func, checkpoint_dir=checkpoint_dir, + get_qdq_activations_from_prev_layer=qdq_from_prev, + save_every=save_every, + save_quantizers_only=save_quantizers_only, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 978f660f9d1..b35beef50d1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1752,8 +1752,16 @@ def layerwise_calibrate( If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints are saved after each layer completes. On restart, calibration resumes from the last completed layer. + + ``get_qdq_activations_from_prev_layer`` (via ``calib_kwargs``) controls + whether the cached inputs handed to layer N+1 come from a forward through + the just-calibrated layer with quantizers active (True; e.g. GPTQ) or + temporarily disabled (False; matches non-layerwise max-calib semantics). 
""" checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) + qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False) + save_every = calib_kwargs.pop("save_every", 1) + save_quantizers_only = calib_kwargs.pop("save_quantizers_only", False) if forward_loop is None: raise ValueError( @@ -1771,7 +1779,12 @@ def layerwise_calibrate( num_layers = len(transformer_layers) print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers") - ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers) + ckpt = _CheckpointState.from_folder( + checkpoint_dir, + num_layers, + save_every=save_every, + save_quantizers_only=save_quantizers_only, + ) start_layer = ckpt.start_layer if ckpt else 0 input_getter = LayerActivationCollector(model) @@ -1807,19 +1820,35 @@ def _layer_forward_loop(m, _inputs=layer_inputs): kwargs_input["past_key_values"] = None m(*args, **kwargs_input) + is_last = layer_idx + 1 >= num_layers + + # qdq_from_prev=False: capture before calib_func so the forward + # replay uses the original FP weights. Disable quantizers too in + # case any pre-calibration observer behavior would perturb the + # captured activations. + if not is_last and not qdq_from_prev: + with set_quantizer_by_cfg_context( + layer, [{"quantizer_name": "*", "enable": False}] + ): + next_inputs = input_getter.cache_outputs_for_next_layer_calib( + layer, forward_loop + ) + # cache_outputs left this layer in "run" mode with an empty + # deque; reset so calib_func's replay hits the real forward. + layer._layerwise_calib.mode = "original" + with persistent_materialization(layer): calib_func(layer, _layer_forward_loop, **calib_kwargs) - # Run one more forward to get next layer's inputs and set - # output_meta on the just-calibrated layer (via "run" mode). - is_last = layer_idx + 1 >= num_layers - if not is_last: + # qdq_from_prev=True: capture after calib_func so the next layer + # sees QDQ error and any in-place weight updates from this layer. 
+ if not is_last and qdq_from_prev: next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) - else: + elif is_last: next_inputs = None if ckpt: - ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs) + ckpt.save(layer_idx, model, transformer_layers, next_inputs) del layer_inputs torch.cuda.empty_cache() @@ -1843,13 +1872,13 @@ def gptq( ): """GPTQ quantization. - Works in two modes depending on ``layerwise`` in the config: + Works in two modes depending on ``layerwise.enable`` in the config: - * **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this - function once per decoder layer with updated activations, producing more - accurate Hessian estimates. - * **Non-layerwise** (``layerwise=False``): called once on the full model. - All layers are quantized in parallel from the original activations. + * **Layerwise** (``layerwise.enable=True``): ``layerwise_calibrate`` calls + this function once per decoder layer with updated activations, producing + more accurate Hessian estimates. + * **Non-layerwise** (``layerwise.enable=False``): called once on the full + model. All layers are quantized in parallel from the original activations. Per-module steps: diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index aed403ad87b..eeb47cf2820 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -473,13 +473,24 @@ def _read_manifest(checkpoint_dir: str) -> dict | None: return None -def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: - """Atomically write manifest.json.""" +def _write_manifest( + checkpoint_dir: str, + last_completed_layer: int, + num_layers: int, + save_every: int, + save_quantizers_only: bool, +) -> None: + """Atomically write manifest.json. 
Config keys are persisted so resume can detect drift.""" path = os.path.join(checkpoint_dir, "manifest.json") tmp = path + ".tmp" with open(tmp, "w") as f: json.dump( - {"last_completed_layer": last_completed_layer, "num_layers": num_layers}, + { + "last_completed_layer": last_completed_layer, + "num_layers": num_layers, + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, + }, f, ) os.replace(tmp, path) @@ -489,26 +500,32 @@ def _layer_dir(checkpoint_dir: str, idx: int) -> str: return os.path.join(checkpoint_dir, f"layer_{idx:04d}") -def _save_layer( +def _save_layer_files( checkpoint_dir: str, idx: int, - weights: dict, + weights: dict | None, qstate: dict, + quantizer_buffers: dict | None, output_meta: tuple, - next_inputs: list | None, - num_layers: int, ) -> None: - """Save a single layer checkpoint and update the manifest atomically.""" + """Write the per-layer files for layer *idx*. + + Exactly one of ``weights`` (full layer state_dict) or ``quantizer_buffers`` + (just the TensorQuantizer state_dict slice, used by ``save_quantizers_only``) + is written; ``full_restore`` falls back to whichever is present. + ``next_inputs.pt`` and ``manifest.json`` are deferred to window boundaries + in :meth:`_CheckpointState.save`. 
+ """ d = _layer_dir(checkpoint_dir, idx) if os.path.isdir(d): shutil.rmtree(d) os.makedirs(d) - torch.save(weights, os.path.join(d, "weights.pt")) + if weights is not None: + torch.save(weights, os.path.join(d, "weights.pt")) + elif quantizer_buffers is not None: + torch.save(quantizer_buffers, os.path.join(d, "quantizer_buffers.pt")) torch.save(qstate, os.path.join(d, "quantizer_state.pt")) torch.save(output_meta, os.path.join(d, "output_meta.pt")) - if next_inputs is not None: - torch.save(next_inputs, os.path.join(d, "next_inputs.pt")) - _write_manifest(checkpoint_dir, idx, num_layers) def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: @@ -541,7 +558,14 @@ class _CheckpointState: and broadcast restored state to all ranks during resume. """ - def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): + def __init__( + self, + checkpoint_dir: str, + num_layers: int, + start_layer: int = 0, + save_every: int = 1, + save_quantizers_only: bool = False, + ): if dist.is_initialized() and dist.size() > 1: raise RuntimeError( "Layerwise calibration checkpointing is not supported in " @@ -552,27 +576,53 @@ def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): self.checkpoint_dir = checkpoint_dir self.num_layers = num_layers self.start_layer = start_layer + self.save_every = save_every + self.save_quantizers_only = save_quantizers_only + # Tracks the most recent saved layer so save() can window-save the layers + # since the last save event. Initialized to start_layer - 1 so the first + # save event after resume covers the new work only. + self._last_saved_layer = start_layer - 1 @classmethod - def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None: + def from_folder( + cls, + checkpoint_dir: str | None, + num_layers: int, + save_every: int = 1, + save_quantizers_only: bool = False, + ) -> _CheckpointState | None: """Create from folder. Detects resume point. 
Returns None if no checkpoint_dir.""" if not checkpoint_dir: return None os.makedirs(checkpoint_dir, exist_ok=True) info = detect_resume_point(checkpoint_dir) if info is not None: - manifest_num_layers = info[1].get("num_layers") - if manifest_num_layers is not None and manifest_num_layers != num_layers: - raise ValueError( - f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} " - f"but model has {num_layers}. Use a fresh checkpoint directory." - ) + manifest = info[1] + # Pre-0.45 manifests omit save_every / save_quantizers_only; skip the + # check for keys absent from the on-disk manifest. + for key, new_value in ( + ("num_layers", num_layers), + ("save_every", save_every), + ("save_quantizers_only", save_quantizers_only), + ): + ckpt_value = manifest.get(key) + if ckpt_value is not None and ckpt_value != new_value: + raise ValueError( + f"Checkpoint {key} mismatch: manifest has {ckpt_value!r} but " + f"new run uses {new_value!r}. Use a fresh checkpoint directory." + ) start = info[0] if info else 0 if start > 0: print_rank_0( f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}" ) - return cls(checkpoint_dir, num_layers, start_layer=start) + return cls( + checkpoint_dir, + num_layers, + start_layer=start, + save_every=save_every, + save_quantizers_only=save_quantizers_only, + ) def setup_resume(self, layers: nn.ModuleList) -> list | None: """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. 
@@ -609,7 +659,10 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.""" from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import restore_quantizer_state - from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + from modelopt.torch.quantization.utils.core_utils import ( + enable_weight_access_and_writeback, + set_quantizer_state_dict, + ) if self.start_layer == 0: return @@ -630,55 +683,96 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: map_location=layer_device, weights_only=False, ) - weights = torch.load( - os.path.join(d, "weights.pt"), - map_location=layer_device, - weights_only=False, - ) restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) - layer.load_state_dict(weights, strict=False, assign=True) + weights_path = os.path.join(d, "weights.pt") + buffers_path = os.path.join(d, "quantizer_buffers.pt") + if os.path.isfile(weights_path): + weights = torch.load( + weights_path, map_location=layer_device, weights_only=False + ) + layer.load_state_dict(weights, strict=False, assign=True) + elif os.path.isfile(buffers_path): + # save_quantizers_only mode: restore just the TensorQuantizer + # state_dict (carries _amax). The layer's other weights + # weren't modified by the algorithm, so the in-memory values + # already match what would have been saved. + quantizer_buffers = torch.load( + buffers_path, map_location=layer_device, weights_only=False + ) + set_quantizer_state_dict(layer, quantizer_buffers) print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") def save( self, layer_idx: int, - layer: nn.Module, model: nn.Module, layers: nn.ModuleList, next_layer_inputs: list | None = None, ) -> None: - """Snapshot layer state and write checkpoint to disk in one step. 
- - Args: - layer_idx: Index of the layer just calibrated. - layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). - model: The full model (needed for ``enable_weight_access_and_writeback``). - layers: The decoder layer list (to read ``output_meta``). - next_layer_inputs: Inputs for the next layer (``None`` for the final layer). + """Snapshot the just-calibrated layer; commit the window at boundaries. + + Each call reads state from ``layers[layer_idx]`` *before* the next + iteration's capture forward swaps it to a ``_SkipLayer``, so state is + always read from the real calibrated layer. Per-layer files are written + every call; ``next_inputs.pt`` and the manifest are deferred to window + boundaries so a mid-window crash leaves the manifest pointing at the + previous boundary. """ from modelopt.torch.quantization.conversion import quantizer_state - from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + from modelopt.torch.quantization.utils.core_utils import ( + enable_weight_access_and_writeback, + get_quantizer_state_dict, + ) _cpu = torch.device("cpu") + layer = layers[layer_idx] with enable_weight_access_and_writeback(layer, model): - weights = _move_to_device(layer.state_dict(), _cpu) qstate = _move_to_device(quantizer_state(layer), _cpu) + if self.save_quantizers_only: + weights = None + quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer), _cpu) + else: + weights = _move_to_device(layer.state_dict(), _cpu) + quantizer_buffers = None output_meta = getattr(layer._layerwise_calib, "output_meta", None) if output_meta is None: - # Placeholder for the last layer: output_meta is never used for skip mode - # since there is no subsequent layer that needs a correctly shaped dummy output. + # Final-layer placeholder: never consumed by skip mode (no successor). 
output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) - _save_layer( + _save_layer_files( self.checkpoint_dir, layer_idx, weights, qstate, + quantizer_buffers, _move_to_device(output_meta, _cpu), - _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, + ) + + is_final = layer_idx + 1 == self.num_layers + is_window_end = (layer_idx + 1) % self.save_every == 0 + if not (is_final or is_window_end): + return + + # Window boundary: write next_inputs.pt + manifest to commit the window. + if next_layer_inputs is not None: + torch.save( + _move_to_device(next_layer_inputs, _cpu), + os.path.join(_layer_dir(self.checkpoint_dir, layer_idx), "next_inputs.pt"), + ) + _write_manifest( + self.checkpoint_dir, + layer_idx, self.num_layers, + save_every=self.save_every, + save_quantizers_only=self.save_quantizers_only, ) + window_start = self._last_saved_layer + 1 + self._last_saved_layer = layer_idx + window_size = layer_idx - window_start + 1 suffix = " (final)" if next_layer_inputs is None else "" - print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") + if window_size > 1: + print_rank_0(f"Checkpoint: committed window {window_start}..{layer_idx}{suffix}") + else: + print_rank_0(f"Checkpoint: committed layer {layer_idx}{suffix}") diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml index 6dee51857c8..5dc5786f0ff 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml @@ -29,8 +29,9 @@ metadata: quantize: algorithm: method: gptq - layerwise: true - layerwise_checkpoint_dir: output/layerwise_ckpts/ + layerwise: + enable: true + checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - $import: base_disable_all - quantizer_name: '*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml 
b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml index 08864c8a50d..af68684989b 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml @@ -30,9 +30,10 @@ quantize: algorithm: method: max # Max calibration is fast and does not typically need checkpointing. - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). - layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml index 9d1f470643f..2da5abb3b89 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml @@ -26,7 +26,8 @@ quantize: algorithm: method: max # Max calibration is fast and does not typically need checkpointing. 
- layerwise: true + layerwise: + enable: true quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml index 5bf9a36dc31..3043ed32951 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml @@ -31,9 +31,10 @@ quantize: algorithm: method: mse fp8_scale_sweep: true - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). - layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp.experts*weight_quantizer' diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml index 2ea2c0ab13e..71c354ee1b1 100644 --- a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml @@ -31,9 +31,10 @@ quantize: algorithm: method: mse fp8_scale_sweep: true - # layerwise=false required for VLMs where the decoder layers are nested under + # layerwise.enable=false required for VLMs where the decoder layers are nested under # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). 
- layerwise: false + layerwise: + enable: false quant_cfg: - $import: base_disable_all - quantizer_name: '*mlp*weight_quantizer' diff --git a/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml b/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml index 0386c70b13e..355bffee67c 100644 --- a/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml +++ b/modelopt_recipes/huggingface/qwen3_5/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml @@ -34,6 +34,7 @@ metadata: quantize: algorithm: method: max - layerwise: false + layerwise: + enable: false quant_cfg: - $import: shared_quant_cfg diff --git a/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml b/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml index 2a3b99cdeb4..fa24bee97af 100644 --- a/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml +++ b/modelopt_recipes/huggingface/qwen3_5_moe/ptq/w4a16_nvfp4-fp8_attn-kv_fp8_cast.yaml @@ -35,6 +35,7 @@ metadata: quantize: algorithm: method: max - layerwise: false + layerwise: + enable: false quant_cfg: - $import: shared_quant_cfg diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index ce98f989f51..82b9d24d7a4 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,8 @@ """Test of quantization config validations.""" +import warnings + import pytest from pydantic import ValidationError @@ -25,8 +27,14 @@ INT4_AWQ_CFG, NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, + AWQLiteCalibConfig, + GPTQCalibConfig, + LayerwiseConfig, + LocalHessianCalibConfig, MaxCalibConfig, + MseCalibConfig, QuantizeConfig, + SmoothQuantCalibConfig, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, @@ -573,32 +581,137 @@ def 
test_validate_quant_cfg_entries_accepts_valid_cfg(self): class TestLayerwiseUseSequentialAlias: - """`layerwise` accepts the legacy `use_sequential` name via validation_alias. - - Old PTQ checkpoints serialized the field as `use_sequential` before #1251 renamed - it to `layerwise`. AliasChoices lets those checkpoints load without a migration - validator while still serializing under the current name. - """ + """`use_sequential` is the legacy alias for `layerwise` (pre-#1251 checkpoints).""" + + @pytest.mark.parametrize("value", [True, False]) + def test_use_sequential_resolves_to_layerwise(self, value): + with pytest.warns(DeprecationWarning): + cfg = MaxCalibConfig(use_sequential=value) + assert cfg.layerwise.enable is value + + def test_serializes_under_layerwise_not_alias(self): + with pytest.warns(DeprecationWarning): + dumped = MaxCalibConfig(use_sequential=True).model_dump() + assert dumped["layerwise"]["enable"] is True + assert "use_sequential" not in dumped - def test_use_sequential_true_sets_layerwise(self): - cfg = MaxCalibConfig(use_sequential=True) - assert cfg.layerwise is True - def test_use_sequential_false_sets_layerwise(self): - cfg = MaxCalibConfig(use_sequential=False) - assert cfg.layerwise is False +class TestLayerwiseNestedConfig: + """Layerwise expands from a bool to a nested ``LayerwiseConfig``. - def test_layerwise_name_still_accepted(self): - cfg = MaxCalibConfig(layerwise=True) - assert cfg.layerwise is True + Backward compatibility: bool input is coerced with a DeprecationWarning, and + the legacy flat ``layerwise_checkpoint_dir`` key is migrated into the nested config, also with a DeprecationWarning.
+ """ - def test_serializes_under_current_name(self): - """Dump must use `layerwise`, not the legacy alias.""" - dumped = MaxCalibConfig(use_sequential=True).model_dump() - assert dumped["layerwise"] is True - assert "use_sequential" not in dumped + def test_nested_form_accepted(self): + cfg = MaxCalibConfig(layerwise={"enable": True, "checkpoint_dir": "/x"}) + assert cfg.layerwise.enable is True + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_bool_form_deprecated_but_accepted(self): + with pytest.warns(DeprecationWarning, match="bool is deprecated"): + cfg = MaxCalibConfig(layerwise=True) + assert cfg.layerwise.enable is True + + def test_dict_form_no_deprecation(self): + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + MaxCalibConfig(layerwise={"enable": True}) + + def test_flat_checkpoint_dir_migrated_with_deprecation(self): + """Legacy ``layerwise_checkpoint_dir`` is migrated into the nested config + and emits a deprecation warning naming the flat key (independent of the + bool-form deprecation tested above). 
+ """ + with pytest.warns(DeprecationWarning, match="layerwise_checkpoint_dir.*deprecated"): + cfg = MaxCalibConfig(layerwise={"enable": True}, layerwise_checkpoint_dir="/x") + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_use_sequential_alias_survives_flat_checkpoint_migration(self): + """``use_sequential`` + flat ``layerwise_checkpoint_dir`` must not drop the alias value.""" + with pytest.warns(DeprecationWarning): + cfg = MaxCalibConfig(use_sequential=True, layerwise_checkpoint_dir="/x") + assert cfg.layerwise.enable is True + assert cfg.layerwise.checkpoint_dir == "/x" + + def test_conflicting_flat_and_nested_checkpoint_dir_raises(self): + with pytest.raises(ValidationError, match="Conflicting checkpoint_dir"): + MaxCalibConfig( + layerwise={"enable": True, "checkpoint_dir": "/a"}, + layerwise_checkpoint_dir="/b", + ) - def test_unknown_field_still_rejected(self): - """extra='forbid' must still reject unrelated unknown fields.""" + @pytest.mark.parametrize( + "kwargs", + [ + {"layerwise": {"checkpoint_dir": "/x"}}, + {"layerwise_checkpoint_dir": "/x"}, + ], + ) + def test_checkpoint_dir_requires_enable(self, kwargs): + with pytest.raises(ValidationError, match="requires layerwise.enable=True"): + MaxCalibConfig(**kwargs) + + @pytest.mark.parametrize( + ("cfg_cls", "expected_qdq"), + [(MaxCalibConfig, False), (GPTQCalibConfig, True)], + ) + def test_per_algorithm_qdq_default(self, cfg_cls, expected_qdq): + assert cfg_cls().layerwise.get_qdq_activations_from_prev_layer is expected_qdq + + @pytest.mark.parametrize( + ("layerwise_input", "expected_qdq"), + [ + # GPTQ default kicks in for user dict that doesn't mention qdq. + ({"enable": True}, True), + # GPTQ default kicks in for legacy bool form too. + pytest.param( + True, True, marks=pytest.mark.filterwarnings("ignore::DeprecationWarning") + ), + # User-explicit False overrides the GPTQ default. 
+ ({"enable": True, "get_qdq_activations_from_prev_layer": False}, False), + # ``LayerwiseConfig`` instance: ``_coerce_layerwise_input`` must + # preserve ``model_fields_set`` so the GPTQ default still kicks in + # for fields the user didn't explicitly set. + (LayerwiseConfig(enable=True), True), + ( + LayerwiseConfig(enable=True, get_qdq_activations_from_prev_layer=False), + False, + ), + ], + ) + def test_gptq_qdq_default_respects_user_explicit_value(self, layerwise_input, expected_qdq): + cfg = GPTQCalibConfig(layerwise=layerwise_input) + assert cfg.layerwise.get_qdq_activations_from_prev_layer is expected_qdq + + def test_default_dump_shape(self): + dumped = MaxCalibConfig().model_dump() + assert dumped["layerwise"] == { + "enable": False, + "get_qdq_activations_from_prev_layer": False, + "checkpoint_dir": None, + "save_every": 1, + "save_quantizers_only": False, + } + assert "layerwise_checkpoint_dir" not in dumped + + def test_save_every_must_be_positive(self): with pytest.raises(ValidationError): - MaxCalibConfig(not_a_real_field=True) + MaxCalibConfig(layerwise={"enable": True, "save_every": 0}) + + @pytest.mark.parametrize( + "cfg_cls", + [GPTQCalibConfig, AWQLiteCalibConfig, SmoothQuantCalibConfig], + ) + def test_save_quantizers_only_rejected_for_weight_mutating_algorithms(self, cfg_cls): + """Whitelist: only amax-only algorithms (max/mse/local_hessian) may set + save_quantizers_only=True. Weight-mutating algorithms (GPTQ folds Hessian + updates, AWQ/SmoothQuant fold pre-quant scales) must reject the flag. 
+ """ + with pytest.raises(ValidationError, match="mutates layer weights in-place"): + cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) + + @pytest.mark.parametrize("cfg_cls", [MaxCalibConfig, MseCalibConfig, LocalHessianCalibConfig]) + def test_save_quantizers_only_accepted_for_amax_only_algorithms(self, cfg_cls): + cfg = cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) + assert cfg.layerwise.save_quantizers_only is True diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index 3739feff969..5ec4cc2daed 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -16,6 +16,7 @@ """Unit tests for layerwise_calibrate and LayerActivationCollector.""" import copy +import json from collections import deque import pytest @@ -713,9 +714,320 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm(): config = _svdquant_layerwise_config() torch.manual_seed(0) model = _SimpleTransformerModel(n_layers=2, dim=16) - with pytest.raises(ValueError, match="does not support layerwise=True"): + with pytest.raises(ValueError, match="does not support layerwise.enable=True"): mtq.quantize( model, config, forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))), ) + + +def _collect_amax(model): + return { + name: q._amax.clone().detach() + for name, q in model.named_modules() + if isinstance(q, TensorQuantizer) and q.is_enabled and getattr(q, "_amax", None) is not None + } + + +def _assert_amax_close(actual, expected, label): + assert set(actual) == set(expected), ( + f"Different quantizers populated for {label}: " + f"missing={set(expected) - set(actual)}, extra={set(actual) - set(expected)}" + ) + for name in expected: + torch.testing.assert_close( + actual[name], + expected[name], + rtol=1e-5, + atol=1e-6, + msg=f"{label}: amax mismatch at {name}", + ) + + +def 
test_layerwise_no_qdq_matches_sequential_amax(monkeypatch): + """Layerwise + ``get_qdq_activations_from_prev_layer=False`` must produce the + same per-quantizer amax as the non-layerwise (sequential) max-calibration + flow. Both paths feed every layer full-precision activations. + """ + _register_test_discoverer(monkeypatch) + + torch.manual_seed(0) + model_seq = _SimpleTransformerModel(n_layers=3, dim=16) + model_lw = copy.deepcopy(model_seq) + calib_data = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def fwd(m): + for batch in calib_data: + m(batch) + + mtq.quantize(model_seq, _int8_layerwise_config({"method": "max"}), forward_loop=fwd) + mtq.quantize( + model_lw, + _int8_layerwise_config( + { + "method": "max", + "layerwise": {"enable": True, "get_qdq_activations_from_prev_layer": False}, + } + ), + forward_loop=fwd, + ) + + seq_amax = _collect_amax(model_seq) + assert seq_amax, "sequential calibration populated no amax values" + _assert_amax_close(_collect_amax(model_lw), seq_amax, "layerwise vs sequential") + + +def test_layerwise_no_qdq_captures_inputs_before_calib_func_mutates_weights(monkeypatch): + """A destructive ``calib_func`` (zeros weights) must not affect what is + captured for downstream layers under ``qdq_from_prev=False`` — otherwise + weight-mutating algorithms (GPTQ/AWQ/SmoothQuant) silently propagate + updates forward and break the "identical to non-layerwise pass" contract. 
+ """ + _register_test_discoverer(monkeypatch) + calib_data = [torch.randint(0, 32, (2, 8))] + + def run_and_capture(calib_func): + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=3, dim=16) + captured: dict[int, torch.Tensor] = {} + real = LayerActivationCollector.cache_outputs_for_next_layer_calib + + def spy(self, layer, fwd): + result = real(self, layer, fwd) + captured[self._layer_to_idx[layer] + 1] = result[0][0][0].clone().detach() + return result + + with monkeypatch.context() as m: + m.setattr(LayerActivationCollector, "cache_outputs_for_next_layer_calib", spy) + layerwise_calibrate( + model, + forward_loop=lambda mm: [mm(b) for b in calib_data], + calib_func=calib_func, + get_qdq_activations_from_prev_layer=False, + ) + return captured + + def identity(layer, fwd, **_): + fwd(layer) + + def destructive(layer, fwd, **_): + fwd(layer) + for sub in layer.modules(): + if isinstance(sub, nn.Linear): + sub.weight.data.zero_() + + benign = run_and_capture(identity) + mutated = run_and_capture(destructive) + + for i in (1, 2): + torch.testing.assert_close(mutated[i], benign[i], rtol=1e-5, atol=1e-6) + + +def _layer_dir_names(checkpoint_dir): + return sorted(p.name for p in checkpoint_dir.iterdir() if p.name.startswith("layer_")) + + +def test_layerwise_save_every_writes_next_inputs_only_at_window_boundaries(monkeypatch, tmp_path): + """With save_every=2 on a 4-layer model, every layer dir is still written + (so resume can replay skip layers), but ``next_inputs.pt`` — the large + activation cache — appears only at window boundaries. 
+ """ + _register_test_discoverer(monkeypatch) + + config = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + mtq.quantize(model, config, forward_loop=lambda m: [m(b) for b in calib_data]) + + # Window boundaries are layer_idx 1 and 3 -> all 4 layer dirs exist (window-save). + assert _layer_dir_names(tmp_path) == [ + "layer_0000", + "layer_0001", + "layer_0002", + "layer_0003", + ] + # next_inputs.pt is only at the boundary layers (the resume restart points). + assert not (tmp_path / "layer_0000" / "next_inputs.pt").exists() + assert (tmp_path / "layer_0001" / "next_inputs.pt").exists() + assert not (tmp_path / "layer_0002" / "next_inputs.pt").exists() + # Last layer never has next_inputs.pt (no subsequent layer). + assert not (tmp_path / "layer_0003" / "next_inputs.pt").exists() + + +@pytest.mark.parametrize( + ("scenario", "n_layers", "save_every", "save_quantizers_only", "rewind_to"), + [ + # Pins the quantizer_buffers.pt restore path (no weights.pt on disk). + ("quantizers_only", 3, 1, True, 0), + # Pins the per-call snapshot fix: each save() captures the + # just-calibrated layer's state before the next-layer capture forward + # swaps it to _SkipLayer. + ("save_every", 4, 2, False, 1), + ], +) +def test_layerwise_checkpoint_resume_matches_one_shot_amax( + monkeypatch, tmp_path, scenario, n_layers, save_every, save_quantizers_only, rewind_to +): + """Full run → rewind manifest → fresh resume reproduces one-shot ``_amax``. + + Single test covering both checkpoint optimizations. For the + ``save_quantizers_only`` case also asserts the on-disk shape (no + ``weights.pt``, ``quantizer_buffers.pt`` present per layer). 
+ """ + _register_test_discoverer(monkeypatch) + + calib_data = [torch.randint(0, 32, (2, 8))] + forward_loop = lambda m: [m(b) for b in calib_data] # noqa: E731 + + def build_cfg(ckpt_dir): + return _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(ckpt_dir), + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, + }, + } + ) + + baseline_dir = tmp_path / "baseline" + torch.manual_seed(0) + baseline_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) + mtq.quantize(baseline_model, build_cfg(baseline_dir), forward_loop=forward_loop) + baseline_amax = _collect_amax(baseline_model) + assert baseline_amax + + resume_dir = tmp_path / "resume" + torch.manual_seed(0) + setup_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) + mtq.quantize(setup_model, build_cfg(resume_dir), forward_loop=forward_loop) + + if scenario == "quantizers_only": + for name in _layer_dir_names(resume_dir): + d = resume_dir / name + assert not (d / "weights.pt").exists() + assert (d / "quantizer_buffers.pt").exists() + + (resume_dir / "manifest.json").write_text( + json.dumps( + { + "last_completed_layer": rewind_to, + "num_layers": n_layers, + "save_every": save_every, + "save_quantizers_only": save_quantizers_only, + } + ) + ) + + torch.manual_seed(0) + resumed_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) + mtq.quantize(resumed_model, build_cfg(resume_dir), forward_loop=forward_loop) + + _assert_amax_close(_collect_amax(resumed_model), baseline_amax, f"{scenario} resume") + + +def test_layerwise_save_every_mid_window_crash_recovers_at_prev_boundary(monkeypatch, tmp_path): + """A crash inside a window must not advance ``last_completed_layer``; resume + re-calibrates the unfinished window from the previous boundary. 
+ + Monkeypatches ``torch.save`` to raise on the second window's first per-layer + write (layer 2), then asserts the on-disk manifest still points at layer 1 + (the previous boundary). + """ + _register_test_discoverer(monkeypatch) + + cfg = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + + real_torch_save = torch.save + state = {"crashed": False} + + def crashing_torch_save(obj, path, *args, **kwargs): + # Crash on the first per-layer file write for layer 2 (mid-window). + if not state["crashed"] and "layer_0002" in str(path): + state["crashed"] = True + raise RuntimeError("simulated crash during layer 2 save") + return real_torch_save(obj, path, *args, **kwargs) + + monkeypatch.setattr( + "modelopt.torch.quantization.utils.layerwise_calib.torch.save", + crashing_torch_save, + ) + with pytest.raises(RuntimeError, match="simulated crash"): + mtq.quantize(model, cfg, forward_loop=lambda m: [m(b) for b in calib_data]) + + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["last_completed_layer"] == 1, f"manifest leaked mid-window state: {manifest}" + + +def test_layerwise_checkpoint_mismatch_save_every_raises(monkeypatch, tmp_path): + """Resuming with a different ``save_every`` than the checkpoint was produced + with must raise — the on-disk window layout assumes a fixed value. 
+ """ + _register_test_discoverer(monkeypatch) + + cfg_first = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 2, + }, + } + ) + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=4, dim=16) + calib_data = [torch.randint(0, 32, (2, 8))] + mtq.quantize(model, cfg_first, forward_loop=lambda m: [m(b) for b in calib_data]) + + # Rewind manifest so the second run sees an in-progress resume, + # then change save_every to trigger the mismatch check. + (tmp_path / "manifest.json").write_text( + json.dumps( + { + "last_completed_layer": 1, + "num_layers": 4, + "save_every": 2, + "save_quantizers_only": False, + } + ) + ) + cfg_mismatched = _int8_layerwise_config( + { + "method": "max", + "layerwise": { + "enable": True, + "checkpoint_dir": str(tmp_path), + "save_every": 4, + }, + } + ) + torch.manual_seed(0) + fresh_model = _SimpleTransformerModel(n_layers=4, dim=16) + with pytest.raises(ValueError, match="save_every mismatch"): + mtq.quantize(fresh_model, cfg_mismatched, forward_loop=lambda m: [m(b) for b in calib_data])