diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ace0fb7d040..e00da4c0f7a 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -40,7 +40,7 @@ Changelog - Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). -- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``save_quantizers_only`` (skip the layer-weights blob for amax-only algorithms — whitelisted to ``max``/``mse``/``local_hessian``). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``. +- Group layerwise calibration options under a nested ``LayerwiseConfig`` and add three knobs: ``get_qdq_activations_from_prev_layer`` (correct GPTQ-Hessian vs max-calib activation semantics — defaults to True for GPTQ, False for max/mse/local_hessian), ``save_every`` (gate per-window ``next_inputs.pt`` activation-cache writes), and ``calib_mutates_weights`` (set False to skip layer-weight checkpointing/writeback when calibration only updates quantizer state). Legacy bool ``layerwise`` and flat ``layerwise_checkpoint_dir`` keys still work; the bool form emits a ``DeprecationWarning``. **Bug Fixes** diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 726a78d129b..4c26c98a8d4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -678,13 +678,14 @@ class LayerwiseConfig(ModeloptBaseConfig): ), ) - save_quantizers_only: bool = ModeloptField( - default=False, - title="Skip the per-layer weights blob; persist only quantizer state.", + calib_mutates_weights: bool = ModeloptField( + default=True, + title="Whether layerwise calibration mutates layer weights.", description=( - "Only accepted by algorithms that update solely ``TensorQuantizer._amax`` " - "(max, mse, local_hessian). Rejected for weight-mutating algorithms " - "(GPTQ, AWQ, SmoothQuant) where it would silently lose updates on resume." + "Set to False only for algorithms that update solely " + "``TensorQuantizer._amax`` (max, mse, local_hessian). Rejected for " + "weight-mutating algorithms (GPTQ, AWQ, SmoothQuant) where it would " + "silently lose updates on resume." ), ) @@ -712,7 +713,7 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" # Set True only for algorithms that update solely ``TensorQuantizer._amax`` - # (no ``layer.weight`` mutation). Gates ``layerwise.save_quantizers_only``. + # (no ``layer.weight`` mutation). Gates ``layerwise.calib_mutates_weights=False``. _supports_save_quantizers_only: ClassVar[bool] = False method: Literal[None] = ModeloptField( @@ -794,12 +795,12 @@ def validate_layerwise_checkpoint_dir(self): return self @model_validator(mode="after") - def _validate_save_quantizers_only_supported(self): - """Enforce the ``_supports_save_quantizers_only`` whitelist.""" - if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only: + def _validate_non_mutating_layerwise_supported(self): + """Enforce the ``calib_mutates_weights=False`` whitelist.""" + if not self.layerwise.calib_mutates_weights and not self._supports_save_quantizers_only: raise ValueError( f"Algorithm '{self.method}' mutates layer weights in-place; " - "save_quantizers_only=True would lose those updates on resume. " + "calib_mutates_weights=False would lose those updates on resume. " "Only max/mse/local_hessian (amax-only) support this flag." ) return self diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 28cadf3aa7f..28c904c1804 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -228,7 +228,7 @@ def wrapped_calib_func( checkpoint_dir = layerwise_cfg.get("checkpoint_dir") qdq_from_prev = layerwise_cfg.get("get_qdq_activations_from_prev_layer", False) save_every = layerwise_cfg.get("save_every", 1) - save_quantizers_only = layerwise_cfg.get("save_quantizers_only", False) + calib_mutates_weights = layerwise_cfg.get("calib_mutates_weights", True) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -263,7 +263,7 @@ def wrapped_calib_func( checkpoint_dir=checkpoint_dir, get_qdq_activations_from_prev_layer=qdq_from_prev, save_every=save_every, - save_quantizers_only=save_quantizers_only, + calib_mutates_weights=calib_mutates_weights, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index b35beef50d1..f6f34b81fbb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -34,7 +34,7 @@ _CheckpointState, ) from modelopt.torch.utils import print_rank_0 -from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState +from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator @@ -1761,7 +1761,7 @@ def layerwise_calibrate( checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False) save_every = calib_kwargs.pop("save_every", 1) - save_quantizers_only = calib_kwargs.pop("save_quantizers_only", False) + calib_mutates_weights = calib_kwargs.pop("calib_mutates_weights", True) if forward_loop is None: raise ValueError( @@ -1783,16 +1783,27 @@ def layerwise_calibrate( checkpoint_dir, num_layers, save_every=save_every, - save_quantizers_only=save_quantizers_only, + calib_mutates_weights=calib_mutates_weights, ) start_layer = ckpt.start_layer if ckpt else 0 - input_getter = LayerActivationCollector(model) - input_getter._patch_all_layers(decoder_layers=transformer_layers) + layer_pbar = tqdm( + total=num_layers, + initial=start_layer, + desc="Layerwise calibration", + disable=not is_master(), + dynamic_ncols=True, + ) + + def _set_layer_status(status: str): + layer_pbar.set_postfix_str(status, refresh=True) - resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None + input_getter = LayerActivationCollector(model, status_callback=_set_layer_status) try: + input_getter._patch_all_layers(decoder_layers=transformer_layers) + resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None + # Bootstrap: get first layer's inputs (or use resumed inputs). layer_inputs = input_getter.get_first_layer_inputs( start_layer, resumed_inputs, forward_loop @@ -1822,39 +1833,43 @@ def _layer_forward_loop(m, _inputs=layer_inputs): is_last = layer_idx + 1 >= num_layers - # qdq_from_prev=False: capture before calib_func so the forward - # replay uses the original FP weights. Disable quantizers too in - # case any pre-calibration observer behavior would perturb the - # captured activations. - if not is_last and not qdq_from_prev: - with set_quantizer_by_cfg_context( - layer, [{"quantizer_name": "*", "enable": False}] - ): - next_inputs = input_getter.cache_outputs_for_next_layer_calib( - layer, forward_loop - ) - # cache_outputs left this layer in "run" mode with an empty - # deque; reset so calib_func's replay hits the real forward. - layer._layerwise_calib.mode = "original" + with persistent_materialization(layer, writeback=calib_mutates_weights): + # qdq_from_prev=False: capture before calib_func so the forward + # replay uses the original FP weights. Disable quantizers too in + # case any pre-calibration observer behavior would perturb the + # captured activations. + if not is_last and not qdq_from_prev: + with set_quantizer_by_cfg_context( + layer, [{"quantizer_name": "*", "enable": False}] + ): + next_inputs = input_getter.cache_outputs_for_next_layer_calib( + layer, forward_loop + ) + # cache_outputs left this layer in "run" mode with an empty + # deque; reset so calib_func's replay hits the real forward. + layer._layerwise_calib.mode = "original" - with persistent_materialization(layer): calib_func(layer, _layer_forward_loop, **calib_kwargs) - # qdq_from_prev=True: capture after calib_func so the next layer - # sees QDQ error and any in-place weight updates from this layer. - if not is_last and qdq_from_prev: - next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) - elif is_last: - next_inputs = None + # qdq_from_prev=True: capture after calib_func so the next layer + # sees QDQ error and any in-place weight updates from this layer. + if not is_last and qdq_from_prev: + next_inputs = input_getter.cache_outputs_for_next_layer_calib( + layer, forward_loop + ) + elif is_last: + next_inputs = None - if ckpt: - ckpt.save(layer_idx, model, transformer_layers, next_inputs) + if ckpt: + ckpt.save(layer_idx, model, transformer_layers, next_inputs) + layer_pbar.update(1) del layer_inputs torch.cuda.empty_cache() layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure) finally: input_getter._unpatch_all_layers() + layer_pbar.close() if ckpt: ckpt.full_restore(transformer_layers, model) diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index f80e2478dc6..156669bfc8f 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -66,15 +66,22 @@ def _writeback_params_to_weights_map(module, align_hook): # on-disk version. OffloadedWeightsLoader.__getitem__ gives # state_dict priority over index, so this is sufficient. w_map[key] = tensor.detach().cpu() + warnings.warn( + "Accelerate disk-offload writeback is currently kept in CPU state_dict; " + "writing updates back to disk offload files is TODO.", + RuntimeWarning, + stacklevel=2, + ) @contextmanager -def weight_access_and_writeback_context(module): +def weight_access_and_writeback_context(module, writeback: bool = True): """Context manager for weight access and writeback for modules managed by accelerate. Handles CPU-offloaded and disk-offloaded models. Iterates over the module and all - its descendants, materializing weights from any offload hook found and writing them - back on exit. ``pre_forward`` is skipped on modules whose weights are already + its descendants, materializing weights from any offload hook found. If ``writeback`` + is True, writes materialized tensors back on exit. ``pre_forward`` is skipped on + modules whose weights are already materialized (not on meta) to avoid overwriting them with stale CPU copies. """ assert hasattr(module, "_hf_hook") @@ -99,7 +106,8 @@ def weight_access_and_writeback_context(module): finally: for mod, hook, was_materialized in materialized: hook.offload = True - _writeback_params_to_weights_map(mod, hook) + if writeback: + _writeback_params_to_weights_map(mod, hook) if was_materialized: hook.post_forward(mod, None) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 15c6504011b..cf3c1d0cdd3 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -29,6 +29,7 @@ from modelopt.torch.quantization.config import QuantizerCfgEntry from modelopt.torch.utils import get_unwrapped_name, print_rank_0 +from modelopt.torch.utils.network import temporarily_remove_accelerate_hook if TYPE_CHECKING: from collections.abc import Generator @@ -471,7 +472,24 @@ def _set_parameter(module: nn.Module, name: str, value: nn.Parameter): @contextmanager -def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module): +def _fsdp2_unshard_context(fsdp_module: FSDPModule): + """Unshard an FSDP2 module without replacing individual DTensor parameters.""" + fsdp_param_group = fully_shard.state(fsdp_module)._fsdp_param_group + was_sharded = fsdp_param_group.is_sharded + if was_sharded: + fsdp_module.unshard() + try: + with _disable_fsdp_unshard_reshard(fsdp_module): + yield + finally: + if was_sharded: + fsdp_module.reshard() + + +@contextmanager +def fsdp2_weight_access_and_writeback_context( + module: nn.Module, root_model: nn.Module, writeback: bool = True +): """Context manager for FSDP2 weight access and writeback. Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be @@ -486,6 +504,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" fsdp_module = _get_enclosing_fsdp_module(module, root_model) assert fsdp_module is not None, "Module is not wrapped by FSDP" + if not writeback: + with _fsdp2_unshard_context(fsdp_module): + yield + return + fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim @@ -525,7 +548,9 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. @contextmanager -def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None): +def enable_weight_access_and_writeback( + module, root_model, name_to_module: dict | None = None, writeback: bool = True +): """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or @@ -539,16 +564,18 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict total cost when called in a loop. This causes significant CPU overhead on large models, particularly Sparse MoE architectures where each expert is typically implemented as its own module. + writeback: Whether modified weights must be written back to the owning sharded/offload + representation when exiting the context. """ if _get_enclosing_fsdp_module(module, root_model, name_to_module) is not None: - context = fsdp2_weight_access_and_writeback_context(module, root_model) + context = fsdp2_weight_access_and_writeback_context(module, root_model, writeback) elif is_quantized_parallel_linear(module) and hasattr(module, "_hf_tp_plan"): # HF transformers TP sharded linear layer context = module.enable_weight_access_and_writeback() elif hasattr(module, "_hf_hook"): from ..plugins.accelerate import weight_access_and_writeback_context - context = weight_access_and_writeback_context(module) + context = weight_access_and_writeback_context(module, writeback) else: context = nullcontext() @@ -557,7 +584,7 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict @contextmanager -def persistent_materialization(layer): +def persistent_materialization(layer, writeback: bool = True): """Keep all layer weights materialized on GPU for the duration. Suppresses per-forward weight transfers so that N calibration batches @@ -565,10 +592,15 @@ def persistent_materialization(layer): - **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then gathers weights once via ``enable_weight_access_and_writeback``. - - **Accelerate**: materializes weights and sets ``hook.offload = False`` - so per-forward hooks skip materialization/offloading. + - **Accelerate**: materializes weights, sets ``hook.offload = False``, + and bypasses the layer's top-level accelerate hook while the weights are + materialized. """ - with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer): + with ( + _disable_fsdp_unshard_reshard(layer), + enable_weight_access_and_writeback(layer, layer, writeback=writeback), + temporarily_remove_accelerate_hook(layer), + ): yield diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index eeb47cf2820..03eff46e803 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -42,6 +42,8 @@ ) if TYPE_CHECKING: + from collections.abc import Callable + from modelopt.torch.opt.searcher import ForwardLoop @@ -124,18 +126,19 @@ class LayerActivationCollector: _decoder_layer_support: list[tuple[Any, Any]] = [] _LAYER_ATTR = "_layerwise_calib" - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, status_callback: Callable[[str], None] | None = None): """Initialize the collector for the given model.""" self.model = model self._decoder_layers: nn.ModuleList | None = None self._layer_to_idx: dict[nn.Module, int] = {} self._patched = False + self._status_callback = status_callback def _swap_to_dummy(self, idx: int): """Replace decoder layer *idx* with a parameter-free dummy. ``output_meta`` is intentionally preserved on the original layer: the - ``_SkipLayer`` reads it to produce correctly shaped zero-filled outputs + ``_SkipLayer`` reads it to produce correctly shaped placeholder outputs for the parent forward pass. """ assert self._decoder_layers is not None @@ -322,8 +325,14 @@ def _set_layer_states(self, layer_idx: int): cur.mode = "capture" cur.collected_inputs = [] - def _log_layer_summary(self, layer_idx: int): - """Log a one-line summary of layer modes for the current calibration step.""" + def _emit_status(self, status: str): + if self._status_callback is None: + print_rank_0(status) + else: + self._status_callback(status) + + def _layer_summary(self, layer_idx: int) -> str: + """Return a one-line summary of layer modes for the current calibration step.""" assert self._decoder_layers is not None n = len(self._decoder_layers) groups: dict[str, list[int]] = {} @@ -338,7 +347,10 @@ def _log_layer_summary(self, layer_idx: int): continue ids = groups[mode] parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}") - print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + return f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}" + + def _log_layer_summary(self, layer_idx: int): + self._emit_status(self._layer_summary(layer_idx)) @torch.no_grad() def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: @@ -402,7 +414,8 @@ def get_first_layer_inputs( assert self._decoder_layers is not None if resumed_inputs is not None: - print_rank_0(f"Calibrating layer {start_layer + 1} (resumed)") + n = len(self._decoder_layers) + self._emit_status(f"Calibrating layer {start_layer + 1}/{n} | resumed") for i in range(start_layer): self._swap_to_dummy(i) layer = self._decoder_layers[start_layer] @@ -420,7 +433,8 @@ def cache_outputs_for_next_layer_calib( This puts *layer* into "run" mode (setting its ``output_meta``) and the next layer into "capture" mode, then runs *forward_loop*. Returns the - captured inputs for the next layer. + captured inputs for the next layer. Callers should keep *layer* + materialized for the duration when using offload frameworks. Must be called only when a next layer exists (i.e. *layer* is not the last decoder layer). @@ -429,11 +443,9 @@ def cache_outputs_for_next_layer_calib( layer_idx = self._layer_to_idx[layer] next_idx = layer_idx + 1 assert next_idx < len(self._decoder_layers), "No next layer to capture inputs for." - from .core_utils import persistent_materialization next_layer = self._decoder_layers[next_idx] - with persistent_materialization(layer): - return self.get_input_activations(next_layer, forward_loop) + return self.get_input_activations(next_layer, forward_loop) def _move_to_device(obj: Any, device: torch.device) -> Any: @@ -478,7 +490,7 @@ def _write_manifest( last_completed_layer: int, num_layers: int, save_every: int, - save_quantizers_only: bool, + calib_mutates_weights: bool, ) -> None: """Atomically write manifest.json. Config keys are persisted so resume can detect drift.""" path = os.path.join(checkpoint_dir, "manifest.json") @@ -489,7 +501,7 @@ def _write_manifest( "last_completed_layer": last_completed_layer, "num_layers": num_layers, "save_every": save_every, - "save_quantizers_only": save_quantizers_only, + "calib_mutates_weights": calib_mutates_weights, }, f, ) @@ -511,7 +523,7 @@ def _save_layer_files( """Write the per-layer files for layer *idx*. Exactly one of ``weights`` (full layer state_dict) or ``quantizer_buffers`` - (just the TensorQuantizer state_dict slice, used by ``save_quantizers_only``) + (just the TensorQuantizer state_dict slice, used when calibration does not mutate weights) is written; ``full_restore`` falls back to whichever is present. ``next_inputs.pt`` and ``manifest.json`` are deferred to window boundaries in :meth:`_CheckpointState.save`. @@ -564,7 +576,7 @@ def __init__( num_layers: int, start_layer: int = 0, save_every: int = 1, - save_quantizers_only: bool = False, + calib_mutates_weights: bool = True, ): if dist.is_initialized() and dist.size() > 1: raise RuntimeError( @@ -577,7 +589,7 @@ def __init__( self.num_layers = num_layers self.start_layer = start_layer self.save_every = save_every - self.save_quantizers_only = save_quantizers_only + self.calib_mutates_weights = calib_mutates_weights # Tracks the most recent saved layer so save() can window-save the layers # since the last save event. Initialized to start_layer - 1 so the first # save event after resume covers the new work only. @@ -589,7 +601,7 @@ def from_folder( checkpoint_dir: str | None, num_layers: int, save_every: int = 1, - save_quantizers_only: bool = False, + calib_mutates_weights: bool = True, ) -> _CheckpointState | None: """Create from folder. Detects resume point. Returns None if no checkpoint_dir.""" if not checkpoint_dir: @@ -598,12 +610,10 @@ def from_folder( info = detect_resume_point(checkpoint_dir) if info is not None: manifest = info[1] - # Pre-0.45 manifests omit save_every / save_quantizers_only; skip the - # check for keys absent from the on-disk manifest. for key, new_value in ( ("num_layers", num_layers), ("save_every", save_every), - ("save_quantizers_only", save_quantizers_only), + ("calib_mutates_weights", calib_mutates_weights), ): ckpt_value = manifest.get(key) if ckpt_value is not None and ckpt_value != new_value: @@ -621,7 +631,7 @@ def from_folder( num_layers, start_layer=start, save_every=save_every, - save_quantizers_only=save_quantizers_only, + calib_mutates_weights=calib_mutates_weights, ) def setup_resume(self, layers: nn.ModuleList) -> list | None: @@ -692,10 +702,10 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: ) layer.load_state_dict(weights, strict=False, assign=True) elif os.path.isfile(buffers_path): - # save_quantizers_only mode: restore just the TensorQuantizer - # state_dict (carries _amax). The layer's other weights - # weren't modified by the algorithm, so the in-memory values - # already match what would have been saved. + # Non-mutating calibration mode: restore just the TensorQuantizer + # state_dict (carries _amax). The layer's other weights were not + # modified, so the in-memory values already match what would have + # been saved. quantizer_buffers = torch.load( buffers_path, map_location=layer_device, weights_only=False ) @@ -727,14 +737,14 @@ def save( _cpu = torch.device("cpu") layer = layers[layer_idx] - with enable_weight_access_and_writeback(layer, model): + with enable_weight_access_and_writeback(layer, model, writeback=False): qstate = _move_to_device(quantizer_state(layer), _cpu) - if self.save_quantizers_only: - weights = None - quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer), _cpu) - else: + if self.calib_mutates_weights: weights = _move_to_device(layer.state_dict(), _cpu) quantizer_buffers = None + else: + weights = None + quantizer_buffers = _move_to_device(get_quantizer_state_dict(layer), _cpu) output_meta = getattr(layer._layerwise_calib, "output_meta", None) if output_meta is None: @@ -766,7 +776,7 @@ def save( layer_idx, self.num_layers, save_every=self.save_every, - save_quantizers_only=self.save_quantizers_only, + calib_mutates_weights=self.calib_mutates_weights, ) window_start = self._last_saved_layer + 1 self._last_saved_layer = layer_idx diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py index 49e74e5851f..3e766e6c1c6 100644 --- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py +++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py @@ -395,8 +395,11 @@ def forward(self, x): def test_skip_dummy_has_no_hf_hook(monkeypatch): """Dummies must not carry _hf_hook from the original layer.""" + from contextlib import nullcontext + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + from modelopt.torch.quantization.utils import persistent_materialization from modelopt.torch.quantization.utils.layerwise_calib import ( LayerActivationCollector, _SkipLayer, @@ -422,8 +425,12 @@ def forward_loop(m): collector = LayerActivationCollector(model) collector._patch_all_layers() try: - for layer in list(model.layers): - collector.get_input_activations(layer, forward_loop) + for i, layer in enumerate(list(model.layers)): + run_layer_context = ( + persistent_materialization(model.layers[i - 1]) if i > 0 else nullcontext() + ) + with run_layer_context: + collector.get_input_activations(layer, forward_loop) for i in range(2): dummy = model.layers[i] @@ -433,6 +440,26 @@ def forward_loop(m): collector._unpatch_all_layers() +def _assert_persistent_materialization_bypasses_top_hook(layer): + from modelopt.torch.quantization.utils import persistent_materialization + + assert hasattr(layer, "_hf_hook") + original_old_forward = layer._old_forward + + def sentinel_forward(*args, **kwargs): + return "unhooked" + + layer._old_forward = sentinel_forward + try: + with persistent_materialization(layer): + assert layer.forward is sentinel_forward + assert layer("unused") == "unhooked" + + assert layer.forward is not sentinel_forward + finally: + layer._old_forward = original_old_forward + + def test_persistent_materialization_cpu_offloaded(tmp_path): """persistent_materialization keeps CPU-offloaded weights on GPU and writes back modifications.""" import torch.nn as nn @@ -446,6 +473,8 @@ def test_persistent_materialization_cpu_offloaded(tmp_path): # Verify offloaded (meta device) assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + _assert_persistent_materialization_bypasses_top_hook(offloaded_layer) + # Save reference weight linear = None with enable_weight_access_and_writeback(offloaded_layer, model): @@ -564,6 +593,8 @@ def test_persistent_materialization_disk_offloaded(tmp_path): # Verify offloaded (meta device) assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + _assert_persistent_materialization_bypasses_top_hook(offloaded_layer) + # Save reference weight linear = None with enable_weight_access_and_writeback(offloaded_layer, model): diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index c5584ece5cf..1a65c70ba54 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -184,7 +184,10 @@ def _test_layerwise_calibrate_fsdp2(rank, size): # Reference: non-FSDP layerwise calibration ref_model = copy.deepcopy(model) seq_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG) - seq_cfg["algorithm"] = {"method": "max", "layerwise": True} + seq_cfg["algorithm"] = { + "method": "max", + "layerwise": {"enable": True, "calib_mutates_weights": False}, + } mtq.quantize(ref_model, seq_cfg, lambda m: m(inputs)) output_ref = ref_model(inputs) @@ -258,6 +261,13 @@ def _test_persistent_materialization(rank, size): with enable_weight_access_and_writeback(layer[0], model): assert torch.allclose(layer[0].weight, ref_weight + 1.0) + with persistent_materialization(layer, writeback=False): + assert not isinstance(layer[0].weight, DTensor) + assert layer[0].weight.device.type == "cuda" + layer(inputs) + + assert isinstance(next(iter(layer.parameters())), DTensor) + def test_persistent_materialization(dist_workers): dist_workers.run(_test_persistent_materialization) diff --git a/tests/gpu/torch/quantization/test_layerwise_calibrate.py b/tests/gpu/torch/quantization/test_layerwise_calibrate.py index d38b82f46fb..e234f9e0638 100644 --- a/tests/gpu/torch/quantization/test_layerwise_calibrate.py +++ b/tests/gpu/torch/quantization/test_layerwise_calibrate.py @@ -329,3 +329,12 @@ def weight_doubling_calib(layer, layer_forward_loop, **kwargs): # Verify by running model.layers[0] with its updated weights actual = model.layers[0](x) assert torch.allclose(actual, expected) + + +def test_skip_placeholder_uses_recorded_device(): + device = torch.device("cpu") + meta = ("tensor", torch.Size([2, 3]), torch.float32, device) + + out = LayerActivationCollector._zeros_from_meta(meta) + assert out.device == device + assert torch.count_nonzero(out) == 0 diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 82b9d24d7a4..daa28df3f3b 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -691,7 +691,7 @@ def test_default_dump_shape(self): "get_qdq_activations_from_prev_layer": False, "checkpoint_dir": None, "save_every": 1, - "save_quantizers_only": False, + "calib_mutates_weights": True, } assert "layerwise_checkpoint_dir" not in dumped @@ -703,15 +703,15 @@ def test_save_every_must_be_positive(self): "cfg_cls", [GPTQCalibConfig, AWQLiteCalibConfig, SmoothQuantCalibConfig], ) - def test_save_quantizers_only_rejected_for_weight_mutating_algorithms(self, cfg_cls): + def test_calib_mutates_weights_false_rejected_for_weight_mutating_algorithms(self, cfg_cls): """Whitelist: only amax-only algorithms (max/mse/local_hessian) may set - save_quantizers_only=True. Weight-mutating algorithms (GPTQ folds Hessian + calib_mutates_weights=False. Weight-mutating algorithms (GPTQ folds Hessian updates, AWQ/SmoothQuant fold pre-quant scales) must reject the flag. """ with pytest.raises(ValidationError, match="mutates layer weights in-place"): - cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) + cfg_cls(layerwise={"enable": True, "calib_mutates_weights": False}) @pytest.mark.parametrize("cfg_cls", [MaxCalibConfig, MseCalibConfig, LocalHessianCalibConfig]) - def test_save_quantizers_only_accepted_for_amax_only_algorithms(self, cfg_cls): - cfg = cfg_cls(layerwise={"enable": True, "save_quantizers_only": True}) - assert cfg.layerwise.save_quantizers_only is True + def test_calib_mutates_weights_false_accepted_for_amax_only_algorithms(self, cfg_cls): + cfg = cfg_cls(layerwise={"enable": True, "calib_mutates_weights": False}) + assert cfg.layerwise.calib_mutates_weights is False diff --git a/tests/unit/torch/quantization/test_layerwise_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py index 5ec4cc2daed..a8d82ff8b54 100644 --- a/tests/unit/torch/quantization/test_layerwise_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -610,6 +610,56 @@ def _int8_layerwise_config(algorithm: dict) -> dict: return cfg +def test_layerwise_calibrate_uses_global_layer_tqdm(monkeypatch): + _register_test_discoverer(monkeypatch) + + class _FakeTqdm: + instances = [] + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.postfixes = [] + self.updates = [] + self.closed = False + _FakeTqdm.instances.append(self) + + def set_postfix_str(self, status, refresh=True): + self.postfixes.append((status, refresh)) + + def update(self, n=1): + self.updates.append(n) + + def close(self): + self.closed = True + + monkeypatch.setattr("modelopt.torch.quantization.model_calib.tqdm", _FakeTqdm) + + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=3, dim=16) + calib_data = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def forward_loop(m): + for batch in calib_data: + m(batch) + + def calib_func(layer, layer_forward_loop): + layer_forward_loop(layer) + + layerwise_calibrate(model, forward_loop, calib_func) + + assert len(_FakeTqdm.instances) == 1 + pbar = _FakeTqdm.instances[0] + assert pbar.kwargs["total"] == 3 + assert pbar.kwargs["initial"] == 0 + assert pbar.kwargs["desc"] == "Layerwise calibration" + assert pbar.kwargs["dynamic_ncols"] is True + assert pbar.updates == [1, 1, 1] + assert pbar.closed + assert any(status.startswith("Calibrating layer 1/3") for status, _ in pbar.postfixes) + assert any(status.startswith("Calibrating layer 3/3") for status, _ in pbar.postfixes) + + def _awq_layerwise_config() -> dict: """INT4 weight-only AWQ config sized for the _DecoderBlock test model.""" cfg = copy.deepcopy(mtq.INT4_AWQ_CFG) @@ -866,23 +916,23 @@ def test_layerwise_save_every_writes_next_inputs_only_at_window_boundaries(monke @pytest.mark.parametrize( - ("scenario", "n_layers", "save_every", "save_quantizers_only", "rewind_to"), + ("scenario", "n_layers", "save_every", "calib_mutates_weights", "rewind_to"), [ # Pins the quantizer_buffers.pt restore path (no weights.pt on disk). - ("quantizers_only", 3, 1, True, 0), + ("non_mutating", 3, 1, False, 0), # Pins the per-call snapshot fix: each save() captures the # just-calibrated layer's state before the next-layer capture forward # swaps it to _SkipLayer. - ("save_every", 4, 2, False, 1), + ("save_every", 4, 2, True, 1), ], ) def test_layerwise_checkpoint_resume_matches_one_shot_amax( - monkeypatch, tmp_path, scenario, n_layers, save_every, save_quantizers_only, rewind_to + monkeypatch, tmp_path, scenario, n_layers, save_every, calib_mutates_weights, rewind_to ): """Full run → rewind manifest → fresh resume reproduces one-shot ``_amax``. Single test covering both checkpoint optimizations. For the - ``save_quantizers_only`` case also asserts the on-disk shape (no + non-mutating calibration case also asserts the on-disk shape (no ``weights.pt``, ``quantizer_buffers.pt`` present per layer). """ _register_test_discoverer(monkeypatch) @@ -898,7 +948,7 @@ def build_cfg(ckpt_dir): "enable": True, "checkpoint_dir": str(ckpt_dir), "save_every": save_every, - "save_quantizers_only": save_quantizers_only, + "calib_mutates_weights": calib_mutates_weights, }, } ) @@ -915,7 +965,7 @@ def build_cfg(ckpt_dir): setup_model = _SimpleTransformerModel(n_layers=n_layers, dim=16) mtq.quantize(setup_model, build_cfg(resume_dir), forward_loop=forward_loop) - if scenario == "quantizers_only": + if scenario == "non_mutating": for name in _layer_dir_names(resume_dir): d = resume_dir / name assert not (d / "weights.pt").exists() @@ -927,7 +977,7 @@ def build_cfg(ckpt_dir): "last_completed_layer": rewind_to, "num_layers": n_layers, "save_every": save_every, - "save_quantizers_only": save_quantizers_only, + "calib_mutates_weights": calib_mutates_weights, } ) ) @@ -1013,7 +1063,7 @@ def test_layerwise_checkpoint_mismatch_save_every_raises(monkeypatch, tmp_path): "last_completed_layer": 1, "num_layers": 4, "save_every": 2, - "save_quantizers_only": False, + "calib_mutates_weights": True, } ) )