diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 28e6b1da1..fe0a3c7a2 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -22,6 +22,8 @@ import torch from onnx_graphsurgeon.ir.tensor import LazyValues +from modelopt.onnx.utils import is_fp8_constant + from .base_exporter import ONNXQuantExporter @@ -61,37 +63,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: graph.cleanup().toposort().fold_constants().cleanup() for node in graph.nodes: - if node.op == "TRT_FP8QuantizeLinear": - # Should not remove input QDQ - if not isinstance(node.inputs[0], gs.Constant): - continue - - weights = node.inputs[0] - scale = node.inputs[1] - torch_weights = torch.from_numpy(weights.values) - torch_scale = torch.from_numpy(scale.values) - quantizer_name = scale.name.rsplit("/", 1)[0] - dq_op = node.outputs[0].outputs[0] - assert dq_op.op == "TRT_FP8DequantizeLinear", ( - f"QDQ does not occur in pairs. You reached {dq_op.op}" - ) - - # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. - numpy_weights = ( - (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() - ) - tensor = onnx.TensorProto() - tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN - tensor.dims.extend(numpy_weights.shape) - tensor.raw_data = numpy_weights.tobytes() - values = LazyValues(tensor) - onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) - - node.outputs.clear() - # DQ Op is separated out - dq_op.inputs[0] = onnx_weights_fp8 - dq_op.op = "DequantizeLinear" - dq_op.outputs[0].dtype = dq_op.inputs[1].dtype + is_trt_fp8_q = node.op == "TRT_FP8QuantizeLinear" + is_std_fp8_q = ( + node.op == "QuantizeLinear" + and len(node.inputs) >= 3 + and isinstance(node.inputs[2], gs.Constant) + and is_fp8_constant(node.inputs[2]) + ) + if not (is_trt_fp8_q or is_std_fp8_q): + continue + + # Should not remove input QDQ + if not isinstance(node.inputs[0], gs.Constant): + continue + + weights = node.inputs[0] + scale = node.inputs[1] + torch_weights = torch.from_numpy(weights.values) + torch_scale = torch.from_numpy(scale.values) + quantizer_name = scale.name.rsplit("/", 1)[0] + dq_op = node.outputs[0].outputs[0] + assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), ( + f"QDQ does not occur in pairs. You reached {dq_op.op}" + ) + + # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. + numpy_weights = ( + (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() + ) + tensor = onnx.TensorProto() + tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + tensor.dims.extend(numpy_weights.shape) + tensor.raw_data = numpy_weights.tobytes() + values = LazyValues(tensor) + onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) + + node.outputs.clear() + # DQ Op is separated out + dq_op.inputs[0] = onnx_weights_fp8 + dq_op.op = "DequantizeLinear" + dq_op.outputs[0].dtype = dq_op.inputs[1].dtype graph.cleanup().toposort() end_time = time.time() diff --git a/modelopt/onnx/llm_export_utils/surgeon_utils.py b/modelopt/onnx/llm_export_utils/surgeon_utils.py index 664ed0f32..f43998e75 100644 --- a/modelopt/onnx/llm_export_utils/surgeon_utils.py +++ b/modelopt/onnx/llm_export_utils/surgeon_utils.py @@ -23,6 +23,8 @@ import torch from onnx_graphsurgeon.ir.tensor import LazyValues +from modelopt.onnx.utils import is_fp8_constant + def clear_inputs(node: gs.Node | gs.Tensor): """Clear all inputs for a node or tensor in ONNX.""" @@ -81,37 +83,46 @@ def fold_fp8_qdq_to_dq(graph: gs.Graph): graph.cleanup().toposort().fold_constants().cleanup() for node in graph.nodes: - if node.op == "TRT_FP8QuantizeLinear": - # Should not remove input QDQ - if not isinstance(node.inputs[0], gs.Constant): - continue - - weights = node.inputs[0] - scale = node.inputs[1] - torch_weights = torch.from_numpy(weights.values) - torch_scale = torch.from_numpy(scale.values) - quantizer_name = scale.name.rsplit("/", 1)[0] - dq_op = node.outputs[0].outputs[0] - assert dq_op.op == "TRT_FP8DequantizeLinear", ( - f"QDQ does not occur in pairs. You reached {dq_op.op}" - ) - - # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. - numpy_weights = ( - (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() - ) - tensor = onnx.TensorProto() - tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN - tensor.dims.extend(numpy_weights.shape) - tensor.raw_data = numpy_weights.tobytes() - values = LazyValues(tensor) - onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) - - node.outputs.clear() - # DQ Op is separated out - dq_op.inputs[0] = onnx_weights_fp8 - dq_op.op = "DequantizeLinear" - dq_op.outputs[0].dtype = dq_op.inputs[1].dtype + is_trt_fp8_q = node.op == "TRT_FP8QuantizeLinear" + is_std_fp8_q = ( + node.op == "QuantizeLinear" + and len(node.inputs) >= 3 + and isinstance(node.inputs[2], gs.Constant) + and is_fp8_constant(node.inputs[2]) + ) + if not (is_trt_fp8_q or is_std_fp8_q): + continue + + # Should not remove input QDQ + if not isinstance(node.inputs[0], gs.Constant): + continue + + weights = node.inputs[0] + scale = node.inputs[1] + torch_weights = torch.from_numpy(weights.values) + torch_scale = torch.from_numpy(scale.values) + quantizer_name = scale.name.rsplit("/", 1)[0] + dq_op = node.outputs[0].outputs[0] + assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), ( + f"QDQ does not occur in pairs. You reached {dq_op.op}" + ) + + # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. + numpy_weights = ( + (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() + ) + tensor = onnx.TensorProto() + tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + tensor.dims.extend(numpy_weights.shape) + tensor.raw_data = numpy_weights.tobytes() + values = LazyValues(tensor) + onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) + + node.outputs.clear() + # DQ Op is separated out + dq_op.inputs[0] = onnx_weights_fp8 + dq_op.op = "DequantizeLinear" + dq_op.outputs[0].dtype = dq_op.inputs[1].dtype graph.cleanup().toposort() end_time = time.time() diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 4025ea065..95fddf76c 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -28,6 +28,7 @@ import onnx_graphsurgeon as gs from onnx.helper import get_attribute_value from onnx_graphsurgeon import Constant, Node, Variable +from onnx_graphsurgeon.ir.tensor import LazyValues from modelopt.onnx.logging_config import logger @@ -35,6 +36,19 @@ BASE_MIN_OPSET = 19 +def is_fp8_constant(const: Constant) -> bool: + """Return True if a gs.Constant holds a FLOAT8E4M3FN tensor. + + Uses getattr to guard against future changes to the LazyValues internal API. + """ + if not isinstance(const.values, LazyValues): + return False + tensor_proto = getattr(const.values, "_tensor", None) + if tensor_proto is None: + return False + return tensor_proto.data_type == onnx.TensorProto.FLOAT8E4M3FN + + def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]: """This function returns the inputs names of the given onnx model in bytes. diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index ddd638bc2..bce7f61d2 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -126,7 +126,12 @@ } mha_valid_precisions = {"Half", "BFloat16"} -torch_dtype_map = {"Float": torch.float32, "Half": torch.float16, "BFloat16": torch.bfloat16} +torch_dtype_map = { + "Float": torch.float32, + "Half": torch.float16, + "BFloat16": torch.bfloat16, + "Float8": torch.float8_e4m3fn, +} def export_int8( @@ -221,8 +226,7 @@ def _fp8_quantize( """Helper Function for Quantization.""" output_shape = sym_help._get_tensor_sizes(inputs) - # TRT StronglyType only supports FP16 QDQs - # custom ops, so cast the input if needed. + # Cast the input to the high-precision dtype if needed. input_type = inputs.type().scalarType() assert trt_high_precision_dtype in (input_type, "Float"), ( "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." @@ -234,9 +238,12 @@ def _fp8_quantize( "Constant", value_t=torch.tensor(scale_inv).to(torch_dtype_map[trt_high_precision_dtype]), ) - q_op = g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType( - inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) - ) + # Use standard ONNX QuantizeLinear with FLOAT8E4M3FN zero_point (opset 19). + # The zero_point dtype determines the output dtype per the ONNX spec. + zero_point = g.op("Constant", value_t=torch.tensor(0.0)) + zero_point = g.op("Cast", zero_point, to_i=onnx_dtype_map["Float8"]) + q_op = g.op("QuantizeLinear", inputs, scale, zero_point, saturate_i=1) + q_op.setType(inputs.type().with_dtype(torch.float8_e4m3fn).with_sizes(output_shape)) return q_op @@ -249,21 +256,22 @@ def _fp8_dequantize( ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) - assert trt_high_precision_dtype in (otype, "Float"), ( - "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." - ) scale = g.op( "Constant", value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]), # type: ignore[index] ) - out = g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType( + # Use standard ONNX DequantizeLinear with FLOAT8E4M3FN zero_point (opset 19). + # Per the ONNX spec, DequantizeLinear with FLOAT8E4M3FN input outputs float32. + zero_point = g.op("Constant", value_t=torch.tensor(0.0)) + zero_point = g.op("Cast", zero_point, to_i=onnx_dtype_map["Float8"]) + out = g.op("DequantizeLinear", inputs, scale, zero_point) + out.setType( inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) ) - # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the output if needed. - if trt_high_precision_dtype != otype: - out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] + # DequantizeLinear outputs float32 in opset 19; cast back to original type if needed. + if otype in torch_dtype_map and otype != "Float": + out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) return out diff --git a/tests/unit/torch/quantization/test_fp8_onnx_shape.py b/tests/unit/torch/quantization/test_fp8_onnx_shape.py new file mode 100644 index 000000000..eac9e9b54 --- /dev/null +++ b/tests/unit/torch/quantization/test_fp8_onnx_shape.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests: FP8 ONNX export shape inference. + +Two complementary tests: + 1. Prove that TRT custom FP8 ops lose shape info (root-cause regression guard). + 2. Prove that standard ONNX QDQ ops preserve shape info after the fix. +""" + +import io + +import pytest + +onnx = pytest.importorskip("onnx") + +import torch +from _test_utils.torch.quantization.models import SimpleConv + +import modelopt.torch.quantization as mtq + +# --------------------------------------------------------------------------- +# Part 1 — root-cause: TRT custom ops have no ONNX shape inference function +# --------------------------------------------------------------------------- + + +def test_trt_fp8_ops_unsupported_by_onnx_inference(): + """ONNX shape inference raises or produces no shape for trt::TRT_FP8QuantizeLinear. + + This documents the root cause: TRT custom ops are not registered with ONNX + shape inference, so any graph containing them cannot have shapes propagated + through those nodes. ONNX either raises InferenceError (newer versions) or + silently leaves the output shape empty (older versions). + """ + from onnx import TensorProto, helper + + # Build a minimal ONNX graph: Input → trt::TRT_FP8QuantizeLinear → Output + input_vi = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]) + scale_init = helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0]) + output_vi = helper.make_tensor_value_info("y", TensorProto.UINT8, None) + + node = helper.make_node( + "TRT_FP8QuantizeLinear", + inputs=["x", "scale"], + outputs=["y"], + domain="trt", + ) + + graph = helper.make_graph([node], "trt_fp8_q_test", [input_vi], [output_vi], [scale_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)]) + model.ir_version = 9 + + try: + inferred = onnx.shape_inference.infer_shapes(model, strict_mode=False) + except onnx.shape_inference.InferenceError: + # Newer ONNX rejects unknown domains outright — inference is impossible. + return + + # Older ONNX silently skips unknown ops, leaving the output shape empty. + output_shape = inferred.graph.output[0].type.tensor_type.shape + assert not output_shape.dim, ( + "Expected TRT_FP8QuantizeLinear output to have no shape (op unknown to ONNX), " + f"but got dims: {list(output_shape.dim)}" + ) + + +# --------------------------------------------------------------------------- +# Part 2 — fix: standard ONNX QDQ ops preserve shape after export +# --------------------------------------------------------------------------- + + +def test_fp8_onnx_export_shape_preserved(): + """FP8-quantized SimpleConv exported with opset 19 retains shape on all QDQ outputs.""" + model = SimpleConv().eval() + dummy_input = SimpleConv.get_input() + + def forward_loop(m): + m(dummy_input) + + model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop) + # Disable output quantizers to avoid export errors (they produce FP8 outputs that + # downstream non-quantized ops can't accept in the TorchScript exporter). + mtq.disable_quantizer(model, lambda name: "output_quantizer" in name) + + buf = io.BytesIO() + torch.onnx.export( + model, + dummy_input, + buf, + opset_version=19, + input_names=["input"], + output_names=["output"], + dynamo=False, + ) + buf.seek(0) + onnx_model = onnx.load_from_string(buf.read()) + + # No TRT custom FP8 ops should remain. + trt_fp8_ops = [ + n.op_type for n in onnx_model.graph.node if n.domain == "trt" and "FP8" in n.op_type + ] + assert not trt_fp8_ops, ( + f"Found TRT custom FP8 ops after export: {trt_fp8_ops}. " + "These have no ONNX shape inference and will cause shape loss." + ) + + # Run shape inference and collect QDQ output shapes. + inferred = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=False) + shape_by_name: dict[str, list] = {} + for vi in (*inferred.graph.value_info, *inferred.graph.output): + shape_by_name[vi.name] = [ + d.dim_value if d.HasField("dim_value") else -1 for d in vi.type.tensor_type.shape.dim + ] + + missing = [] + for node in inferred.graph.node: + if node.op_type in ("QuantizeLinear", "DequantizeLinear"): + for out in node.output: + shape = shape_by_name.get(out) + if not shape: + missing.append(out) + + assert not missing, ( + f"The following QDQ outputs are missing shape info after shape inference: {missing}. " + "This indicates the FP8 export still uses ops without ONNX shape inference support." + )