From bcc05b1c7fdca740f3dabcb3d0e061f5218bdfb3 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:02:12 +0000 Subject: [PATCH 01/33] add rabbit feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc..9f72fb9c6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -563,10 +563,17 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): +<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: +======= + module.local_hessian = LocalHessianHelper(module, name) + module.local_hessian.setup() + all_patched_modules.append((name, module)) + if module.local_hessian.is_enabled: +>>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -689,7 +696,11 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: +<<<<<<< HEAD module.hessian_helper.cleanup() +======= + module.local_hessian.cleanup() +>>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From e2c781e2a5a52413b8c808464b923f883e5fd9f5 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:12:30 -0800 Subject: [PATCH 02/33] minor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- 
modelopt/torch/quantization/model_calib.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9f72fb9c6..ed57ea3fc 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -563,17 +563,10 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): -<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: -======= - module.local_hessian = LocalHessianHelper(module, name) - module.local_hessian.setup() - all_patched_modules.append((name, module)) - if module.local_hessian.is_enabled: ->>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -696,11 +689,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: -<<<<<<< HEAD module.hessian_helper.cleanup() -======= - module.local_hessian.cleanup() ->>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From 7f21be1ae9f9ea2ebd49b2944933cf771753f5ed Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 03/33] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 50 +++++++++++++++++++++++++++ modelopt/torch/quantization/mode.py | 14 ++++++++ modelopt/torch/utils/network.py | 1 + 3 files changed, 65 insertions(+) diff --git a/modelopt/torch/quantization/config.py 
b/modelopt/torch/quantization/config.py index 46a36cb8f..fe5b56493 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -269,6 +269,18 @@ "algorithm": "max", } +INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + INT4_AWQ_CFG = { "quant_cfg": { @@ -1392,6 +1404,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): ) +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + The default values are taken from the official GPTQ implementation: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. + + + """ + + method: Literal["gptq"] = ModeloptField("gptq") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="""The block size for GPTQ weight update, which must be a multiple of the + group_size used in the quantization.""", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="""The path to the Hessian state file. 
If hessian path exists, we load from + hessian file instead of recomputing them.""", + ) + + QuantizeQuantCfgType = dict[ str | Callable, QuantizerAttributeConfig diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9..88e93bb77 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQConfig, GPTQLiteConfig, LocalHessianCalibConfig, MaxCalibConfig, @@ -59,6 +60,7 @@ ) from .model_calib import ( awq, + gptq, gptq_lite, local_hessian_calibrate, max_calibrate, @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQLiteConfig _calib_func = gptq_lite + + +@CalibrateModeRegistry.register_mode +class GPTQModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQConfig + + _calib_func = gptq diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375..b07ca570c 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,6 +46,7 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", + "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", From 9b600bfd82f3ff80894e227b1775a5828bba93cc Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 04/33] tested, revert later --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f1856c113..cb05d88b9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -675,6 +675,82 @@ def 
export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits 
= lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization From 0cc53df7de0cf32cde9fc2bc7d3de7c6e7ae9d1b Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 05/33] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 17 +++-- modelopt/torch/quantization/config.py | 94 +++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index cb05d88b9..cab6965f0 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -675,14 +675,16 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) - if True: + if args.export_qdq_weights: # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") + if "gptq" not in args.qformat: + mtq.fold_weight(full_model) + print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") mtq.disable_quantizer(full_model, "*") + if True: - # mtq.fold_weight(full_model) import os import torch.nn.functional as F @@ -750,7 +752,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -1217,6 +1218,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--export_qdq_weights", + help=("Used for GPTQ weights as is without compressed weights for deployment."), + default=False, + action="store_true", + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). 
Disable by --no-verbose.", diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fe5b56493..c8531b2d3 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -281,6 +281,100 @@ }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + 
**_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": { From cde122a585f47f01eb0bc98b54bb6d001e1cb036 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 11 Feb 2026 07:43:06 +0000 Subject: [PATCH 06/33] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 80 +++++----------------- 1 file changed, 16 insertions(+), 64 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc..7b0a3df5e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1601,56 +1601,6 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
- - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start - - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - - # We perform column-wise update for GPTQ within the block - group_size = 1 - - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) - - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols - - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight - - return quantized_block, losses, errors - - def blockwise_weight_update(module, h, block_size, percdamp): """Update module weights using GPTQ-style blockwise quantization. 
@@ -1666,28 +1616,30 @@ def blockwise_weight_update(module, h, block_size, percdamp): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses + n_cols = block_end - block_start + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name) # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) def gptq_lite( From c2aeed5aab3bbc16a61ad9b3e48eea9d46ed2d6c Mon Sep 17 00:00:00 2001 From: realAsma <86726418+realAsma@users.noreply.github.com> Date: Fri, 6 Feb 2026 
11:47:36 -0800 Subject: [PATCH 07/33] Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQantizer, NVFP4MSECalibrator (#849) **Type of change:** ? **Overview:** ? ```python ``` - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No * **New Features** * Added NVFP4StaticQuantizer for improved 4-bit quantization with enhanced precision control * Introduced NVFP4MSECalibrator with flexible candidate generation for calibration optimization * **Improvements** * Optimized GPU kernels for Hopper+ graphics cards with better performance * Extended Triton support to broader GPU compatibility * Enhanced backward compatibility for restoring previously quantized models * **Tests** * Added comprehensive test coverage for new quantizers and calibration methods --------- Signed-off-by: realAsma --- modelopt/torch/quantization/triton/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index def70e591..6e8d4dba1 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,6 +34,10 @@ from .fp4_kernel import * from .fp8_kernel import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) + if torch.cuda.get_device_capability() >= (8, 9): + from .fp4_kernel_hopper import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From d1ebcca9cd5cb4884f7a20377435b604b3188b8e Mon Sep 17 00:00:00 2001 From: Fridah-nv 
<201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:56:02 +0000 Subject: [PATCH 08/33] address reviewers feedback, delegate scaling factor calculation to NVFP4QTensor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 4ceb51cd2..b762757cb 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -360,9 +360,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) + if not is_nvfp4_static: + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
From 2cf82949a32f6c07482ae93ff5c0ce30be99686c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 09/33] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 87 ++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7b0a3df5e..6fff65410 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1835,3 +1835,90 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed successfully") + + +@torch.no_grad() +def gptq( + layer: nn.Module, + inputs: list[tuple[tuple, dict]], + percdamp: float = 0.01, + block_size: int = 128, + **kwargs, +): + """GPTQ quantization - a GPTQ variant.""" + import time + + total_start = time.time() + + # Dictionary to store hessian matrices for all linear layers in this decoder + hessian_state = {} + + # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer + tensor_mapping = {} + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[-1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + if not tensor_mapping: + print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + return + + # Initialize hessian state with zeros + for name, (shape, device) in tensor_mapping.items(): + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=device), + "n_samples": 0, + } + + # Phase 2: Register hooks to collect Hessians during forward passes + 
def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + + handles = [] + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward passes with the provided inputs to collect Hessians + hessian_start = time.time() + print_rank_0( + f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." + ) + for args, kwargs_input in inputs: + layer(*args, **kwargs_input) + + # Remove hooks after collecting Hessians + for handle in handles: + handle.remove() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() + print_rank_0("Updating weights using GPTQ algorithm...") + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + # Free memory + del hessian_state[module.name] + torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " + f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From abf6e8dbcbae41261b9ce87d2415ff5240c257a0 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:54:46 +0000 
Subject: [PATCH 10/33] tested exported checkpoints on 0211 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 69 ++++++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 4 +- modelopt/torch/quantization/config.py | 22 +++++++ modelopt/torch/quantization/model_calib.py | 57 +++++++++++++++++- 4 files changed, 147 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index cab6965f0..7bc486656 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -675,6 +675,75 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + + if True: + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + 
tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) + if args.export_qdq_weights: # Disable quantizers if "gptq" not in args.qformat: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 78c8874a0..e5810bc1f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -556,7 +556,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, @@ -573,7 +573,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c8531b2d3..93483374b 
100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -300,6 +300,28 @@ }, } +NVFP4_STATIC_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + NVFP4_STATIC_WO_GPTQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6fff65410..2ac165df9 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,6 +15,7 @@ """Calibration utilities.""" +import contextlib import math import os import warnings @@ -1789,6 +1790,56 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") +def _set_input_quantizers_calib_mode(layer: nn.Module): + """Set all input quantizers of a layer to calibration mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_input_quantizers_quant_mode(layer: nn.Module): + """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + module.load_calib_amax() + 
module.enable_quant() + module.disable_calib() + + +@contextlib.contextmanager +def _disable_input_quantizers(layer: nn.Module): + """Temporarily disable all enabled input quantizers in a layer.""" + enabled_quantizers = [] + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + ): + module.disable() + enabled_quantizers.append(module) + try: + yield + finally: + for module in enabled_quantizers: + module.enable() + + @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1836,8 +1887,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed successfully") - @torch.no_grad() def gptq( @@ -1877,8 +1926,10 @@ def gptq( # Phase 2: Register hooks to collect Hessians during forward passes def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" + if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: + inp = module.input_quantizer(input[0]) state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} handles = [] From c604539b82e1c3e015a516ecdd1e40d8a6e42339 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 13 Feb 2026 19:53:25 +0000 Subject: [PATCH 11/33] tested nano v3 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7bc486656..ce90705e6 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -741,8 +741,7 @@ def 
_compute_perplexity(model, data, batch_size: int = 1): eval_data = _get_wikitext2(tokenizer, 2048) ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) + breakpoint() if args.export_qdq_weights: # Disable quantizers @@ -750,8 +749,8 @@ def _compute_perplexity(model, data, batch_size: int = 1): mtq.fold_weight(full_model) print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) if True: import os From f94e5771a46a4d323723f1a420b2e59e2f74ee5b Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Feb 2026 02:48:11 +0000 Subject: [PATCH 12/33] added activation MSE logging Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 48 ++++++++++++++++++++++ modelopt/torch/quantization/__init__.py | 1 + modelopt/torch/quantization/model_calib.py | 2 + 3 files changed, 51 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ce90705e6..6d2e62a58 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,7 @@ import argparse import copy +import os import random import time import warnings @@ -565,6 +566,43 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) + if getattr(args, "measure_activation_mse", False): + mse_max_samples = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Materialize or load a frozen set of MSE inputs so that 
the exact + # same samples are used across runs and across codebases. + if mse_input_path and os.path.isfile(mse_input_path): + mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) + else: + from torch.utils.data import DataLoader as _DataLoader + + mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) + if mse_input_path: + mse_data = mtq.ActivationMSELogger.materialize_data( + mse_dataloader, + mse_input_path, + max_samples=mse_max_samples, + ) + else: + # No path given -- materialize in memory only + mse_data = [] + for i, batch in enumerate(mse_dataloader): + if i >= mse_max_samples: + break + t = batch["input_ids"] if isinstance(batch, dict) else batch + mse_data.append(t.cpu()) + + mse_logger = mtq.ActivationMSELogger( + max_samples=mse_max_samples, + layer_filter=getattr(args, "activation_mse_layer_filter", None), + save_dir=mse_save_dir, + ) + print("\n--- Phase 1: Collecting pre-quantization activations ---") + mse_logger.collect(language_model, mse_data, phase="original") + if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -572,6 +610,16 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) + # Phase 2: Compute MSE against stored pre-quant activations + if getattr(args, "measure_activation_mse", False): + print("\n--- Phase 2: Computing per-layer activation MSE ---") + mse_logger.collect(language_model, mse_data, phase="quantized") + mse_logger.compute_mse() + print(mse_logger.summary()) + if mse_save_dir: + mse_logger.save() + del mse_logger, mse_data + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb..757b844fb 100644 --- a/modelopt/torch/quantization/__init__.py +++ 
b/modelopt/torch/quantization/__init__.py @@ -19,6 +19,7 @@ from . import mode, plugins, utils # Add methods to mtq namespace +from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 2ac165df9..959e6117a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1928,6 +1928,8 @@ def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: inp = module.input_quantizer(input[0]) + else: + inp = input[0] state = hessian_state[module.name] hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} From b570e7bbc0490285b3e820773b873cb2880fa112 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 06:07:59 +0000 Subject: [PATCH 13/33] super v3 run Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 75 +++++++++ modelopt/torch/quantization/model_calib.py | 144 +++++++++++++++++- .../nn/modules/tensor_quantizer.py | 13 +- 3 files changed, 222 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 93483374b..1e56c1164 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -171,6 +171,54 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } +SUPER_NVFP4_CONSERVATIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + 
"*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": "max", +} + +SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + + INT8_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, @@ -293,6 +341,9 @@ "enable": False, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -315,6 +366,9 @@ "enable": True, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -551,6 +605,9 @@ "*weight_quantizer": _nvfp4_quantizer, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": "max", } @@ -564,6 +621,9 @@ }, 
"*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "mse", @@ -1217,6 +1277,21 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + checkpoint_every_n_layers: int | None = ModeloptField( + default=None, + title="Save intermediate checkpoint every N layers during sequential calibration.", + ) + + checkpoint_dir: str | None = ModeloptField( + default=None, + title="Directory for saving/loading intermediate GPTQ checkpoints.", + ) + + resume_from_layer: int = ModeloptField( + default=0, + title="Layer index to resume sequential calibration from (0 = start from beginning).", + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 959e6117a..54406c159 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,8 @@ """Calibration utilities.""" import contextlib +import datetime +import json import math import os import warnings @@ -1520,7 +1522,13 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): +def _print_relative_mse_error( + q: torch.Tensor, + w: torch.Tensor, + h: torch.Tensor, + module_name: str, + n_samples: int | None = None, +): """Print relative mean squared error between quantized and original weights. 
Computes the Hessian-weighted relative MSE between quantized and original weights, @@ -1532,13 +1540,15 @@ def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, w (torch.Tensor): Original weight tensor h (torch.Tensor): Hessian matrix used for weighting the error module_name (str): Name of the module for logging purposes + n_samples (int | None): Number of Hessian samples (batches) used for this layer Note: Implementation adapted from the GPTQ repository: https://github.com/IST-DASLab/FP-Quant """ delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") + suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" + print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1551,15 +1561,15 @@ def update_hessian(input, hessian, n_samples): Returns: Tuple of (updated_hessian, new_sample_count) """ - batch_size = input.shape[0] + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian hessian *= n_samples / (n_samples + batch_size) n_samples += batch_size # Compute outer product: H += (2/n_samples) * X @ X^T - # where X is the flattened input reshaped to (features, batch*seq) - input_flat = input.reshape(-1, input.shape[-1]).t().float() scaled_input = math.sqrt(2 / n_samples) * input_flat hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) @@ -1602,7 +1612,7 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def blockwise_weight_update(module, h, block_size, percdamp): +def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. 
Args: @@ -1610,6 +1620,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): H: Hessian matrix (d x d) block_size: Size of blocks to process at once percdamp: Damping percentage for Hessian diagonal + n_samples: Number of Hessian samples for logging (optional) """ weight = module.weight.data.float().clone() _, num_cols = weight.shape @@ -1638,7 +1649,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) # Update module weights module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) @@ -1840,11 +1851,117 @@ def _disable_input_quantizers(layer: nn.Module): module.enable() +def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: + """Save fake quant checkpoint using save_pretrained() (HuggingFace format). + + Args: + model: The quantized model to save. + output_dir: Directory to write the checkpoint into. + """ + from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state + from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state + + os.makedirs(output_dir, exist_ok=True) + + # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. + # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') + # that can't be pickled. Even after removing hooks from modules, they may still be captured + # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). 
+ try: + from accelerate.hooks import remove_hook_from_module + + remove_hook_from_module(model, recurse=True) + except ImportError: + pass + + # Save model weights first (without modelopt_state to avoid pickling error) + model.save_pretrained(output_dir, save_modelopt_state=False) + + # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. + # We need to rebuild quantizer_state because hooks may have been captured in closures + # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). + if ModeloptStateManager.is_converted(model): + modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") + state = modelopt_state(model) + + # Rebuild quantizer_state in metadata to remove any hook references captured in closures + if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): + cleaned_state_dict = [] + for entry in state["modelopt_state_dict"]: + if isinstance(entry, tuple) and len(entry) >= 2: + mode_str, state_dict_entry = entry[0], entry[1] + if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: + # Rebuild quantizer_state after hooks are removed + cleaned_entry = state_dict_entry.copy() + cleaned_metadata = cleaned_entry["metadata"].copy() + cleaned_metadata["quantizer_state"] = get_quantizer_state(model) + cleaned_entry["metadata"] = cleaned_metadata + cleaned_state_dict.append((mode_str, cleaned_entry)) + else: + cleaned_state_dict.append(entry) + else: + cleaned_state_dict.append(entry) + state["modelopt_state_dict"] = cleaned_state_dict + + torch.save(state, modelopt_state_path) + print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") + + +def _save_gptq_checkpoint( + model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int +) -> None: + """Save intermediate GPTQ checkpoint with metadata for resume support. 
+ + Saves accelerate hooks before calling save_fake_checkpoint (which removes them), + then re-attaches them so the model remains functional for subsequent layers. + """ + print_rank_0( + f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" + ) + + # Save accelerate hooks before save_fake_checkpoint removes them. + # We need to re-attach them after saving so the model keeps working. + saved_hooks = {} + for name, module in model.named_modules(): + if hasattr(module, "_hf_hook"): + saved_hooks[name] = module._hf_hook + + try: + save_fake_checkpoint(model, checkpoint_dir) + finally: + # Re-attach accelerate hooks so the model keeps working for remaining layers. + if saved_hooks: + try: + from accelerate.hooks import add_hook_to_module + + name_to_module = dict(model.named_modules()) + for name, hook in saved_hooks.items(): + if name in name_to_module: + add_hook_to_module(name_to_module[name], hook) + print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") + except ImportError: + pass + + # Save checkpoint metadata for resume support. + meta = { + "last_completed_layer": last_layer_idx, + "total_layers": total_layers, + "timestamp": datetime.datetime.now().isoformat(), + } + meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") + + @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, + checkpoint_every_n_layers: int | None = None, + checkpoint_dir: str | None = None, + resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -1880,12 +1997,21 @@ def _layer_forward_loop(m, _inputs=layer_inputs): for args, kwargs_input in _inputs: m(*args, **kwargs_input) +<<<<<<< HEAD calib_func(layer, _layer_forward_loop, **calib_kwargs) del layer_inputs torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() +======= + # Call calibration function + calib_func(layer, _layer_forward_loop, **calib_kwargs) + del layer_inputs + torch.cuda.empty_cache() + + print_rank_0("Sequential calibration completed") +>>>>>>> 086c1d21 (super v3 run) @torch.no_grad() @@ -1961,7 +2087,9 @@ def hessian_hook(module, input, output): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: state = hessian_state[module.name] hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) + blockwise_weight_update( + module, hessian, block_size, percdamp, n_samples=state["n_samples"] + ) # Free memory del hessian_state[module.name] torch.cuda.empty_cache() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ec2c3cfc5..4317c5860 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1331,10 +1331,19 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: + # Ensure amax/global_amax are on the same device as inputs. + # After from_pretrained with device_map, quantizer buffers may remain + # on CPU while model weights/activations are on GPU. 
+ amax = self.amax + if amax.device != inputs.device: + amax = amax.to(inputs.device) + global_amax = self.global_amax + if global_amax is not None and global_amax.device != inputs.device: + global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - self.amax, - self.global_amax, # Can be None, will be computed internally + amax, + global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, From ae30ff1ac823fd08ca2184a487b78946841ff7b7 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 20:27:50 +0000 Subject: [PATCH 14/33] debug logs Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 54406c159..103a6dc45 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1997,21 +1997,14 @@ def _layer_forward_loop(m, _inputs=layer_inputs): for args, kwargs_input in _inputs: m(*args, **kwargs_input) -<<<<<<< HEAD calib_func(layer, _layer_forward_loop, **calib_kwargs) del layer_inputs torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() -======= - # Call calibration function - calib_func(layer, _layer_forward_loop, **calib_kwargs) - del layer_inputs - torch.cuda.empty_cache() print_rank_0("Sequential calibration completed") ->>>>>>> 086c1d21 (super v3 run) @torch.no_grad() From f8099758f2461775154dd45feba41c46db312c75 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 21:17:38 +0000 Subject: [PATCH 15/33] added activationmse logging helper Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 
++++++++++++++++++ 1 file changed, 787 insertions(+) create mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py new file mode 100644 index 000000000..df90c84a3 --- /dev/null +++ b/modelopt/torch/quantization/activation_mse.py @@ -0,0 +1,787 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-layer activation MSE measurement for quantization analysis. + +This module provides utilities to measure per-linear-layer MSE between a model's +activations before and after quantization. Inspired by FP-Quant's two-phase approach: + +- **Phase 1** (before quantization): ``collect_activations()`` runs the model on + calibration data and stores per-layer outputs in CPU RAM. +- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized + model on the same data and computes MSE on-the-fly against the stored Phase 1 + outputs. Only running scalar accumulators are kept -- no second set of tensors + is stored. 
+ +Typical usage in hf_ptq.py:: + + # Phase 1: before quantization + orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) + + # Quantize + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + # Phase 2: after quantization -- computes MSE incrementally + mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) +""" + +import contextlib +import fnmatch +import hashlib +import os +from collections.abc import Iterable +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from modelopt.torch.utils.network import get_decoder_layers + +__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] + + +def _tensor_from_output(out) -> torch.Tensor: + """Extract a single tensor from a layer's output (handles tuple returns).""" + if isinstance(out, torch.Tensor): + return out.detach() + return out[0].detach() + + +def _is_linear(module: nn.Module) -> bool: + """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" + return isinstance(module, nn.Linear) + + +def _matches_filter(name: str, layer_filter: str | None) -> bool: + """Check if a layer name matches the optional filter pattern (fnmatch-style).""" + if layer_filter is None: + return True + return fnmatch.fnmatch(name, layer_filter) + + +def _discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers within decoder blocks of the model. + + Uses get_decoder_layers() to find transformer blocks, then finds all linear + submodules within those blocks. Falls back to all linear layers in the model + if decoder blocks cannot be identified. + + Args: + model: The model to inspect. + layer_filter: Optional fnmatch pattern to select specific layers + (e.g., ``"*self_attn*"``). + + Returns: + Dict mapping full module path -> module reference. 
+ """ + decoder_layers = get_decoder_layers(model) + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + # Build a reverse lookup: module id -> full name in model + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if _is_linear(sub_mod): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # Fallback: scan all modules + for name, module in model.named_modules(): + if _is_linear(module): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model.""" + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + + +@torch.no_grad() +def collect_activations( + model: nn.Module, + dataloader: Iterable, + max_samples: int | None = None, + layer_filter: str | None = None, +) -> dict[str, list[torch.Tensor]]: + """Collect per-linear-layer output activations into CPU memory (Phase 1). + + Registers forward hooks on linear layers within the model's decoder blocks, + runs calibration data through the model, and returns captured per-layer outputs. + + Args: + model: The model to collect activations from (typically pre-quantization). + dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). + Use batch_size=1 to minimize memory. + max_samples: Maximum number of batches to process. ``None`` means all. + layer_filter: Optional fnmatch pattern to restrict which layers are + collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers + inside decoder blocks. + + Returns: + Dict mapping layer name to a list of output tensors (one per batch, on CPU). 
+ """ + was_training = model.training + model.eval() + + # Discover target linear layers + targets = _discover_target_layers(model, layer_filter) + if not targets: + raise ValueError( + f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" + ) + + print(f"Collecting activations for {len(targets)} layers...") + + # Storage: {layer_name: [tensor_per_batch, ...]} + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(dataloader, desc="Collecting activations", leave=False): + if max_samples is not None and n_batches >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + print(f"Collected {n_batches} samples across {len(targets)} layers") + return saved + + +@torch.no_grad() +def measure_activation_mse( + model: nn.Module, + dataloader: Iterable, + orig_activations: dict[str, list[torch.Tensor]], + max_samples: int | None = None, + layer_filter: str | None = None, +) -> dict[str, float]: + """Compute per-layer MSE between stored and live activations (Phase 2). + + Runs the (quantized) model on calibration data and computes MSE on-the-fly + against the pre-quantization activations stored by :func:`collect_activations`. + + Only scalar accumulators (sum of squared errors and element count) are kept + per layer -- no second set of activation tensors is stored. 
+ + The MSE for each layer is computed as:: + + MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements + + Args: + model: The quantized model to measure. + dataloader: Same dataloader used for :func:`collect_activations` + (must yield batches in the same order). + orig_activations: Output of :func:`collect_activations` -- dict mapping + layer name to a list of pre-quantization output tensors. + max_samples: Maximum number of batches to process (should match Phase 1). + layer_filter: Optional fnmatch pattern (should match Phase 1). + + Returns: + Dict mapping layer name to its MSE value. + """ + was_training = model.training + model.eval() + + # Discover target layers on the (now-quantized) model + targets = _discover_target_layers(model, layer_filter) + + # Only measure layers that exist in both the model and orig_activations + common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) + if not common_keys: + raise ValueError( + "No matching layers between the quantized model and stored activations. " + "Ensure the same layer_filter is used for both phases." 
+ ) + + skipped = set(orig_activations.keys()) - set(targets.keys()) + if skipped: + print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") + + print(f"Computing activation MSE for {len(common_keys)} layers...") + + # Scalar accumulators + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks only on common layers + hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] + + try: + batch_idx = 0 + for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): + if max_samples is not None and batch_idx >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in common_keys: + if name not in captured: + continue + if batch_idx >= len(orig_activations.get(name, [])): + continue + + o = orig_activations[name][batch_idx].float() + q = captured[name].float() + + if o.shape != q.shape: + print( + f"Warning: shape mismatch for {name} batch {batch_idx}: " + f"{o.shape} vs {q.shape}, skipping" + ) + continue + + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + batch_idx += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + mse = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys + } + + return mse + + +# --------------------------------------------------------------------------- +# Portable ActivationMSELogger class +# --------------------------------------------------------------------------- + + +def _portable_discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers in decoder blocks with a portable fallback 
chain. + + Strategy: + 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). + 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). + 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. + + Within each set of decoder blocks the function collects every ``nn.Linear`` + sub-module and optionally filters by *layer_filter* (fnmatch pattern). + """ + decoder_layers = None + + # 1. Try modelopt helper (may not exist when file is copied elsewhere) + with contextlib.suppress(Exception): + decoder_layers = get_decoder_layers(model) + + # 2. Try common HF / other patterns + if decoder_layers is None: + for attr_chain in ( + ("model", "layers"), + ("decoder", "layers"), + ("transformer", "h"), + ("backbone", "layers"), + ): + obj = model + try: + for attr in attr_chain: + obj = getattr(obj, attr) + if isinstance(obj, nn.ModuleList): + decoder_layers = obj + break + except AttributeError: + continue + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if isinstance(sub_mod, nn.Linear): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # 3. Fallback: all linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +class ActivationMSELogger: + """Portable activation MSE logger for comparing original vs quantized models. + + Works with both: + + - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` + or ``[B, seq_len]``, consumed via ``model(tensor)``. 
+ - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): + ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. + + Guarantees same samples are used for both phases via SHA-256 hashing of + input tensors. Supports saving / loading all activations to disk for + later cross-codebase comparison. + + Example (ModelOpt -- DataLoader with dict batches):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model, dataloader, phase="original") + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + mse_logger.collect(model, dataloader, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + + Example (FP-Quant -- List[Tensor]):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model_orig, calibration_data, phase="original") + mse_logger.collect(model_quant, calibration_data, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + """ + + def __init__( + self, + max_samples: int = 16, + layer_filter: str | None = None, + save_dir: str | None = None, + ): + """Initialize the ActivationMSELogger. + + Args: + max_samples: Maximum number of calibration batches to process per phase. + layer_filter: Optional glob pattern to restrict which layers are tracked. + save_dir: Optional directory path for persisting activation data to disk. 
+ """ + self.max_samples = max_samples + self.layer_filter = layer_filter + self.save_dir = save_dir + + # Per-phase state + self.original_activations: dict[str, list[torch.Tensor]] = {} + self.quantized_activations: dict[str, list[torch.Tensor]] = {} + self.input_hashes: list[str] = [] # hashes for "original" phase + self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase + + # Computed after both phases + self.mse_results: dict[str, float] | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + model: nn.Module, + data: Iterable, + phase: str, + target_modules: dict[str, nn.Module] | None = None, + ) -> None: + """Collect per-linear-layer output activations for a given phase. + + Args: + model: The model to run (original or quantized). + data: An iterable of batches. Each batch can be: + + - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). + - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). + - ``list`` / ``tuple`` of tensors. + phase: ``"original"`` or ``"quantized"``. + target_modules: Optional explicit mapping of ``{name: nn.Module}`` + to attach hooks to. If *None*, layers are auto-discovered + via decoder-block scanning. + """ + if phase not in ("original", "quantized"): + raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") + + was_training = model.training + model.eval() + + # ----- layer discovery ----- + targets = ( + target_modules + if target_modules is not None + else (_portable_discover_target_layers(model, self.layer_filter)) + ) + if not targets: + raise ValueError( + "No linear layers found. Provide target_modules explicitly or " + f"check layer_filter={self.layer_filter!r}." 
+ ) + + print( + f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " + f"max_samples={self.max_samples}" + ) + + # ----- storage ----- + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + hashes: list[str] = [] + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): + if self.max_samples is not None and n_batches >= self.max_samples: + break + + captured.clear() + self._run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + hashes.append(self._hash_batch(batch)) + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + # ----- store results on self ----- + if phase == "original": + self.original_activations = saved + self.input_hashes = hashes + else: + self.quantized_activations = saved + self.quant_input_hashes = hashes + # Verify sample consistency + if self.input_hashes: + self._verify_hashes() + + # Invalidate any previous MSE since we have new activations + self.mse_results = None + + print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") + + def compute_mse(self) -> dict[str, float]: + """Compute per-layer MSE between original and quantized activations. + + Returns: + Dict mapping layer name to its MSE value. + + Raises: + ValueError: If either phase has not been collected yet. + """ + if not self.original_activations: + raise ValueError( + "No original activations collected. Call collect(..., phase='original') first." + ) + if not self.quantized_activations: + raise ValueError( + "No quantized activations collected. 
Call collect(..., phase='quantized') first." + ) + + common_keys = sorted( + set(self.original_activations.keys()) & set(self.quantized_activations.keys()) + ) + if not common_keys: + raise ValueError( + "No matching layer names between original and quantized activations. " + "Ensure the same model architecture / layer_filter is used for both phases." + ) + + orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) + quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) + if orig_only: + print( + f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" + ) + if quant_only: + print( + f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" + ) + + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + for name in common_keys: + orig_list = self.original_activations[name] + quant_list = self.quantized_activations[name] + n = min(len(orig_list), len(quant_list)) + for i in range(n): + o = orig_list[i].float() + q = quant_list[i].float() + if o.shape != q.shape: + print( + f"[ActivationMSELogger] Warning: shape mismatch for {name} " + f"batch {i}: {o.shape} vs {q.shape}, skipping" + ) + continue + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + self.mse_results = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") + for key in common_keys + } + return self.mse_results + + def save(self, path: str | None = None) -> str: + """Save all state (activations, hashes, MSE) to disk via ``torch.save``. + + Args: + path: Explicit file path. If *None*, a timestamped file is created + inside ``self.save_dir`` (which must be set). + + Returns: + The path where the file was saved. 
+ """ + if path is None: + if self.save_dir is None: + raise ValueError("Provide a path or set save_dir in the constructor.") + os.makedirs(self.save_dir, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") + + payload = { + "max_samples": self.max_samples, + "layer_filter": self.layer_filter, + "input_hashes": self.input_hashes, + "quant_input_hashes": self.quant_input_hashes, + "original_activations": self.original_activations, + "quantized_activations": self.quantized_activations, + "mse": self.mse_results, + } + torch.save(payload, path) + print(f"[ActivationMSELogger] Saved to {path}") + return path + + @classmethod + def load(cls, path: str) -> "ActivationMSELogger": + """Load a previously saved ``ActivationMSELogger`` from disk. + + Args: + path: Path to the ``.pt`` file created by :meth:`save`. + + Returns: + A new ``ActivationMSELogger`` instance with restored state. + """ + payload = torch.load(path, map_location="cpu", weights_only=False) + logger = cls( + max_samples=payload.get("max_samples", 16), + layer_filter=payload.get("layer_filter"), + ) + logger.original_activations = payload.get("original_activations", {}) + logger.quantized_activations = payload.get("quantized_activations", {}) + logger.input_hashes = payload.get("input_hashes", []) + logger.quant_input_hashes = payload.get("quant_input_hashes", []) + logger.mse_results = payload.get("mse") + print(f"[ActivationMSELogger] Loaded from {path}") + return logger + + def summary(self) -> str: + """Return a formatted string summarising per-layer MSE results. + + Computes MSE first if not already done. 
+ """ + if self.mse_results is None: + self.compute_mse() + assert self.mse_results is not None + + lines = ["Per-layer activation MSE (original vs quantized):"] + lines.extend( + f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) + ) + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Pre-materialized MSE data (cross-run / cross-codebase safety) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_data( + data: Iterable, + path: str, + max_samples: int | None = None, + ) -> list[torch.Tensor]: + """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. + + Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a + single ``input_ids`` CPU tensor before saving. The resulting file is a + plain ``List[Tensor]`` that can be loaded in **any** codebase and passed + straight to :meth:`collect`. + + If *path* already exists it is **not** overwritten -- call + :meth:`load_data` instead. + + Args: + data: Iterable of batches (DataLoader, List[Tensor], etc.). + path: Destination ``.pt`` file path. + max_samples: How many batches to keep. ``None`` means all. + + Returns: + The materialised list of CPU tensors (same object that was saved). 
+ """ + samples: list[torch.Tensor] = [] + for batch in data: + if max_samples is not None and len(samples) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + samples.append(t.cpu()) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save(samples, path) + print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") + return samples + + @staticmethod + def load_data(path: str) -> list[torch.Tensor]: + """Load a previously materialised MSE input set. + + Args: + path: Path to the ``.pt`` file created by :meth:`materialize_data`. + + Returns: + ``List[Tensor]`` of input batches (on CPU). + """ + samples = torch.load(path, map_location="cpu", weights_only=True) + print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") + return samples + + # ------------------------------------------------------------------ + # Static / private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model (handles Tensor, dict, list/tuple). + + Automatically moves inputs to the model's device so that CPU-stored + materialized data works transparently with a CUDA model. 
+ """ + device = next(model.parameters()).device + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + elif isinstance(batch, torch.Tensor): + model(batch.to(device)) + elif isinstance(batch, (list, tuple)): + batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) + model(*batch) + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def _hash_batch(batch) -> str: + """Compute SHA-256 hash of the primary input tensor in *batch*. + + - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). + - ``Tensor`` -> hashes the tensor directly. + - ``list/tuple`` -> hashes the first element. + """ + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] if batch else None + else: + return "" + + if t is None or not isinstance(t, torch.Tensor): + return "" + return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() + + def _verify_hashes(self) -> None: + """Compare input hashes between original and quantized phases.""" + n = min(len(self.input_hashes), len(self.quant_input_hashes)) + mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) + if mismatches: + print( + f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " + f"different input hashes between original and quantized phases. " + f"The same data may not have been used for both phases!" 
+ ) + else: + print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") From 8bf4e290d6961fa94bf5c5ab188521831daec158 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 19 Feb 2026 23:47:48 +0000 Subject: [PATCH 16/33] input amax sync added + tested gptq super sft checkpoint Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 31 ++++++++ tests/gpu/torch/quantization/test_gptq.py | 87 ++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 103a6dc45..5bb9739a3 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -135,6 +135,8 @@ def max_calibrate( for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() + elif hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() if not distributed_sync: return @@ -1832,6 +1834,35 @@ def _set_input_quantizers_quant_mode(layer: nn.Module): module.disable_calib() +def _set_kv_quantizers_calib_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_kv_quantizers_quant_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + module.load_calib_amax() 
+ module.enable_quant() + module.disable_calib() + + @contextlib.contextmanager def _disable_input_quantizers(layer: nn.Module): """Temporarily disable all enabled input quantizers in a layer.""" diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 0c60bcd00..c47b48b1e 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -20,7 +20,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 @@ -156,6 +158,91 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" +def test_gptq_export_roundtrip(): + """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" + torch.manual_seed(RAND_SEED) + dim = 128 + block_size = 4 + + # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers + model = torch.nn.Linear(dim, dim).to("cuda") + model.name = "linear" + original_weight = model.weight.data.clone() + input_tensor = torch.randn(2, 16, dim).to("cuda") + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) + + # Restore original weight before GPTQ + model.weight.data = original_weight.clone() + + # Step 2: Perform GPTQ — compute Hessian and update weights + hessian = torch.zeros(dim, dim, dtype=torch.float32) + n_samples = 0 + hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) + hessian = hessian.to("cuda") + + blockwise_weight_update(model, hessian, block_size, 
percdamp=0.1) + + # Save the QDQ reference from the quantizer applied to GPTQ'd weights + gptq_weight_shape = model.weight.data.shape + gptq_weight_dtype = model.weight.data.dtype + qdq_ref = model.weight.data.clone() + + # Step 3: Export — converts weight to packed NVFP4 and registers scale buffers + _export_quantized_weight(model, torch.bfloat16) + + # Verify export produced the expected buffers + assert hasattr(model, "weight_scale"), "Export should register weight_scale buffer" + assert hasattr(model, "weight_scale_2"), "Export should register weight_scale_2 buffer" + + # Step 4: Dequantize the exported packed weight and compare with QDQ reference + packed_weight = model.weight.data + weight_scale = model.weight_scale + weight_scale_2 = model.weight_scale_2 + + nvfp4_qtensor = NVFP4QTensor(gptq_weight_shape, gptq_weight_dtype, packed_weight) + deq_weight = nvfp4_qtensor.dequantize( + dtype=torch.bfloat16, + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: 16}, + ) + + assert deq_weight.shape == qdq_ref.shape, ( + f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" + ) + diff = (deq_weight - qdq_ref.to(torch.bfloat16)).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + max_diff_row = max_diff_idx // deq_weight.shape[1] + max_diff_col = max_diff_idx % deq_weight.shape[1] + num_mismatched = (diff > 1e-3).sum().item() + total_elements = diff.numel() + + print("\n--- Diff Stats ---") + print(f" Max diff: {max_diff}") + print(f" Mean diff: {diff.mean().item()}") + print(f" Median diff: {diff.median().item()}") + print(f" Std diff: {diff.std().item()}") + print( + f" Mismatched (>1e-3): {num_mismatched}/{total_elements} " + f"({100 * num_mismatched / total_elements:.2f}%)" + ) + print( + f" Max diff at [{max_diff_row}, {max_diff_col}]: " + f"deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()}" + ) + + assert torch.allclose(deq_weight, 
qdq_ref.to(torch.bfloat16), atol=1e-2), ( + f"Dequantized weight does not match QDQ reference. " + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}] " + f"(deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()})" + ) + + @pytest.mark.parametrize( "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) From 01640cf57aeb7c80c1ca56479a59bf092efbfd17 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:24:34 +0000 Subject: [PATCH 17/33] checkpoints generated on 0223 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 1 - modelopt/torch/quantization/config.py | 24 ++++++--- modelopt/torch/quantization/model_calib.py | 57 ++++++++++++++++------ 3 files changed, 58 insertions(+), 24 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d2e62a58..4c0710381 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -868,7 +868,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1e56c1164..6375d81b0 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -156,12 +156,23 @@ "*mlp.gate.*": {"enable": False}, # Skip the MOE router "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, # Skip 
mamba conv1d + "*mixer.conv1d*": {"enable": False}, "*output_layer*": {"enable": False}, "output.*": {"enable": False}, "default": {"enable": False}, } +super_disabled_quantizer_cfg = { + "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE + "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE + "*q_proj*": {"enable": False}, # Skip QKV Linear + "*k_proj*": {"enable": False}, # Skip QKV Linear + "*v_proj*": {"enable": False}, # Skip QKV Linear + "*o_proj*": {"enable": False}, # Skip Output Linear + "*mtp*": {"enable": False}, # Skip MTP layers +} + + _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -186,7 +197,7 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -208,7 +219,7 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -341,7 +352,7 @@ "enable": False, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + # **_mamba_moe_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -366,9 +377,6 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -605,7 +613,7 @@ "*weight_quantizer": _nvfp4_quantizer, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + 
**super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5bb9739a3..7971ce59f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -138,6 +138,26 @@ def max_calibrate( elif hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() + for name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + # Get the initial amax from max calibration + initial_amax = module._amax.clone().detach() + + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) + + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + if not distributed_sync: return @@ -342,6 +362,7 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator + print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -628,6 +649,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: + print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, @@ -2073,21 +2095,26 @@ def gptq( "n_samples": 0, } - # Phase 
2: Register hooks to collect Hessians during forward passes - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: - inp = module.input_quantizer(input[0]) - else: - inp = input[0] - state = hessian_state[module.name] - hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + # Phase 2: Patch forwards to collect Hessians (similar to local_hessian_calibrate) + def _make_hessian_forward(module_name): + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + state = hessian_state[module_name] + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) + hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} + + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + return hessian_forward - handles = [] + patched_modules = [] for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) + bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") + patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians hessian_start = time.time() @@ -2097,9 +2124,9 @@ def hessian_hook(module, input, output): for args, kwargs_input in inputs: layer(*args, **kwargs_input) - # Remove hooks after collecting Hessians - for handle in handles: - handle.remove() + # Unpatch forwards + for module in patched_modules: + unpatch_forward_method(module, "_forward_no_gptq_hessian") torch.cuda.synchronize() if torch.cuda.is_available() else None hessian_time = time.time() - hessian_start From 
c1c7c96e419f72ddafa24c9bde525dd1f929affe Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 18/33] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 134 -------------- modelopt/torch/quantization/mode.py | 2 + modelopt/torch/quantization/model_calib.py | 205 +-------------------- 3 files changed, 5 insertions(+), 336 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 6375d81b0..1e02468da 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -340,125 +340,6 @@ }, } -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - # **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - 
**_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} INT4_AWQ_CFG = { "quant_cfg": { @@ -1285,21 +1166,6 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - checkpoint_every_n_layers: int | None = ModeloptField( - default=None, - title="Save intermediate checkpoint every N layers during sequential calibration.", - ) - - checkpoint_dir: str | None = ModeloptField( - default=None, - title="Directory for saving/loading intermediate GPTQ checkpoints.", - ) - - resume_from_layer: int = ModeloptField( - default=0, - title="Layer index to resume sequential calibration from (0 = start from beginning).", - ) - class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. 
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 88e93bb77..efc66ffa9 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,6 +255,8 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) + else: + raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7971ce59f..100c74923 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1825,196 +1825,11 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") -def _set_input_quantizers_calib_mode(layer: nn.Module): - """Set all input quantizers of a layer to calibration mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_input_quantizers_quant_mode(layer: nn.Module): - """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -def _set_kv_quantizers_calib_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not 
module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_kv_quantizers_quant_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -@contextlib.contextmanager -def _disable_input_quantizers(layer: nn.Module): - """Temporarily disable all enabled input quantizers in a layer.""" - enabled_quantizers = [] - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - ): - module.disable() - enabled_quantizers.append(module) - try: - yield - finally: - for module in enabled_quantizers: - module.enable() - - -def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: - """Save fake quant checkpoint using save_pretrained() (HuggingFace format). - - Args: - model: The quantized model to save. - output_dir: Directory to write the checkpoint into. - """ - from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state - from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state - - os.makedirs(output_dir, exist_ok=True) - - # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. - # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') - # that can't be pickled. Even after removing hooks from modules, they may still be captured - # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). 
- try: - from accelerate.hooks import remove_hook_from_module - - remove_hook_from_module(model, recurse=True) - except ImportError: - pass - - # Save model weights first (without modelopt_state to avoid pickling error) - model.save_pretrained(output_dir, save_modelopt_state=False) - - # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. - # We need to rebuild quantizer_state because hooks may have been captured in closures - # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). - if ModeloptStateManager.is_converted(model): - modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") - state = modelopt_state(model) - - # Rebuild quantizer_state in metadata to remove any hook references captured in closures - if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): - cleaned_state_dict = [] - for entry in state["modelopt_state_dict"]: - if isinstance(entry, tuple) and len(entry) >= 2: - mode_str, state_dict_entry = entry[0], entry[1] - if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: - # Rebuild quantizer_state after hooks are removed - cleaned_entry = state_dict_entry.copy() - cleaned_metadata = cleaned_entry["metadata"].copy() - cleaned_metadata["quantizer_state"] = get_quantizer_state(model) - cleaned_entry["metadata"] = cleaned_metadata - cleaned_state_dict.append((mode_str, cleaned_entry)) - else: - cleaned_state_dict.append(entry) - else: - cleaned_state_dict.append(entry) - state["modelopt_state_dict"] = cleaned_state_dict - - torch.save(state, modelopt_state_path) - print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") - - -def _save_gptq_checkpoint( - model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int -) -> None: - """Save intermediate GPTQ checkpoint with metadata for resume support. 
- - Saves accelerate hooks before calling save_fake_checkpoint (which removes them), - then re-attaches them so the model remains functional for subsequent layers. - """ - print_rank_0( - f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" - ) - - # Save accelerate hooks before save_fake_checkpoint removes them. - # We need to re-attach them after saving so the model keeps working. - saved_hooks = {} - for name, module in model.named_modules(): - if hasattr(module, "_hf_hook"): - saved_hooks[name] = module._hf_hook - - try: - save_fake_checkpoint(model, checkpoint_dir) - finally: - # Re-attach accelerate hooks so the model keeps working for remaining layers. - if saved_hooks: - try: - from accelerate.hooks import add_hook_to_module - - name_to_module = dict(model.named_modules()) - for name, hook in saved_hooks.items(): - if name in name_to_module: - add_hook_to_module(name_to_module[name], hook) - print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") - except ImportError: - pass - - # Save checkpoint metadata for resume support. - meta = { - "last_completed_layer": last_layer_idx, - "total_layers": total_layers, - "timestamp": datetime.datetime.now().isoformat(), - } - meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") - with open(meta_path, "w") as f: - json.dump(meta, f, indent=2) - print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") - - @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, - checkpoint_every_n_layers: int | None = None, - checkpoint_dir: str | None = None, - resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -2064,14 +1879,14 @@ def _layer_forward_loop(m, _inputs=layer_inputs): def gptq( layer: nn.Module, inputs: list[tuple[tuple, dict]], + forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): """GPTQ quantization - a GPTQ variant.""" - import time - - total_start = time.time() + # Set weight amax and activation amax'es for the current layer using max_calibrate + max_calibrate(layer, forward_loop=forward_loop) # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -2117,7 +1932,6 @@ def hessian_forward(self, input, *args, **kwargs): patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians - hessian_start = time.time() print_rank_0( f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." ) @@ -2128,11 +1942,8 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - torch.cuda.synchronize() if torch.cuda.is_available() else None - hessian_time = time.time() - hessian_start # Phase 3: Update weights using computed Hessians (same as gptq_lite) - weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -2144,13 +1955,3 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() - - torch.cuda.synchronize() if torch.cuda.is_available() else None - weight_update_time = time.time() - weight_update_start - - total_time = time.time() - total_start - print_rank_0( - f"GPTQ timing - Hessian: {hessian_time:.2f}s, " - f"Weight update: {weight_update_time:.2f}s, " - f"Total: {total_time:.2f}s" - ) From bcf6fe3f56e15028ca86e340822dc7cb05622414 Mon Sep 17 00:00:00 2001 From: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 19/33] tested, revert later --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 4c0710381..f7f4d94a1 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -723,6 +723,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative 
log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() if True: import os From 558041ca0953027e6a805cb8c93e73f504288b56 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 20/33] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 220 -------------------------- modelopt/torch/quantization/config.py | 94 +++++++++++ 2 files changed, 94 insertions(+), 220 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f7f4d94a1..2d7267768 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -723,226 +723,6 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) - if True: - # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") - if True: - # mtq.fold_weight(full_model) - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, 
shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - - breakpoint() - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, 
num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() - - if args.export_qdq_weights: - # Disable quantizers - if "gptq" not in args.qformat: - mtq.fold_weight(full_model) - print("Folded weights") - - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - 
test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1e02468da..f46276b4f 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -340,6 +340,100 @@ }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": 
(2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": { From a90703810523b4a991e933ae5376557b17318b1c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:01:34 +0000 Subject: [PATCH 21/33] initial cleanup 
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 53 ------------------ modelopt/torch/export/quant_utils.py | 62 ++++++---------------- modelopt/torch/export/unified_export_hf.py | 11 ++-- modelopt/torch/quantization/__init__.py | 1 - modelopt/torch/quantization/model_calib.py | 4 -- 5 files changed, 20 insertions(+), 111 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 2d7267768..939ecee96 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -566,43 +566,6 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) - if getattr(args, "measure_activation_mse", False): - mse_max_samples = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Materialize or load a frozen set of MSE inputs so that the exact - # same samples are used across runs and across codebases. 
- if mse_input_path and os.path.isfile(mse_input_path): - mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) - else: - from torch.utils.data import DataLoader as _DataLoader - - mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) - if mse_input_path: - mse_data = mtq.ActivationMSELogger.materialize_data( - mse_dataloader, - mse_input_path, - max_samples=mse_max_samples, - ) - else: - # No path given -- materialize in memory only - mse_data = [] - for i, batch in enumerate(mse_dataloader): - if i >= mse_max_samples: - break - t = batch["input_ids"] if isinstance(batch, dict) else batch - mse_data.append(t.cpu()) - - mse_logger = mtq.ActivationMSELogger( - max_samples=mse_max_samples, - layer_filter=getattr(args, "activation_mse_layer_filter", None), - save_dir=mse_save_dir, - ) - print("\n--- Phase 1: Collecting pre-quantization activations ---") - mse_logger.collect(language_model, mse_data, phase="original") - if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -610,16 +573,6 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) - # Phase 2: Compute MSE against stored pre-quant activations - if getattr(args, "measure_activation_mse", False): - print("\n--- Phase 2: Computing per-layer activation MSE ---") - mse_logger.collect(language_model, mse_data, phase="quantized") - mse_logger.compute_mse() - print(mse_logger.summary()) - if mse_save_dir: - mse_logger.save() - del mse_logger, mse_data - # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) @@ -1189,12 +1142,6 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--export_qdq_weights", - help=("Used for GPTQ weights as is without compressed weights for deployment."), - 
default=False, - action="store_true", - ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index b762757cb..674d0596e 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer +from ..quantization.nn import SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,17 +353,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) - if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) - if not is_nvfp4_static: - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
@@ -373,10 +371,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( - weight_quantizer, + return NVFP4QTensor.get_weights_scaling_factor( weight, + weight_quantizer.block_sizes[-1], weight_scaling_factor_2.to(weight.device), )[0] @@ -410,13 +407,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. - return weight_quantizer._amax.float() / 448.0 - else: - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + if quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. 
+ return weight_quantizer._amax.float() / 448.0 # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +799,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v in ["nvfp4", "nvfp4_static"]: + elif v == "nvfp4": layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,18 +1397,6 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax - # Handle NVFP4StaticQuantizer: unify global_amax for fused layers - elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): - global_amax_list = [ - m.weight_quantizer.global_amax - for m in modules - if m.weight_quantizer.global_amax is not None - ] - if global_amax_list: - unified_global_amax = torch.max(torch.stack(global_amax_list)) - for module in modules: - module.weight_quantizer.global_amax = unified_global_amax - elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1493,22 +1481,6 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) - # Static NVFP4 uses pre-computed per-block scales from MSE calibration - if quantization_format == QUANTIZATION_NVFP4: - weight_quantizer = getattr(module, "weight_quantizer", None) - if weight_quantizer is None: - # Try to get from first weight attribute - for wn in weight_names: - weight_quantizer = getattr( - module, quantizer_attr_names(wn).weight_quantizer, None - ) - if weight_quantizer is not None: - break - if weight_quantizer is not None: - is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) - if is_static: - quantization_format = "nvfp4_static" - # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format 
layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index e5810bc1f..4c87c3157 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,11 +52,7 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import ( - NVFP4StaticQuantizer, - SequentialQuantizer, - TensorQuantizer, -) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -548,7 +544,6 @@ def _export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, @@ -556,7 +551,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, @@ -573,7 +568,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 757b844fb..87dbf30bb 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -19,7 +19,6 @@ from . 
import mode, plugins, utils # Add methods to mtq namespace -from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 100c74923..7b390ef0f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,9 +15,6 @@ """Calibration utilities.""" -import contextlib -import datetime -import json import math import os import warnings @@ -1942,7 +1939,6 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - # Phase 3: Update weights using computed Hessians (same as gptq_lite) print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): From 1e5ce74c214f99cf3e5d545d8d73026fae480b03 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:24:55 +0000 Subject: [PATCH 22/33] cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 ------------------ modelopt/torch/quantization/config.py | 167 +--- 2 files changed, 1 insertion(+), 953 deletions(-) delete mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py deleted file mode 100644 index df90c84a3..000000000 --- a/modelopt/torch/quantization/activation_mse.py +++ /dev/null @@ -1,787 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Per-layer activation MSE measurement for quantization analysis. - -This module provides utilities to measure per-linear-layer MSE between a model's -activations before and after quantization. Inspired by FP-Quant's two-phase approach: - -- **Phase 1** (before quantization): ``collect_activations()`` runs the model on - calibration data and stores per-layer outputs in CPU RAM. -- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized - model on the same data and computes MSE on-the-fly against the stored Phase 1 - outputs. Only running scalar accumulators are kept -- no second set of tensors - is stored. 
- -Typical usage in hf_ptq.py:: - - # Phase 1: before quantization - orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) - - # Quantize - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - # Phase 2: after quantization -- computes MSE incrementally - mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) -""" - -import contextlib -import fnmatch -import hashlib -import os -from collections.abc import Iterable -from datetime import datetime - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - -from modelopt.torch.utils.network import get_decoder_layers - -__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] - - -def _tensor_from_output(out) -> torch.Tensor: - """Extract a single tensor from a layer's output (handles tuple returns).""" - if isinstance(out, torch.Tensor): - return out.detach() - return out[0].detach() - - -def _is_linear(module: nn.Module) -> bool: - """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" - return isinstance(module, nn.Linear) - - -def _matches_filter(name: str, layer_filter: str | None) -> bool: - """Check if a layer name matches the optional filter pattern (fnmatch-style).""" - if layer_filter is None: - return True - return fnmatch.fnmatch(name, layer_filter) - - -def _discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers within decoder blocks of the model. - - Uses get_decoder_layers() to find transformer blocks, then finds all linear - submodules within those blocks. Falls back to all linear layers in the model - if decoder blocks cannot be identified. - - Args: - model: The model to inspect. - layer_filter: Optional fnmatch pattern to select specific layers - (e.g., ``"*self_attn*"``). - - Returns: - Dict mapping full module path -> module reference. 
- """ - decoder_layers = get_decoder_layers(model) - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - # Build a reverse lookup: module id -> full name in model - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if _is_linear(sub_mod): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # Fallback: scan all modules - for name, module in model.named_modules(): - if _is_linear(module): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model.""" - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - - -@torch.no_grad() -def collect_activations( - model: nn.Module, - dataloader: Iterable, - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, list[torch.Tensor]]: - """Collect per-linear-layer output activations into CPU memory (Phase 1). - - Registers forward hooks on linear layers within the model's decoder blocks, - runs calibration data through the model, and returns captured per-layer outputs. - - Args: - model: The model to collect activations from (typically pre-quantization). - dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). - Use batch_size=1 to minimize memory. - max_samples: Maximum number of batches to process. ``None`` means all. - layer_filter: Optional fnmatch pattern to restrict which layers are - collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers - inside decoder blocks. - - Returns: - Dict mapping layer name to a list of output tensors (one per batch, on CPU). 
- """ - was_training = model.training - model.eval() - - # Discover target linear layers - targets = _discover_target_layers(model, layer_filter) - if not targets: - raise ValueError( - f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" - ) - - print(f"Collecting activations for {len(targets)} layers...") - - # Storage: {layer_name: [tensor_per_batch, ...]} - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(dataloader, desc="Collecting activations", leave=False): - if max_samples is not None and n_batches >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - print(f"Collected {n_batches} samples across {len(targets)} layers") - return saved - - -@torch.no_grad() -def measure_activation_mse( - model: nn.Module, - dataloader: Iterable, - orig_activations: dict[str, list[torch.Tensor]], - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, float]: - """Compute per-layer MSE between stored and live activations (Phase 2). - - Runs the (quantized) model on calibration data and computes MSE on-the-fly - against the pre-quantization activations stored by :func:`collect_activations`. - - Only scalar accumulators (sum of squared errors and element count) are kept - per layer -- no second set of activation tensors is stored. 
- - The MSE for each layer is computed as:: - - MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements - - Args: - model: The quantized model to measure. - dataloader: Same dataloader used for :func:`collect_activations` - (must yield batches in the same order). - orig_activations: Output of :func:`collect_activations` -- dict mapping - layer name to a list of pre-quantization output tensors. - max_samples: Maximum number of batches to process (should match Phase 1). - layer_filter: Optional fnmatch pattern (should match Phase 1). - - Returns: - Dict mapping layer name to its MSE value. - """ - was_training = model.training - model.eval() - - # Discover target layers on the (now-quantized) model - targets = _discover_target_layers(model, layer_filter) - - # Only measure layers that exist in both the model and orig_activations - common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) - if not common_keys: - raise ValueError( - "No matching layers between the quantized model and stored activations. " - "Ensure the same layer_filter is used for both phases." 
- ) - - skipped = set(orig_activations.keys()) - set(targets.keys()) - if skipped: - print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") - - print(f"Computing activation MSE for {len(common_keys)} layers...") - - # Scalar accumulators - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks only on common layers - hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] - - try: - batch_idx = 0 - for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): - if max_samples is not None and batch_idx >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in common_keys: - if name not in captured: - continue - if batch_idx >= len(orig_activations.get(name, [])): - continue - - o = orig_activations[name][batch_idx].float() - q = captured[name].float() - - if o.shape != q.shape: - print( - f"Warning: shape mismatch for {name} batch {batch_idx}: " - f"{o.shape} vs {q.shape}, skipping" - ) - continue - - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - batch_idx += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - mse = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys - } - - return mse - - -# --------------------------------------------------------------------------- -# Portable ActivationMSELogger class -# --------------------------------------------------------------------------- - - -def _portable_discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers in decoder blocks with a portable fallback 
chain. - - Strategy: - 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). - 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). - 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. - - Within each set of decoder blocks the function collects every ``nn.Linear`` - sub-module and optionally filters by *layer_filter* (fnmatch pattern). - """ - decoder_layers = None - - # 1. Try modelopt helper (may not exist when file is copied elsewhere) - with contextlib.suppress(Exception): - decoder_layers = get_decoder_layers(model) - - # 2. Try common HF / other patterns - if decoder_layers is None: - for attr_chain in ( - ("model", "layers"), - ("decoder", "layers"), - ("transformer", "h"), - ("backbone", "layers"), - ): - obj = model - try: - for attr in attr_chain: - obj = getattr(obj, attr) - if isinstance(obj, nn.ModuleList): - decoder_layers = obj - break - except AttributeError: - continue - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if isinstance(sub_mod, nn.Linear): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # 3. Fallback: all linear layers - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -class ActivationMSELogger: - """Portable activation MSE logger for comparing original vs quantized models. - - Works with both: - - - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` - or ``[B, seq_len]``, consumed via ``model(tensor)``. 
- - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): - ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. - - Guarantees same samples are used for both phases via SHA-256 hashing of - input tensors. Supports saving / loading all activations to disk for - later cross-codebase comparison. - - Example (ModelOpt -- DataLoader with dict batches):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model, dataloader, phase="original") - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - mse_logger.collect(model, dataloader, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - - Example (FP-Quant -- List[Tensor]):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model_orig, calibration_data, phase="original") - mse_logger.collect(model_quant, calibration_data, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - """ - - def __init__( - self, - max_samples: int = 16, - layer_filter: str | None = None, - save_dir: str | None = None, - ): - """Initialize the ActivationMSELogger. - - Args: - max_samples: Maximum number of calibration batches to process per phase. - layer_filter: Optional glob pattern to restrict which layers are tracked. - save_dir: Optional directory path for persisting activation data to disk. 
- """ - self.max_samples = max_samples - self.layer_filter = layer_filter - self.save_dir = save_dir - - # Per-phase state - self.original_activations: dict[str, list[torch.Tensor]] = {} - self.quantized_activations: dict[str, list[torch.Tensor]] = {} - self.input_hashes: list[str] = [] # hashes for "original" phase - self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase - - # Computed after both phases - self.mse_results: dict[str, float] | None = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def collect( - self, - model: nn.Module, - data: Iterable, - phase: str, - target_modules: dict[str, nn.Module] | None = None, - ) -> None: - """Collect per-linear-layer output activations for a given phase. - - Args: - model: The model to run (original or quantized). - data: An iterable of batches. Each batch can be: - - - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). - - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). - - ``list`` / ``tuple`` of tensors. - phase: ``"original"`` or ``"quantized"``. - target_modules: Optional explicit mapping of ``{name: nn.Module}`` - to attach hooks to. If *None*, layers are auto-discovered - via decoder-block scanning. - """ - if phase not in ("original", "quantized"): - raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") - - was_training = model.training - model.eval() - - # ----- layer discovery ----- - targets = ( - target_modules - if target_modules is not None - else (_portable_discover_target_layers(model, self.layer_filter)) - ) - if not targets: - raise ValueError( - "No linear layers found. Provide target_modules explicitly or " - f"check layer_filter={self.layer_filter!r}." 
- ) - - print( - f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " - f"max_samples={self.max_samples}" - ) - - # ----- storage ----- - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - hashes: list[str] = [] - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): - if self.max_samples is not None and n_batches >= self.max_samples: - break - - captured.clear() - self._run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - hashes.append(self._hash_batch(batch)) - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - # ----- store results on self ----- - if phase == "original": - self.original_activations = saved - self.input_hashes = hashes - else: - self.quantized_activations = saved - self.quant_input_hashes = hashes - # Verify sample consistency - if self.input_hashes: - self._verify_hashes() - - # Invalidate any previous MSE since we have new activations - self.mse_results = None - - print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") - - def compute_mse(self) -> dict[str, float]: - """Compute per-layer MSE between original and quantized activations. - - Returns: - Dict mapping layer name to its MSE value. - - Raises: - ValueError: If either phase has not been collected yet. - """ - if not self.original_activations: - raise ValueError( - "No original activations collected. Call collect(..., phase='original') first." - ) - if not self.quantized_activations: - raise ValueError( - "No quantized activations collected. 
Call collect(..., phase='quantized') first." - ) - - common_keys = sorted( - set(self.original_activations.keys()) & set(self.quantized_activations.keys()) - ) - if not common_keys: - raise ValueError( - "No matching layer names between original and quantized activations. " - "Ensure the same model architecture / layer_filter is used for both phases." - ) - - orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) - quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) - if orig_only: - print( - f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" - ) - if quant_only: - print( - f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" - ) - - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - for name in common_keys: - orig_list = self.original_activations[name] - quant_list = self.quantized_activations[name] - n = min(len(orig_list), len(quant_list)) - for i in range(n): - o = orig_list[i].float() - q = quant_list[i].float() - if o.shape != q.shape: - print( - f"[ActivationMSELogger] Warning: shape mismatch for {name} " - f"batch {i}: {o.shape} vs {q.shape}, skipping" - ) - continue - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - self.mse_results = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") - for key in common_keys - } - return self.mse_results - - def save(self, path: str | None = None) -> str: - """Save all state (activations, hashes, MSE) to disk via ``torch.save``. - - Args: - path: Explicit file path. If *None*, a timestamped file is created - inside ``self.save_dir`` (which must be set). - - Returns: - The path where the file was saved. 
- """ - if path is None: - if self.save_dir is None: - raise ValueError("Provide a path or set save_dir in the constructor.") - os.makedirs(self.save_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") - - payload = { - "max_samples": self.max_samples, - "layer_filter": self.layer_filter, - "input_hashes": self.input_hashes, - "quant_input_hashes": self.quant_input_hashes, - "original_activations": self.original_activations, - "quantized_activations": self.quantized_activations, - "mse": self.mse_results, - } - torch.save(payload, path) - print(f"[ActivationMSELogger] Saved to {path}") - return path - - @classmethod - def load(cls, path: str) -> "ActivationMSELogger": - """Load a previously saved ``ActivationMSELogger`` from disk. - - Args: - path: Path to the ``.pt`` file created by :meth:`save`. - - Returns: - A new ``ActivationMSELogger`` instance with restored state. - """ - payload = torch.load(path, map_location="cpu", weights_only=False) - logger = cls( - max_samples=payload.get("max_samples", 16), - layer_filter=payload.get("layer_filter"), - ) - logger.original_activations = payload.get("original_activations", {}) - logger.quantized_activations = payload.get("quantized_activations", {}) - logger.input_hashes = payload.get("input_hashes", []) - logger.quant_input_hashes = payload.get("quant_input_hashes", []) - logger.mse_results = payload.get("mse") - print(f"[ActivationMSELogger] Loaded from {path}") - return logger - - def summary(self) -> str: - """Return a formatted string summarising per-layer MSE results. - - Computes MSE first if not already done. 
- """ - if self.mse_results is None: - self.compute_mse() - assert self.mse_results is not None - - lines = ["Per-layer activation MSE (original vs quantized):"] - lines.extend( - f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) - ) - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Pre-materialized MSE data (cross-run / cross-codebase safety) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_data( - data: Iterable, - path: str, - max_samples: int | None = None, - ) -> list[torch.Tensor]: - """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. - - Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a - single ``input_ids`` CPU tensor before saving. The resulting file is a - plain ``List[Tensor]`` that can be loaded in **any** codebase and passed - straight to :meth:`collect`. - - If *path* already exists it is **not** overwritten -- call - :meth:`load_data` instead. - - Args: - data: Iterable of batches (DataLoader, List[Tensor], etc.). - path: Destination ``.pt`` file path. - max_samples: How many batches to keep. ``None`` means all. - - Returns: - The materialised list of CPU tensors (same object that was saved). 
- """ - samples: list[torch.Tensor] = [] - for batch in data: - if max_samples is not None and len(samples) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - samples.append(t.cpu()) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - torch.save(samples, path) - print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") - return samples - - @staticmethod - def load_data(path: str) -> list[torch.Tensor]: - """Load a previously materialised MSE input set. - - Args: - path: Path to the ``.pt`` file created by :meth:`materialize_data`. - - Returns: - ``List[Tensor]`` of input batches (on CPU). - """ - samples = torch.load(path, map_location="cpu", weights_only=True) - print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") - return samples - - # ------------------------------------------------------------------ - # Static / private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model (handles Tensor, dict, list/tuple). - - Automatically moves inputs to the model's device so that CPU-stored - materialized data works transparently with a CUDA model. 
- """ - device = next(model.parameters()).device - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - elif isinstance(batch, torch.Tensor): - model(batch.to(device)) - elif isinstance(batch, (list, tuple)): - batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) - model(*batch) - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - @staticmethod - def _hash_batch(batch) -> str: - """Compute SHA-256 hash of the primary input tensor in *batch*. - - - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). - - ``Tensor`` -> hashes the tensor directly. - - ``list/tuple`` -> hashes the first element. - """ - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] if batch else None - else: - return "" - - if t is None or not isinstance(t, torch.Tensor): - return "" - return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() - - def _verify_hashes(self) -> None: - """Compare input hashes between original and quantized phases.""" - n = min(len(self.input_hashes), len(self.quant_input_hashes)) - mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) - if mismatches: - print( - f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " - f"different input hashes between original and quantized phases. " - f"The same data may not have been used for both phases!" 
- ) - else: - print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index f46276b4f..5b41f35b6 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -156,22 +156,12 @@ "*mlp.gate.*": {"enable": False}, # Skip the MOE router "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d "*output_layer*": {"enable": False}, "output.*": {"enable": False}, "default": {"enable": False}, } -super_disabled_quantizer_cfg = { - "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE - "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE - "*q_proj*": {"enable": False}, # Skip QKV Linear - "*k_proj*": {"enable": False}, # Skip QKV Linear - "*v_proj*": {"enable": False}, # Skip QKV Linear - "*o_proj*": {"enable": False}, # Skip Output Linear - "*mtp*": {"enable": False}, # Skip MTP layers -} - _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -182,53 +172,6 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } -SUPER_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": "max", -} - -SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, 
"type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - INT8_DEFAULT_CFG = { "quant_cfg": { @@ -328,113 +271,6 @@ "algorithm": "max", } -INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - 
-NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - INT4_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -588,7 +424,6 @@ "*weight_quantizer": _nvfp4_quantizer, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, From 96f35c8d664932704c4a3466e8060a14f8f33da5 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:27:54 +0000 Subject: [PATCH 23/33] removed stray config Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 5b41f35b6..07539fbfc 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -424,8 +424,6 @@ "*weight_quantizer": _nvfp4_quantizer, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": "max", } From 72faf3eec0001cfbab1de6ad7faf5adbd4dde290 Mon 
Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:31:18 +0000 Subject: [PATCH 24/33] removed stray prints Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7b390ef0f..e45e7c3bd 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -359,7 +359,6 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator - print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -646,7 +645,6 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: - print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, From 890f28bf2670bd3551538cfae97f06ca5cb991a5 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:04:14 +0000 Subject: [PATCH 25/33] fix rebase issues Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 55 ++++++++++++++----- modelopt/torch/export/unified_export_hf.py | 7 ++- modelopt/torch/quantization/config.py | 3 - modelopt/torch/quantization/mode.py | 2 - modelopt/torch/quantization/model_calib.py | 22 -------- .../nn/modules/tensor_quantizer.py | 13 +---- .../torch/quantization/triton/__init__.py | 4 -- 7 files changed, 49 insertions(+), 57 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py 
b/modelopt/torch/export/quant_utils.py index 674d0596e..4ceb51cd2 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import SequentialQuantizer, TensorQuantizer +from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,6 +353,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) + if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -371,9 +372,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - return NVFP4QTensor.get_weights_scaling_factor( + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + weight_quantizer, weight, - weight_quantizer.block_sizes[-1], weight_scaling_factor_2.to(weight.device), )[0] @@ -407,16 +409,13 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format in [ - QUANTIZATION_NVFP4, - QUANTIZATION_NVFP4_AWQ, - QUANTIZATION_NVFP4_SVDQUANT, - ]: - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) - elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. 
- return weight_quantizer._amax.float() / 448.0 + if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. + return weight_quantizer._amax.float() / 448.0 + else: + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +798,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v == "nvfp4": + elif v in ["nvfp4", "nvfp4_static"]: layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,6 +1396,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax + # Handle NVFP4StaticQuantizer: unify global_amax for fused layers + elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): + global_amax_list = [ + m.weight_quantizer.global_amax + for m in modules + if m.weight_quantizer.global_amax is not None + ] + if global_amax_list: + unified_global_amax = torch.max(torch.stack(global_amax_list)) + for module in modules: + module.weight_quantizer.global_amax = unified_global_amax + elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1481,6 +1492,22 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) + # Static NVFP4 uses pre-computed per-block scales from MSE calibration + if quantization_format == QUANTIZATION_NVFP4: + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is None: + # Try to get from first weight 
attribute + for wn in weight_names: + weight_quantizer = getattr( + module, quantizer_attr_names(wn).weight_quantizer, None + ) + if weight_quantizer is not None: + break + if weight_quantizer is not None: + is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) + if is_static: + quantization_format = "nvfp4_static" + # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 4c87c3157..78c8874a0 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,7 +52,11 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn import ( + NVFP4StaticQuantizer, + SequentialQuantizer, + TensorQuantizer, +) from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -544,6 +548,7 @@ def _export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) + weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 07539fbfc..b034d89a0 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -437,9 +437,6 @@ }, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "mse", diff --git 
a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index efc66ffa9..88e93bb77 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,8 +255,6 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) - else: - raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e45e7c3bd..c3e1c993b 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -132,28 +132,6 @@ def max_calibrate( for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() - elif hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - - for name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration - initial_amax = module._amax.clone().detach() - - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not distributed_sync: return diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 4317c5860..ec2c3cfc5 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ 
b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1331,19 +1331,10 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: - # Ensure amax/global_amax are on the same device as inputs. - # After from_pretrained with device_map, quantizer buffers may remain - # on CPU while model weights/activations are on GPU. - amax = self.amax - if amax.device != inputs.device: - amax = amax.to(inputs.device) - global_amax = self.global_amax - if global_amax is not None and global_amax.device != inputs.device: - global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - amax, - global_amax, # Can be None, will be computed internally + self.amax, + self.global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index 6e8d4dba1..def70e591 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,10 +34,6 @@ from .fp4_kernel import * from .fp8_kernel import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) - if torch.cuda.get_device_capability() >= (8, 9): - from .fp4_kernel_hopper import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From 40a665728161d1b22d8328b602f7993bb67c3411 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:06:18 +0000 Subject: [PATCH 26/33] minor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/config.py 
b/modelopt/torch/quantization/config.py index b034d89a0..fa20356ff 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -162,7 +162,6 @@ "default": {"enable": False}, } - _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -172,7 +171,6 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } - INT8_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, @@ -271,6 +269,7 @@ "algorithm": "max", } + INT4_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": { From cd9246a8f23488a10a781db29283a95c87639927 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:12:00 +0000 Subject: [PATCH 27/33] tested e2e on qwen Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 34 ++++++++++ modelopt/torch/quantization/config.py | 20 ++++++ modelopt/torch/quantization/mode.py | 4 +- modelopt/torch/quantization/model_calib.py | 73 +++++++++++++++++++--- 4 files changed, 120 insertions(+), 11 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 939ecee96..9dbdc42bf 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -107,6 +108,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, } @@ -922,6 +924,7 @@ def quantize_main( else: # mono quantization 
+<<<<<<< HEAD if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -929,6 +932,26 @@ def quantize_main( f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" ) quant_cfg = recipe.ptq_cfg +======= + assert ( + args.qformat + in [ + "int8_wo", + "int4_awq", + "fp8", + "nvfp4", + "nvfp4_awq", + "nvfp4_mse", + "nvfp4_gptq", + "w4a8_awq", + "fp8_pb_wo", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + "mxfp8", + ] + or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES + ), f"Plain quantization format {args.qformat} not supported for HF export path" +>>>>>>> 6b8812d6 (tested e2e on qwen) else: assert len(args.qformat.split(",")) == 1, ( @@ -1000,6 +1023,11 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) + + if args.eval_perplexity and tokenizer is not None: + print("Evaluating Wikitext-2 perplexity...") + evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) + export_quantized( args, full_model, @@ -1158,6 +1186,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--eval_perplexity", + help="Evaluate Wikitext-2 perplexity after quantization.", + default=False, + action="store_true", + ) parser.add_argument( "--low_memory_mode", help=( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fa20356ff..5611f3e42 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -459,6 +459,25 @@ }, } +NVFP4_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": {"method": "gptq", "use_sequential": 
True}, +} + MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": { "*weight_quantizer": _nvfp4_quantizer, @@ -679,6 +698,7 @@ "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", + "NVFP4_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 88e93bb77..df48c72c2 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -242,8 +242,8 @@ def wrapped_calib_func( if sequential: if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" + assert method in ["max", "gptq"], ( + f"Sequential calibration currently only supports max and gptq calibration, got {method}" ) # Wrap with sequential processing sequential_calibrate( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c3e1c993b..d27151dcb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1848,19 +1848,62 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. 
+ """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def gptq( layer: nn.Module, - inputs: list[tuple[tuple, dict]], forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): - """GPTQ quantization - a GPTQ variant.""" - # Set weight amax and activation amax'es for the current layer using max_calibrate + """GPTQ quantization - a GPTQ variant. + + Args: + layer: A single decoder layer to quantize. + forward_loop: Callable that replays calibration inputs through the layer. + Provided by ``sequential_calibrate`` which captures per-layer activations. + percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. 
+ """ + import time + + total_start = time.time() + + # Set weight amax and activation amax for the current layer using max_calibrate max_calibrate(layer, forward_loop=forward_loop) + # Promote NVFP4 static quantizers so they use the two-level scaling path + n_promoted = _promote_nvfp4_static_quantizers(layer) + if n_promoted: + print_rank_0(f"Promoted {n_promoted} quantizer(s) to NVFP4StaticQuantizer") + # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -1904,18 +1947,20 @@ def hessian_forward(self, input, *args, **kwargs): bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") patched_modules.append(module) - # Run forward passes with the provided inputs to collect Hessians - print_rank_0( - f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." - ) - for args, kwargs_input in inputs: - layer(*args, **kwargs_input) + # Run forward passes to collect Hessians + hessian_start = time.time() + print_rank_0(f"Computing Hessians for {len(tensor_mapping)} linear layers...") + forward_loop(layer) # Unpatch forwards for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -1927,3 +1972,13 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " 
+ f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From 49e54cc54b754a540a7b7877f8ddc608a4011e8d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:20:44 +0000 Subject: [PATCH 28/33] removed perplexity eval Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 9dbdc42bf..56a12e28c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -924,7 +923,6 @@ def quantize_main( else: # mono quantization -<<<<<<< HEAD if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -932,26 +930,6 @@ def quantize_main( f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" ) quant_cfg = recipe.ptq_cfg -======= - assert ( - args.qformat - in [ - "int8_wo", - "int4_awq", - "fp8", - "nvfp4", - "nvfp4_awq", - "nvfp4_mse", - "nvfp4_gptq", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - "mxfp8", - ] - or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES - ), f"Plain quantization format {args.qformat} not supported for HF export path" ->>>>>>> 6b8812d6 (tested e2e on qwen) else: assert len(args.qformat.split(",")) == 1, ( @@ -1024,10 +1002,6 @@ def quantize_main( first_text_speech_dataset, ) - if args.eval_perplexity and tokenizer is not None: - print("Evaluating Wikitext-2 perplexity...") - evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) - export_quantized( args, full_model, @@ -1186,12 +1160,7 @@ def parse_args() -> 
argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--eval_perplexity", - help="Evaluate Wikitext-2 perplexity after quantization.", - default=False, - action="store_true", - ) + parser.add_argument( "--low_memory_mode", help=( From 2bc18cd495698d4f3e3fe17793eb3ff396ada901 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 23:39:59 +0000 Subject: [PATCH 29/33] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 56a12e28c..bbedb43e6 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -678,6 +679,9 @@ def export_quantized( "They will be set at deployment time." ) + if getattr(args, "eval_perplexity", False) and tokenizer is not None: + evaluate_perplexity(full_model, tokenizer, seq_len=2048) + # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) @@ -1219,6 +1223,12 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." 
), ) + parser.add_argument( + "--eval_perplexity", + action=argparse.BooleanOptionalAction, + default=False, + help="Evaluate Wikitext-2 perplexity after quantization (before export).", + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): From c3119505c061ff1ec6bd0cde2610b18dc2f7294e Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:26:47 +0000 Subject: [PATCH 30/33] revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 123 +++++++++++++++++++++++- modelopt/torch/quantization/__init__.py | 8 +- modelopt/torch/quantization/config.py | 30 ++++++ 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index bbedb43e6..d81d4bed9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -63,6 +62,11 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.metrics import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -99,6 +103,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG, "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": 
mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -108,6 +113,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG, "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, @@ -680,7 +686,10 @@ def export_quantized( ) if getattr(args, "eval_perplexity", False) and tokenizer is not None: - evaluate_perplexity(full_model, tokenizer, seq_len=2048) + seq_len = getattr(args, "eval_perplexity_seq_len", 2048) + eval_data = get_wikitext2(tokenizer, seq_len) + ppl = compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -913,6 +922,64 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) + # Collect original (unquantized) activations before quantization modifies the model + mse_logger = None + if getattr(args, "measure_activation_mse", False): + n_mse = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader + mse_data = None + if mse_input_path is not None: + if mse_input_path.endswith(".json"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .json file: {mse_input_path}") + texts = ActivationMSELogger.load_raw_text(mse_input_path) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + else: + 
assert tokenizer is not None, ( + "--activation_mse_input_path with .json requires a tokenizer to decode" + ) + print(f"Creating MSE input data .json file: {mse_input_path}") + texts = ActivationMSELogger.materialize_raw_text( + calib_dataloader, + mse_input_path, + tokenizer=tokenizer, + max_samples=n_mse, + ) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + elif mse_input_path.endswith(".pt"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.load_data(mse_input_path) + else: + print(f"Creating MSE input data .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.materialize_data( + calib_dataloader, + mse_input_path, + max_samples=n_mse, + ) + else: + raise ValueError( + f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" + ) + + if mse_data is None: + mse_data = calib_dataloader + + mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) + print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") + mse_logger.collect(language_model, mse_data, phase="original") + if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." 
@@ -1006,6 +1073,22 @@ def quantize_main( first_text_speech_dataset, ) + if mse_logger is not None: + import gc + + print("Collecting quantized activations for MSE...") + mse_logger.collect(language_model, mse_data, phase="quantized") + + mse_logger.compute_mse() + print(mse_logger.summary()) + + if getattr(args, "activation_mse_save_dir", None): + mse_logger.save() + + del mse_logger, mse_data + gc.collect() + torch.cuda.empty_cache() + export_quantized( args, full_model, @@ -1229,6 +1312,42 @@ def parse_args() -> argparse.Namespace: default=False, help="Evaluate Wikitext-2 perplexity after quantization (before export).", ) + parser.add_argument( + "--eval_perplexity_seq_len", + type=int, + default=2048, + help="Sequence length for perplexity evaluation (default: 2048).", + ) + parser.add_argument( + "--measure_activation_mse", + action=argparse.BooleanOptionalAction, + default=False, + help="Measure per-layer activation MSE (original vs quantized) after quantization.", + ) + parser.add_argument( + "--activation_mse_max_samples", + type=int, + default=16, + help="Max calibration samples for activation MSE (default: 16).", + ) + parser.add_argument( + "--activation_mse_save_dir", + type=str, + default=None, + help="Directory to save activation MSE results. If not set, results are only printed.", + ) + parser.add_argument( + "--activation_mse_input_path", + type=str, + default=None, + help=( + "Path to frozen MSE input data. Supports two formats:\n" + " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " + "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" + " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " + "if not, materializes from calibration data and saves." 
# NVFP4 weight-only quantization: FP4 (E2M1) weights with per-16-column block
# scales encoded as scale_bits=(4, 3) (E4M3-style), computed dynamically at
# quantization time. Activations are left unquantized (weight-only).
NVFP4_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),  # FP4 E2M1: 2 exponent bits, 1 mantissa bit
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {"enable": False},  # weight-only: no activation quant
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",  # plain max calibration, no weight updates
}

# Same FP4 weight-only format, but with *static* block scales (GPTQ must quantize
# against scales that are fixed before the weight updates) and GPTQ as the
# calibration algorithm.
NVFP4_WEIGHT_ONLY_GPTQ_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),  # FP4 E2M1
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
        **_default_disabled_quantizer_cfg,
    },
    # use_sequential: process layers one at a time so later layers are calibrated
    # against the (already-quantized) outputs of earlier layers.
    "algorithm": {"method": "gptq", "use_sequential": True},
}
def _build_column_qdq(quantizer, weight_shape):
    """Build a fast column-wise quantize-dequantize function for integer quantizers.

    Instead of calling the full TensorQuantizer on the entire weight matrix (which
    quantizes all elements) and extracting one column, this returns a closure that
    quantizes only a single column using the quantizer's pre-computed amax/scales.

    Since max_calibrate fixes the amax before GPTQ weight updates, quantizing a
    single column with the same fixed scale gives bit-identical results to
    quantizing the full matrix and extracting that column.

    Args:
        quantizer: The weight TensorQuantizer (already calibrated).
        weight_shape: Shape of the weight tensor (out_features, in_features).

    Returns:
        Tuple of (column_qdq_fn, supported) where:
            - column_qdq_fn(column, col_idx) -> qdq_column (if supported)
            - supported: True if column-wise qdq is available, False to fall back.
    """
    # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple)
    if isinstance(quantizer, NVFP4StaticQuantizer):
        return None, False
    if isinstance(quantizer._num_bits, tuple):
        return None, False

    # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns,
    # so a single column can no longer be quantized in isolation.
    if getattr(quantizer, "pre_quant_scale", None) is not None:
        return None, False
    if getattr(quantizer, "rotate_is_enabled", False):
        return None, False

    # Need calibrated amax — without it there is no fixed scale to reuse.
    if not hasattr(quantizer, "_amax") or quantizer._amax is None:
        return None, False

    num_bits = quantizer._num_bits
    unsigned = getattr(quantizer, "_unsigned", False)
    narrow_range = getattr(quantizer, "_narrow_range", False)
    max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1
    # NOTE(review): min_bound = -max_bound + narrow_range is assumed to mirror the
    # TensorQuantizer's own clamp convention — confirm against fake_tensor_quant.
    min_bound = -max_bound + int(narrow_range)

    amax = quantizer._amax.float()
    out_features, in_features = weight_shape

    # Determine quantization geometry from block_sizes
    block_sizes = quantizer.block_sizes
    group_size = None
    if block_sizes is not None:
        # Skip dynamic block quantization — scales depend on the live weight values,
        # so a pre-computed per-column scale would not match.
        if block_sizes.get("type", "static") == "dynamic":
            return None, False
        group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None)

    if group_size is not None and group_size > 0:
        # Per-group block quantization along last dim.
        # After _setup_for_blockquant, weight is reshaped to (-1, group_size) with axis=(0,).
        # amax shape: (out_features * n_groups, 1) where n_groups = in_features // group_size.
        if in_features % group_size != 0:
            return None, False  # Padding case — fall back

        n_groups = in_features // group_size

        try:
            # Reshape amax to (out_features, n_groups) for O(1) group lookup
            amax_2d = amax.reshape(out_features, n_groups)
        except RuntimeError:
            return None, False

        # Closure defaults bind the scale tensors once so the per-column call is cheap.
        def _column_qdq_group(
            col, col_idx, _a=amax_2d, _mx=max_bound, _mn=min_bound, _gs=group_size
        ):
            col_scale = _mx / _a[:, col_idx // _gs].clamp(min=1e-12)
            return torch.clamp(torch.round(col * col_scale), _mn, _mx) / col_scale

        return _column_qdq_group, True

    # Per-channel (axis != None) or per-tensor (axis == None)
    axis = quantizer.axis
    if axis is not None:
        # Per-channel: amax has shape (out_features, 1) or similar;
        # flatten so col_scale lines up with the (out_features,) column vector.
        col_scale = max_bound / amax.reshape(-1).clamp(min=1e-12)

        def _column_qdq_channel(col, col_idx, _s=col_scale, _mx=max_bound, _mn=min_bound):
            return torch.clamp(torch.round(col * _s), _mn, _mx) / _s

        return _column_qdq_channel, True

    # Per-tensor: single scalar scale
    scalar_scale = max_bound / amax.clamp(min=1e-12).item()

    def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bound):
        return torch.clamp(torch.round(col * _s), _mn, _mx) / _s

    return _column_qdq_tensor, True
@@ -1625,22 +1722,41 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) + # Try to build fast column-wise qdq (avoids quantizing the full matrix per column) + col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape) + # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) n_cols = block_end - block_start - wblk = weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] - for i in range(n_cols): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = module.weight_quantizer(wblk) - weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err + if col_qdq_supported: + # Fast path: clone only the block columns, quantize only per-column + wblk = weight[:, block_start:block_end].clone() + errs = torch.zeros_like(wblk) + + for i in range(n_cols): + w_ci = wblk[:, i] + d = h_inv_cho_blk[i, i] + qdq_col = col_qdq_fn(w_ci, block_start + i) + weight[:, block_start + i] = qdq_col + err = (w_ci - qdq_col) / d + wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + else: + # Fallback: original full-matrix quantization path + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights weight[:, 
block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) @@ -1844,7 +1960,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() - + print_rank_0("Sequential calibration completed") @@ -1969,9 +2085,9 @@ def hessian_forward(self, input, *args, **kwargs): blockwise_weight_update( module, hessian, block_size, percdamp, n_samples=state["n_samples"] ) - # Free memory del hessian_state[module.name] - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() torch.cuda.synchronize() if torch.cuda.is_available() else None weight_update_time = time.time() - weight_update_start From 3f2d7c05a4895d5890fe373198bdc3fe8d4d5da2 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:30:36 +0000 Subject: [PATCH 33/33] gptq faster Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 108 +++++++--- .../quantization/triton/gptq_fused_kernel.py | 189 ++++++++++++++++++ tests/gpu/torch/quantization/test_gptq.py | 93 ++++++++- 3 files changed, 365 insertions(+), 25 deletions(-) create mode 100644 modelopt/torch/quantization/triton/gptq_fused_kernel.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 79c184b27..1e8a94b3d 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1545,7 +1545,7 @@ def _print_relative_mse_error( delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") + print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1604,7 +1604,7 @@ def prepare_hessian_inverse(h, 
def _can_use_fused_gptq(quantizer) -> bool:
    """Check whether the fused Triton GPTQ kernel can be used for *quantizer*.

    Requires an NVFP4 static quantizer with a calibrated amax, and Triton
    available at runtime (checked via the lazily-imported availability flag).
    """
    if not isinstance(quantizer, NVFP4StaticQuantizer):
        return False
    if not hasattr(quantizer, "_amax") or quantizer._amax is None:
        return False
    # Imported lazily so environments without Triton can still use the fallback paths.
    from modelopt.torch.quantization.triton import IS_AVAILABLE as _TRITON_OK

    return _TRITON_OK


def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
    """Update module weights using GPTQ-style blockwise quantization.

    Dispatches to one of three internal paths depending on quantizer type:

    1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is
       available. Runs the entire column loop in a single GPU kernel per
       block (~130x faster than the unfused path on Blackwell GPUs).
    2. **Column-QDQ** — for integer quantizers whose scale geometry allows
       single-column fake-quant via :func:`_build_column_qdq`.
    3. **Full-matrix fallback** — calls the quantizer on the full weight matrix
       once per column (slowest, but always correct).

    Args:
        module: Neural network module with ``weight`` and ``weight_quantizer``.
        h: Hessian matrix of shape ``(d, d)``.
        block_size: Number of columns processed per block.
        percdamp: Damping as a fraction of the mean Hessian diagonal.
        n_samples: Number of Hessian samples (used only for logging).
    """
    # Work in float32 on a copy; the original weight is needed later for MSE logging.
    weight = module.weight.data.float().clone()
    num_rows, num_cols = weight.shape

    h_inv = prepare_hessian_inverse(h, weight, percdamp)

    quantizer = module.weight_quantizer
    if _can_use_fused_gptq(quantizer):
        _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size)
    else:
        col_qdq_fn, col_qdq_supported = _build_column_qdq(quantizer, weight.shape)
        _blockwise_weight_update_unfused(
            weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported
        )

    # Log Hessian-weighted relative MSE of the update, then write back in the
    # module's original dtype.
    _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples)
    module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype)


def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size):
    """Fused Triton path for NVFP4: one kernel launch per block.

    Mutates *weight* in place, block by block, propagating the GPTQ error of
    each processed block into the remaining (right-hand) columns.
    """
    from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block

    group_size = quantizer.block_sizes.get(-1, None) or quantizer.block_sizes.get(1, None)
    num_groups = math.ceil(num_cols / group_size)
    # NOTE(review): the reshape below assumes num_cols is an exact multiple of
    # group_size (so _amax holds num_rows * num_groups entries); the padded case
    # would raise here — confirm upstream guarantees divisibility for NVFP4.
    amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous()
    global_amax = quantizer.global_amax.float()

    for block_start in range(0, num_cols, block_size):
        block_end = min(block_start + block_size, num_cols)
        n_cols_blk = block_end - block_start

        # Clone: the kernel updates its working buffer in place.
        w_block = weight[:, block_start:block_end].clone().contiguous()
        h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end].contiguous()

        qw_block, err_block = gptq_fused_block(
            w_block,
            amax_grouped,
            global_amax,
            h_inv_cho_blk,
            group_size,
            block_start,
            n_cols_blk,
        )

        weight[:, block_start:block_end] = qw_block
        # Propagate the block's quantization error into all remaining columns.
        if block_end < num_cols:
            weight[:, block_end:].addmm_(
                err_block[:, :n_cols_blk],
                h_inv[block_start:block_end, block_end:],
                alpha=-1,
            )


def _blockwise_weight_update_unfused(
    weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported
):
    """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers.

    Mutates *weight* in place. When ``col_qdq_supported`` is True the per-column
    closure from :func:`_build_column_qdq` is used; otherwise the quantizer is
    invoked on the full matrix once per column (slow but always correct).
    """
    for block_start in range(0, num_cols, block_size):
        block_end = min(block_start + block_size, num_cols)
        n_cols_blk = block_end - block_start

        h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end]

        if col_qdq_supported:
            # Fast path: clone only the block columns, quantize a single column
            # at a time with the pre-computed fixed scales.
            wblk = weight[:, block_start:block_end].clone()
            errs = torch.zeros_like(wblk)

            for i in range(n_cols_blk):
                w_ci = wblk[:, i]
                d = h_inv_cho_blk[i, i]
                qdq_col = col_qdq_fn(w_ci, block_start + i)
                weight[:, block_start + i] = qdq_col
                err = (w_ci - qdq_col) / d
                # Rank-1 update of the still-unprocessed columns inside the block.
                wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
                errs[:, i] = err
        else:
            # Fallback: re-quantize the whole (updated) matrix for every column.
            wblk = weight.clone()
            errs = torch.zeros_like(wblk[:, block_start:block_end])

            for i in range(n_cols_blk):
                w_ci = wblk[:, block_start + i]
                d = h_inv_cho_blk[i, i]
                qdq = quantizer(wblk)
                weight[:, block_start + i] = qdq[:, block_start + i]
                err = (w_ci - qdq[:, block_start + i]) / d
                wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
                errs[:, i] = err

        # Propagate accumulated block errors to all columns right of the block.
        weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1)
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop.

The standard GPTQ inner loop launches ~10-15 CUDA kernels per column
(amax lookup, FP4 quantization, error computation, rank-1 update).
For ``block_size=128`` that is ~1500 kernel launches per block, each with
~5-10 us of launch overhead dominating actual compute.

This module fuses the entire inner loop into a **single** Triton kernel per
block. Rows are independent and map to Triton programs; columns are processed
sequentially inside each program so the rank-1 error update is carried forward
without synchronization.

Supported quantization format: **NVFP4 static block quantization** (two-level
scaling with per-group amax and a global amax).
"""

import torch
import triton
import triton.language as tl

__all__ = ["gptq_fused_block"]

# -- NVFP4 constants used by the kernel ------------------------------------
# Maximum representable FP4-E2M1 value (1 + 1 + 0.5 = 6.0 when decoded via
# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}).
_FP4_MAX = 6.0
# FP8-E4M3 has max representable value 448.
_FP8_E4M3_MAX = 448.0


@triton.jit
def _gptq_fused_block_kernel(
    w_ptr,  # [num_rows, BLOCK_SIZE] working weight block (in-place)
    qw_ptr,  # [num_rows, BLOCK_SIZE] output: quantized weights
    err_ptr,  # [num_rows, BLOCK_SIZE] output: quantization errors
    amax_ptr,  # [num_rows, num_groups] per-group amax, row-major
    global_amax_ptr,  # scalar float32 on device
    hinv_ptr,  # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1}
    num_rows,
    num_groups,
    group_size: tl.constexpr,
    block_start,  # column offset of this block in the full weight matrix
    n_cols,  # actual columns in this block (may be < BLOCK_SIZE)
    BLOCK_SIZE: tl.constexpr,
):
    """One program per row; sequentially quantizes columns, propagating errors."""
    row = tl.program_id(0)
    if row >= num_rows:
        return

    # Base pointers for this row
    w_base = w_ptr + row * BLOCK_SIZE
    qw_base = qw_ptr + row * BLOCK_SIZE
    err_base = err_ptr + row * BLOCK_SIZE
    amax_row_base = amax_ptr + row * num_groups

    # Pre-compute global FP8 scale factors (constant across columns)
    global_amax = tl.load(global_amax_ptr).to(tl.float32)
    global_scale = global_amax / 6.0  # _FP4_MAX
    fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0)

    j_range = tl.arange(0, BLOCK_SIZE)

    for i in range(BLOCK_SIZE):
        wi = tl.load(w_base + i)

        # -- Compute NVFP4 two-level scale for this column's group -----------
        col_idx = block_start + i
        group_idx = col_idx // group_size
        raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32)
        raw_scale = raw_amax / 6.0  # _FP4_MAX

        # FP8-quantize the block scale: scale * fp8_scale -> cast E4M3 -> back
        fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0)
        si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale

        # Guard: replace zero / nan / inf scale with 1.0
        # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan).
        si_safe = tl.where(
            (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")),  # noqa: PLR0124
            1.0,
            si,
        )

        # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ----------
        # The thresholds are the midpoints between consecutive E2M1 code values
        # {0, 0.5, 1, 1.5, 2, 3, 4, 6}, mixing <= / < to match round-half-even
        # at exact midpoints.
        abs_scaled = tl.abs(wi) / si_safe
        q_val = tl.where(
            abs_scaled <= 0.25,
            0.0,
            tl.where(
                abs_scaled < 0.75,
                0.5,
                tl.where(
                    abs_scaled <= 1.25,
                    1.0,
                    tl.where(
                        abs_scaled < 1.75,
                        1.5,
                        tl.where(
                            abs_scaled <= 2.5,
                            2.0,
                            tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
                        ),
                    ),
                ),
            ),
        )

        # Restore sign and dequantize back to real scale.
        qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0)
        tl.store(qw_base + i, qi)

        # -- GPTQ error and rank-1 update ------------------------------------
        di = tl.load(hinv_ptr + i * BLOCK_SIZE + i)
        err_i = (wi - qi) / di
        tl.store(err_base + i, err_i)

        # Subtract err_i * H^{-1} row from the still-unprocessed columns (> i).
        j_mask = (j_range > i) & (j_range < n_cols)
        hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0)
        w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0)
        w_rem = w_rem - err_i * hinv_row
        tl.store(w_base + j_range, w_rem, mask=j_mask)


