diff --git a/experimental/nvfp4_global_scale_study/README.md b/experimental/nvfp4_global_scale_study/README.md new file mode 100644 index 00000000000..0bd66902587 --- /dev/null +++ b/experimental/nvfp4_global_scale_study/README.md @@ -0,0 +1,289 @@ +# NVFP4 Global-Scale (g_amax) Study + +A self-contained numerical study of how the **per-tensor global scale** in NVFP4 +affects quantization error, and how to **calibrate it** — especially for +activations, where calibration data may not cover the true inference dynamic +range. + +Everything here is reproduced by [`nvfp4_global_scale_study.py`](./nvfp4_global_scale_study.py), +which drives the **real** `modelopt.torch.quantization.qtensor.nvfp4_tensor.NVFP4QTensor` +quantize/dequantize path (not a re-implementation) and cross-checks it against +the closed-form math. + +```bash +python nvfp4_global_scale_study.py +``` + +## TL;DR + +- NVFP4 uses two-level scaling: a per-tensor `global_scale` (set by a global + amax `g_amax`) and a per-16-element FP8-E4M3 `block_scale` (set by each + block's amax `b_amax`). +- `g_amax` does **not** set element resolution — the per-block scale already + normalizes every block to the e2m1 range `[-6, 6]`. `g_amax` only decides + **where each block's FP8 scale lands in the FP8 range**, i.e. which blocks + fall out of the well-conditioned "normal FP8" zone. +- Therefore choosing `g_amax` is a **range-only, second-order** decision with a + closed-form feasible window. Pick `g_amax` anywhere in `[B_max, 28672·B_min]` + and you are essentially optimal. + +## 1. 
The math (verified) + +For an element `e` in a block with block-amax `b_amax`, in a tensor with global +amax `g_amax`: + +```text +global_scale = g_amax / (6 · 448) +block_scale = fp8_e4m3( b_amax / (6 · global_scale) ) # = fp8(b_amax · 448 / g_amax) + clamped to the FP8-E4M3 range [2^-9, 448] +deq(e) = snap_e2m1( e / (block_scale · global_scale) ) · (block_scale · global_scale) +``` + +The product collapses to a clean value when the block scale is neither clamped +nor (heavily) FP8-rounded: + +```text +block_scale · global_scale -> b_amax / 6 +``` + +so `scaled = e / (b_amax/6) = 6e/b_amax`, mapping the block's largest-magnitude +element onto the e2m1 max value `6`. + +`nvfp4_global_scale_study.py` **Part 1** asserts this against the live code over +15 scenarios (`ALL real==manual: True`). + +## 2. Error vs. g_amax — three regimes + +![error vs g_amax](./error_vs_gamax.png) + +Each panel locks `(e, b_amax)` and sweeps `g_amax` from `1e-2` to `1e6`. +Colored points = signed element error `deq - e` (left axis); gray points = +**relative** FP8 block-scale error `(fp8(x) - x)/x`, `x = b_amax·448/g_amax` +(right axis). Dotted vertical = the natural single-block choice `g_amax = b_amax`. + +| Regime | Condition (`ρ = g_amax / b_amax`) | Block-scale `x` | What happens | +|---|---|---|---| +| **1. Saturation** | `ρ < 1` | `x > 448` → upper clamp | block's large values get clipped — **large error** | +| **2. Well-conditioned** | `1 ≤ ρ ≤ 28672` | normal FP8 | error is **flat** = e2m1 grid floor + bounded ±6.25% FP8 wobble | +| **3a. Subnormal** | `28672 < ρ ≤ 229376` | subnormal FP8 | block-scale precision degrades, error grows | +| **3b. 
Underflow** | `ρ > 229376` | `x < 2^-9` → lower clamp | block scale floors at `2^-9`, elements round to zero: **catastrophic** |
+
+Key subtlety the plot makes concrete: the *absolute* FP8 error of the block
+scale → 0 at large `g_amax` only because the scale magnitude itself shrinks.
+The **relative** FP8 error (the gray points, and what actually perturbs the
+result) stays bounded in the normal range by half the E4M3 mantissa step
+(step `2^-3`, so `2^-4 = 6.25%` worst case) and only hits exactly 0 at discrete
+`g_amax` where `b_amax·448/g_amax` is an exact E4M3 value. Those exact points
+are where `deq - e` touches its grid floor.
+
+### The "fixed zone" boundaries (regime 2)
+
+The well-conditioned band is where the block scale stays in **normal FP8**
+(`2^-6 ≤ x ≤ 448`). Solving for `g_amax`:
+
+```text
+B_max ≤ g_amax ≤ 28672 · B_min
+  ^                ^
+  x = 448 edge     x = 2^-6 edge (28672 = 448 · 64 = 7 · 2^12)
+```
+
+- For a single block, the window **width is always 28672× (≈4.46 decades)** in
+  `g_amax`, regardless of `b_amax`.
+- It simply **slides right proportionally to `b_amax`**. This is exactly why, in
+  the figure, larger-`b_amax` panels have their plateau shifted right.
+- Above the window (larger `g_amax`), the subnormal regime starts at
+  `28672·B_min` and the lower clamp at `229376·B_min` (`= 448·512 = 7·2^15`).
+
+#### Where 28672 comes from
+
+`28672` is not arbitrary: it is the **dynamic range of normal (non-subnormal)
+FP8-E4M3 values**, `max_normal / min_normal`. Derivation:
+
+1. The stored block scale (pre-clamp) is `x = b_amax · 448 / g_amax`, and the
+   well-conditioned regime requires `x` in normal FP8: `2^-6 ≤ x ≤ 448`.
+2. Substituting and solving for `g_amax`:
+
+   ```text
+   x ≤ 448  -> g_amax ≥ b_amax
+   x ≥ 2^-6 -> g_amax ≤ b_amax · 448 · 2^6
+   ```
+
+   so the window is `[b_amax, 448·2^6·b_amax]` and its width factor is
+   `448 · 64 = 28672`. Equivalently, since `x ∝ 1/g_amax`, the window width is
+   just the range of `x` that normal FP8 spans, `448 / 2^-6`.
+3.
The two FP8-E4M3FN landmarks come from its bit layout (1 sign, 4 exponent,
+   3 mantissa, bias 7; value `(1 + m/8)·2^(e-7)` for normal `e ∈ [1,15]`):
+   - **min normal**: `e=1, m=0` → `2^(1-7) = 2^-6 = 0.015625`
+   - **max normal**: `e=15, m=6` (the "FN" variant reuses `e=15` for finite
+     values; only `S.1111.111` is NaN) → `(1+6/8)·2^8 = 1.75·256 = 448`
+
+So `28672 = 448 / 2^-6 = 1.75·2^8·2^6 = 1.75·2^14 = 7·2^12`.
+
+The sibling constant `229376` uses the smallest *subnormal* `2^-9` instead of
+the smallest normal: `448 / 2^-9 = 448·512 = 7·2^15 = 229376`. This is the edge
+of representability, below which the block scale floors and the block is zeroed.
+In short: the two magic numbers are simply FP8-E4M3's normal and full dynamic
+ranges, because the block scale is what is stored in FP8.
+
+## 3. The regimes, in one curve: FP8 block-scale error vs. b_amax/g_amax
+
+![block-scale error across regimes](./error_vs_ratio.png)
+
+The ideal block scale is `bscale = b_amax·448/g_amax = 448·t` with
+`t = b_amax/g_amax = 1/ρ`, so the **relative** FP8 quantization error of the
+block scale, `(fp8(bscale) − bscale)/bscale`, depends **only on `t`**: a single
+curve that exposes every regime at once (y-axis is symlog):
+
+| Region (`t = b_amax/g_amax`) | `ρ = g_amax/b_amax` | Behaviour |
+|---|---|---|
+| `t > 1` | `ρ < 1` | **Saturation**: `bscale` clamps to `448`; rel err → `−1` as the true scale runs away above the clamp. The block's large values get clipped. |
+| `1/28672 ≤ t ≤ 1` (shaded) | `1 ≤ ρ ≤ 28672` | **Normal FP8**: bounded relative error `≤ 6.25%` (half the `2^-3` E4M3 mantissa step); touches 0 at exact E4M3 values. This is the well-conditioned zone. |
+| `1/229376 ≤ t < 1/28672` | `28672 < ρ ≤ 229376` | **Subnormal**: the "fan" of widening FP8 steps; rel err grows. |
+| `t < 1/229376` | `ρ > 229376` | **Lower clamp**: `bscale` floors at `2⁻⁹`; rel err → `+∞` (`+42×` at `t=1e-7`), block effectively zeroed.
| + +This is the per-tensor view of the same boundaries derived in §2: the shaded +band is exactly the `[B_max, 28672·B_min]` window mapped onto the block-amax +ratio. It makes the asymmetry used in calibration (§4) visually obvious — +saturation drives the error hard toward `−1` (catastrophic), while the +subnormal side degrades gracefully until the lower clamp. + +### Aside: the e2m1 grid dead zone (independent of g_amax) + +Separately from the scale, the **4-bit e2m1 grid** imposes a hard floor on +element error. With an ideal scale (`g_amax = b_amax`), an element sits at +`scaled = 6·(e/b_amax)`; the smallest nonzero e2m1 level is `0.5` (rounding +boundary `0.25`), so + +```text +|e| < b_amax / 24 -> rounds to 0 (the element is annihilated) +``` + +Per-block dynamic range is only ~24×: any element more than 24× smaller than its +block's max is lost — independent of `g_amax`. This is the core reason small +blocks (16) and outlier handling matter, and why `g_amax` calibration cannot +rescue small-`e`/large-`b_amax` loss. + +## 4. Calibrating g_amax (activations) + +Because the **per-block scale is recomputed dynamically at inference**, an unseen +activation pattern can only hurt through one scalar per block: the block amax +`b_i` relative to the static `g_amax`. You don't need calibration to cover the +full activation distribution — only the **range of block amaxes**. + +For a fixed `g_amax`, a runtime block is safe (normal FP8) iff its amax falls in +a fixed 28672×-wide window: + +```text +g_amax / 28672 ≤ b ≤ g_amax +``` + +The two ways to fall out are **very asymmetric**: + +| Failure | Trigger | Severity | +|---|---|---| +| **Saturation** | `b > g_amax` (block bigger than expected) | catastrophic — clips the largest/outlier activations | +| **Subnormal/underflow** | `b < g_amax/28672` | graceful — small blocks, small absolute error | + +So **bias `g_amax` high**: spend most of the 28672× budget as upward headroom. + +### Recipe + +1. 
From calibration, collect the per-block (16-wide) amax distribution and take
+   robust statistics, not a single-batch max:
+   - `B_max` = high percentile / EMA running max (e.g. 99.99th).
+   - `B_min` = low percentile of blocks you care to represent (e.g. 1st);
+     near-zero blocks are fine, since they go gracefully subnormal.
+2. Feasible (no-saturation + no-subnormal) window: `[B_max, 28672·B_min]`.
+   Available **slack** (upward margin budget) `= 28672 / (B_max/B_min)`.
+3. Choose inside the window, biased upward for outlier insurance:
+   - balanced / log-center: `g_amax = sqrt(B_max · 28672·B_min)`
+   - upward-biased (recommended): `g_amax ≈ B_max · slack^0.65`
+4. If `B_max/B_min > 28672` (slack < 1): no single `g_amax` covers the range, so
+   fix the **range** (SmoothQuant-style outlier migration, per-channel scaling,
+   or higher-precision fallback for outlier channels), not `g_amax`.
+5. Optionally refine with a 1-D (MSE or Hessian-weighted) search **constrained to
+   the feasible window** so it can never pick a value that saturates the tail.
+
+### Worked example
+
+Calibration block amaxes span `B_min = 0.5`, `B_max = 1000`
+(dynamic range 2000 ≈ 3.3 decades):
+
+- slack `= 28672 / 2000 ≈ 14.3×`
+- feasible window `[1000, 14336]`
+- **recommended `g_amax ≈ 5600`** (`B_max · slack^0.65 ≈ 1000 · 5.6`): ~5.6×
+  outlier insurance before any block saturates, while the smallest 0.5-amax
+  blocks stay well inside normal FP8 (`ρ = 5600/0.5 = 11200 < 28672`).
+- avoid `g_amax = 1000` (zero headroom, so saturation is likely at inference)
+  and avoid `g_amax = 14336` (zero subnormal cushion, no benefit).
+
+### Robustness when you can't predict B_max: anchor to B_min
+
+The recipe above anchors to `B_max`, but in practice **calibration rarely
+bounds the inference `B_max`**: it is outlier-driven and heavy-tailed, so real
+inputs routinely exceed it, causing saturation (the catastrophic failure).
The +*stable* end is `B_min`: the bulk/floor activation scale is governed by the +preceding normalization layer and is far more consistent across data. + +So a more robust formulation **anchors the bottom edge of the normal window at +`B_min`** and extends the full width upward, handing the format its entire +dynamic range as outlier insurance — without predicting `B_max` at all: + +```text +g_amax = rho · B_min, rho up to 28672 (recommended ~16384) +``` + +- Larger `rho` → larger `g_amax` → more saturation headroom. Within the normal + window the element error is *flat*, so this costs nothing for blocks that stay + normal; it only trades **downward margin** (blocks below `B_min` tip into the + graceful subnormal regime). +- `rho ≈ 16384` (just under the 28672 cliff) keeps a ~1.75× cushion so a + moderate inference drift in `B_min` doesn't immediately subnormal those blocks, + while still giving ~16000× upward headroom. + +Guardrails: use a robust low-percentile `B_min` (not the literal min, which can +be ~0); keep `g_amax ≥ margin·B_max_calib` as a sanity floor; and if +`B_max/B_min > 28672` the range exceeds the format — fix it with outlier +mitigation, not `g_amax`. 
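The `B_min`-anchored rule and its guardrails can be sketched in a few lines. This is a minimal illustration and not part of `nvfp4_global_scale_study.py`: the helper name `choose_g_amax_bmin_anchored`, its parameters `low_pct` and `margin`, the nearest-rank percentile, and the choice to raise an error when the range exceeds the window are all assumptions made for the example. Only the constants `28672` and `16384` and the rule `g_amax = rho · B_min` come from the text above.

```python
import math

# Width of the normal-FP8 window: max_normal / min_normal of FP8-E4M3.
NORMAL_WINDOW = 448.0 / 2.0**-6  # 28672


def choose_g_amax_bmin_anchored(block_amaxes, rho=16384.0, low_pct=1.0, margin=1.5):
    """Hypothetical helper: g_amax = rho * B_min, with the guardrails from the text.

    block_amaxes: per-block (16-wide) abs-max values collected during calibration.
    low_pct:      percentile used as the robust B_min (assumed default: 1st).
    margin:       sanity-floor multiplier on the calibration max (assumed default).
    """
    if not 0.0 < rho <= NORMAL_WINDOW:
        raise ValueError("rho must lie in (0, 28672]")
    ordered = sorted(float(b) for b in block_amaxes)
    if not ordered:
        raise ValueError("need at least one block amax")

    # Nearest-rank percentile: a robust floor instead of the literal min (can be ~0).
    rank = max(1, math.ceil(low_pct / 100.0 * len(ordered)))
    b_min = ordered[rank - 1]
    b_max = ordered[-1]
    if b_min <= 0.0:
        raise ValueError("robust B_min must be positive")

    # The range exceeds the format: no g_amax can fix this, mitigate outliers instead.
    if b_max / b_min > NORMAL_WINDOW:
        raise ValueError("B_max/B_min > 28672: fix the range, not g_amax")

    # Anchor at B_min, but never drop below margin * B_max_calib (sanity floor).
    g_amax = max(rho * b_min, margin * b_max)
    return g_amax, b_min, b_max


# Same numbers as the worked example: B_min = 0.5, B_max = 1000.
g_amax, b_min, b_max = choose_g_amax_bmin_anchored([0.5, 1.0, 2.0, 1000.0])
print(g_amax, b_min, b_max)  # 8192.0 0.5 1000.0
```

With these inputs the result `8192` sits inside the feasible window `[1000, 14336]`, leaving about 8× saturation headroom over the calibration `B_max` and a 1.75× cushion below `B_min` before blocks turn subnormal.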
+ +### Verifying it: robustness to unseen outliers + +![calibration strategy robustness](./calib_strategy.png) + +Part 4 of the script simulates a stable lognormal block-amax bulk plus outlier +blocks that **grow by a factor `k` at inference** (unseen during calibration), +and compares quantization MSE for `B_max`-anchored vs `B_min`-anchored `g_amax` +against an oracle that knows the true inference `B_max`: + +| inference outlier growth `k` | MSE `B_max`-anchored | MSE `B_min`-anchored | MSE oracle | +|---|---|---|---| +| 1.0 | 0.0092 | 0.0093 | 0.0090 | +| 3.2 | 0.139 | 0.0263 | 0.0259 | +| 10 | **7.87** | 0.189 | 0.192 | +| 32 | **156** | 1.96 | 1.86 | + +The `B_min`-anchored choice **tracks the oracle almost exactly** across the whole +range, while the `B_max`-anchored choice saturates as soon as outliers grow and +degrades by 40×+ — concrete confirmation that anchoring to the stable `B_min` +and using the format's full window is the robust default for activations. + +## Why this differs from INT8/FP8 per-tensor calibration + +For INT8/FP8 the per-tensor scale directly trades range vs. resolution, so its +choice is first-order. For NVFP4 the per-block scale already owns resolution; +`g_amax` is a range-only knob with a wide (~4.46-decade) safe window. NVFP4 is +consequently robust to the global-amax choice across a very wide range — the +remaining error is dominated by the irreducible e2m1 grid, not by `g_amax`. + +## Files + +| File | Description | +|---|---| +| `nvfp4_global_scale_study.py` | Reproduces all numbers and figures against the live `NVFP4QTensor` code path | +| `error_vs_gamax.png` | Signed error & relative FP8 block-scale error vs. `g_amax` (3×3 grid of `(e, b_amax)` cases) | +| `error_vs_ratio.png` | Relative FP8 block-scale error vs. 
`b_amax/g_amax` (all four regimes in one curve) | +| `calib_strategy.png` | Robustness to unseen outliers: `B_max`- vs `B_min`-anchored `g_amax` (MSE vs outlier growth) | + +## References + +- `modelopt/torch/quantization/qtensor/nvfp4_tensor.py` — the NVFP4 implementation this study exercises. +- FP8 E4M3FN: max normal `448`, min normal `2^-6`, min subnormal `2^-9`. diff --git a/experimental/nvfp4_global_scale_study/calib_strategy.png b/experimental/nvfp4_global_scale_study/calib_strategy.png new file mode 100644 index 00000000000..273e93fad1e Binary files /dev/null and b/experimental/nvfp4_global_scale_study/calib_strategy.png differ diff --git a/experimental/nvfp4_global_scale_study/error_vs_gamax.png b/experimental/nvfp4_global_scale_study/error_vs_gamax.png new file mode 100644 index 00000000000..8b798fe7e63 Binary files /dev/null and b/experimental/nvfp4_global_scale_study/error_vs_gamax.png differ diff --git a/experimental/nvfp4_global_scale_study/error_vs_ratio.png b/experimental/nvfp4_global_scale_study/error_vs_ratio.png new file mode 100644 index 00000000000..d25d4a2cbe8 Binary files /dev/null and b/experimental/nvfp4_global_scale_study/error_vs_ratio.png differ diff --git a/experimental/nvfp4_global_scale_study/nvfp4_global_scale_study.py b/experimental/nvfp4_global_scale_study/nvfp4_global_scale_study.py new file mode 100644 index 00000000000..f7efd7cc075 --- /dev/null +++ b/experimental/nvfp4_global_scale_study/nvfp4_global_scale_study.py @@ -0,0 +1,439 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Study NVFP4 quantize->dequantize error vs. the choice of global amax (g_amax). + +Reproduces the figures in this directory's README: + Part 1: prove the hand-derived math matches the REAL NVFP4QTensor code path. + Part 2: lock (e, b_amax), sweep g_amax, plot signed error dequant(quant(e)) - e. + Part 3: relative FP8 block-scale error vs. b_amax/g_amax (shows all regimes). + Part 4: activation calibration robustness — B_max- vs B_min-anchored g_amax. + +Run from anywhere: python nvfp4_global_scale_study.py +""" + +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import torch + +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +HERE = os.path.dirname(os.path.abspath(__file__)) +BLOCK = 16 +torch.set_printoptions(precision=8) + + +def build_block(e: float, b_amax: float) -> torch.Tensor: + """One block of 16 elements whose abs-max is exactly b_amax and that contains e. + + Element 0 = e, element 1 = +b_amax (forces the block amax), rest = 0. + Requires |e| <= b_amax. + """ + assert abs(e) <= b_amax + 1e-12, "need |e| <= b_amax" + blk = torch.zeros(1, BLOCK, dtype=torch.float32) + blk[0, 0] = e + blk[0, 1] = b_amax + return blk + + +def real_code_qdq(e: float, b_amax: float, g_amax: float): + """Run the actual NVFP4QTensor quantize + dequantize. Returns (deq_e, block_scale_fp8, prod).""" + blk = build_block(e, b_amax) + # Inject g_amax independently of the tensor's own amax via weights_scaling_factor_2. 
+ wsf2 = torch.tensor(g_amax / (6.0 * 448.0), dtype=torch.float32) + + qtensor, wsf, wsf2_out = NVFP4QTensor.quantize( + blk, block_size=BLOCK, weights_scaling_factor_2=wsf2 + ) + deq = qtensor.dequantize( + dtype=torch.float32, + scale=wsf, # per-block scale (fp8 e4m3) + double_scale=wsf2_out, # global / per-tensor scale + block_sizes={-1: BLOCK}, + ) + deq_e = deq[0, 0].item() + block_scale_fp8 = wsf.float().flatten()[0].item() + prod = block_scale_fp8 * wsf2_out.item() # effective divisor used in (de)quant + return deq_e, block_scale_fp8, prod + + +def manual_qdq(e: float, b_amax: float, g_amax: float): + """Hand-derived math, reusing only the e2m1 snapping primitive (_cast_fp4).""" + global_scale = g_amax / (6.0 * 448.0) + # block_scale (high precision) = b_amax / (6 * global_scale) = b_amax*448/g_amax + block_scale_hp = b_amax / (6.0 * global_scale) + # stored as fp8 e4m3, clamped to [2**-9, 448] + block_scale_fp8 = ( + torch.tensor(block_scale_hp, dtype=torch.float32) + .clamp(min=2**-9, max=448.0) + .to(torch.float8_e4m3fn) + .float() + .item() + ) + prod = block_scale_fp8 * global_scale # effective divisor + scaled = torch.tensor([[e / prod]], dtype=torch.float32) + + # snap to e2m1 grid using the SAME primitive the library uses + codes = NVFP4QTensor._cast_fp4(scaled.clone()) + snapped = NVFP4QTensor.get_e2m1_values("cpu")[codes.long()].flatten()[0].item() + deq_e = snapped * prod + return deq_e, block_scale_fp8, prod + + +# ---------------------------------------------------------------------------- +# PART 1 — numeric proof on concrete scenarios +# ---------------------------------------------------------------------------- +print("=" * 100) +print("PART 1: real NVFP4 code vs. 
hand-derived math")
+print("=" * 100)
+
+scenarios = [
+    # (e, b_amax, g_amax)
+    (0.37, 0.50, 1.0),
+    (0.37, 0.50, 0.5),  # g_amax == b_amax (block is the global max block)
+    (0.37, 0.50, 4.0),  # g_amax = 8x b_amax -> block scale 56, still exact in fp8
+    (0.37, 0.50, 0.05),  # g_amax < b_amax -> block scale wants >448, gets clamped
+    (-2.9, 3.0, 6.0),
+    (1.234, 5.0, 5.0),
+    (0.018, 0.02, 12.0),  # g_amax = 600x b_amax -> block scale ~0.747, fp8-rounded (still normal)
+    # ---- small e + large b_amax, sweeping the ratio r = e/b_amax (g_amax = b_amax) ----
+    (3.0, 6.0, 6.0),  # r = 1/2 -> scaled 3.0 (on-grid)
+    (1.5, 6.0, 6.0),  # r = 1/4 -> scaled 1.5 (on-grid)
+    (0.75, 6.0, 6.0),  # r = 1/8 -> scaled 0.75 (tie bound, round-to-even)
+    (0.30, 6.0, 6.0),  # r = 1/20 -> scaled 0.30 -> snaps to 0.5, big rel err
+    (0.10, 6.0, 6.0),  # r = 1/60 -> scaled 0.10 -> snaps to 0, total loss
+    (0.02, 6.0, 6.0),  # r = 1/300 -> scaled 0.02 -> snaps to 0, total loss
+    (0.30, 30.0, 30.0),  # tiny ratio, large b_amax (same r=1/100 behaviour)
+    (0.30, 30.0, 60.0),  # same but g_amax 2x b_amax -> block scale 224, still exact in fp8
+]
+
+hdr = (
+    f"{'e':>8} {'b_amax':>7} {'g_amax':>7} | {'deq(real)':>12} {'deq(manual)':>12} | "
+    f"{'bscale_fp8':>11} {'prod':>12} | {'err(real)':>11} {'match?':>7}"
+)
+print(hdr)
+print("-" * len(hdr))
+all_match = True
+for e, b, g in scenarios:
+    dr, bs_r, pr = real_code_qdq(e, b, g)
+    dm, bs_m, pm = manual_qdq(e, b, g)
+    err = dr - e  # signed error (deq - e)
+    match = abs(dr - dm) < 1e-6 and abs(pr - pm) <= 1e-6 * max(1.0, abs(pr))
+    all_match &= match
+    print(
+        f"{e:>8.4f} {b:>7.3f} {g:>7.3f} | {dr:>12.6f} {dm:>12.6f} | "
+        f"{bs_r:>11.5f} {pr:>12.8f} | {err:>11.6f} {match!s:>7}"
+    )
+
+print("-" * len(hdr))
+print(f"ALL real==manual: {all_match}")
+print()
+print(
+    "Note prod = block_scale_fp8 * global_scale; analytically prod -> b_amax/6 "
+    "when no fp8 clamp/rounding."
+) +for e, b, g in scenarios: + _, _, pr = real_code_qdq(e, b, g) + print( + f" b_amax={b:>5.3f}: prod={pr:.8f} vs b_amax/6={b / 6:.8f} (ratio {pr / (b / 6):.5f})" + ) + +# ---------------------------------------------------------------------------- +# PART 2 — lock (e, b_amax), sweep g_amax, plot error +# ---------------------------------------------------------------------------- +print() +print("=" * 100) +print("PART 2: error vs g_amax (e and b_amax locked)") +print("=" * 100) + +cases = [ + (0.37, 0.50), + (0.123, 2.00), + (1.70, 2.00), + # small e + large b_amax (various ratios) + (0.30, 6.00), # r = 1/20 -> in the e2m1 dead zone for moderate g_amax + (0.02, 6.00), # r = 1/300 -> deep dead zone + (0.30, 30.00), # tiny ratio, large b_amax + # very large b_amax = 1000, e from 100 / 500 / 900 + (100.0, 1000.0), # r = 0.1 + (500.0, 1000.0), # r = 0.5 + (900.0, 1000.0), # r = 0.9 +] + +g_grid = torch.logspace(-2, 6, 800).tolist() # g_amax from 0.01 to 1e6 + +# One subplot per (e, b_amax) case. +ncols = 3 +nrows = (len(cases) + ncols - 1) // ncols +fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3.6 * nrows), squeeze=False) +for idx, (e, b) in enumerate(cases): + ax = axes[idx // ncols][idx % ncols] + errs = [] + scale_rel_errs = [] # RELATIVE fp8 block-scale error: (fp8(x) - x) / x, x = b*448/g + for g in g_grid: + dr, bs, _ = real_code_qdq(e, b, g) + errs.append(dr - e) # signed error (deq - e) + ideal_block_scale = b * 448.0 / g + scale_rel_errs.append((bs - ideal_block_scale) / ideal_block_scale) + ax.plot(g_grid, errs, ".", color=f"C{idx}", markersize=3, label="deq - e") + # RELATIVE FP8 block-scale quant error on a secondary axis (this is what perturbs prod) + ax_r = ax.twinx() + ax_r.plot( + g_grid, + scale_rel_errs, + ".", + color="gray", + markersize=2, + alpha=0.7, + label="(fp8(bscale)-bscale)/bscale", + ) + ax_r.set_ylabel("rel. 
fp8 bscale err", color="gray", fontsize=8) + ax_r.tick_params(axis="y", labelcolor="gray", labelsize=7) + ax_r.axhline(0, color="gray", ls="-", lw=0.4, alpha=0.5) + ax.axvline(b, color="gray", ls=":", lw=0.8) # natural choice g_amax = b_amax + ax.axhline(0, color="black", lw=0.8) + ax.set_xscale("log") + ax.set_xlabel("g_amax [log]") + ax.set_ylabel("deq - e (signed err)") + ax.set_title(f"e={e}, b_amax={b}") + ax.grid(True, which="both", ls="--", alpha=0.3) + ax.set_zorder(ax_r.get_zorder() + 1) # keep deq-e curve on top + ax.patch.set_visible(False) + lines = ax.get_lines()[:1] + ax_r.get_lines()[:1] + ax.legend(lines, [ln.get_label() for ln in lines], fontsize=7, loc="best") +# Hide any unused panels. +for idx in range(len(cases), nrows * ncols): + axes[idx // ncols][idx % ncols].axis("off") +fig.suptitle("NVFP4 quantize→dequantize signed error vs. choice of global amax") +fig.tight_layout() +out = os.path.join(HERE, "error_vs_gamax.png") +fig.savefig(out, dpi=130) +print(f"saved plot -> {out}") + +# Also print a small table for one case so the trend is visible in text. +e, b = cases[0] +print(f"\nSample (e={e}, b_amax={b}):") +print(f"{'g_amax':>10} {'block_scale_fp8':>16} {'prod':>12} {'deq_e':>12} {'signed_err':>12}") +for g in [0.01, 0.05, 0.1, b, 0.7, 1.0, 2.0, 10.0, 50.0, 100.0]: + dr, bs, pr = real_code_qdq(e, b, g) + print(f"{g:>10.3f} {bs:>16.6f} {pr:>12.8f} {dr:>12.6f} {dr - e:>12.6f}") + +# ---------------------------------------------------------------------------- +# PART 3 — relative FP8 block-scale error vs the ratio b_amax / g_amax. +# block_scale (ideal) = b_amax*448/g_amax = 448 * (b_amax/g_amax), so the +# relative quant error depends ONLY on t = b_amax/g_amax -> one curve that +# cleanly shows every regime (saturation / normal / subnormal / underflow). 
+# ---------------------------------------------------------------------------- +print() +print("=" * 100) +print("PART 3: relative FP8 block-scale error (fp8(bscale)-bscale)/bscale vs b_amax/g_amax") +print("=" * 100) + +# FP8-E4M3FN landmarks and the resulting regime boundaries in t = b_amax/g_amax. +FP8_MAX = 448.0 # max normal +FP8_MIN_NORMAL = 2.0**-6 # min normal +FP8_MIN_SUBNORMAL = 2.0**-9 # min subnormal == lower clamp +T_SAT = FP8_MAX / FP8_MAX # = 1.0 : bscale hits 448 (upper clamp) for t > 1 +T_SUBNORMAL = FP8_MIN_NORMAL / FP8_MAX # = 1/28672 : normal -> subnormal +T_LOWER_CLAMP = FP8_MIN_SUBNORMAL / FP8_MAX # = 1/229376 : subnormal -> lower clamp + + +def block_scale_rel_err(t: float) -> float: + """(fp8(bscale) - bscale)/bscale for ideal bscale = 448*t, using the library clamp+cast.""" + bscale = FP8_MAX * t + fp8 = ( + torch.tensor(bscale, dtype=torch.float32) + .clamp(min=FP8_MIN_SUBNORMAL, max=FP8_MAX) + .to(torch.float8_e4m3fn) + .float() + .item() + ) + return (fp8 - bscale) / bscale + + +t_grid = torch.logspace(-7, 2, 1500).tolist() # b_amax/g_amax from 1e-7 to 1e2 +rel_errs = [block_scale_rel_err(t) for t in t_grid] + +fig2, ax2 = plt.subplots(figsize=(10, 5.5)) +ax2.plot(t_grid, rel_errs, ".", color="C0", markersize=2.5) +ax2.axhline(0, color="black", lw=0.8) + +# Regime boundaries + shaded normal zone. +for tb, lbl in [ + (T_SAT, "t=1 (upper clamp)"), + (T_SUBNORMAL, "t=1/28672"), + (T_LOWER_CLAMP, "t=1/229376"), +]: + ax2.axvline(tb, color="gray", ls="--", lw=0.9) + ax2.text( + tb, + 0.92, + lbl, + rotation=90, + va="top", + ha="right", + fontsize=7, + transform=ax2.get_xaxis_transform(), + color="gray", + ) +ax2.axvspan(T_SUBNORMAL, T_SAT, color="green", alpha=0.07) +# Regime labels. 
+for xpos, txt in [ + (3.0, "saturation\n(values clipped)"), + (3e-3, "normal FP8\n|rel err| <= 6.25%"), + (3e-6, "subnormal"), + (2e-7, "lower\nclamp"), +]: + ax2.text( + xpos, 0.5, txt, ha="center", va="center", fontsize=7.5, transform=ax2.get_xaxis_transform() + ) + +ax2.set_xscale("log") +ax2.set_yscale("symlog", linthresh=0.1) +ax2.set_xlabel("b_amax / g_amax [log scale] (= 1 / rho)") +ax2.set_ylabel("(fp8(bscale) - bscale) / bscale [symlog]") +ax2.set_title("NVFP4 relative FP8 block-scale quantization error across regimes") +ax2.grid(True, which="both", ls="--", alpha=0.3) +fig2.tight_layout() +out2 = os.path.join(HERE, "error_vs_ratio.png") +fig2.savefig(out2, dpi=130) +print(f"saved plot -> {out2}") + +print( + f"\n{'b_amax/g_amax':>13} {'bscale=448t':>12} {'fp8(bscale)':>12} {'rel_err':>10} {'regime':>12}" +) +for t in [1e2, 1e1, 1.0, 0.1, 1e-3, T_SUBNORMAL, 1e-5, T_LOWER_CLAMP, 1e-6, 1e-7]: + bscale = FP8_MAX * t + fp8 = ( + torch.tensor(bscale) + .clamp(min=FP8_MIN_SUBNORMAL, max=FP8_MAX) + .to(torch.float8_e4m3fn) + .float() + .item() + ) + rel = (fp8 - bscale) / bscale + if t > T_SAT: + regime = "saturation" + elif t >= T_SUBNORMAL: + regime = "normal" + elif t >= T_LOWER_CLAMP: + regime = "subnormal" + else: + regime = "lower-clamp" + print(f"{t:>13.3e} {bscale:>12.5g} {fp8:>12.5g} {rel:>10.4f} {regime:>12}") + +# ---------------------------------------------------------------------------- +# PART 4 — activation calibration: B_max-anchored vs B_min-anchored g_amax as +# unseen inference outliers grow. Demonstrates that anchoring g_amax to +# the stable B_min (g = rho * B_min) is robust to outliers that +# calibration never saw, whereas B_max-anchoring saturates immediately. 
+# ---------------------------------------------------------------------------- +print() +print("=" * 100) +print("PART 4: calibration strategy robustness (B_max-anchored vs B_min-anchored)") +print("=" * 100) + +NORMAL_WINDOW = FP8_MAX / FP8_MIN_NORMAL # 28672: width of the normal-FP8 g_amax window + + +def build_tensor_from_block_amaxes(block_amaxes: torch.Tensor, seed: int = 0) -> torch.Tensor: + """Build a (num_blocks, 16) tensor whose per-block abs-max equals block_amaxes.""" + g = torch.Generator().manual_seed(seed) + n = block_amaxes.numel() + x = torch.randn(n, BLOCK, generator=g) + x = x / x.abs().amax(dim=-1, keepdim=True) # normalize each block to amax 1 + return x * block_amaxes.view(-1, 1).float() + + +def quant_mse(tensor: torch.Tensor, g_amax: float) -> float: + """NVFP4 quant->dequant MSE over all elements for a given per-tensor g_amax.""" + wsf2 = torch.tensor(g_amax / (6.0 * 448.0), dtype=torch.float32) + qt, wsf, wsf2o = NVFP4QTensor.quantize(tensor, block_size=BLOCK, weights_scaling_factor_2=wsf2) + deq = qt.dequantize(dtype=torch.float32, scale=wsf, double_scale=wsf2o, block_sizes={-1: BLOCK}) + return float(((deq - tensor) ** 2).mean()) + + +# Calibration block-amax distribution: stable lognormal bulk + a few moderate outliers. 
+gen = torch.Generator().manual_seed(0)
+n_blocks = 2000
+bulk = torch.exp(torch.randn(n_blocks, generator=gen) * 0.7)  # lognormal, median ~1
+bulk[:: max(1, n_blocks // 20)] *= 8.0  # every 100th block: 20 of 2000 (~1%) moderate outliers
+calib_amaxes = bulk
+
+B_min = float(torch.quantile(calib_amaxes, 0.01))  # robust floor (1st percentile)
+B_max_calib = float(calib_amaxes.max())
+rho = 16384.0  # leaning high, just under the 28672 cliff
+
+g_bmin = rho * B_min  # B_min-anchored (does not depend on inference outliers)
+g_bmax = 1.5 * B_max_calib  # B_max-anchored with a 1.5x margin
+
+print(f"calib: B_min(1%)={B_min:.4g} B_max={B_max_calib:.4g} range={B_max_calib / B_min:.1f}x")
+print(f" B_min-anchored g_amax = {rho:.0f} * B_min = {g_bmin:.4g}")
+print(f" B_max-anchored g_amax = 1.5 * B_max = {g_bmax:.4g}")
+print(f" normal-window width = {NORMAL_WINDOW:.0f}x ; feasible iff range < that\n")
+
+# At inference, the outlier blocks grow by factor k (unseen during calibration).
+k_grid = torch.logspace(0, 1.5, 40).tolist()  # 1x ..
~31x +mse_bmin, mse_bmax, mse_oracle = [], [], [] +outlier_mask = torch.zeros(n_blocks, dtype=torch.bool) +outlier_mask[:: max(1, n_blocks // 20)] = True +for k in k_grid: + infer_amaxes = calib_amaxes.clone() + infer_amaxes[outlier_mask] *= k # outliers grow at inference + tensor = build_tensor_from_block_amaxes(infer_amaxes, seed=1) + mse_bmin.append(quant_mse(tensor, g_bmin)) + mse_bmax.append(quant_mse(tensor, g_bmax)) + mse_oracle.append(quant_mse(tensor, float(infer_amaxes.max()))) # knows inference max + +fig4, ax4 = plt.subplots(figsize=(9, 5.5)) +ax4.plot( + k_grid, mse_bmax, "o-", ms=3, color="C3", label=f"B_max-anchored (g=1.5·B_max={g_bmax:.2g})" +) +ax4.plot( + k_grid, mse_bmin, "s-", ms=3, color="C2", label=f"B_min-anchored (g=16384·B_min={g_bmin:.2g})" +) +ax4.plot(k_grid, mse_oracle, "--", color="gray", label="oracle (g=inference B_max)") +ax4.axvline(g_bmin / B_max_calib, color="C2", ls=":", lw=0.8) +ax4.text( + g_bmin / B_max_calib, + 0.97, + "B_min-anchor saturates here", + rotation=90, + va="top", + ha="right", + fontsize=7, + color="C2", + transform=ax4.get_xaxis_transform(), +) +ax4.set_xscale("log") +ax4.set_yscale("log") +ax4.set_xlabel("inference outlier growth factor k (B_max_infer / B_max_calib)") +ax4.set_ylabel("quantization MSE on inference data") +ax4.set_title("Robustness to unseen activation outliers: B_min- vs B_max-anchored g_amax") +ax4.legend(fontsize=8) +ax4.grid(True, which="both", ls="--", alpha=0.3) +fig4.tight_layout() +out4 = os.path.join(HERE, "calib_strategy.png") +fig4.savefig(out4, dpi=130) +print(f"saved plot -> {out4}") + +print(f"\n{'k':>6} {'MSE B_max-anch':>15} {'MSE B_min-anch':>15} {'MSE oracle':>12}") +for k, mx, mn, mo in zip(k_grid, mse_bmax, mse_bmin, mse_oracle): + if any(abs(k - kk) < 1e-9 for kk in [k_grid[0], k_grid[13], k_grid[26], k_grid[-1]]): + print(f"{k:>6.2f} {mx:>15.6g} {mn:>15.6g} {mo:>12.6g}")