From 37fc965433418d10bc9908a45ba5db880e3ab672 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 17 Jun 2026 12:02:24 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/cuda/tests/test_sdpa_midm.py | 120 +++++++ backends/cuda/triton/kernels/__init__.py | 2 + backends/cuda/triton/kernels/sdpa_midm.py | 391 ++++++++++++++++++++++ examples/models/eagle3/export.py | 40 ++- examples/models/eagle3/main.cpp | 205 ++++++++---- examples/models/eagle3/speculator.py | 7 +- examples/models/eagle3/target.py | 8 +- examples/models/gemma4_31b/model.py | 82 ++++- 8 files changed, 774 insertions(+), 81 deletions(-) create mode 100644 backends/cuda/tests/test_sdpa_midm.py create mode 100644 backends/cuda/triton/kernels/sdpa_midm.py diff --git a/backends/cuda/tests/test_sdpa_midm.py b/backends/cuda/tests/test_sdpa_midm.py new file mode 100644 index 00000000000..931e4d387fc --- /dev/null +++ b/backends/cuda/tests/test_sdpa_midm.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Correctness (vs F.sdpa) + isolated speedup for the mid-M flash SDPA kernel. + +CUDA + Triton only. Validates the length-bounded mid-M kernel against the exact +attention the gemma4 full-attention layers compute (causal, enable_gqa, scale=1) +and shows it beats a full-buffer F.sdpa when the valid length << max_seq_len. 
+""" + +import unittest + +import torch + +from executorch.backends.cuda.triton.kernels.sdpa_midm import ( + midm_sdpa, + sdpa_midm, + sdpa_midm_reference, +) + + +def _require_cuda(tc): + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _rand(B, Hkv, H, M, D, S, anchor, device="cuda", dtype=torch.bfloat16): + q = torch.randn(B, H, M, D, device=device, dtype=dtype) + k = torch.randn(B, Hkv, S, D, device=device, dtype=dtype) + v = torch.randn(B, Hkv, S, D, device=device, dtype=dtype) + input_pos = torch.arange(anchor, anchor + M, device=device, dtype=torch.long) + return q, k, v, input_pos + + +def _rel_err(a, b): + return ( + (a.float() - b.float()).abs().mean() / b.float().abs().mean().clamp_min(1e-6) + ).item() + + +class TestMidMSDPA(unittest.TestCase): + def setUp(self): + _require_cuda(self) + torch.manual_seed(0) + + def _check(self, B, Hkv, H, M, D, S, anchor, tol=0.02): + q, k, v, pos = _rand(B, Hkv, H, M, D, S, anchor) + got = sdpa_midm(q, k, v, pos, scale=1.0) + ref = sdpa_midm_reference(q, k, v, pos, scale=1.0) + self.assertEqual(got.shape, (B, H, M, D)) + err = _rel_err(got, ref) + self.assertLess(err, tol, f"rel_err={err} for M={M} D={D} anchor={anchor}") + + # gemma4 global-attention shape: H=32, HKV=4 (GQA 8), D=512. + def test_global_layer_verify_window(self): + for M in (2, 4, 5, 8): + for anchor in (0, 17, 200, 1000): + self._check(1, 4, 32, M, 512, 4096, anchor) + + def test_other_gqa_and_headdim(self): + # smaller config (head_dim 256, GQA 4) to exercise generality + for M in (2, 5, 8): + self._check(1, 2, 8, M, 256, 2048, 300) + + def test_anchor_zero_single_diagonal(self): + # anchor 0: row j attends keys [0, j] only + self._check(1, 4, 32, 4, 512, 1024, 0) + + def test_matches_full_buffer_fsdpa(self): + # The bounded kernel must equal F.sdpa over the FULL buffer with the + # model's causal additive mask (the rest masked to -inf). 
+ import torch.nn.functional as F + + q, k, v, pos = _rand(1, 4, 32, 5, 512, 8192, 500) + key_idx = torch.arange(8192, device="cuda") + keep = key_idx[None, :] <= pos[:, None] + am = torch.where(keep, 0.0, float("-inf")).to(q.dtype) + full = F.scaled_dot_product_attention( + q, k, v, attn_mask=am, is_causal=False, enable_gqa=True, scale=1.0 + ) + got = sdpa_midm(q, k, v, pos, scale=1.0) + self.assertLess(_rel_err(got, full), 0.02) + + def test_splitk_large_context(self): + # Many active splits: 64K buffer, anchors across the range. Exercises the + # cross-split online-softmax reduce at the lengths that motivated split-K. + for anchor in (2048, 30000, 60000): + for M in (2, 5, 8): + self._check(1, 4, 32, M, 512, 65536, anchor) + + def test_splitk_masked_and_boundary_splits(self): + # anchor small vs a large buffer: late key-range splits are fully causal- + # masked for the early rows (null partials), and a row's cutoff lands mid + # chunk. Reduce must discard -inf/0 partials cleanly. + for anchor in (1, 31, 33, 500): + self._check(1, 2, 8, 5, 256, 65536, anchor) + + def test_dispatch_falls_back(self): + # M=1 and M>MIDM_MAX_M must take the F.sdpa path (not the mid-M kernel). 
+ import torch.nn.functional as F + + for M in (1, 16): + q, k, v, pos = _rand(1, 4, 32, M, 512, 1024, 100) + am = torch.zeros(M, 1024, device="cuda", dtype=q.dtype) + key_idx = torch.arange(1024, device="cuda") + am = torch.where(key_idx[None, :] <= pos[:, None], 0.0, float("-inf")).to( + q.dtype + ) + out = midm_sdpa(q, k, v, pos, am, scale=1.0, enable=True) + ref = F.scaled_dot_product_attention( + q, k, v, attn_mask=am, is_causal=False, enable_gqa=True, scale=1.0 + ) + self.assertLess(_rel_err(out, ref), 0.02) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index 4db10fbf82d..a2281c5c06b 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -17,6 +17,7 @@ int4_matvec, ) from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk +from executorch.backends.cuda.triton.kernels.sdpa_midm import sdpa_midm from executorch.backends.cuda.triton.kernels.topk import topk __all__ = [ @@ -29,6 +30,7 @@ "moe_align_block_size", "sdpa", "sdpa_decode_splitk", + "sdpa_midm", "topk", ] diff --git a/backends/cuda/triton/kernels/sdpa_midm.py b/backends/cuda/triton/kernels/sdpa_midm.py new file mode 100644 index 00000000000..66de8fbf731 --- /dev/null +++ b/backends/cuda/triton/kernels/sdpa_midm.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Mid-M flash SDPA: a length-bounded, split-K attention kernel for a few query rows. + +A companion to the kernels in ``sdpa.py``. 
The shared ``_sdpa_fwd_kernel_m64/m32`` +path scans the whole K/V buffer (mask width = max_seq_len), so a speculative +verify forward (M = chain+1, a few rows) over a 64K-context export pays a 64K-wide +attention pass even when only a few hundred positions are valid. + +This kernel specializes that mid-M regime (analogous to the small-M weight- +stationary INT4 GEMM): it keeps all M query rows resident, streams K/V once, and +**bounds the key range to the actual valid length**. Crucially it also **splits +the key range across CTAs** (flash-decoding / split-K): a single (B, H) grid puts +only one CTA per head looping the whole KV serially, so at long context the verify +attention is occupancy-starved and grows linearly with length. Split-K partitions +[0, valid_len) into NUM_SPLITS chunks computed in parallel, then a reduce kernel +combines the per-split online-softmax partials -- the same trick that keeps the +M=1 decode path (``_sdpa_decode_splitk``) flat out to 64K. Since attention here is +bandwidth-bound on the K/V read, the M query rows ride along for free and the +verify attention approaches the decode floor regardless of context. + +Causal masking is computed per row from ``input_pos`` (each of the M rows has its +own cutoff) -- no materialized max_seq_len mask. scale / GQA / bf16-in-fp32- +accumulate match ``F.scaled_dot_product_attention(..., is_causal=False, +enable_gqa=True)`` (Gemma 4 uses scale=1; QK-norm absorbs 1/sqrt(d)). Sliding- +window layers use a ring cache and stay on F.sdpa; this targets the flat full- +attention layers. + +Lives in the CUDA backend (imported only during CUDA lowering / by the model's +mid-M dispatch), so triton is imported unconditionally -- same as ``sdpa.py``. 
+""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +# Verify windows up to this M route to the mid-M kernel; above it the prefill +# path is appropriate (enough rows to amortize a tiled kernel). +MIDM_MAX_M = 8 + +# Number of key-range partitions for split-K. The verify method exports a static +# M / B / H / D, so the partial buffers and grid are static-shaped; only the +# per-split chunk size (derived from the dynamic valid_len) is a runtime scalar. +# 32 splits x (B*H) heads gives ~1K CTAs at the gemma4 global shape -- ample +# occupancy on an A100 while keeping the fp32 partials small. +NUM_SPLITS = 32 + + +@triton.jit +def _sdpa_midm_splitk_kernel( + Q, + K, + V, + POS, + Opart, + Lpart, + Mpart, + sqb, + sqh, + sqm, + sqd, + skb, + skh, + skn, + skd, + svb, + svh, + svn, + svd, + sops, + sopb, + soph, + sopm, + sopd, + slps, + slpb, + slph, + slpm, + smps, + smpb, + smph, + smpm, + valid_len, + chunk_size, + scale, + H: tl.constexpr, + HKV: tl.constexpr, + M: tl.constexpr, + D: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + split_id = tl.program_id(0) + pid_b = tl.program_id(1) + pid_h = tl.program_id(2) + kv_h = pid_h // (H // HKV) + + # This CTA owns keys [start_n, end_n) of the valid range. Splits whose range + # falls entirely past valid_len run an empty loop and emit a null partial + # (m=-inf, l=0, acc=0), which the reduce discards. 
+ start_n = split_id * chunk_size + end_n = tl.minimum(start_n + chunk_size, valid_len) + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, D) + offs_n = tl.arange(0, BLOCK_N) + + q = tl.load( + Q + pid_b * sqb + pid_h * sqh + offs_m[:, None] * sqm + offs_d[None, :] * sqd, + mask=offs_m[:, None] < M, + other=0.0, + ) + qpos = tl.load(POS + offs_m, mask=offs_m < M, other=0).to(tl.int32) + + m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_i = tl.zeros([BLOCK_M], tl.float32) + acc = tl.zeros([BLOCK_M, D], tl.float32) + + kbase = K + pid_b * skb + kv_h * skh + vbase = V + pid_b * svb + kv_h * svh + for sn in tl.range(start_n, end_n, BLOCK_N): + n = sn + offs_n + nmask = n < end_n + k = tl.load( + kbase + n[:, None] * skn + offs_d[None, :] * skd, + mask=nmask[:, None], + other=0.0, + ) + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale + causal = (n[None, :] <= qpos[:, None]) & nmask[None, :] + # Keep fp32: a bare -inf python literal promotes the loop-carried softmax + # state to fp64, which AOTI's Triton compile rejects. + qk = tl.where(causal, qk, float("-inf")).to(tl.float32) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + # A split whose whole tile is causal-masked for a row leaves m_ij=-inf; + # guard exp(-inf - -inf)=NaN (mirrors sdpa.py). Such null tiles then yield + # a (m=-inf, l=0, acc=0) partial that the reduce discards. + safe_qk = tl.where( + m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") + ) + p = tl.exp(safe_qk) + safe_alpha = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) + alpha = tl.exp(safe_alpha) + l_i = l_i * alpha + tl.sum(p, 1) + v = tl.load( + vbase + n[:, None] * svn + offs_d[None, :] * svd, + mask=nmask[:, None], + other=0.0, + ) + acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v) + m_i = m_ij + + # Store the un-normalized partial (acc, running max, running denom). The + # reduce kernel rescales across splits and divides by the global denom. 
+ pbase = split_id * sops + pid_b * sopb + pid_h * soph + tl.store( + Opart + pbase + offs_m[:, None] * sopm + offs_d[None, :] * sopd, + acc, + mask=offs_m[:, None] < M, + ) + tl.store( + Lpart + split_id * slps + pid_b * slpb + pid_h * slph + offs_m * slpm, + l_i, + mask=offs_m < M, + ) + tl.store( + Mpart + split_id * smps + pid_b * smpb + pid_h * smph + offs_m * smpm, + m_i, + mask=offs_m < M, + ) + + +@triton.jit +def _sdpa_midm_reduce_kernel( + Opart, + Lpart, + Mpart, + OUT, + sops, + sopb, + soph, + sopm, + sopd, + slps, + slpb, + slph, + slpm, + smps, + smpb, + smph, + smpm, + sob, + soh, + som, + sod, + NUM_SPLITS: tl.constexpr, + M: tl.constexpr, + D: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, D) + + m_g = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_g = tl.zeros([BLOCK_M], tl.float32) + acc = tl.zeros([BLOCK_M, D], tl.float32) + + for s in range(0, NUM_SPLITS): + m_s = tl.load( + Mpart + s * smps + pid_b * smpb + pid_h * smph + offs_m * smpm, + mask=offs_m < M, + other=-float("inf"), + ) + l_s = tl.load( + Lpart + s * slps + pid_b * slpb + pid_h * slph + offs_m * slpm, + mask=offs_m < M, + other=0.0, + ) + o_s = tl.load( + Opart + + s * sops + + pid_b * sopb + + pid_h * soph + + offs_m[:, None] * sopm + + offs_d[None, :] * sopd, + mask=offs_m[:, None] < M, + other=0.0, + ) + m_new = tl.maximum(m_g, m_s) + # Guard the all-empty case (m_new = -inf): -inf - -inf is NaN; where + # selects the safe value and discards it (mirrors sdpa.py). 
+ finite = m_new > -float("inf") + alpha_g = tl.where(finite, tl.exp(m_g - m_new), 1.0) + alpha_s = tl.where(finite, tl.exp(m_s - m_new), 0.0) + l_g = l_g * alpha_g + l_s * alpha_s + acc = acc * alpha_g[:, None] + o_s * alpha_s[:, None] + m_g = m_new + + inv = tl.where(l_g > 0, 1.0 / l_g, 0.0) + acc = acc * inv[:, None] + tl.store( + OUT + pid_b * sob + pid_h * soh + offs_m[:, None] * som + offs_d[None, :] * sod, + acc.to(OUT.dtype.element_ty), + mask=offs_m[:, None] < M, + ) + + +@triton_op("triton::sdpa_midm", mutates_args={}) +def _sdpa_midm_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + valid_len: int, + scale: float, +) -> torch.Tensor: + """Length-bounded, split-K mid-M flash SDPA (triton_op so AOTI codegens it). + + ``valid_len`` (max valid position + 1) bounds the key range; it is split into + NUM_SPLITS chunks of ``chunk_size`` keys computed in parallel, then reduced. + M / B / H / D are static for the exported verify method, so only chunk_size is + a runtime (backed-SymInt) scalar -- the grid and partial buffers are static. + """ + B, H, M, D = q.shape + HKV = k.shape[1] + out = torch.empty_like(q) + BLOCK_M = max(16, triton.next_power_of_2(M)) + # gemma4 global layers use D=512; a wide key tile + pipelining overflow SMEM + # there, so shrink both. Small D can afford more. 
+ BLOCK_N, num_stages = (32, 1) if D >= 512 else (64, 2) + chunk_size = (valid_len + NUM_SPLITS - 1) // NUM_SPLITS + + Opart = torch.empty((NUM_SPLITS, B, H, M, D), device=q.device, dtype=torch.float32) + Lpart = torch.empty((NUM_SPLITS, B, H, M), device=q.device, dtype=torch.float32) + Mpart = torch.empty((NUM_SPLITS, B, H, M), device=q.device, dtype=torch.float32) + + wrap_triton(_sdpa_midm_splitk_kernel)[(NUM_SPLITS, B, H)]( + q, + k, + v, + input_pos, + Opart, + Lpart, + Mpart, + *q.stride(), + *k.stride(), + *v.stride(), + *Opart.stride(), + *Lpart.stride(), + *Mpart.stride(), + valid_len, + chunk_size, + scale, + H=H, + HKV=HKV, + M=M, + D=D, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=num_stages, + num_warps=4, + ) + wrap_triton(_sdpa_midm_reduce_kernel)[(B, H)]( + Opart, + Lpart, + Mpart, + out, + *Opart.stride(), + *Lpart.stride(), + *Mpart.stride(), + *out.stride(), + NUM_SPLITS=NUM_SPLITS, + M=M, + D=D, + BLOCK_M=BLOCK_M, + num_warps=4, + ) + return out + + +@_sdpa_midm_op.register_fake +def _sdpa_midm_abstract(q, k, v, input_pos, valid_len, scale): + return torch.empty_like(q) + + +def sdpa_midm(q, k, v, input_pos, scale=1.0, valid_len=None): + """Eager/convenience wrapper. ``valid_len`` defaults to max(input_pos)+1 + clamped to the buffer; callers in a traced graph should pass a single + precomputed ``valid_len`` to avoid per-layer SymInts.""" + if valid_len is None: + valid_len = min(int(input_pos[-1]) + 1, k.shape[2]) + return torch.ops.triton.sdpa_midm(q, k, v, input_pos, valid_len, scale) + + +def sdpa_midm_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + scale: float = 1.0, +) -> torch.Tensor: + """Reference: F.sdpa over the valid prefix with a causal additive mask. + + Mirrors the gemma4 full-attention call (is_causal=False, enable_gqa=True, + explicit additive mask) sliced to the valid length, so it equals what the + model computes over the full buffer (the rest is masked to -inf anyway). 
+ """ + valid_len = int(input_pos.max().item()) + 1 + key_idx = torch.arange(valid_len, device=q.device) + keep = key_idx[None, :] <= input_pos[:, None] + attn_mask = torch.where(keep, 0.0, float("-inf")).to(q.dtype) + return F.scaled_dot_product_attention( + q, + k[:, :, :valid_len], + v[:, :, :valid_len], + attn_mask=attn_mask, + is_causal=False, + enable_gqa=True, + scale=scale, + ) + + +def midm_sdpa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, + scale: float = 1.0, + enable: bool = False, + valid_len=None, +) -> torch.Tensor: + """Dispatch: the mid-M op for a small query window when enabled; otherwise + the standard F.sdpa the model already uses (which the replacement pass swaps + for triton::sdpa). M is static per exported method, so the branch resolves at + trace time. ``valid_len`` is the shared per-forward key bound.""" + M = q.shape[2] + if enable and 2 <= M <= MIDM_MAX_M: + return sdpa_midm(q, k, v, input_pos, scale, valid_len=valid_len) + return F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=True, scale=scale + ) diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index c3a4decf8f4..e171cc4b505 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -82,8 +82,11 @@ def __init__(self, spec: Eagle3Speculator): super().__init__() self.spec = spec - def forward(self, tokens, input_pos): - return self.spec.target_verify(tokens, input_pos) + def forward(self, tokens, input_pos, kv_window): + # kv_window length = number of valid KV positions; its dynamic dim is a + # backed SymInt that bounds the mid-M SDPA key loop (ignored if mid-M is + # off). Only its shape matters, not its contents. 
+ return self.spec.target_verify(tokens, input_pos, kv_window) class _DraftDecode(nn.Module): @@ -118,6 +121,15 @@ def _export_cuda( inductor_config.coordinate_descent_tuning = False inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + import time + + _t = [time.time()] + + def _lap(msg: str) -> None: + now = time.time() + print(f"[export +{now - _t[0]:6.1f}s] {msg}", flush=True) + _t[0] = now + # Register Int4Tensor dispatch -> executorch_cuda::int4_plain_mm for the # target. main() sets MATVEC_MAX_M (and restores it) around this call. import executorch.backends.cuda.int4_dispatch as int4_dispatch @@ -150,17 +162,24 @@ def _export_cuda( dynamic_shapes=({1: prefill_dim}, {0: prefill_dim}), strict=True, ) + _lap("export prefill") print(f"Exporting target_verify (T = {verify_len})...") + # The mid-M SDPA key bound is the dynamic length of kv_window: valid KV + # positions = anchor_pos + chain + 1, in [verify_len, max_seq_len]. + kv_dim = Dim("kv_len", min=verify_len, max=target_config.max_seq_len) with torch.no_grad(): verify_ep = export( _TargetVerify(spec), ( torch.zeros((1, verify_len), dtype=torch.long), torch.arange(verify_len, dtype=torch.long), + torch.zeros((8 * verify_len,), dtype=torch.int32), ), + dynamic_shapes=({}, {}, {0: kv_dim}), strict=True, ) + _lap("export target_verify") # draft_decode: T>1 seeds the draft KV (prompt / newly confirmed tokens), T=1 # steps the chain. 
The feature is hidden-size for both (fused target feature @@ -181,6 +200,7 @@ def _export_cuda( dynamic_shapes=({1: draft_dim}, {1: draft_dim}, {0: draft_dim}), strict=True, ) + _lap("export draft_decode") del spec gc.collect() @@ -226,6 +246,7 @@ def _partitioner(name: str): ) del prefill_ep, verify_ep, draft_ep gc.collect() + _lap("to_edge_transform_and_lower (AOTI compile)") et_program = et_prog.to_executorch( config=ExecutorchBackendConfig( @@ -240,6 +261,7 @@ def _partitioner(name: str): ) del et_prog gc.collect() + _lap("to_executorch") os.makedirs(output_dir, exist_ok=True) pte_path = os.path.join(output_dir, "model.pte") @@ -250,6 +272,7 @@ def _partitioner(name: str): if et_program._tensor_data: et_program.write_tensor_data_to_file(output_dir) print(f" Saved tensor data (.ptd) to {output_dir}/") + _lap("write .pte + .ptd") print("Done.") @@ -278,6 +301,12 @@ def main() -> None: p.add_argument( "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." ) + p.add_argument( + "--no-midm-sdpa", + action="store_true", + help="Disable the length-bounded mid-M SDPA kernel for target_verify " + "(it accelerates full-attention layers at long context).", + ) args = p.parse_args() spec_t = TARGETS[args.target_model] @@ -287,6 +316,13 @@ def main() -> None: print(f"Loading {args.target_model} target from {args.target}...") target = spec_t.load(args.target, args.max_seq_len) + # Route the target's full-attention layers' verify SDPA (M=chain+1) through + # the length-bounded mid-M Triton kernel. Only affects target_verify (prefill + # M is out of range, decode isn't exported); huge win at long context. 
+ if not args.no_midm_sdpa and hasattr(target, "set_midm_sdpa"): + target.set_midm_sdpa(True) + print("Enabled mid-M SDPA for target_verify.") + print(f"Loading draft head from {args.draft}...") draft, _ = Eagle3Draft.from_checkpoint( args.draft, device="cpu", dtype=torch.bfloat16, max_seq_len=args.max_seq_len diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 6a68e89eaaa..758a8bbf6b1 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -11,8 +11,13 @@ // Loads the speculator .pte (examples/models/eagle3/export.py) exposing three // methods that share the target / draft KV caches: // prefill(tokens[1,T], pos[T]) -> (next_token[1,1], feat[1,T,H]) -// target_verify(tokens[1,C], pos[C]) -> (greedy_ids[1,C], feat[1,C,H]) -// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], g[1,T,H]) +// target_verify(tokens[1,C], pos[C], kv_window[V]) -> (greedy_ids[1,C], +// feat[1,C,H]) -- kv_window's dynamic length V (= valid KV positions = +// anchor_pos+C) bounds the mid-M SDPA key loop (ignored if mid-M is off). +// Its growing per-round shape means target_verify can't be a CUDA graph when +// mid-M is on, so pass --cuda_graph=false there. +// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], +// g[1,T,H]) // where feat is the fused (hidden-size) draft feature and H is the draft hidden // size. Verification is greedy (argmax), so emitted tokens equal greedy target // decoding (lossless) by construction. @@ -30,23 +35,24 @@ // negligible next to the INT4 31B target forward, and it keeps device-tensor // lifetimes simple. 
// -// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing the -// CUDA env, and building the eagle3-cuda preset): +// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing +// the CUDA env, and building the eagle3-cuda preset): // eagle3_speculator_runner --model_path /model.pte \ // --data_path /aoti_cuda_blob.ptd --tokenizer_path \ // --prompt "..." --max_new_tokens 128 // The chat template and stop tokens default to Gemma 4 IT; override -// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for other -// target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) print at -// the end. +// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for +// other target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) +// print at the end. // // Scope: a single-sequence, greedy, fixed-shape demo runner -- not a generic -// EAGLE serving path. No batching, sampler stack (top-k/p/temperature), grammar/ -// tool constraints, streaming API, or integration with the standard ExecuTorch -// LLM runner. The host feature round-trip above is a first-implementation choice -// (the target forward dominates here); a device-resident handoff is future work. -// The target, draft, and tokenizer must be a matched, co-trained set -- a -// mismatch can pass export and silently degrade acceptance/output. +// EAGLE serving path. No batching, sampler stack (top-k/p/temperature), +// grammar/tool constraints, streaming API, or integration with the standard +// ExecuTorch LLM runner. The host feature round-trip above is a +// first-implementation choice (the target forward dominates here); a +// device-resident handoff is future work. The target, draft, and tokenizer must +// be a matched, co-trained set -- a mismatch can pass export and silently +// degrade acceptance/output. 
#include @@ -96,9 +102,18 @@ DEFINE_bool(raw_prompt, false, "Skip the Gemma 4 IT chat template."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id (-1 to skip; Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); -DEFINE_bool(cuda_graph, true, "Capture target_verify as a CUDA graph (CUDA only)."); +DEFINE_bool( + cuda_graph, + false, + "Capture target_verify as a CUDA graph (CUDA only). Off by default: the " + "current export feeds target_verify a kv_window whose length changes every " + "round, so capture is unsafe (stale-shape replay). Only enable for an " + "export whose target_verify inputs all have stable shapes."); // Chat template + stop tokens default to Gemma 4 IT; override for other models. -DEFINE_string(chat_prefix, "<|turn>user\n", "Chat-template text before the prompt."); +DEFINE_string( + chat_prefix, + "<|turn>user\n", + "Chat-template text before the prompt."); DEFINE_string( chat_suffix, "\n<|turn>model\n<|channel>thought\n", @@ -130,7 +145,8 @@ std::vector to_host_bytes(const executorch::aten::Tensor& t) { cudaPointerAttributes attrs{}; if (cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && attrs.type == cudaMemoryTypeDevice) { - cudaError_t err = cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); + cudaError_t err = + cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); if (err != cudaSuccess) { ET_LOG(Error, "D2H copy failed: %s", cudaGetErrorString(err)); exit(1); @@ -183,7 +199,10 @@ int main(int argc, char** argv) { auto tokenizer = std::make_unique(); if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { - ET_LOG(Error, "Failed to load tokenizer from %s", FLAGS_tokenizer_path.c_str()); + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); return 1; } @@ -208,10 +227,12 @@ int main(int argc, char** argv) { executorch::runtime::set_option("CudaBackend", 
backend_options.view()); } if (FLAGS_cuda_graph) { - // target_verify is the one target forward per round and has a static shape - // (chain+1 tokens), so capture it as a CUDA graph to avoid paying the - // 60-layer per-kernel launch overhead every round (the dominant cost - // otherwise). Its input tensors must wrap stable host buffers (below). + // Opt-in only (default off): capturing target_verify avoids the 60-layer + // per-kernel launch overhead every round, but it is only sound when every + // target_verify input has a stable shape across rounds. This export does + // not satisfy that -- kv_window's length is the per-round valid-KV count + // (see the kvwin_buf NOTE below) -- so enabling capture here risks stale- + // shape replay. The flag is kept for a future fixed-shape verify export. executorch::runtime::BackendOptions<1> g; g.set_option("enable_cuda_graph_for_method", "target_verify"); executorch::runtime::set_option("CudaBackend", g.view()); @@ -282,24 +303,38 @@ int main(int argc, char** argv) { } const int64_t L = static_cast(prompt.size()); // The runner does not chunk: the whole prompt must fit one prefill, and its - // length must be within the exported prefill range [min_prefill, max_prefill]. + // length must be within the exported prefill range [min_prefill, + // max_prefill]. 
if (L > max_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 - "; this runner does not chunk prefill.", L, max_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 + "; this runner does not chunk prefill.", + L, + max_prefill); return 1; } if (L < min_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) is below the exported prefill " - "minimum %" PRId64 "; use a longer prompt.", L, min_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 + " tokens) is below the exported prefill " + "minimum %" PRId64 "; use a longer prompt.", + L, + min_prefill); return 1; } // The prefill bonus token is always emittable (no KV write past the prompt). - // Each speculative round, however, writes a K-token verify window, so it needs - // anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap the total - // at the positions available; max_new >= 1 since L <= max_prefill < max_seq_len. + // Each speculative round, however, writes a K-token verify window, so it + // needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap + // the total at the positions available; max_new >= 1 since L <= max_prefill < + // max_seq_len. 
int64_t max_new = std::min(FLAGS_max_new_tokens, max_seq_len - L); - printf("Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 - "\n", L, K, max_new); + printf( + "Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n", + L, + K, + max_new); auto S = [](int64_t v) { return static_cast(v); }; @@ -309,11 +344,15 @@ int main(int argc, char** argv) { auto long_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {1, S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {1, S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; auto pos_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; // draft_decode over (tokens, feat rows, positions); returns proposals + the @@ -333,7 +372,9 @@ int main(int argc, char** argv) { feat_buf.assign(feat_rows, feat_rows + feat_T * H); auto t_tok = long_tensor(tok_buf); auto t_feat = from_blob( - feat_buf.data(), {1, S(feat_T), S(H)}, executorch::aten::ScalarType::BFloat16); + feat_buf.data(), + {1, S(feat_T), S(H)}, + executorch::aten::ScalarType::BFloat16); auto t_pos = pos_tensor(pos_buf); auto r = module->execute( "draft_decode", {EValue(t_tok), EValue(t_feat), EValue(t_pos)}); @@ -345,8 +386,7 @@ int main(int argc, char** argv) { HostFeature g = read_feature(r->at(1).toTensor()); out_last_g.T = 1; out_last_g.H = g.H; - out_last_g.data.assign( - g.data.end() - g.H, g.data.end()); // last row of g + out_last_g.data.assign(g.data.end() - g.H, g.data.end()); // last row of g }; // Run a draft chain seeded by (seed_tokens, seed_feat) at seed positions; the @@ -358,16 +398,26 @@ int main(int argc, char** argv) { std::vector ids; HostFeature last_g; draft_decode( - seed_tokens, seed_feat.data.data(), seed_feat.T, seed_feat.H, - seed_start_pos, ids, last_g); + seed_tokens, + seed_feat.data.data(), + 
seed_feat.T, + seed_feat.H, + seed_start_pos, + ids, + last_g); proposals.push_back(ids.back()); int64_t last_pos = seed_start_pos + seed_feat.T - 1; for (int64_t k = 1; k < K; k++) { std::vector step_ids; HostFeature step_g; draft_decode( - {proposals.back()}, last_g.data.data(), 1, last_g.H, - last_pos + k, step_ids, step_g); + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_ids, + step_g); proposals.push_back(step_ids[0]); last_g = step_g; } @@ -377,7 +427,8 @@ int main(int argc, char** argv) { stats.model_load_end_ms = llm::time_in_ms(); stats.inference_start_ms = stats.model_load_end_ms; - // --- Prefill: target over the prompt -> bonus token + per-position feature. --- + // --- Prefill: target over the prompt -> bonus token + per-position feature. + // --- tok_buf = prompt; pos_buf.resize(L); for (int64_t i = 0; i < L; i++) { @@ -389,7 +440,8 @@ int main(int argc, char** argv) { ET_LOG(Error, "prefill failed"); return 1; } - int64_t anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token at position L + int64_t anchor = + read_ids(pf->at(0).toTensor())[0]; // bonus token at position L HostFeature feat_prompt = read_feature(pf->at(1).toTensor()); const int64_t H = feat_prompt.H; int64_t anchor_pos = L; @@ -401,13 +453,16 @@ int main(int argc, char** argv) { uint64_t prev = static_cast(prompt.back()); { auto s = tokenizer->decode(prev, static_cast(anchor)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(anchor); } // We only run the speculative loop if more than the (already emitted) prefill - // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token verify - // window. Otherwise we are done -- no draft seeding needed. + // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token + // verify window. Otherwise we are done -- no draft seeding needed. 
bool hit_eos = eos_ids.count(static_cast(anchor)) > 0; bool speculate = max_new > 1 && !hit_eos && anchor_pos + K <= max_seq_len - 1; std::vector proposals; @@ -427,6 +482,13 @@ int main(int argc, char** argv) { vtok_buf.data(), {1, S(K + 1)}, executorch::aten::ScalarType::Long); auto vpos_t = from_blob( vpos_buf.data(), {S(K + 1)}, executorch::aten::ScalarType::Long); + // kv_window: its dynamic length (= valid KV positions this round) is the + // mid-M SDPA key bound (ignored if the export has mid-M off). Contents are + // unused -- only the shape matters -- so one max-size buffer is reused and + // viewed at the per-round length. NOTE: this per-round shape change is why + // target_verify can't be captured as a CUDA graph for this export -- hence + // --cuda_graph defaults to false. + std::vector kvwin_buf(max_seq_len, 0); // --- Speculative rounds: one target forward (target_verify) per round. --- int64_t rounds = 0; @@ -442,8 +504,13 @@ int main(int argc, char** argv) { for (int64_t i = 0; i <= K; i++) { vpos_buf[i] = anchor_pos + i; } + // Valid KV positions after writing this round = [0, anchor_pos+K]. 
+ int64_t valid_len = anchor_pos + K + 1; + auto kvwin_t = from_blob( + kvwin_buf.data(), {S(valid_len)}, executorch::aten::ScalarType::Int); int64_t t_v = llm::time_in_ms(); - auto vr = module->execute("target_verify", {EValue(vtok_t), EValue(vpos_t)}); + auto vr = module->execute( + "target_verify", {EValue(vtok_t), EValue(vpos_t), EValue(kvwin_t)}); if (vr.error() != Error::Ok) { ET_LOG(Error, "target_verify failed"); return 1; @@ -467,10 +534,14 @@ int main(int argc, char** argv) { std::vector newly(proposals.begin(), proposals.begin() + a); newly.push_back(corrected); for (int64_t t : newly) { - if ((int64_t)emitted.size() >= max_new) break; + if ((int64_t)emitted.size() >= max_new) + break; emitted.push_back(t); auto s = tokenizer->decode(prev, static_cast(t)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(t); if (eos_ids.count(static_cast(t)) > 0) { // Stop at the first accepted EOS; do not emit the rest of this batch. @@ -481,13 +552,15 @@ int main(int argc, char** argv) { break; } } - if (hit_eos || (int64_t)emitted.size() >= max_new) break; + if (hit_eos || (int64_t)emitted.size() >= max_new) + break; // Reseed the draft (shifted): slot anchor_pos+i holds (verify_feat[i], - // token_{anchor_pos+i+1}) where token = p_i (i