From db5e00b4177afc7dcd27ef90b539a478cc1d02d3 Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Fri, 29 May 2026 20:24:57 +0000
Subject: [PATCH 1/4] [6078291] Add ViT FP8/NVFP4 recipes + Torch-TRT example, wire softmax_quantizer in _QuantAttention

* modelopt_recipes/huggingface/vit/ptq/{fp8,nvfp4}.yaml -- self-contained
  ViT-tuned PTQ recipes targeting HuggingFace ViTForImageClassification.
  Encoder Linear weights/inputs quantized; attention Q/K/V BMMs, softmax,
  and per-block LayerNorm outputs at FP8; patch-embed nn.Conv2d,
  classifier, and the final vit.layernorm left FP16. NVFP4 variant runs
  encoder Linears in W4A4 NVFP4 (E2M1, block 16, FP8 scales) with
  AWQ-lite calibration.

* examples/torch_trt/ -- end-to-end Torch-TensorRT deployment example
  (load HF model -> calibrate from tiny-imagenet -> mtq.quantize ->
  torch_tensorrt.compile(ir="dynamo") -> benchmark). Defaults to
  google/vit-large-patch16-224; --model_id + --recipe retarget any HF
  model + ModelOpt PTQ recipe.

* modelopt/torch/quantization/plugins/huggingface.py -- inside
  _QuantAttention._quantized_attention, the non-kitchen branch now
  temporarily replaces torch.nn.functional.softmax with a wrapper that
  pipes the softmax output through self.softmax_quantizer. Previously
  the slot was created on every registered attention class but only
  consumed by the optional Kitchen MXFP8 path, so FP8 / NVFP4 recipes
  that enabled *softmax_quantizer saw it stay uncalibrated (amax=None)
  and emitted no Q/DQ around softmax during ONNX / Torch-TRT export.
  Short-circuits to the unwrapped call when the quantizer is disabled
  (zero-overhead). SDPA-fused softmax inside the C++ kernel is
  unaffected.
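
The softmax wiring in the last bullet can be sketched in isolation as
follows. This is a minimal, torch-free illustration of the "temporarily
replace a function, then restore it" pattern, not the code in this patch:
`replace_attr`, `ToyQuantizer`, `attention_scores`, and the `F` namespace
are hypothetical stand-ins for the real helper, `self.softmax_quantizer`,
the original attention interface, and `torch.nn.functional`.

```python
import math
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def replace_attr(namespace, name, new_value):
    """Temporarily rebind `namespace.name`; always restore the original."""
    original = getattr(namespace, name)
    setattr(namespace, name, new_value)
    try:
        yield
    finally:
        setattr(namespace, name, original)


def softmax(xs):
    """Plain softmax over a list of floats."""
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical stand-in for the torch.nn.functional namespace.
F = SimpleNamespace(softmax=softmax)


class ToyQuantizer:
    """Hypothetical quantizer stand-in: snaps values to a 1/16 grid."""

    def __init__(self, enabled=True):
        self.is_enabled = enabled
        self.calls = 0

    def __call__(self, xs):
        self.calls += 1
        return [round(x * 16) / 16 for x in xs]


def attention_scores(logits):
    """Stands in for the attention interface; resolves F.softmax at call time."""
    return F.softmax(logits)


def quantized_attention(logits, quantizer):
    if not quantizer.is_enabled:
        # Short-circuit: nothing is patched when the quantizer is disabled.
        return attention_scores(logits)
    original = F.softmax

    def quantized_softmax(*args, **kwargs):
        return quantizer(original(*args, **kwargs))

    # The patch is active only while the attention call runs.
    with replace_attr(F, "softmax", quantized_softmax):
        return attention_scores(logits)
```

The sketch only works because `attention_scores` looks up `F.softmax` at
call time; a softmax fused inside a compiled kernel never goes through
that lookup, which matches the SDPA caveat above.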

ImageNet-1k full-50k validation accuracy on google/vit-base-patch16-224
(batch=128, 49920/50000 samples):

  FP16 baseline:            Top-1 81.769%  Top-5 96.124%
  FP8 modelopt.onnx CLI:    Top-1 81.707%  Top-5 96.110%  (-0.062 pp)
  FP8 torch path (this PR): Top-1 81.637%  Top-5 96.140%  (-0.132 pp)

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 CHANGELOG.rst                                 |   2 +
 examples/torch_trt/README.md                  | 102 ++++++
 .../torch_trt/quantize_and_compile_vit.py     | 317 ++++++++++++++++++
 examples/torch_trt/requirements.txt           |   3 +
 .../torch/quantization/plugins/huggingface.py |  11 +
 modelopt_recipes/huggingface/vit/ptq/fp8.yaml |  55 +++
 .../huggingface/vit/ptq/nvfp4.yaml            |  63 ++++
 7 files changed, 553 insertions(+)
 create mode 100644 examples/torch_trt/README.md
 create mode 100644 examples/torch_trt/quantize_and_compile_vit.py
 create mode 100644 examples/torch_trt/requirements.txt
 create mode 100644 modelopt_recipes/huggingface/vit/ptq/fp8.yaml
 create mode 100644 modelopt_recipes/huggingface/vit/ptq/nvfp4.yaml

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 304f800644a..b0888efd714 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -34,10 +34,12 @@ Changelog
 - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g.
``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. - The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). +- Add Torch-TensorRT FP8 / NVFP4 deployment example for HuggingFace ViT (``examples/torch_trt/``) covering ``mtq.quantize`` → ``torch_tensorrt.compile(ir="dynamo")``. Ships two ViT-tuned PTQ recipes under ``modelopt_recipes/huggingface/vit/ptq/`` (``fp8.yaml``, ``nvfp4.yaml``) — encoder Linear weights+inputs quantized; attention Q/K/V BMMs, softmax, and per-block LayerNorm outputs at FP8; patch-embed ``nn.Conv2d``, ``classifier``, and the final ``vit.layernorm`` left FP16. Verified on ``google/vit-base-patch16-224`` (ImageNet-1k 50k validation): FP8 stays within 0.13 pp Top-1 of the FP16 baseline. **Bug Fixes** - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) 
under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. +- Apply ``self.softmax_quantizer`` on the standard (non-kitchen) attention forward path in ``_QuantAttention._quantized_attention`` for HuggingFace transformer attention modules. Previously the slot was created on every registered attention class but only invoked through the optional Kitchen MXFP8 branch, so FP8 / NVFP4 recipes that enabled ``*softmax_quantizer`` saw it stay uncalibrated (``amax=None``) and emitted no Q/DQ around the softmax output during ONNX / Torch-TRT export. The fix temporarily replaces ``torch.nn.functional.softmax`` with a wrapper that pipes its output through ``self.softmax_quantizer`` while the original attention interface runs; the patch is reverted as soon as the attention call returns, and short-circuits to the unwrapped call when the quantizer is disabled (zero-overhead). SDPA-fused softmax inside the C++ kernel is unaffected. 0.44 (2026-05-14) ^^^^^^^^^^^^^^^^^ diff --git a/examples/torch_trt/README.md b/examples/torch_trt/README.md new file mode 100644 index 00000000000..bed90d6c50f --- /dev/null +++ b/examples/torch_trt/README.md @@ -0,0 +1,102 @@ +# ModelOpt + Torch-TensorRT Deployment + +End-to-end examples that quantize a PyTorch model with NVIDIA ModelOpt and +then compile the quantized graph with +[Torch-TensorRT](https://docs.pytorch.org/TensorRT/) for deployment. + +The flow follows the +[Torch-TensorRT quantization guide](https://docs.pytorch.org/TensorRT/user_guide/shapes_precision/quantization.html): +ModelOpt inserts Q/DQ nodes into the eager PyTorch graph, then +`torch_tensorrt.compile(ir="dynamo")` converts those Q/DQ nodes into native +TensorRT precision layers. 
+ +## Setup + +```bash +# From the NVIDIA TensorRT docker image (recommended): +docker run --gpus all -it --rm -v $(pwd):/workspace -w /workspace nvcr.io/nvidia/tensorrt:26.02-py3 bash + +pip install -U "nvidia-modelopt[torch]" +pip install -r examples/torch_trt/requirements.txt +``` + +Torch-TensorRT itself follows the +[official install instructions](https://docs.pytorch.org/TensorRT/getting_started/installation.html) — +the version pulled by `pip` must match your installed PyTorch. + +## Usage + +```bash +# FP8 / NVFP4 default model is google/vit-large-patch16-224 +python examples/torch_trt/quantize_and_compile_vit.py \ + --precision fp8/nvfp4 \ + --calib_samples 128 \ + --batch_size 1 + +# Quantize but don't TRT-compile (handy on a non-TRT host) +python examples/torch_trt/quantize_and_compile_vit.py \ + --precision fp8/nvfp4 \ + --skip_trt + +# Custom model + custom recipe +python examples/torch_trt/quantize_and_compile_vit.py \ + --model_id \ + --recipe +``` + +## What the example does + +1. Loads a HuggingFace model (default: `google/vit-large-patch16-224`). +2. Builds a tiny calibration loader from `zh-plus/tiny-imagenet` (avoids the + gated `ILSVRC/imagenet-1k` repo so the example runs unauthenticated). +3. Runs `mtq.quantize` with one of the recipes shipped under + [`modelopt_recipes/`](../../modelopt_recipes/). The default recipes + target ViT; pass `--recipe ` to use a different one for a + different model. +4. Compiles the quantized model with `torch_tensorrt.compile` and prints a + median-latency benchmark against the BF16 eager baseline. + +## ViT-specific recipes shipped with the example + +These are the recipes the CLI selects by default when `--model_id` points +at a HF ViT classifier. They are **not** thin wrappers around the modelopt +defaults — they're tuned for the HF ViT module layout. 
+ +| Flag | Recipe path | Key differences from the default | +|------|-------------|----------------------------------| +| `--precision fp8` | `huggingface/vit/ptq/fp8` | W8A8 FP8 **plus** MHA-aware FP8 on every per-block `nn.LayerNorm` output (shared Q/DQ feeds Q/K/V + MLP), FP8 attention Q/K/V BMM + softmax slots, patch-embedding `nn.Conv2d` left FP16, `classifier` head left in FP16, final `vit.layernorm` left FP16. | +| `--precision nvfp4` | `huggingface/vit/ptq/nvfp4` | Same skip list as the FP8 recipe; encoder Linear weights/inputs run NVFP4 W4A4 (E2M1, block 16, FP8 scales). Attention BMMs, softmax, and per-block LayerNorm outputs stay at FP8 — NVFP4 is too aggressive there. Uses `awq_lite` calibration. | + +Each recipe is self-contained (no `$import` of shared snippets) and uses +the "specific-enable" style: narrow `parent_class` + path scoping on the +enable rules means no `enable: false` carve-outs are needed. + +## Hardware requirements + +| Recipe | Minimum GPU | +|--------|-------------| +| `fp8` | Hopper (H100) / Ada (RTX 4090 / 6000 Ada) — compute capability 8.9+ | +| `nvfp4` | Blackwell (B100/B200) — TRT ≥ 10.8 | + +Older GPUs will still let `mtq.quantize` succeed (it emits fake-quant +nodes in PyTorch), but `torch_tensorrt.compile` will not find a real +low-precision kernel and the speedup column will be ~1×. + +### Resuming from a saved checkpoint + +Pass `--save_dir ` to persist the modelopt-quantized model +(`vit_modelopt_state.pt`). To reload without recalibrating, restore it +before the TRT compile step with: + +```python +import modelopt.torch.opt as mto +mto.restore(model, "vit_modelopt_state.pt") +``` + +## Custom recipes + +Use `--recipe ` to plug in a different recipe — either a path +relative to `modelopt_recipes/` (resolved against the built-in library) or +an absolute filesystem path to a YAML file. 
The recipe must declare +`metadata.recipe_type: ptq` and a `quantize:` section; see existing +`modelopt_recipes/huggingface/vit/ptq/*.yaml` for the patterns used here. diff --git a/examples/torch_trt/quantize_and_compile_vit.py b/examples/torch_trt/quantize_and_compile_vit.py new file mode 100644 index 00000000000..e986ec6ff2d --- /dev/null +++ b/examples/torch_trt/quantize_and_compile_vit.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantize a HuggingFace ViT model with ModelOpt and deploy it with Torch-TensorRT. + +Pipeline: + +1. Load ``google/vit-large-patch16-224`` (`ViTForImageClassification`) from HF. +2. Build a calibration loader from `zh-plus/tiny-imagenet` (same pattern as the + `torch_onnx` example) so the recipe runs end-to-end without ImageNet access. +3. Run ``mtq.quantize`` with one of the ViT-specific recipes + (`modelopt_recipes/huggingface/vit/ptq/`). Two non-default variants are + shipped: + + * ``fp8`` -> ``fp8_mha-classifier_skip``: W8A8 FP8 with an MHA-aware + LayerNorm output quantizer, FP8 attention BMM/softmax slots, and the + `classifier` head left in FP16. 
+ * ``nvfp4`` -> ``nvfp4_linear-fp8_conv-classifier_skip``: NVFP4 W4A4 on + encoder Linear layers, FP8 override on the patch-embedding Conv2d (TRT + has no NVFP4 kernel for 4D Conv inputs), AWQ-lite calibration, and the + `classifier` head left in FP16. + +4. Compile the quantized model with ``torch_tensorrt.compile`` (Dynamo IR, + ``min_block_size=1``) and run an end-to-end sanity check + small benchmark + against the eager BF16 baseline. + +This script is intentionally CLI-driven and side-effect-free outside of the +optional ``--save_dir`` checkpoint. The quantized graph keeps Q/DQ nodes; the +TRT compile step is what turns them into TRT precision layers. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import torch +from datasets import load_dataset +from transformers import AutoImageProcessor, ViTForImageClassification + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.recipe import ModelOptPTQRecipe, load_recipe +from modelopt.torch.quantization.utils import export_torch_mode + +# Maps the user-facing precision flag to the ViT-specific recipe under +# `modelopt_recipes/huggingface/vit/ptq/`. The recipe loader resolves this +# relative path against the built-in recipe library. +PRECISION_TO_RECIPE: dict[str, str] = { + "fp8": "huggingface/vit/ptq/fp8", + "nvfp4": "huggingface/vit/ptq/nvfp4", +} + + +def load_model_and_processor(model_id: str, device: torch.device, dtype: torch.dtype): + """Pull the HF ViT classifier and its preprocessor.""" + print(f"Loading {model_id} (dtype={dtype})...") + processor = AutoImageProcessor.from_pretrained(model_id) + model = ViTForImageClassification.from_pretrained(model_id, torch_dtype=dtype) + model.eval().to(device) + return model, processor + + +def build_calibration_loader( + processor, + num_samples: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, +): + """Build a calibration tensor stream from tiny-imagenet. 
+ + tiny-imagenet avoids the gated `ILSVRC/imagenet-1k` repo so this example + runs unauthenticated. Images go through the HF processor (resize + center + crop + ImageNet normalization), which is exactly the eval-time transform + used by the released `vit-large-patch16-224` checkpoint. + """ + print(f"Loading calibration data ({num_samples} samples)...") + dataset = load_dataset("zh-plus/tiny-imagenet", split="train") + dataset = dataset.shuffle(seed=42).select(range(num_samples)) + + tensors: list[torch.Tensor] = [] + for sample in dataset: + image = sample["image"] + if image.mode != "RGB": + image = image.convert("RGB") + # HF image processors emit `pixel_values` of shape (1, 3, H, W). + pixel_values = processor(images=image, return_tensors="pt")["pixel_values"] + tensors.append(pixel_values.squeeze(0)) + + batched = torch.stack(tensors).to(device=device, dtype=dtype) + return torch.split(batched, batch_size) + + +def quantize_with_recipe(model, recipe_path: str, calib_batches): + """Resolve the YAML recipe and run `mtq.quantize`. + + Returns the quantized model. The graph still uses high-precision math at + this point — Q/DQ nodes have been inserted around weights and activations + and amax values populated, but no kernel substitution has happened yet. + """ + print(f"Loading recipe: {recipe_path}") + recipe = load_recipe(recipe_path) + if not isinstance(recipe, ModelOptPTQRecipe): + raise TypeError(f"Expected PTQ recipe, got {type(recipe).__name__}") + quant_cfg = recipe.quantize.model_dump() + + def forward_loop(model_): + with torch.no_grad(): + for batch in calib_batches: + model_(pixel_values=batch) + + print("Running mtq.quantize ...") + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + mtq.print_quant_summary(model) + return model + + +class ViTLogitsWrapper(torch.nn.Module): + """Returns raw logits as a single tensor. + + HF's `ViTForImageClassification.forward` returns an `ImageClassifierOutput` + dataclass. 
`torch_tensorrt.compile` (and `torch.export`) need a tensor-tree + return, so we unwrap it here. The wrapper holds the quantized model as a + submodule; Q/DQ nodes flow through unchanged. + """ + + def __init__(self, vit_model: torch.nn.Module): + super().__init__() + self.vit = vit_model + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + return self.vit(pixel_values=pixel_values).logits + + +def compile_with_torch_tensorrt(model: torch.nn.Module, example_input: torch.Tensor): + """Compile the quantized model with Torch-TensorRT (Dynamo IR). + + `min_block_size=1` follows the Torch-TRT quantization guide — it makes the + partitioner accept single-node TRT subgraphs, which is what we want so the + Q/DQ + matmul pairs become TRT precision layers instead of falling back to + eager. The compile step expects fake-quant operators in the graph; we run + it under `export_torch_mode` so modelopt's Q/DQ are exported in the + TRT-friendly form. + """ + import torch_tensorrt + + print("Compiling with torch_tensorrt.compile (Dynamo IR)...") + with export_torch_mode(): + trt_model = torch_tensorrt.compile( + model, + ir="dynamo", + arg_inputs=[example_input], + min_block_size=1, + # The recipes export weights in BF16; TRT picks the FP8/NVFP4 + # kernel from the Q/DQ pattern, not from this list. + enabled_precisions={torch.bfloat16, torch.float16, torch.float32}, + truncate_double=True, + ) + return trt_model + + +def benchmark(model: torch.nn.Module, example_input: torch.Tensor, n_warmup: int, n_iters: int): + """Median-of-`n_iters` latency over `example_input`. 
CUDA-event timed.""" + torch.cuda.synchronize() + with torch.no_grad(): + for _ in range(n_warmup): + model(example_input) + torch.cuda.synchronize() + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)] + with torch.no_grad(): + for i in range(n_iters): + starts[i].record() + model(example_input) + ends[i].record() + torch.cuda.synchronize() + times = sorted(s.elapsed_time(e) for s, e in zip(starts, ends)) + return times[len(times) // 2] + + +def _argmax_logits(out) -> torch.Tensor: + """Handle either an HF `ImageClassifierOutput` or a raw tensor.""" + logits = out.logits if hasattr(out, "logits") else out + return logits.argmax(dim=-1) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model_id", + default="google/vit-large-patch16-224", + help="HuggingFace model id of the ViT classifier to quantize.", + ) + parser.add_argument( + "--precision", + choices=sorted(PRECISION_TO_RECIPE), + default="fp8", + help="Which ViT recipe variant to apply.", + ) + parser.add_argument( + "--recipe", + default=None, + help="Override the recipe path (relative to modelopt_recipes/ or absolute). 
" + "If unset, the recipe is picked by --precision.", + ) + parser.add_argument( + "--calib_samples", + type=int, + default=128, + help="Number of tiny-imagenet samples to use for calibration.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for calibration / TRT compile / benchmarking.", + ) + parser.add_argument( + "--benchmark_iters", + type=int, + default=50, + help="Number of timed iterations (after warmup) per benchmark phase.", + ) + parser.add_argument( + "--save_dir", + type=str, + default=None, + help="If set, save the quantized modelopt state-dict here (BF16 weights " + "+ Q/DQ metadata) — re-usable across runs without recalibration.", + ) + parser.add_argument( + "--skip_trt", + action="store_true", + help="Quantize + run the BF16-fake-quant model only; skip torch_tensorrt.compile. " + "Useful for environments without torch_tensorrt installed.", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise SystemExit("This example requires a CUDA-capable GPU.") + device = torch.device("cuda") + # ViT-Large is a transformer in BF16 on the released checkpoint; the Q/DQ + # nodes operate on top of BF16 master weights either way. + dtype = torch.bfloat16 + + model, processor = load_model_and_processor(args.model_id, device, dtype) + image_size = model.config.image_size + num_channels = model.config.num_channels + example_input = torch.randn( + args.batch_size, num_channels, image_size, image_size, device=device, dtype=dtype + ) + + # Baseline forward + benchmark for a comparison number that survives + # quantization. argmax preserves the predicted-class check below. 
+ print("\n=== Baseline (BF16) ===") + with torch.no_grad(): + baseline_pred = _argmax_logits(model(example_input)) + baseline_latency = benchmark( + lambda x: model(x), example_input, n_warmup=5, n_iters=args.benchmark_iters + ) + print(f"Baseline argmax class: {baseline_pred.tolist()}") + print(f"Baseline latency: {baseline_latency:.3f} ms (median over {args.benchmark_iters} iters)") + + calib_batches = build_calibration_loader( + processor, args.calib_samples, args.batch_size, device, dtype + ) + + recipe_path = args.recipe or PRECISION_TO_RECIPE[args.precision] + quantize_with_recipe(model, recipe_path, calib_batches) + + if args.save_dir: + save_path = Path(args.save_dir) + save_path.mkdir(parents=True, exist_ok=True) + ckpt = save_path / "vit_modelopt_state.pt" + mto.save(model, ckpt) + print(f"Saved quantized modelopt state to {ckpt}") + + print("\n=== Fake-quant (modelopt, BF16 math) ===") + with torch.no_grad(): + fq_pred = _argmax_logits(model(example_input)) + fq_match = (fq_pred == baseline_pred).all().item() + print(f"Quantized argmax class: {fq_pred.tolist()} (matches baseline: {fq_match})") + + if args.skip_trt: + print("\n--skip_trt set; not compiling with Torch-TensorRT.") + return + + wrapped = ViTLogitsWrapper(model).to(device).eval() + trt_model = compile_with_torch_tensorrt(wrapped, example_input) + + print("\n=== Torch-TensorRT compiled ===") + with torch.no_grad(): + trt_pred = trt_model(example_input).argmax(dim=-1) + trt_match = (trt_pred == baseline_pred).all().item() + trt_latency = benchmark(trt_model, example_input, n_warmup=5, n_iters=args.benchmark_iters) + print(f"TRT argmax class: {trt_pred.tolist()} (matches baseline: {trt_match})") + print(f"TRT latency: {trt_latency:.3f} ms (median over {args.benchmark_iters} iters)") + speedup = baseline_latency / trt_latency if trt_latency > 0 else float("inf") + print(f"\nSpeedup vs. 
BF16 baseline: {speedup:.2f}x") + + +if __name__ == "__main__": + main() diff --git a/examples/torch_trt/requirements.txt b/examples/torch_trt/requirements.txt new file mode 100644 index 00000000000..8caf5b5fd93 --- /dev/null +++ b/examples/torch_trt/requirements.txt @@ -0,0 +1,3 @@ +datasets>=2.14.4 +torch-tensorrt>=2.4.0 +transformers>=4.40 diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 1873ecda528..ceede41a891 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -134,6 +134,17 @@ def _quantized_attention( key_states = self.k_bmm_quantizer(key_states) value_states = self.v_bmm_quantizer(value_states) if not self.use_kitchen: + if self.softmax_quantizer.is_enabled: + _sq = self.softmax_quantizer + _orig_softmax = torch.nn.functional.softmax + + def _quantized_softmax(*s_args, **s_kwargs): + return _sq(_orig_softmax(*s_args, **s_kwargs)) + + with replace_function(torch.nn.functional, "softmax", _quantized_softmax): + return original_attention_interface( + self, query_states, key_states, value_states, *args, **kwargs + ) return original_attention_interface( self, query_states, key_states, value_states, *args, **kwargs ) diff --git a/modelopt_recipes/huggingface/vit/ptq/fp8.yaml b/modelopt_recipes/huggingface/vit/ptq/fp8.yaml new file mode 100644 index 00000000000..6158a99590a --- /dev/null +++ b/modelopt_recipes/huggingface/vit/ptq/fp8.yaml @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: >- + HuggingFace ViT FP8 PTQ recipe matching modelopt.onnx FP8 layout: + per-tensor FP8 E4M3 on encoder Linear weights + inputs, attention QKV + BMMs and softmax outputs, MHA-aware per-block LayerNorm outputs; + patch-embed Conv2d, classifier head, and the final LayerNorm are + left FP16. Uses max calibration. +quantize: + algorithm: max + quant_cfg: + - quantizer_name: '*' + enable: false + + - parent_class: 'nn.Linear' + quantizer_name: 'vit.encoder.layer.*.weight_quantizer' + cfg: + num_bits: e4m3 + axis: + - parent_class: 'nn.Linear' + quantizer_name: 'vit.encoder.layer.*.input_quantizer' + cfg: + num_bits: e4m3 + axis: + + - quantizer_name: '*[qkv]_bmm_quantizer' + cfg: + num_bits: e4m3 + axis: + + - quantizer_name: '*softmax_quantizer' + cfg: + num_bits: e4m3 + axis: + + - parent_class: 'nn.LayerNorm' + quantizer_name: 'vit.encoder.layer.*.output_quantizer' + cfg: + num_bits: e4m3 + axis: diff --git a/modelopt_recipes/huggingface/vit/ptq/nvfp4.yaml b/modelopt_recipes/huggingface/vit/ptq/nvfp4.yaml new file mode 100644 index 00000000000..e30f7184dd6 --- /dev/null +++ b/modelopt_recipes/huggingface/vit/ptq/nvfp4.yaml @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: >- + HuggingFace ViT NVFP4 W4A4 PTQ recipe. Skip list mirrors modelopt.onnx + FP8 (patch-embed Conv2d, classifier, and final vit.layernorm stay + FP16). Encoder Linear weights/inputs run NVFP4; attention BMMs, + softmax, and per-block LayerNorm outputs stay at FP8. + Uses AWQ-lite calibration. +quantize: + algorithm: awq_lite + quant_cfg: + - quantizer_name: '*' + enable: false + + - parent_class: 'nn.Linear' + quantizer_name: 'vit.encoder.layer.*.weight_quantizer' + cfg: + num_bits: e2m1 + axis: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + - parent_class: 'nn.Linear' + quantizer_name: 'vit.encoder.layer.*.input_quantizer' + cfg: + num_bits: e2m1 + axis: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + + - quantizer_name: '*[qkv]_bmm_quantizer' + cfg: + num_bits: e4m3 + axis: + + - quantizer_name: '*softmax_quantizer' + cfg: + num_bits: e4m3 + axis: + + - parent_class: 'nn.LayerNorm' + quantizer_name: 'vit.encoder.layer.*.output_quantizer' + cfg: + num_bits: e4m3 + axis: From 9c4ecf060fa2c62aa69a422cf1ffbde8166f33ed Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 29 May 2026 20:27:58 +0000 Subject: [PATCH 2/4] Rename example script to torch_tensorrt_ptq.py + strip benchmarking comparison * examples/torch_trt/quantize_and_compile_vit.py -> torch_tensorrt_ptq.py * Drop the latency / speedup benchmarking comparison from the script and README; the script now only verifies that the compiled-model argmax matches the fake-quant argmax on a 
sample input. Accuracy comparison belongs in a separate harness, not in a "quantize + compile" example. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_trt/README.md | 13 +- ...d_compile_vit.py => torch_tensorrt_ptq.py} | 112 ++++-------------- 2 files changed, 29 insertions(+), 96 deletions(-) rename examples/torch_trt/{quantize_and_compile_vit.py => torch_tensorrt_ptq.py} (62%) diff --git a/examples/torch_trt/README.md b/examples/torch_trt/README.md index bed90d6c50f..d826f7fc3d6 100644 --- a/examples/torch_trt/README.md +++ b/examples/torch_trt/README.md @@ -28,18 +28,18 @@ the version pulled by `pip` must match your installed PyTorch. ```bash # FP8 / NVFP4 default model is google/vit-large-patch16-224 -python examples/torch_trt/quantize_and_compile_vit.py \ +python examples/torch_trt/torch_tensorrt_ptq.py \ --precision fp8/nvfp4 \ --calib_samples 128 \ --batch_size 1 # Quantize but don't TRT-compile (handy on a non-TRT host) -python examples/torch_trt/quantize_and_compile_vit.py \ +python examples/torch_trt/torch_tensorrt_ptq.py \ --precision fp8/nvfp4 \ --skip_trt # Custom model + custom recipe -python examples/torch_trt/quantize_and_compile_vit.py \ +python examples/torch_trt/torch_tensorrt_ptq.py \ --model_id \ --recipe ``` @@ -53,8 +53,9 @@ python examples/torch_trt/quantize_and_compile_vit.py \ [`modelopt_recipes/`](../../modelopt_recipes/). The default recipes target ViT; pass `--recipe ` to use a different one for a different model. -4. Compiles the quantized model with `torch_tensorrt.compile` and prints a - median-latency benchmark against the BF16 eager baseline. +4. Compiles the quantized model with `torch_tensorrt.compile` and verifies + that the compiled-model argmax matches the fake-quant argmax on a sample + input. ## ViT-specific recipes shipped with the example @@ -80,7 +81,7 @@ enable rules means no `enable: false` carve-outs are needed. 
Older GPUs will still let `mtq.quantize` succeed (it emits fake-quant nodes in PyTorch), but `torch_tensorrt.compile` will not find a real -low-precision kernel and the speedup column will be ~1×. +low-precision kernel. ### Resuming from a saved checkpoint diff --git a/examples/torch_trt/quantize_and_compile_vit.py b/examples/torch_trt/torch_tensorrt_ptq.py similarity index 62% rename from examples/torch_trt/quantize_and_compile_vit.py rename to examples/torch_trt/torch_tensorrt_ptq.py index e986ec6ff2d..2e65e59a269 100644 --- a/examples/torch_trt/quantize_and_compile_vit.py +++ b/examples/torch_trt/torch_tensorrt_ptq.py @@ -13,32 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Quantize a HuggingFace ViT model with ModelOpt and deploy it with Torch-TensorRT. +"""Quantize a HuggingFace ViT model with ModelOpt and compile with Torch-TensorRT. Pipeline: 1. Load ``google/vit-large-patch16-224`` (`ViTForImageClassification`) from HF. -2. Build a calibration loader from `zh-plus/tiny-imagenet` (same pattern as the - `torch_onnx` example) so the recipe runs end-to-end without ImageNet access. -3. Run ``mtq.quantize`` with one of the ViT-specific recipes - (`modelopt_recipes/huggingface/vit/ptq/`). Two non-default variants are - shipped: - - * ``fp8`` -> ``fp8_mha-classifier_skip``: W8A8 FP8 with an MHA-aware - LayerNorm output quantizer, FP8 attention BMM/softmax slots, and the - `classifier` head left in FP16. - * ``nvfp4`` -> ``nvfp4_linear-fp8_conv-classifier_skip``: NVFP4 W4A4 on - encoder Linear layers, FP8 override on the patch-embedding Conv2d (TRT - has no NVFP4 kernel for 4D Conv inputs), AWQ-lite calibration, and the - `classifier` head left in FP16. - -4. Compile the quantized model with ``torch_tensorrt.compile`` (Dynamo IR, - ``min_block_size=1``) and run an end-to-end sanity check + small benchmark - against the eager BF16 baseline. 
- -This script is intentionally CLI-driven and side-effect-free outside of the -optional ``--save_dir`` checkpoint. The quantized graph keeps Q/DQ nodes; the -TRT compile step is what turns them into TRT precision layers. +2. Build a calibration loader from `zh-plus/tiny-imagenet` so the recipe runs + end-to-end without ImageNet access. +3. Run ``mtq.quantize`` with one of the ViT-specific recipes under + `modelopt_recipes/huggingface/vit/ptq/` (FP8 or NVFP4). +4. Compile the quantized model with ``torch_tensorrt.compile(ir="dynamo", + min_block_size=1)`` and verify the compiled-model argmax matches the + fake-quant argmax on a sample input. + +The quantized graph keeps Q/DQ nodes; the TRT compile step is what turns +them into TRT precision layers. """ from __future__ import annotations @@ -80,13 +69,7 @@ def build_calibration_loader( device: torch.device, dtype: torch.dtype, ): - """Build a calibration tensor stream from tiny-imagenet. - - tiny-imagenet avoids the gated `ILSVRC/imagenet-1k` repo so this example - runs unauthenticated. Images go through the HF processor (resize + center - crop + ImageNet normalization), which is exactly the eval-time transform - used by the released `vit-large-patch16-224` checkpoint. - """ + """Build a calibration tensor stream from tiny-imagenet.""" print(f"Loading calibration data ({num_samples} samples)...") dataset = load_dataset("zh-plus/tiny-imagenet", split="train") dataset = dataset.shuffle(seed=42).select(range(num_samples)) @@ -96,7 +79,6 @@ def build_calibration_loader( image = sample["image"] if image.mode != "RGB": image = image.convert("RGB") - # HF image processors emit `pixel_values` of shape (1, 3, H, W). pixel_values = processor(images=image, return_tensors="pt")["pixel_values"] tensors.append(pixel_values.squeeze(0)) @@ -105,12 +87,7 @@ def build_calibration_loader( def quantize_with_recipe(model, recipe_path: str, calib_batches): - """Resolve the YAML recipe and run `mtq.quantize`. 
- - Returns the quantized model. The graph still uses high-precision math at - this point — Q/DQ nodes have been inserted around weights and activations - and amax values populated, but no kernel substitution has happened yet. - """ + """Resolve the YAML recipe and run `mtq.quantize`.""" print(f"Loading recipe: {recipe_path}") recipe = load_recipe(recipe_path) if not isinstance(recipe, ModelOptPTQRecipe): @@ -133,8 +110,7 @@ class ViTLogitsWrapper(torch.nn.Module): HF's `ViTForImageClassification.forward` returns an `ImageClassifierOutput` dataclass. `torch_tensorrt.compile` (and `torch.export`) need a tensor-tree - return, so we unwrap it here. The wrapper holds the quantized model as a - submodule; Q/DQ nodes flow through unchanged. + return, so we unwrap it here. """ def __init__(self, vit_model: torch.nn.Module): @@ -146,14 +122,11 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def compile_with_torch_tensorrt(model: torch.nn.Module, example_input: torch.Tensor): - """Compile the quantized model with Torch-TensorRT (Dynamo IR). - - `min_block_size=1` follows the Torch-TRT quantization guide — it makes the - partitioner accept single-node TRT subgraphs, which is what we want so the - Q/DQ + matmul pairs become TRT precision layers instead of falling back to - eager. The compile step expects fake-quant operators in the graph; we run - it under `export_torch_mode` so modelopt's Q/DQ are exported in the - TRT-friendly form. + """Compile the quantized model with Torch-TensorRT (Dynamo IR, strongly-typed). + + `min_block_size=1` follows the Torch-TRT quantization guide so single-node + Q/DQ + matmul subgraphs become TRT precision layers. `export_torch_mode` + makes modelopt emit Q/DQ in the TRT-friendly form during `torch.export`. 
""" import torch_tensorrt @@ -164,34 +137,11 @@ def compile_with_torch_tensorrt(model: torch.nn.Module, example_input: torch.Ten ir="dynamo", arg_inputs=[example_input], min_block_size=1, - # The recipes export weights in BF16; TRT picks the FP8/NVFP4 - # kernel from the Q/DQ pattern, not from this list. - enabled_precisions={torch.bfloat16, torch.float16, torch.float32}, truncate_double=True, ) return trt_model -def benchmark(model: torch.nn.Module, example_input: torch.Tensor, n_warmup: int, n_iters: int): - """Median-of-`n_iters` latency over `example_input`. CUDA-event timed.""" - torch.cuda.synchronize() - with torch.no_grad(): - for _ in range(n_warmup): - model(example_input) - torch.cuda.synchronize() - - starts = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)] - ends = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)] - with torch.no_grad(): - for i in range(n_iters): - starts[i].record() - model(example_input) - ends[i].record() - torch.cuda.synchronize() - times = sorted(s.elapsed_time(e) for s, e in zip(starts, ends)) - return times[len(times) // 2] - - def _argmax_logits(out) -> torch.Tensor: """Handle either an HF `ImageClassifierOutput` or a raw tensor.""" logits = out.logits if hasattr(out, "logits") else out @@ -227,13 +177,7 @@ def main(): "--batch_size", type=int, default=1, - help="Batch size for calibration / TRT compile / benchmarking.", - ) - parser.add_argument( - "--benchmark_iters", - type=int, - default=50, - help="Number of timed iterations (after warmup) per benchmark phase.", + help="Batch size for calibration / TRT compile.", ) parser.add_argument( "--save_dir", @@ -245,7 +189,7 @@ def main(): parser.add_argument( "--skip_trt", action="store_true", - help="Quantize + run the BF16-fake-quant model only; skip torch_tensorrt.compile. " + help="Quantize + run the fake-quant model only; skip torch_tensorrt.compile. 
" "Useful for environments without torch_tensorrt installed.", ) args = parser.parse_args() @@ -253,8 +197,6 @@ def main(): if not torch.cuda.is_available(): raise SystemExit("This example requires a CUDA-capable GPU.") device = torch.device("cuda") - # ViT-Large is a transformer in BF16 on the released checkpoint; the Q/DQ - # nodes operate on top of BF16 master weights either way. dtype = torch.bfloat16 model, processor = load_model_and_processor(args.model_id, device, dtype) @@ -264,16 +206,10 @@ def main(): args.batch_size, num_channels, image_size, image_size, device=device, dtype=dtype ) - # Baseline forward + benchmark for a comparison number that survives - # quantization. argmax preserves the predicted-class check below. print("\n=== Baseline (BF16) ===") with torch.no_grad(): baseline_pred = _argmax_logits(model(example_input)) - baseline_latency = benchmark( - lambda x: model(x), example_input, n_warmup=5, n_iters=args.benchmark_iters - ) print(f"Baseline argmax class: {baseline_pred.tolist()}") - print(f"Baseline latency: {baseline_latency:.3f} ms (median over {args.benchmark_iters} iters)") calib_batches = build_calibration_loader( processor, args.calib_samples, args.batch_size, device, dtype @@ -289,7 +225,7 @@ def main(): mto.save(model, ckpt) print(f"Saved quantized modelopt state to {ckpt}") - print("\n=== Fake-quant (modelopt, BF16 math) ===") + print("\n=== Fake-quant (modelopt) ===") with torch.no_grad(): fq_pred = _argmax_logits(model(example_input)) fq_match = (fq_pred == baseline_pred).all().item() @@ -306,11 +242,7 @@ def main(): with torch.no_grad(): trt_pred = trt_model(example_input).argmax(dim=-1) trt_match = (trt_pred == baseline_pred).all().item() - trt_latency = benchmark(trt_model, example_input, n_warmup=5, n_iters=args.benchmark_iters) print(f"TRT argmax class: {trt_pred.tolist()} (matches baseline: {trt_match})") - print(f"TRT latency: {trt_latency:.3f} ms (median over {args.benchmark_iters} iters)") - speedup = baseline_latency / 
trt_latency if trt_latency > 0 else float("inf") - print(f"\nSpeedup vs. BF16 baseline: {speedup:.2f}x") if __name__ == "__main__": From cef592e8727ce55c68dd20e9f1c1123a12abd0a0 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 29 May 2026 20:43:17 +0000 Subject: [PATCH 3/4] Add e2e integration test for torch_trt example * tests/examples/torch_trt/test_torch_tensorrt_ptq.py -- mirrors the tests/examples/torch_onnx/test_torch_quant_to_onnx.py pattern: invokes the example via run_example_command, parametrizes over (fp8, nvfp4), uses a 1-layer ViT config (--no_pretrained + --model_kwargs) so the test completes in ~30 s per parametrized case. Two variants: - test_torch_tensorrt_ptq[precision] -- full e2e through torch_tensorrt.compile (importorskip on torch_tensorrt). - test_torch_tensorrt_ptq_skip_trt[precision] -- quantize-only smoke test, useful on hosts without torch_tensorrt installed. * examples/torch_trt/torch_tensorrt_ptq.py: - Add --no_pretrained + --model_kwargs flags (mirroring torch_onnx) so the same script doubles as the test entry point. - Force aten.cat.default into PyTorch fallback inside compile_with_torch_tensorrt -- torch_tensorrt 2.10's cat converter chokes on the HF ViT cls-token + patch-embedding concat (BF16: "Got unsupported ScalarType BFloat16"; FP16: rank-(-1) TRT tensor that crashes the downstream `embeddings + position_embeddings` add). The cat is a tiny [1,1,H] + [1,N,H] op that runs once per forward, so PyTorch fallback costs essentially nothing. Verified locally: pytest tests/examples/torch_trt/test_torch_tensorrt_ptq.py -> 4 passed in 103 s on RTX 6000 Ada. 
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_trt/torch_tensorrt_ptq.py | 61 +++++++++++++-- .../torch_trt/test_torch_tensorrt_ptq.py | 78 +++++++++++++++++++ 2 files changed, 133 insertions(+), 6 deletions(-) create mode 100644 tests/examples/torch_trt/test_torch_tensorrt_ptq.py diff --git a/examples/torch_trt/torch_tensorrt_ptq.py b/examples/torch_trt/torch_tensorrt_ptq.py index 2e65e59a269..da584928865 100644 --- a/examples/torch_trt/torch_tensorrt_ptq.py +++ b/examples/torch_trt/torch_tensorrt_ptq.py @@ -33,11 +33,12 @@ from __future__ import annotations import argparse +import json from pathlib import Path import torch from datasets import load_dataset -from transformers import AutoImageProcessor, ViTForImageClassification +from transformers import AutoImageProcessor, ViTConfig, ViTForImageClassification import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq @@ -53,11 +54,30 @@ } -def load_model_and_processor(model_id: str, device: torch.device, dtype: torch.dtype): - """Pull the HF ViT classifier and its preprocessor.""" - print(f"Loading {model_id} (dtype={dtype})...") +def load_model_and_processor( + model_id: str, + device: torch.device, + dtype: torch.dtype, + pretrained: bool = True, + config_overrides: dict | None = None, +): + """Pull the HF ViT classifier and its preprocessor. + + With ``pretrained=False`` the model is built from a config with random + weights (test path); ``config_overrides`` lets the caller shrink it + (e.g. ``{"num_hidden_layers": 1, "hidden_size": 64, ...}``). The + preprocessor is always loaded from ``model_id`` since it only carries + a small JSON config. 
+ """ + print(f"Loading {model_id} (dtype={dtype}, pretrained={pretrained})...") processor = AutoImageProcessor.from_pretrained(model_id) - model = ViTForImageClassification.from_pretrained(model_id, torch_dtype=dtype) + if pretrained: + model = ViTForImageClassification.from_pretrained(model_id, torch_dtype=dtype) + else: + config = ViTConfig.from_pretrained(model_id) + for k, v in (config_overrides or {}).items(): + setattr(config, k, v) + model = ViTForImageClassification(config).to(dtype) model.eval().to(device) return model, processor @@ -131,6 +151,13 @@ def compile_with_torch_tensorrt(model: torch.nn.Module, example_input: torch.Ten import torch_tensorrt print("Compiling with torch_tensorrt.compile (Dynamo IR)...") + # `aten.cat.default` is force-executed in PyTorch because torch_tensorrt + # 2.10's cat converter chokes on the cls-token + patch-embedding concat + # in HF ViT (BFloat16 path: `TypeError: Got unsupported ScalarType + # BFloat16`; FP16 path: rank-(-1) TRT tensor that trips the downstream + # `embeddings + position_embeddings` add). The cat is a tiny [1,1,H] + # + [1,N,H] concat that runs once per forward, so falling back to + # PyTorch costs essentially nothing. with export_torch_mode(): trt_model = torch_tensorrt.compile( model, @@ -138,6 +165,7 @@ def compile_with_torch_tensorrt(model: torch.nn.Module, example_input: torch.Ten arg_inputs=[example_input], min_block_size=1, truncate_double=True, + torch_executed_ops={torch.ops.aten.cat.default}, ) return trt_model @@ -192,6 +220,20 @@ def main(): help="Quantize + run the fake-quant model only; skip torch_tensorrt.compile. " "Useful for environments without torch_tensorrt installed.", ) + parser.add_argument( + "--no_pretrained", + action="store_true", + help="Build the model from config with random weights instead of " + "downloading pretrained weights. 
Useful for fast e2e tests.", + ) + parser.add_argument( + "--model_kwargs", + type=str, + default=None, + help="JSON string of ViTConfig overrides applied when --no_pretrained " + 'is set (e.g. \'{"num_hidden_layers": 1, "hidden_size": 64, ' + '"intermediate_size": 128, "num_attention_heads": 2}\').', + ) args = parser.parse_args() if not torch.cuda.is_available(): @@ -199,7 +241,14 @@ def main(): device = torch.device("cuda") dtype = torch.bfloat16 - model, processor = load_model_and_processor(args.model_id, device, dtype) + config_overrides = json.loads(args.model_kwargs) if args.model_kwargs else None + model, processor = load_model_and_processor( + args.model_id, + device, + dtype, + pretrained=not args.no_pretrained, + config_overrides=config_overrides, + ) image_size = model.config.image_size num_channels = model.config.num_channels example_input = torch.randn( diff --git a/tests/examples/torch_trt/test_torch_tensorrt_ptq.py b/tests/examples/torch_trt/test_torch_tensorrt_ptq.py new file mode 100644 index 00000000000..3b540cd00e8 --- /dev/null +++ b/tests/examples/torch_trt/test_torch_tensorrt_ptq.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pytest +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command + +# Recipe variants the example ships. 
Mirrors the parametrization style of +# ``tests/examples/torch_onnx/test_torch_quant_to_onnx.py``. +_PRECISIONS = ["fp8", "nvfp4"] + +# Tiny ViT config (~1 encoder block) so the test stays under a few seconds +# of GPU time while exercising every code path the recipe touches: encoder +# Linear weight/input quantizers, attention BMM + softmax quantizers, +# per-block LayerNorm output quantizer, and the patch-embed Conv / final +# vit.layernorm / classifier skip rules. +_TINY_VIT_KWARGS = { + "num_hidden_layers": 1, + "hidden_size": 64, + "intermediate_size": 128, + "num_attention_heads": 2, +} + + +@pytest.mark.parametrize("precision", _PRECISIONS) +def test_torch_tensorrt_ptq(precision): + """End-to-end: load tiny ViT -> mtq.quantize via recipe -> torch_tensorrt.compile. + + Runs against the smallest viable ``ViTForImageClassification`` config so + the test stays fast; ``--no_pretrained`` skips the pretrained-weights + download. The example's CLI exits non-zero if any step (calibration, + quantization, TRT compile) raises; the TRT-vs-baseline argmax comparison + is printed by the script but not asserted. + """ + pytest.importorskip("torch_tensorrt") + + cmd_parts = extend_cmd_parts( + ["python", "torch_tensorrt_ptq.py"], + model_id="google/vit-base-patch16-224", + precision=precision, + calib_samples="4", + batch_size="1", + model_kwargs=json.dumps(_TINY_VIT_KWARGS), + ) + cmd_parts.append("--no_pretrained") + run_example_command(cmd_parts, "torch_trt") + + +@pytest.mark.parametrize("precision", _PRECISIONS) +def test_torch_tensorrt_ptq_skip_trt(precision): + """Quantize-only smoke test (no torch_tensorrt.compile). + + Useful on hosts without ``torch_tensorrt`` installed and as a faster + sanity check that just exercises the recipe + ``mtq.quantize`` path. 
+ """ + cmd_parts = extend_cmd_parts( + ["python", "torch_tensorrt_ptq.py"], + model_id="google/vit-base-patch16-224", + precision=precision, + calib_samples="4", + batch_size="1", + model_kwargs=json.dumps(_TINY_VIT_KWARGS), + ) + cmd_parts.extend(["--no_pretrained", "--skip_trt"]) + run_example_command(cmd_parts, "torch_trt") From 7db9fad204619282246f6bf801ab0811efbd7b05 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 29 May 2026 20:52:31 +0000 Subject: [PATCH 4/4] Remove quantize-only smoke variant from torch_trt e2e test Drops `test_torch_tensorrt_ptq_skip_trt` -- the full `test_torch_tensorrt_ptq` variant already exercises the same mtq.quantize path and goes further (torch_tensorrt.compile). The skip-variant added duplicate CI runtime without unique coverage. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .../torch_trt/test_torch_tensorrt_ptq.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/examples/torch_trt/test_torch_tensorrt_ptq.py b/tests/examples/torch_trt/test_torch_tensorrt_ptq.py index 3b540cd00e8..f0350d525ad 100644 --- a/tests/examples/torch_trt/test_torch_tensorrt_ptq.py +++ b/tests/examples/torch_trt/test_torch_tensorrt_ptq.py @@ -57,22 +57,3 @@ def test_torch_tensorrt_ptq(precision): ) cmd_parts.append("--no_pretrained") run_example_command(cmd_parts, "torch_trt") - - -@pytest.mark.parametrize("precision", _PRECISIONS) -def test_torch_tensorrt_ptq_skip_trt(precision): - """Quantize-only smoke test (no torch_tensorrt.compile). - - Useful on hosts without ``torch_tensorrt`` installed and as a faster - sanity check that just exercises the recipe + ``mtq.quantize`` path. 
- """ - cmd_parts = extend_cmd_parts( - ["python", "torch_tensorrt_ptq.py"], - model_id="google/vit-base-patch16-224", - precision=precision, - calib_samples="4", - batch_size="1", - model_kwargs=json.dumps(_TINY_VIT_KWARGS), - ) - cmd_parts.extend(["--no_pretrained", "--skip_trt"]) - run_example_command(cmd_parts, "torch_trt")
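Note on the `--no_pretrained` / `--model_kwargs` path introduced in PATCH 3/4: the test serializes `_TINY_VIT_KWARGS` with `json.dumps`, the script parses it with `json.loads`, and each key is applied to the config via `setattr`. The snippet below is a minimal sketch of that round trip, not the script itself. `StandInConfig`, `parse_overrides`, and `apply_overrides` are hypothetical names, and the dataclass only stands in for `ViTConfig` so the sketch runs without `transformers` installed.

```python
from __future__ import annotations

import argparse
import json
from dataclasses import dataclass


@dataclass
class StandInConfig:
    """Hypothetical stand-in for ``ViTConfig``; default values are illustrative."""

    num_hidden_layers: int = 12
    hidden_size: int = 768
    intermediate_size: int = 3072
    num_attention_heads: int = 12


def parse_overrides(argv: list[str]) -> dict | None:
    """Parse the two flags the patch adds and decode the JSON override string."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_pretrained", action="store_true")
    parser.add_argument("--model_kwargs", type=str, default=None)
    args = parser.parse_args(argv)
    return json.loads(args.model_kwargs) if args.model_kwargs else None


def apply_overrides(config, overrides: dict | None):
    """Copy each override onto the config, as the patched loader does with setattr."""
    for key, value in (overrides or {}).items():
        setattr(config, key, value)
    return config


# Same values as _TINY_VIT_KWARGS in the test file.
tiny_kwargs = {
    "num_hidden_layers": 1,
    "hidden_size": 64,
    "intermediate_size": 128,
    "num_attention_heads": 2,
}
argv = ["--no_pretrained", "--model_kwargs", json.dumps(tiny_kwargs)]
config = apply_overrides(StandInConfig(), parse_overrides(argv))
print(config)
```

In this sketch `setattr` accepts any key, so a misspelled override would be attached as a new attribute rather than rejected. Validating keys against the known config fields before applying them is one possible hardening step if that matters for the test.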