From 0b83d06ff8c42d642f8061c09f0a17d54714f79e Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Thu, 12 Mar 2026 00:33:47 +0530 Subject: [PATCH 1/4] Added torch ptq followed by onnx export followed by GQA surgery example in windows and changed torch and onnx export related files which were broken Signed-off-by: Hrishith Thadicherla --- .../windows/torch_onnx/llm_export/README.md | 42 + .../torch_onnx/llm_export/llm_export.py | 837 ++++++++++++++++++ .../torch_onnx/llm_export/requirements.txt | 6 + modelopt/onnx/export/int4_exporter.py | 26 +- modelopt/onnx/graph_surgery/__init__.py | 86 +- .../onnx/llm_export_utils/export_utils.py | 58 +- .../llm_export_utils/quantization_utils.py | 7 +- modelopt/torch/quantization/export_onnx.py | 43 +- .../nn/modules/tensor_quantizer.py | 1 + modelopt/torch/quantization/tensor_quant.py | 13 +- 10 files changed, 1052 insertions(+), 67 deletions(-) create mode 100644 examples/windows/torch_onnx/llm_export/README.md create mode 100644 examples/windows/torch_onnx/llm_export/llm_export.py create mode 100644 examples/windows/torch_onnx/llm_export/requirements.txt diff --git a/examples/windows/torch_onnx/llm_export/README.md b/examples/windows/torch_onnx/llm_export/README.md new file mode 100644 index 0000000000..439197b6c5 --- /dev/null +++ b/examples/windows/torch_onnx/llm_export/README.md @@ -0,0 +1,42 @@ +# LLM Export (Windows) + +Export LLMs from PyTorch to ONNX with quantization and GQA surgery. 
+ +## Supported Precisions + +- `nvfp4` — NVIDIA FP4 quantization +- `int4_awq` — INT4 AWQ quantization +- `int8_sq` — INT8 SmoothQuant + +## Usage + +### NVFP4 + +```bash +python llm_export.py --hf_model_path "meta-llama/Llama-3.2-3B-Instruct" --dtype nvfp4 --output_dir ./llama3.2-3b-nvfp4 +``` + +### INT4 AWQ + +```bash +python llm_export.py --hf_model_path "meta-llama/Llama-3.2-3B-Instruct" --dtype int4_awq --output_dir ./llama3.2-3b-int4 +``` + +### INT8 SmoothQuant + +```bash +python llm_export.py --hf_model_path "Qwen/Qwen2.5-3B-Instruct" --dtype int8_sq --output_dir ./qwen-3b-int8 +``` + +## Options + +| Argument | Description | +|---|---| +| `--hf_model_path` | HuggingFace model name or local path | +| `--dtype` | Quantization precision (`fp16`, `fp8`, `int4_awq`, `int8_sq`, `nvfp4`) | +| `--output_dir` | Directory to save the exported ONNX model | +| `--calib_size` | Calibration dataset size (default: 512) | +| `--save_original` | Save the pre-surgery ONNX for debugging | +| `--trust_remote_code` | Trust remote code when loading from HuggingFace | +| `--onnx_path` | Skip export, run surgery on an existing ONNX | +| `--config_path` | Path to config.json if not alongside the model | diff --git a/examples/windows/torch_onnx/llm_export/llm_export.py b/examples/windows/torch_onnx/llm_export/llm_export.py new file mode 100644 index 0000000000..739125e2ef --- /dev/null +++ b/examples/windows/torch_onnx/llm_export/llm_export.py @@ -0,0 +1,837 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Windows-optimized LLM export script for torch→ONNX pathway. + +This script extends the base torch_onnx/llm_export.py with Windows/NVFP4-specific +post-processing: + - Overrides _trt_high_precision_dtype to "Half" on all TensorQuantizers after + quantization (so FP4 scale tensors are FP16 instead of FP32) + - NVFP4 surgeon: converts TRT-domain DQ nodes to native ONNX, upgrades opset to 23, + fixes Transpose output dtypes for projection weight paths + - Sets ir_version = 10 for compatibility +""" + +import argparse +import json +import os +import re +import shutil +import tempfile +import time +from contextlib import contextmanager + +import onnx +import onnx_graphsurgeon as gs +import torch +from packaging.version import Version +from transformers import AutoConfig, AutoTokenizer + +import modelopt +from modelopt.onnx.export import INT4QuantExporter, NVFP4QuantExporter +from modelopt.onnx.graph_surgery import replace_attention_with_gqa +from modelopt.onnx.llm_export_utils.export_utils import ( + ModelLoader, + WrapperModelForCausalLM, + llm_to_onnx, +) +from modelopt.onnx.llm_export_utils.quantization_utils import quantize +from modelopt.onnx.llm_export_utils.surgeon_utils import fold_fp8_qdq_to_dq +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.utils import is_quantized_linear + + +def compress_int8_weights(onnx_model): + """Compress INT8 QDQ weights: fold QuantizeLinear+DequantizeLinear into DequantizeLinear with INT8 initializers. 
+ + Finds patterns: initializer(FP16) -> QuantizeLinear -> DequantizeLinear -> consumer + Replaces with: initializer(INT8) -> DequantizeLinear -> consumer + """ + import numpy as np + from onnx import numpy_helper + + graph = onnx_model.graph + init_map = {i.name: i for i in graph.initializer} + nodes_to_remove = [] + + producer = {} + for node in graph.node: + for out in node.output: + producer[out] = node + + for dq_node in graph.node: + if dq_node.op_type != "DequantizeLinear": + continue + + q_input = dq_node.input[0] + q_node = producer.get(q_input) + if q_node is None or q_node.op_type != "QuantizeLinear": + continue + + weight_name = q_node.input[0] + if weight_name not in init_map: + continue + + weight_arr = numpy_helper.to_array(init_map[weight_name]) + scale_name = q_node.input[1] + + if scale_name in init_map: + scale_arr = numpy_helper.to_array(init_map[scale_name]) + else: + scale_prod = producer.get(scale_name) + if scale_prod and scale_prod.op_type == "Constant": + for attr in scale_prod.attribute: + if attr.name == "value": + scale_arr = numpy_helper.to_array(attr.t) + else: + continue + + axis = None + for attr in q_node.attribute: + if attr.name == "axis": + axis = attr.i + + if axis is not None and scale_arr.ndim == 1: + shape = [1] * weight_arr.ndim + shape[axis] = -1 + scale_broad = scale_arr.reshape(shape) + else: + scale_broad = scale_arr + + quantized = np.clip(np.round(weight_arr / scale_broad), -128, 127).astype(np.int8) + + int8_name = weight_name + "_int8" + int8_tensor = numpy_helper.from_array(quantized, int8_name) + graph.initializer.append(int8_tensor) + + dq_node.input[0] = int8_name + dq_node.input[1] = q_node.input[1] + if len(dq_node.input) > 2 and len(q_node.input) > 2: + dq_node.input[2] = q_node.input[2] + elif len(q_node.input) > 2: + dq_node.input.append(q_node.input[2]) + + dq_has_axis = any(a.name == "axis" for a in dq_node.attribute) + if axis is not None and not dq_has_axis: + axis_attr = dq_node.attribute.add() + 
axis_attr.name = "axis" + axis_attr.i = axis + + nodes_to_remove.append(q_node.name) + + if nodes_to_remove: + new_nodes = [n for n in graph.node if n.name not in nodes_to_remove] + del graph.node[:] + graph.node.extend(new_nodes) + + used_inputs = set() + for n in graph.node: + for inp in n.input: + used_inputs.add(inp) + new_inits = [i for i in graph.initializer if i.name in used_inputs or i.name.endswith("_int8")] + del graph.initializer[:] + graph.initializer.extend(new_inits) + + print(f" Compressed {len(nodes_to_remove)} weight Q+DQ pairs to DQ with INT8 weights") + + from onnx import numpy_helper as _nh + + vi_map = {vi.name: vi for vi in graph.value_info} + init_map = {i.name: i for i in graph.initializer} + + cast_pattern = re.compile( + r"/model/layers\.\d+/(self_attn/o_proj|mlp/down_proj)/input_quantizer/Cast$" + ) + mul_pattern = re.compile( + r"/model/layers\.\d+/(self_attn/o_proj|mlp/down_proj)/input_quantizer/Mul$" + ) + + casts_to_remove = [] + for node in graph.node: + if node.op_type == "Mul" and mul_pattern.search(node.name): + init_name = node.input[1] + if init_name in init_map: + init = init_map[init_name] + if init.data_type != onnx.TensorProto.FLOAT16: + arr = _nh.to_array(init).astype("float16") + init.CopyFrom(_nh.from_array(arr, init_name)) + for out in node.output: + if out in vi_map: + vi_map[out].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + + if node.op_type == "Cast" and cast_pattern.search(node.name): + cast_input = node.input[0] + cast_output = node.output[0] + for other in graph.node: + for i, inp in enumerate(other.input): + if inp == cast_output: + other.input[i] = cast_input + casts_to_remove.append(node.name) + + if casts_to_remove: + new_nodes = [n for n in graph.node if n.name not in casts_to_remove] + del graph.node[:] + graph.node.extend(new_nodes) + print(f" Removed {len(casts_to_remove)} input_quantizer Cast nodes") + + qkv_qdq_pattern = re.compile( + 
r"/model/layers\.\d+/(self_attn/(q_proj|k_proj|v_proj)|mlp/(gate_proj|up_proj|down_proj))/input_quantizer/" + ) + qkv_q_nodes = {} + qkv_dq_nodes = {} + + for node in graph.node: + if qkv_qdq_pattern.search(node.name): + if node.op_type == "QuantizeLinear": + qkv_q_nodes[node.output[0]] = node + elif node.op_type == "DequantizeLinear": + qkv_dq_nodes[node.output[0]] = (node, node.input[0]) + + qkv_nodes_to_remove = set() + for dq_out, (dq_node, q_out) in qkv_dq_nodes.items(): + if q_out in qkv_q_nodes: + q_node = qkv_q_nodes[q_out] + mul_output = q_node.input[0] + for other in graph.node: + for i, inp in enumerate(other.input): + if inp == dq_out: + other.input[i] = mul_output + qkv_nodes_to_remove.add(q_node.name) + qkv_nodes_to_remove.add(dq_node.name) + + if qkv_nodes_to_remove: + new_nodes = [n for n in graph.node if n.name not in qkv_nodes_to_remove] + del graph.node[:] + graph.node.extend(new_nodes) + print( + f" Removed {len(qkv_nodes_to_remove)} Q/DQ nodes from " + f"q/k/v_proj + gate/up/down_proj activations" + ) + + return onnx_model + + +def llm_arguments(): + """Parse the arguments for the llm export script.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--hf_model_path", + type=str, + help="The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen3-0.6B')", + required=False, + ) + parser.add_argument( + "--dtype", + type=str, + default="fp16", + choices=["fp16", "fp8", "int4_awq", "int8_sq", "nvfp4"], + help="The precision of onnx export", + ) + + parser.add_argument( + "--lm_head", + type=str, + default="fp16", + choices=["fp16"], + help="The precision of lm_head. 
Currently only fp16 is tested and supported", + ) + parser.add_argument( + "--output_dir", + type=str, + help="The directory to store the generated ONNX model", + required=True, + ) + + parser.add_argument( + "--onnx_path", + type=str, + help="Pass this option when you have existing onnx to surgeon", + required=False, + ) + parser.add_argument( + "--save_original", + action="store_true", + default=False, + help="Save the original ONNX from torch.onnx.export without any modification", + ) + parser.add_argument( + "--dataset_dir", type=str, help="The path of dataset for quantization", required=False + ) + parser.add_argument( + "--config_path", + type=str, + help="The path of config.json, in case it is not with the PyTorch or ONNX file", + default=None, + ) + parser.add_argument( + "--calib_size", type=int, help="The size of calibration dataset", default=512 + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + default=False, + help="Trust remote code when loading model from HuggingFace Hub", + ) + return parser + + +def get_config_path(args): + """Get config.json file path from the arguments. 
+ + The default priority is: config_path > hf_model_path/config.json > onnx_path/../config.json + """ + if args.config_path and os.path.exists(args.config_path): + return args.config_path + if args.hf_model_path: + if os.path.isdir(args.hf_model_path): + torch_config = os.path.join(args.hf_model_path, "config.json") + if os.path.exists(torch_config): + return torch_config + else: + try: + config = AutoConfig.from_pretrained( + args.hf_model_path, trust_remote_code=args.trust_remote_code + ) + temp_config_path = os.path.join( + tempfile.gettempdir(), f"config_{args.hf_model_path.replace('/', '_')}.json" + ) + with open(temp_config_path, "w") as f: + json.dump(config.to_dict(), f, indent=2) + return temp_config_path + except Exception as e: + print(f"Warning: Could not download config for {args.hf_model_path}: {e}") + + if args.onnx_path: + onnx_config = os.path.join(os.path.dirname(args.onnx_path), "config.json") + if os.path.exists(onnx_config): + return onnx_config + print("Warning: cannot find config.json. Please pass in --config_path.") + return None + + +def _override_trt_high_precision_dtype(model, dtype_str="Half"): + """Override _trt_high_precision_dtype on all TensorQuantizers in the model. + + For the Windows NVFP4 pathway, we set this to "Half" so that all Q/DQ scale + tensors are exported as FP16 instead of FP32, avoiding mixed-precision casts + in the ONNX graph. + + Args: + model: The quantized PyTorch model. + dtype_str: The target dtype string ("Half", "Float", "BFloat16"). 
+ """ + count = 0 + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer): + module._trt_high_precision_dtype = dtype_str + count += 1 + print(f"Overrode _trt_high_precision_dtype to '{dtype_str}' on {count} TensorQuantizers.") + + +def export_raw_llm( + model, + output_dir, + dtype, + config_path, + hf_model_path, + lm_head_precision="fp16", + dataset_dir="", + wrapper_cls=WrapperModelForCausalLM, + extra_inputs={}, + extra_dyn_axes={}, + calib_size=512, + trust_remote_code=False, +): + """Export raw llm model to ONNX and perform quantization. + + Args: + model: torch.nn.module + output_dir: str + dtype: str + config_path: str + hf_model_path: str, Used for loading tokenizer for quantization + dataset_dir: str, Used for quantization + wrapper_cls: class, Used for wrapping the model + extra_inputs: dict, Used for extra inputs + extra_dyn_axes: dict, Used for extra dynamic axes + calib_size: int, Used for quantization calibration size + trust_remote_code: bool, Trust remote code when loading tokenizer + """ + os.makedirs(output_dir, exist_ok=True) + + if dtype == "fp16": + print("Loading fp16 ONNX model...") + + llm_to_onnx( + wrapper_cls(model), output_dir, extra_inputs=extra_inputs, extra_dyn_axes=extra_dyn_axes + ) + shutil.copy(config_path, os.path.join(output_dir, "config.json")) + + if dtype in ["fp8", "int4_awq", "int8_sq", "nvfp4"]: + tokenizer = AutoTokenizer.from_pretrained( + hf_model_path, trust_remote_code=trust_remote_code + ) + if os.path.isdir(hf_model_path): + modelopt_state = os.path.join(hf_model_path, "modelopt_state.pth") + model_needs_quantization = not os.path.exists(modelopt_state) + else: + model_needs_quantization = True + + if model_needs_quantization: + model = quantize( + model, tokenizer, dtype, lm_head_precision, dataset_dir, calib_size=calib_size + ) + + _override_trt_high_precision_dtype(model, "Half") + + if dtype == "nvfp4": + for module in model.modules(): + assert not isinstance(module, 
torch.nn.Linear) or is_quantized_linear(module) + if isinstance(module, torch.nn.Linear): + module.input_quantizer._trt_high_precision_dtype = "Half" + module.input_quantizer._onnx_quantizer_type = "dynamic" + module.weight_quantizer._onnx_quantizer_type = "static" + + if dtype in {"fp8", "int4_awq", "int8_sq", "nvfp4"}: + print(f"Exporting {dtype} ONNX model from quantized PyTorch model...") + llm_to_onnx( + wrapper_cls( + model, + ), + output_dir, + extra_inputs=extra_inputs, + extra_dyn_axes=extra_dyn_axes, + ) + shutil.copy(config_path, os.path.join(output_dir, "config.json")) + + quantized_model_dir = f"{output_dir}_{dtype}_quantized" + os.makedirs(quantized_model_dir, exist_ok=True) + with torch.inference_mode(): + export_hf_checkpoint(model, dtype=torch.float16, export_dir=quantized_model_dir) + + return model.state_dict() + + +def surgeon_llm( + raw_onnx_path, + output_dir, + dtype, + config_path, + hf_model_path=None, + lm_head_precision="fp16", + trust_remote_code=False, +): + """Surgeon raw llm onnx to fit TRT. + + For example, insert quantization q/dq nodes. 
+ Includes Windows-specific NVFP4 post-processing: + - Convert DQ nodes from TRT domain to native ONNX + - Upgrade opset to 23 (minimum for FP4 DequantizeLinear) + - Remove trt opset import + - Fix Transpose output dtype for projection weight paths + - GQA surgery: replace attention with GroupQueryAttention + + Args: + raw_onnx_path: str + output_dir: str + dtype: str + config_path: str + hf_model_path: str, HuggingFace model ID for GQA surgery (RoPE caches, config) + lm_head_precision: str + trust_remote_code: bool, Trust remote code when loading HF config + """ + + t0 = time.time() + onnx.shape_inference.infer_shapes_path(raw_onnx_path) + graph = gs.import_onnx(onnx.load(raw_onnx_path)) + t1 = time.time() + print(f"Importing ONNX graph takes {t1 - t0}s.") + graph.fold_constants().cleanup().toposort() + + if dtype == "fp8" or lm_head_precision == "fp8": + graph = fold_fp8_qdq_to_dq(graph) + + os.makedirs(output_dir, exist_ok=True) + t2 = time.time() + + onnx_model = gs.export_onnx(graph) + + @contextmanager + def time_operation(operation_name): + start_time = time.time() + yield + end_time = time.time() + print(f"{operation_name} takes {end_time - start_time}s.") + + if dtype == "nvfp4": + with time_operation("quantizing weights to nvfp4"): + onnx_model = NVFP4QuantExporter.process_model(onnx_model) + + for node in onnx_model.graph.node: + if node.op_type == "DequantizeLinear" and node.domain != "": + node.domain = "" + + existing_vi = {vi.name: vi for vi in onnx_model.graph.value_info} + fp4_count = 0 + for node in onnx_model.graph.node: + if node.op_type == "TRT_FP4DynamicQuantize": + for output_name in node.output: + if output_name in existing_vi: + existing_vi[ + output_name + ].type.tensor_type.elem_type = onnx.TensorProto.FLOAT4E2M1 + else: + vi = onnx.helper.make_empty_tensor_value_info(output_name) + vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT4E2M1 + onnx_model.graph.value_info.append(vi) + fp4_count += 1 + print(f" Set {fp4_count} 
TRT_FP4DynamicQuantize output(s) -> float4e2m1") + + has_trt_nodes = any(node.domain == "trt" for node in onnx_model.graph.node) + + new_opsets = [] + for opset in onnx_model.opset_import: + if opset.domain == "trt": + if has_trt_nodes: + new_opsets.append(opset) + continue + if opset.domain == "": + opset.version = max(opset.version, 23) + new_opsets.append(opset) + if has_trt_nodes and not any(op.domain == "trt" for op in new_opsets): + trt_opset = onnx.OperatorSetIdProto() + trt_opset.domain = "trt" + trt_opset.version = 1 + new_opsets.append(trt_opset) + print(" Added missing opset_import: domain='trt', version=1") + del onnx_model.opset_import[:] + onnx_model.opset_import.extend(new_opsets) + + vi_map = {vi.name: vi for vi in onnx_model.graph.value_info} + transpose_pattern = re.compile( + r"layers\.\d+/" + r"(mlp/(down_proj|up_proj|gate_proj)|self_attn/(q_proj|k_proj|v_proj|o_proj))" + r"/Transpose" + ) + for node in onnx_model.graph.node: + if node.op_type == "Transpose" and transpose_pattern.search(node.name): + for out in node.output: + if out in vi_map: + vi_map[out].type.tensor_type.elem_type = onnx.TensorProto.FLOAT + + elif dtype == "int4_awq": + with time_operation("quantizing weights to int4"): + onnx_model = INT4QuantExporter.process_model(onnx_model) + + elif dtype == "int8_sq": + with time_operation("compressing INT8 weights (Q+DQ -> DQ with INT8 weights)"): + onnx_model = compress_int8_weights(onnx_model) + + if dtype in ("int4_awq", "int8_sq"): + for node in onnx_model.graph.node: + if node.op_type == "DequantizeLinear" and node.domain != "": + node.domain = "" + new_opsets = [] + min_opset = 23 if dtype == "int4_awq" else 19 + for opset in onnx_model.opset_import: + if opset.domain == "trt": + continue + if opset.domain == "": + opset.version = max(opset.version, min_opset) + new_opsets.append(opset) + del onnx_model.opset_import[:] + onnx_model.opset_import.extend(new_opsets) + print(f" Converted DQ nodes to native ONNX domain, opset set to 
{min_opset}") + + from onnx import numpy_helper as _nh + + vi_map = {vi.name: vi for vi in onnx_model.graph.value_info} + init_map = {i.name: i for i in onnx_model.graph.initializer} + mul_pattern = re.compile(r"/input_quantizer/Mul$") + mul_fixed = 0 + for node in onnx_model.graph.node: + if node.op_type == "Mul" and mul_pattern.search(node.name): + scale_name = node.input[1] + if scale_name in init_map: + init = init_map[scale_name] + if init.data_type != onnx.TensorProto.FLOAT16: + arr = _nh.to_array(init).astype("float16") + init.CopyFrom(_nh.from_array(arr, scale_name)) + for out in node.output: + if out in vi_map: + vi_map[out].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + mul_fixed += 1 + if mul_fixed: + print(f" Fixed {mul_fixed} input_quantizer/Mul nodes to FP16 output") + + layer_cast_pattern = re.compile(r"^/model/layers\.\d+/Cast(_\d+)?$") + casts_removed = 0 + for node in list(onnx_model.graph.node): + if node.op_type == "Cast" and layer_cast_pattern.match(node.name): + cast_input = node.input[0] + cast_output = node.output[0] + for other in onnx_model.graph.node: + for i, inp in enumerate(other.input): + if inp == cast_output: + other.input[i] = cast_input + for out in onnx_model.graph.output: + if out.name == cast_output: + out.name = cast_input + onnx_model.graph.node.remove(node) + casts_removed += 1 + if casts_removed: + print(f" Removed {casts_removed} /model/layers.*/Cast nodes") + + vi_map = {vi.name: vi for vi in onnx_model.graph.value_info} + add_pattern = re.compile(r"^/model/layers\.\d+/Add$") + oproj_pattern = re.compile(r"^/model/layers\.\d+/self_attn/o_proj/MatMul$") + dtype_fixed = 0 + for node in onnx_model.graph.node: + if (node.op_type == "Add" and add_pattern.match(node.name)) or ( + node.op_type == "MatMul" and oproj_pattern.match(node.name) + ): + for tensor_name in list(node.input) + list(node.output): + if tensor_name in vi_map: + vi_map[tensor_name].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + dtype_fixed 
+= 1 + if dtype_fixed: + print(f" Fixed {dtype_fixed} value_info entries to FP16 for Add/o_proj nodes") + + last_add1_pattern = re.compile(r"^/model/layers\.(\d+)/Add_1$") + last_add1_node = None + last_layer_idx = -1 + for node in onnx_model.graph.node: + if node.op_type == "Add": + m = last_add1_pattern.match(node.name) + if m and int(m.group(1)) > last_layer_idx: + last_layer_idx = int(m.group(1)) + last_add1_node = node + + if last_add1_node is not None: + add1_output = last_add1_node.output[0] + if add1_output in vi_map: + vi_map[add1_output].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + + new_add1_output = add1_output + "_fp16" + cast_output = add1_output + + last_add1_node.output[0] = new_add1_output + + cast_node = onnx.helper.make_node( + "Cast", + inputs=[new_add1_output], + outputs=[cast_output], + name=f"/model/layers.{last_layer_idx}/Add_1/Cast_to_fp32", + to=onnx.TensorProto.FLOAT, + ) + onnx_model.graph.node.append(cast_node) + + fp16_vi = onnx.helper.make_empty_tensor_value_info(new_add1_output) + fp16_vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + onnx_model.graph.value_info.append(fp16_vi) + + if cast_output in vi_map: + vi_map[cast_output].type.tensor_type.elem_type = onnx.TensorProto.FLOAT + + print( + f" Fixed layers.{last_layer_idx}/Add_1: output→FP16, " + f"inserted Cast→FP32 for norm/lm_head" + ) + + # Fix logits output shape for all dtypes + vocab_size = None + if config_path and os.path.exists(config_path): + with open(config_path) as f: + model_config = json.load(f) + vocab_size = model_config.get("vocab_size") + + for output in onnx_model.graph.output: + if output.name == "logits": + shape = output.type.tensor_type.shape + if shape and len(shape.dim) == 3: + old_dims = [d.dim_param if d.dim_param else str(d.dim_value) for d in shape.dim] + shape.dim[0].ClearField("dim_value") + shape.dim[0].dim_param = "batch_size" + shape.dim[1].ClearField("dim_value") + shape.dim[1].dim_param = "sequence_length" + if vocab_size is 
not None: + shape.dim[2].ClearField("dim_param") + shape.dim[2].dim_value = vocab_size + print( + f" Fixed logits shape: {old_dims} -> " + f"[batch_size, sequence_length, {vocab_size or 'unchanged'}]" + ) + break + + print( + f"Saving ONNX files in {output_dir}. All existing ONNX in the folder will be overwritten." + ) + for filename in os.listdir(output_dir): + file_path = os.path.join(output_dir, filename) + try: + if ( + os.path.isfile(file_path) or os.path.islink(file_path) + ) and ".json" not in file_path: + os.unlink(file_path) + + except Exception as e: + print(f"Failed to delete {file_path}. Reason: {e}") + + onnx_model.ir_version = 10 + + pre_gqa_onnx = os.path.join(output_dir, "_pre_gqa_model.onnx") + pre_gqa_data = "_pre_gqa_model.data" + onnx.save_model( + onnx_model, + pre_gqa_onnx, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=pre_gqa_data, + convert_attribute=True, + ) + + if os.path.exists(config_path): + if os.path.isfile(config_path) and config_path.endswith("config.json"): + shutil.copy(config_path, os.path.join(output_dir, "config.json")) + elif os.path.isdir(config_path): + shutil.copy( + os.path.join(config_path, "config.json"), os.path.join(output_dir, "config.json") + ) + else: + print(f"Warning: Unexpected config_path format: {config_path}") + + t3 = time.time() + print(f"Surgeon LLM completed in {t3 - t2}s.") + + final_onnx = os.path.join(output_dir, "model.onnx") + if hf_model_path: + print("\n" + "=" * 60) + print("Running GQA surgery: replacing attention with GroupQueryAttention...") + print("=" * 60) + t_gqa_start = time.time() + + replace_attention_with_gqa( + model_path=pre_gqa_onnx, + output_path=final_onnx, + hf_model_id=hf_model_path, + max_seq_len=4096, + io_dtype="float16", + use_external_data=True, + external_data_name="model.onnx_data", + ir_version=10, + trust_remote_code=trust_remote_code, + ) + + t_gqa_end = time.time() + print(f"GQA surgery completed in {t_gqa_end - t_gqa_start:.1f}s.") + else: + 
print("Warning: hf_model_path not provided, skipping GQA surgery.") + os.rename(pre_gqa_onnx, final_onnx) + os.rename( + os.path.join(output_dir, pre_gqa_data), + os.path.join(output_dir, "model.onnx_data"), + ) + + for temp_file in [pre_gqa_onnx, os.path.join(output_dir, pre_gqa_data)]: + if os.path.exists(temp_file): + os.unlink(temp_file) + print(f" Removed intermediate: {os.path.basename(temp_file)}") + + +def check_dtype_support(args): + """Check whether the dtype is supported by DriveOS LLM SDK. + + Returns False if it is not supported because of: + 1. Modelopt < 0.23.0 does not support nvfp4 + """ + + def get_modelopt_version(): + try: + return Version(modelopt.__version__) + except Exception as e: + print(f"Modelopt version cannot be parsed. Reason: {e!s}") + + if (args.dtype == "nvfp4") and get_modelopt_version() < Version("0.23.0"): + print( + "nvfp4 is not supported by installed modelopt version. Please upgrade to 0.23.0 or above for nvfp4 export." + ) + return False + + return True + + +def main(args): + """Main function to export the LLM model to ONNX.""" + assert args.hf_model_path or args.onnx_path, ( + "You need to provide either --hf_model_path or --onnx_path to process the export script." 
+ ) + start_time = time.time() + + if not check_dtype_support(args): + return + + if args.onnx_path: + raw_onnx_path = args.onnx_path + + model_loader = ModelLoader(args.hf_model_path, args.config_path) + + if args.hf_model_path: + model = model_loader.load_model(trust_remote_code=args.trust_remote_code) + onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir + raw_onnx_path = f"{onnx_dir}/model.onnx" + extra_inputs, extra_dyn_axes = {}, {} + export_raw_llm( + model=model, + output_dir=onnx_dir, + dtype=args.dtype, + config_path=args.config_path, + hf_model_path=args.hf_model_path, + lm_head_precision=args.lm_head, + dataset_dir=args.dataset_dir, + wrapper_cls=WrapperModelForCausalLM, + extra_inputs=extra_inputs, + extra_dyn_axes=extra_dyn_axes, + calib_size=args.calib_size, + trust_remote_code=args.trust_remote_code, + ) + + surgeon_llm( + raw_onnx_path=raw_onnx_path, + output_dir=args.output_dir, + dtype=args.dtype, + config_path=args.config_path, + hf_model_path=args.hf_model_path, + lm_head_precision=args.lm_head, + trust_remote_code=args.trust_remote_code, + ) + + end_time = time.time() + print( + f"LLM ONNX saved to {args.output_dir} with {args.dtype} precision in {end_time - start_time}s." + ) + + if args.dtype == "int8_sq": + print( + "\nNOTE: INT8 SmoothQuant models currently work only with the CUDA EP. " + "There are some decoding issues with fp16 max precision when running with NvTensorRtRtx." 
+ ) + + +if __name__ == "__main__": + parser = llm_arguments() + args = parser.parse_args() + args.config_path = get_config_path(args) + main(args) diff --git a/examples/windows/torch_onnx/llm_export/requirements.txt b/examples/windows/torch_onnx/llm_export/requirements.txt new file mode 100644 index 0000000000..d8ea5d6f2a --- /dev/null +++ b/examples/windows/torch_onnx/llm_export/requirements.txt @@ -0,0 +1,6 @@ +--extra-index-url https://download.pytorch.org/whl/cu128 +datasets>=2.14.4 +timm +torch==2.8 +torchvision==0.23.0 +transformers diff --git a/modelopt/onnx/export/int4_exporter.py b/modelopt/onnx/export/int4_exporter.py index 0da217ae76..8fab80f958 100644 --- a/modelopt/onnx/export/int4_exporter.py +++ b/modelopt/onnx/export/int4_exporter.py @@ -107,6 +107,19 @@ def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: else: matmul_node = next_node + # Handle optional Cast between Transpose and MatMul + # (inserted when trt_high_precision_dtype is set to "Half") + if matmul_node.op_type == "Cast": + cast_after_transpose = matmul_node + nodes_to_remove.append(cast_after_transpose.name) + cast_child_nodes = [ + n for n in graph.node if cast_after_transpose.output[0] in n.input + ] + assert len(cast_child_nodes) == 1, ( + f"Expected exactly one child after Cast for {node.name}" + ) + matmul_node = cast_child_nodes[0] + assert matmul_node.op_type in ["MatMul", "Gemm"], ( f"Expected MatMul or Gemm node for {node.name}" ) @@ -249,16 +262,15 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool: node.output.extend(cast_node.output) nodes_to_remove.append(cast_node.name) - # Remove unnecessay Cast after Pre-quant scale + # Remove unnecessary Cast after Pre-quant scale (if present) for node in graph.node: if is_pre_quant_scale_node(node): pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input] - assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}" - cast_node = pqs_child_nodes[0] - assert cast_node.op_type == 
"Cast", f"Expected Cast node for {node.name}" - node.output.clear() - node.output.extend(cast_node.output) - nodes_to_remove.append(cast_node.name) + if len(pqs_child_nodes) == 1 and pqs_child_nodes[0].op_type == "Cast": + cast_node = pqs_child_nodes[0] + node.output.clear() + node.output.extend(cast_node.output) + nodes_to_remove.append(cast_node.name) # Remove unnecessary casts new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] diff --git a/modelopt/onnx/graph_surgery/__init__.py b/modelopt/onnx/graph_surgery/__init__.py index 06ac87c0b3..a92b2d5c19 100644 --- a/modelopt/onnx/graph_surgery/__init__.py +++ b/modelopt/onnx/graph_surgery/__init__.py @@ -24,44 +24,54 @@ - Transposing DequantizeLinear weights for column-major storage optimization - Graph cleanup and optimization -Example usage: - >>> from modelopt.onnx.graph_surgery import ( - ... replace_attention_with_gqa, - ... convert_fp16_to_bf16, - ... transpose_dequantize_linear_weights, - ... add_cross_kv_to_encoder, - ... ) - >>> # Replace attention with GQA for LLMs (FP16 model) - >>> replace_attention_with_gqa( - ... model_path="model_fp16.onnx", - ... output_path="model_gqa.onnx", - ... hf_model_id="meta-llama/Llama-2-7b-hf", - ... io_dtype="float16", - ... ) - >>> # Replace attention with GQA and convert to BF16 in one step - >>> replace_attention_with_gqa( - ... model_path="model_fp16.onnx", - ... output_path="model_gqa_bf16.onnx", - ... hf_model_id="meta-llama/Llama-2-7b-hf", - ... io_dtype="bfloat16", # Automatically converts FP16 to BF16 - ... ) - >>> # Add cross-attention KV cache outputs to encoder (GenAI compatible) - >>> add_cross_kv_to_encoder( - ... encoder_path="encoder_model.onnx", - ... output_path="encoder_with_kv.onnx", - ... hf_model_id="openai/whisper-large-v3-turbo", - ... ) - >>> # Standalone FP16 to BF16 conversion - >>> convert_fp16_to_bf16( - ... input_path="model_fp16.onnx", - ... output_path="model_bf16.onnx", - ... 
) - >>> - >>> # Transpose DequantizeLinear weights for column-major storage - >>> transpose_dequantize_linear_weights( - ... model_path="model_quantized.onnx", - ... output_path="model_quantized_transposed.onnx", - ... ) +CLI Usage:: + + python -m modelopt.onnx.graph_surgery [options] + +Available commands: + +Replace attention with GQA (for FP16/BF16 LLMs):: + + python -m modelopt.onnx.graph_surgery replace-gqa \ + --input model.onnx \ + --output model_gqa.onnx \ + --model-id meta-llama/Llama-2-7b-hf + +Replace attention with GQA (for INT4/AWQ quantized LLMs):: + + python -m modelopt.onnx.graph_surgery replace-gqa \ + --input model.onnx \ + --output model_gqa.onnx \ + --model-id meta-llama/Llama-3.1-8B + +Add cross-attention KV cache to encoder:: + + python -m modelopt.onnx.graph_surgery add-cross-kv \ + --input encoder_model.onnx \ + --output encoder_with_kv.onnx \ + --model-id openai/whisper-large-v3-turbo + +Convert FP16 to BF16:: + + python -m modelopt.onnx.graph_surgery convert-bf16 \ + --input model_fp16.onnx \ + --output model_bf16.onnx + +Transpose DequantizeLinear weights (column-major optimization):: + + python -m modelopt.onnx.graph_surgery transpose-dq \ + --input model_quantized.onnx \ + --output model_quantized_transposed.onnx + +Analyze attention pattern:: + + python -m modelopt.onnx.graph_surgery analyze \ + --input model.onnx \ + --layer 0 + +For full options on any command, run:: + + python -m modelopt.onnx.graph_surgery --help """ from .dq_transpose import transpose_dequantize_linear_weights diff --git a/modelopt/onnx/llm_export_utils/export_utils.py b/modelopt/onnx/llm_export_utils/export_utils.py index 4009b119e7..c5b81e3122 100644 --- a/modelopt/onnx/llm_export_utils/export_utils.py +++ b/modelopt/onnx/llm_export_utils/export_utils.py @@ -76,19 +76,61 @@ def __init__(self, model): self.lm_head = model.lm_head self.config = model.config + # Patch DynamicLayer.lazy_initialization so it does NOT create empty + # tensors (which torch.jit.trace 
bakes as constants). Instead, set + # keys/values to None; the update() cat path handles the rest. + from transformers.cache_utils import DynamicLayer + + def _patched_update(self_layer, key_states, value_states, cache_kwargs=None): + if not self_layer.is_initialized: + self_layer.dtype = key_states.dtype + self_layer.device = key_states.device + self_layer.is_initialized = True + self_layer.keys = key_states + self_layer.values = value_states + return self_layer.keys, self_layer.values + self_layer.keys = torch.cat([self_layer.keys, key_states], dim=-2) + self_layer.values = torch.cat([self_layer.values, value_states], dim=-2) + return self_layer.keys, self_layer.values + + DynamicLayer.update = _patched_update + + # Monkey-patch create_causal_mask to return None during export. + # This avoids baking mask shapes as constants during JIT tracing. + # SDPA uses is_causal=True internally so the explicit mask is unnecessary. + import importlib + + import transformers.masking_utils + + setattr(transformers.masking_utils, "create_causal_mask", lambda *args, **kwargs: None) + model_type = getattr(self.config, "model_type", "llama") + try: + mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}") + setattr(mod, "create_causal_mask", lambda *args, **kwargs: None) + except (ImportError, ModuleNotFoundError): + pass + + # Force use_gqa_in_sdpa to return False so SDPA does manual repeat_kv + # instead of using enable_gqa=True (which torch.onnx.export doesn't support). + # With attention_mask=None and enable_gqa=False, SDPA uses is_causal=True. 
+ import transformers.integrations.sdpa_attention as sdpa_mod + + sdpa_mod.use_gqa_in_sdpa = lambda *args, **kwargs: False + def forward(self, input_ids: torch.Tensor | None, past_key_values: tuple): """Forward pass.""" - # Convert tuple cache to DynamicCache for models that require it (e.g., Qwen3) - cache = DynamicCache(config=self.config) - cache.key_cache = [kv[0] for kv in past_key_values] - cache.value_cache = [kv[1] for kv in past_key_values] - past_key_values = cache + cache = DynamicCache(ddp_cache_data=past_key_values, config=self.config) - outputs = self.model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True) + outputs = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True) hidden_states = outputs[0] - past_key_values = outputs.past_key_values.to_legacy_cache() + + if hasattr(outputs.past_key_values, "to_legacy_cache"): + past_key_values_out = outputs.past_key_values.to_legacy_cache() + else: + past_key_values_out = outputs.past_key_values + logits = self.lm_head(hidden_states) - return logits, past_key_values + return logits, past_key_values_out def llm_to_onnx(model, output_dir, extra_inputs={}, extra_dyn_axes={}): diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 61f551b634..07ca697316 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -65,6 +65,9 @@ def get_quant_config(precision, lm_head_precision="fp16"): elif precision == "int4_awq": quant_cfg = mtq.INT4_AWQ_CFG + elif precision == "int8_sq": + quant_cfg = mtq.INT8_SMOOTHQUANT_CFG + else: raise ValueError(f"Unsupported precision: {precision}") @@ -96,9 +99,11 @@ def quantize( assert precision in [ "fp8", "int4_awq", + "int8_sq", "nvfp4", ], ( - f"Only fp8(W8A8), int4_awq(W4A16), nvfp4(W4A4) is supported. You passed an unsupported precision: {precision}." 
+ "Only fp8(W8A8), int4_awq(W4A16), int8_sq(W8A8 SmoothQuant), nvfp4(W4A4) is supported." + f" You passed an unsupported precision: {precision}." ) assert lm_head_precision in ["fp16"], ( diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index ddd638bc28..7fc0b0cc39 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -137,8 +137,14 @@ def export_int8( unsigned: bool, narrow_range: bool, trt_high_precision_dtype: str | None, + onnx_quantizer_type: str | None = None, ): - """Export quantized model to INT8 ONNX.""" + """Export quantized model to INT8 ONNX. + + When onnx_quantizer_type is "static" (weight quantizer), emits DQ-only + (DequantizeLinear without QuantizeLinear), matching INT4 behavior. + When None or "dynamic" (activation quantizer), emits Q+DQ as before. + """ assert num_bits == 8, "Number of bits must be 8 for INT8 ONNX export." output_shape = sym_help._get_tensor_sizes(inputs) maxbound = (1 << (num_bits - 1 + int(unsigned))) - 1 @@ -169,19 +175,32 @@ def export_int8( scale.masked_fill_(scale == 0, 1.0) scale = g.op("Constant", value_t=scale) - assert trt_high_precision_dtype in (input_type, "Float", "BFloat16"), ( - "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." + assert trt_high_precision_dtype in (input_type, "Float", "Half", "BFloat16"), ( + "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float/Half." ) - # custom ops, so cast the input if needed. + # Cast the input if needed. 
if trt_high_precision_dtype != input_type: inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype]) - quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) - out = g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis).setType( - inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) - ) - # custom ops, so cast the output if needed. + if onnx_quantizer_type == "static": + # DQ-only for weight quantizers (same pattern as INT4). + # Weight stays FP16 in the ONNX; post-processing packs to INT8. + out = g.op("DequantizeLinear", inputs, scale, zero_point, axis_i=axis).setType( + inputs.type() + .with_dtype(torch_dtype_map[trt_high_precision_dtype]) + .with_sizes(output_shape) + ) + else: + # Q+DQ for activation quantizers (standard fake-quant pattern). + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) + out = g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis).setType( + inputs.type() + .with_dtype(torch_dtype_map[trt_high_precision_dtype]) + .with_sizes(output_shape) + ) + + # Cast the output back if needed. if trt_high_precision_dtype != input_type: inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type]) @@ -527,9 +546,10 @@ def _fp4_dynamic_quantize( if trt_high_precision_dtype != input_type: inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype]) + scale_dtype = trt_high_precision_dtype if trt_high_precision_dtype else "Float" scale = g.op( "Constant", - value_t=torch.tensor(scale).to(torch_dtype_map["Float"]), + value_t=torch.tensor(scale).to(torch_dtype_map[scale_dtype]), ) # This is a TensorRT local function, it dynamically quantizes the input tensor to FP4. 
xf4, sx_f8 = g.op( @@ -553,9 +573,10 @@ def _fp4_dequantize( ): """Helper Function for Dequantization.""" if isinstance(scale, float): + scale_dtype = trt_high_precision_dtype if trt_high_precision_dtype else "Float" scale = g.op( "Constant", - value_t=torch.tensor(scale, dtype=torch_dtype_map["Float"]), + value_t=torch.tensor(scale, dtype=torch_dtype_map[scale_dtype]), ) return g.op("DequantizeLinear", inputs, scale) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 7db479f76f..b1c2604cc7 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -809,6 +809,7 @@ def _fake_quantize(self, inputs): self._pass_through_bwd, self.block_sizes.get(-1) if self.block_sizes else None, self.axis[0] if isinstance(self.axis, tuple) else self.axis, + getattr(self, "_onnx_quantizer_type", None), ) return outputs diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 16b9d32997..dcdb0efbff 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -319,7 +319,7 @@ class FakeTensorQuantFunction(Function): """Fake version of TensorQuantFunction use CUDA extension.""" @staticmethod - @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s", "b", "i", "i") + @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s", "b", "i", "i", "s") def symbolic( g, inputs, @@ -332,6 +332,7 @@ def symbolic( pass_through_bwd=False, block_size=None, axis=None, + onnx_quantizer_type=None, ): """ONNX symbolic function.""" from .export_onnx import export_int4, export_int8 @@ -342,7 +343,14 @@ def symbolic( ) return export_int8( - g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype + g, + inputs, + amax, + num_bits, + unsigned, + narrow_range, + trt_high_precision_dtype, + 
onnx_quantizer_type=onnx_quantizer_type, ) @staticmethod @@ -358,6 +366,7 @@ def forward( pass_through_bwd=False, block_size=None, axis=None, + onnx_quantizer_type=None, ): """Forward method.""" if bias is not None: From 036942c67697c18e66441b550ae6b4b34ecb775a Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Thu, 12 Mar 2026 14:57:41 +0530 Subject: [PATCH 2/4] Fixed no of args passed to the backward function Signed-off-by: Hrishith Thadicherla --- modelopt/onnx/quantization/autotune/benchmark.py | 3 ++- modelopt/torch/quantization/tensor_quant.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 6278eb43e2..bbbe83661d 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -384,7 +384,8 @@ def _load_plugin_libraries(self): try: if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"): plugin_handle = ctypes.CDLL( - str(plugin_path), mode=os.RTLD_LAZY | os.RTLD_GLOBAL + str(plugin_path), + mode=os.RTLD_LAZY | os.RTLD_GLOBAL, # type: ignore[attr-defined] ) else: # Fallback for platforms without RTLD flags (e.g., Windows) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index dcdb0efbff..4f7252e943 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -404,7 +404,7 @@ def legacy_quant_func(): @staticmethod def backward(ctx, grad_outputs): """Implements straight through estimation with clipping.""" - return _fake_quant_backward_function(ctx, grad_outputs, num_args=10) + return _fake_quant_backward_function(ctx, grad_outputs, num_args=11) class ScaledE4M3Function(Function): From f84cdb14c7da58dca00da3687f0f60f332ff81ca Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Thu, 12 Mar 2026 15:06:03 +0530 Subject: [PATCH 3/4] reverted benchmark.py to its original 
state Signed-off-by: Hrishith Thadicherla --- modelopt/onnx/quantization/autotune/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index bbbe83661d..765cb1d357 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -385,7 +385,7 @@ def _load_plugin_libraries(self): if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"): plugin_handle = ctypes.CDLL( str(plugin_path), - mode=os.RTLD_LAZY | os.RTLD_GLOBAL, # type: ignore[attr-defined] + mode=os.RTLD_LAZY | os.RTLD_GLOBAL, ) else: # Fallback for platforms without RTLD flags (e.g., Windows) From 23ee38e01ff609af288eafdb872c9fb318062836 Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Thu, 12 Mar 2026 17:16:52 +0530 Subject: [PATCH 4/4] fixed export utils to pick up head dim from config by default and only use hidden_size // num_attention_heads when head_dim is not specified Signed-off-by: Hrishith Thadicherla --- modelopt/onnx/llm_export_utils/export_utils.py | 2 +- modelopt/onnx/quantization/autotune/benchmark.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modelopt/onnx/llm_export_utils/export_utils.py b/modelopt/onnx/llm_export_utils/export_utils.py index c5b81e3122..a1d94e529b 100644 --- a/modelopt/onnx/llm_export_utils/export_utils.py +++ b/modelopt/onnx/llm_export_utils/export_utils.py @@ -148,7 +148,7 @@ def llm_to_onnx(model, output_dir, extra_inputs={}, extra_dyn_axes={}): num_attention_heads = config.num_attention_heads num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size - hidden_size_per_layer = hidden_size // num_attention_heads + hidden_size_per_layer = getattr(config, "head_dim", hidden_size // num_attention_heads) dummy_bs = 1 dummy_len = 10 diff --git a/modelopt/onnx/quantization/autotune/benchmark.py 
b/modelopt/onnx/quantization/autotune/benchmark.py index 765cb1d357..6278eb43e2 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -384,8 +384,7 @@ def _load_plugin_libraries(self): try: if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"): plugin_handle = ctypes.CDLL( - str(plugin_path), - mode=os.RTLD_LAZY | os.RTLD_GLOBAL, + str(plugin_path), mode=os.RTLD_LAZY | os.RTLD_GLOBAL ) else: # Fallback for platforms without RTLD flags (e.g., Windows)