def gptq_fused_block(
    w_block: torch.Tensor,
    amax_grouped: torch.Tensor,
    global_amax: torch.Tensor,
    h_inv_cho_blk: torch.Tensor,
    group_size: int,
    block_start: int,
    n_cols: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the GPTQ column loop for one block in a single Triton kernel launch.

    Args:
        w_block: Working weight block of shape ``[num_rows, block_size]``. It is
            modified in place by the kernel — callers should pass a scratch copy
            (``.contiguous()`` below only copies when the input is non-contiguous).
        amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``.
        global_amax: Scalar tensor with the global amax.
        h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``.
        group_size: NVFP4 quantization group size (typically 16).
        block_start: Column offset of this block in the full weight matrix.
        n_cols: Actual number of columns in this block (``<= block_size``).

    Returns:
        Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``.
    """
    num_rows, block_size = w_block.shape
    num_groups = amax_grouped.shape[1]

    w_block = w_block.contiguous()
    amax_grouped = amax_grouped.contiguous()
    h_inv_cho_blk = h_inv_cho_blk.contiguous()

    qw_block = torch.empty_like(w_block)
    err_block = torch.empty_like(w_block)

    # One Triton program per weight row; columns are serialized inside the kernel.
    grid = (num_rows,)
    with torch.cuda.device(w_block.device):
        _gptq_fused_block_kernel[grid](
            w_block,
            qw_block,
            err_block,
            amax_grouped,
            global_amax,
            h_inv_cho_blk,
            num_rows,
            num_groups,
            group_size,
            block_start,
            n_cols,
            BLOCK_SIZE=block_size,
        )

    return qw_block, err_block
@pytest.mark.parametrize("dim", [256, 512])
def test_fused_vs_unfused_nvfp4(dim):
    """Verify that the fused Triton GPTQ kernel produces equivalent results to the unfused path.

    The fused kernel computes NVFP4 quantization inline using Triton intrinsics,
    which can differ slightly from the PyTorch-level quantizer path (different FP
    rounding order). On real models (dim >= 4096) the relative MSE difference is
    typically < 0.1%; at the smaller dims used here the tolerance is set to 20%.
    """
    from modelopt.torch.quantization.model_calib import _promote_nvfp4_static_quantizers

    torch.manual_seed(RAND_SEED)
    block_size = min(128, dim)

    # NVFP4_WEIGHT_ONLY_GPTQ_CFG uses *static* blocks, which get promoted to
    # NVFP4StaticQuantizer — the prerequisite for the fused Triton path.
    quant_cfg = copy.deepcopy(mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG)
    quant_cfg["algorithm"] = "max"  # calibrate only, don't run GPTQ

    model = torch.nn.Linear(dim, dim, bias=False).to("cuda")
    model.name = "test_fused"  # blockwise_weight_update logs per-module by name
    original_weight = model.weight.data.clone()
    inp = torch.randn(4, 32, dim, device="cuda")

    mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(inp))

    # Promote to NVFP4StaticQuantizer (normally done by gptq / sequential_calibrate)
    n_promoted = _promote_nvfp4_static_quantizers(model)
    assert n_promoted > 0, "Expected at least one quantizer to be promoted"

    quantizer = model.weight_quantizer
    assert isinstance(quantizer, NVFP4StaticQuantizer), (
        f"Expected NVFP4StaticQuantizer, got {type(quantizer).__name__}"
    )

    # Restore original weight and compute Hessian
    # NOTE(review): hessian starts on CPU while inp is on CUDA — presumably
    # update_hessian handles device placement internally; confirm.
    model.weight.data = original_weight.clone()
    hessian = torch.zeros(dim, dim, dtype=torch.float32)
    n_samples = 0
    hessian, n_samples = update_hessian(inp, hessian, n_samples)
    hessian = hessian.to("cuda")

    # --- Run fused path ---
    weight_fused = original_weight.float().clone()
    num_rows, num_cols = weight_fused.shape
    h_inv = prepare_hessian_inverse(hessian, weight_fused, percdamp=0.01)
    _blockwise_weight_update_fused(weight_fused, h_inv, quantizer, num_rows, num_cols, block_size)

    # --- Run unfused path ---
    # col_qdq_fn=None / col_qdq_supported=False forces the full-matrix fallback,
    # the slowest but reference-correct path.
    weight_unfused = original_weight.float().clone()
    h_inv_unfused = prepare_hessian_inverse(hessian, weight_unfused, percdamp=0.01)
    _blockwise_weight_update_unfused(
        weight_unfused, h_inv_unfused, quantizer, num_cols, block_size, None, False
    )

    # Both paths must produce non-trivial updates
    assert not torch.equal(weight_fused, original_weight.float()), (
        "Fused path did not update weights"
    )
    assert not torch.equal(weight_unfused, original_weight.float()), (
        "Unfused path did not update weights"
    )

    # Compare Hessian-weighted relative MSE (same metric as _print_relative_mse_error)
    def _relative_mse(q, w, h):
        delta = q - w
        return (delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)).item()

    orig_f = original_weight.float()
    mse_fused = _relative_mse(weight_fused, orig_f, hessian)
    mse_unfused = _relative_mse(weight_unfused, orig_f, hessian)

    assert mse_fused > 0, "Fused MSE should be positive"
    assert mse_unfused > 0, "Unfused MSE should be positive"

    # At small test dimensions, inline Triton FP4 rounding can diverge up to ~15%
    # from the PyTorch path. On production-scale layers this drops below 0.1%.
    relative_mse_diff = abs(mse_fused - mse_unfused) / max(mse_fused, mse_unfused)
    assert relative_mse_diff < 0.20, (
        f"Fused ({mse_fused:.6e}) and unfused ({mse_unfused:.6e}) MSE differ by "
        f"{relative_mse_diff:.2%}, expected < 20%"
    )