diff --git a/backends/cuda/tests/test_sdpa_midm.py b/backends/cuda/tests/test_sdpa_midm.py new file mode 100644 index 00000000000..931e4d387fc --- /dev/null +++ b/backends/cuda/tests/test_sdpa_midm.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Correctness tests (vs F.sdpa) for the mid-M flash SDPA kernel. + +CUDA + Triton only. Validates the length-bounded mid-M kernel against the exact +attention the gemma4 full-attention layers compute (causal, enable_gqa, scale=1) +and checks it equals a full-buffer F.sdpa when the valid length << max_seq_len. +""" + +import unittest + +import torch + +from executorch.backends.cuda.triton.kernels.sdpa_midm import ( + midm_sdpa, + sdpa_midm, + sdpa_midm_reference, +) + + +def _require_cuda(tc): + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _rand(B, Hkv, H, M, D, S, anchor, device="cuda", dtype=torch.bfloat16): + q = torch.randn(B, H, M, D, device=device, dtype=dtype) + k = torch.randn(B, Hkv, S, D, device=device, dtype=dtype) + v = torch.randn(B, Hkv, S, D, device=device, dtype=dtype) + input_pos = torch.arange(anchor, anchor + M, device=device, dtype=torch.long) + return q, k, v, input_pos + + +def _rel_err(a, b): + return ( + (a.float() - b.float()).abs().mean() / b.float().abs().mean().clamp_min(1e-6) + ).item() + + +class TestMidMSDPA(unittest.TestCase): + def setUp(self): + _require_cuda(self) + torch.manual_seed(0) + + def _check(self, B, Hkv, H, M, D, S, anchor, tol=0.02): + q, k, v, pos = _rand(B, Hkv, H, M, D, S, anchor) + got = sdpa_midm(q, k, v, pos, scale=1.0) + ref = sdpa_midm_reference(q, k, v, pos, scale=1.0) + self.assertEqual(got.shape, (B, H, M, D)) + err = _rel_err(got, ref) + self.assertLess(err, tol, f"rel_err={err} for M={M} D={D} anchor={anchor}") + + # gemma4 global-attention 
shape: H=32, HKV=4 (GQA 8), D=512. + def test_global_layer_verify_window(self): + for M in (2, 4, 5, 8): + for anchor in (0, 17, 200, 1000): + self._check(1, 4, 32, M, 512, 4096, anchor) + + def test_other_gqa_and_headdim(self): + # smaller config (head_dim 256, GQA 4) to exercise generality + for M in (2, 5, 8): + self._check(1, 2, 8, M, 256, 2048, 300) + + def test_anchor_zero_single_diagonal(self): + # anchor 0: row j attends keys [0, j] only + self._check(1, 4, 32, 4, 512, 1024, 0) + + def test_matches_full_buffer_fsdpa(self): + # The bounded kernel must equal F.sdpa over the FULL buffer with the + # model's causal additive mask (the rest masked to -inf). + import torch.nn.functional as F + + q, k, v, pos = _rand(1, 4, 32, 5, 512, 8192, 500) + key_idx = torch.arange(8192, device="cuda") + keep = key_idx[None, :] <= pos[:, None] + am = torch.where(keep, 0.0, float("-inf")).to(q.dtype) + full = F.scaled_dot_product_attention( + q, k, v, attn_mask=am, is_causal=False, enable_gqa=True, scale=1.0 + ) + got = sdpa_midm(q, k, v, pos, scale=1.0) + self.assertLess(_rel_err(got, full), 0.02) + + def test_splitk_large_context(self): + # Many active splits: 64K buffer, anchors across the range. Exercises the + # cross-split online-softmax reduce at the lengths that motivated split-K. + for anchor in (2048, 30000, 60000): + for M in (2, 5, 8): + self._check(1, 4, 32, M, 512, 65536, anchor) + + def test_splitk_masked_and_boundary_splits(self): + # anchor small vs a large buffer: late key-range splits are fully causal- + # masked for the early rows (null partials), and a row's cutoff lands mid + # chunk. Reduce must discard -inf/0 partials cleanly. + for anchor in (1, 31, 33, 500): + self._check(1, 2, 8, 5, 256, 65536, anchor) + + def test_dispatch_falls_back(self): + # M=1 and M>MIDM_MAX_M must take the F.sdpa path (not the mid-M kernel). 
+ import torch.nn.functional as F + + for M in (1, 16): + q, k, v, pos = _rand(1, 4, 32, M, 512, 1024, 100) + am = torch.zeros(M, 1024, device="cuda", dtype=q.dtype) + key_idx = torch.arange(1024, device="cuda") + am = torch.where(key_idx[None, :] <= pos[:, None], 0.0, float("-inf")).to( + q.dtype + ) + out = midm_sdpa(q, k, v, pos, am, scale=1.0, enable=True) + ref = F.scaled_dot_product_attention( + q, k, v, attn_mask=am, is_causal=False, enable_gqa=True, scale=1.0 + ) + self.assertLess(_rel_err(out, ref), 0.02) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index 4db10fbf82d..a2281c5c06b 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -17,6 +17,7 @@ int4_matvec, ) from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk +from executorch.backends.cuda.triton.kernels.sdpa_midm import sdpa_midm from executorch.backends.cuda.triton.kernels.topk import topk __all__ = [ @@ -29,6 +30,7 @@ "moe_align_block_size", "sdpa", "sdpa_decode_splitk", + "sdpa_midm", "topk", ] diff --git a/backends/cuda/triton/kernels/sdpa_midm.py b/backends/cuda/triton/kernels/sdpa_midm.py new file mode 100644 index 00000000000..66de8fbf731 --- /dev/null +++ b/backends/cuda/triton/kernels/sdpa_midm.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Mid-M flash SDPA: a length-bounded, split-K attention kernel for a few query rows. + +A companion to the kernels in ``sdpa.py``. 
The shared ``_sdpa_fwd_kernel_m64/m32`` +path scans the whole K/V buffer (mask width = max_seq_len), so a speculative +verify forward (M = chain+1, a few rows) over a 64K-context export pays a 64K-wide +attention pass even when only a few hundred positions are valid. + +This kernel specializes for that mid-M regime (analogous to the small-M weight- +stationary INT4 GEMM): it keeps all M query rows resident, streams K/V once, and +**bounds the key range to the actual valid length**. Crucially, it also **splits +the key range across CTAs** (flash-decoding / split-K): a single (B, H) grid runs +only one CTA per head, looping over the whole KV serially, so at long context the +verify attention is occupancy-starved and its latency grows linearly with length. +Split-K partitions [0, valid_len) into NUM_SPLITS chunks computed in parallel, then +a reduce kernel combines the per-split online-softmax partials -- the same trick +that keeps the M=1 decode path (``_sdpa_decode_splitk``) latency flat out to 64K. +Since attention here is bandwidth-bound on the K/V read, the M query rows ride along +for free and the verify attention approaches the decode floor regardless of context. + +Causal masking is computed per row from ``input_pos`` (each of the M rows has its +own cutoff) -- no materialized max_seq_len mask. Scale, GQA, and bf16-in / fp32- +accumulate behavior match ``F.scaled_dot_product_attention(..., is_causal=False, +enable_gqa=True)`` (Gemma 4 uses scale=1; QK-norm absorbs 1/sqrt(d)). Sliding- +window layers use a ring cache and stay on F.sdpa; this targets the flat full- +attention layers. + +Lives in the CUDA backend (imported only during CUDA lowering / by the model's +mid-M dispatch), so triton is imported unconditionally -- same as ``sdpa.py``. 
+""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +# Verify windows up to this M route to the mid-M kernel; above it the prefill +# path is appropriate (enough rows to amortize a tiled kernel). +MIDM_MAX_M = 8 + +# Number of key-range partitions for split-K. The verify method exports a static +# M / B / H / D, so the partial buffers and grid are static-shaped; only the +# per-split chunk size (derived from the dynamic valid_len) is a runtime scalar. +# 32 splits x (B*H) heads gives ~1K CTAs at the gemma4 global shape -- ample +# occupancy on an A100 while keeping the fp32 partials small. +NUM_SPLITS = 32 + + +@triton.jit +def _sdpa_midm_splitk_kernel( + Q, + K, + V, + POS, + Opart, + Lpart, + Mpart, + sqb, + sqh, + sqm, + sqd, + skb, + skh, + skn, + skd, + svb, + svh, + svn, + svd, + sops, + sopb, + soph, + sopm, + sopd, + slps, + slpb, + slph, + slpm, + smps, + smpb, + smph, + smpm, + valid_len, + chunk_size, + scale, + H: tl.constexpr, + HKV: tl.constexpr, + M: tl.constexpr, + D: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + split_id = tl.program_id(0) + pid_b = tl.program_id(1) + pid_h = tl.program_id(2) + kv_h = pid_h // (H // HKV) + + # This CTA owns keys [start_n, end_n) of the valid range. Splits whose range + # falls entirely past valid_len run an empty loop and emit a null partial + # (m=-inf, l=0, acc=0), which the reduce discards. 
+ start_n = split_id * chunk_size + end_n = tl.minimum(start_n + chunk_size, valid_len) + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, D) + offs_n = tl.arange(0, BLOCK_N) + + q = tl.load( + Q + pid_b * sqb + pid_h * sqh + offs_m[:, None] * sqm + offs_d[None, :] * sqd, + mask=offs_m[:, None] < M, + other=0.0, + ) + qpos = tl.load(POS + offs_m, mask=offs_m < M, other=0).to(tl.int32) + + m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_i = tl.zeros([BLOCK_M], tl.float32) + acc = tl.zeros([BLOCK_M, D], tl.float32) + + kbase = K + pid_b * skb + kv_h * skh + vbase = V + pid_b * svb + kv_h * svh + for sn in tl.range(start_n, end_n, BLOCK_N): + n = sn + offs_n + nmask = n < end_n + k = tl.load( + kbase + n[:, None] * skn + offs_d[None, :] * skd, + mask=nmask[:, None], + other=0.0, + ) + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale + causal = (n[None, :] <= qpos[:, None]) & nmask[None, :] + # Keep fp32: a bare -inf python literal promotes the loop-carried softmax + # state to fp64, which AOTI's Triton compile rejects. + qk = tl.where(causal, qk, float("-inf")).to(tl.float32) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + # A split whose whole tile is causal-masked for a row leaves m_ij=-inf; + # guard exp(-inf - -inf)=NaN (mirrors sdpa.py). Such null tiles then yield + # a (m=-inf, l=0, acc=0) partial that the reduce discards. + safe_qk = tl.where( + m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") + ) + p = tl.exp(safe_qk) + safe_alpha = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) + alpha = tl.exp(safe_alpha) + l_i = l_i * alpha + tl.sum(p, 1) + v = tl.load( + vbase + n[:, None] * svn + offs_d[None, :] * svd, + mask=nmask[:, None], + other=0.0, + ) + acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v) + m_i = m_ij + + # Store the un-normalized partial (acc, running max, running denom). The + # reduce kernel rescales across splits and divides by the global denom. 
+ pbase = split_id * sops + pid_b * sopb + pid_h * soph + tl.store( + Opart + pbase + offs_m[:, None] * sopm + offs_d[None, :] * sopd, + acc, + mask=offs_m[:, None] < M, + ) + tl.store( + Lpart + split_id * slps + pid_b * slpb + pid_h * slph + offs_m * slpm, + l_i, + mask=offs_m < M, + ) + tl.store( + Mpart + split_id * smps + pid_b * smpb + pid_h * smph + offs_m * smpm, + m_i, + mask=offs_m < M, + ) + + +@triton.jit +def _sdpa_midm_reduce_kernel( + Opart, + Lpart, + Mpart, + OUT, + sops, + sopb, + soph, + sopm, + sopd, + slps, + slpb, + slph, + slpm, + smps, + smpb, + smph, + smpm, + sob, + soh, + som, + sod, + NUM_SPLITS: tl.constexpr, + M: tl.constexpr, + D: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, D) + + m_g = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_g = tl.zeros([BLOCK_M], tl.float32) + acc = tl.zeros([BLOCK_M, D], tl.float32) + + for s in range(0, NUM_SPLITS): + m_s = tl.load( + Mpart + s * smps + pid_b * smpb + pid_h * smph + offs_m * smpm, + mask=offs_m < M, + other=-float("inf"), + ) + l_s = tl.load( + Lpart + s * slps + pid_b * slpb + pid_h * slph + offs_m * slpm, + mask=offs_m < M, + other=0.0, + ) + o_s = tl.load( + Opart + + s * sops + + pid_b * sopb + + pid_h * soph + + offs_m[:, None] * sopm + + offs_d[None, :] * sopd, + mask=offs_m[:, None] < M, + other=0.0, + ) + m_new = tl.maximum(m_g, m_s) + # Guard the all-empty case (m_new = -inf): -inf - -inf is NaN; where + # selects the safe value and discards it (mirrors sdpa.py). 
+ finite = m_new > -float("inf") + alpha_g = tl.where(finite, tl.exp(m_g - m_new), 1.0) + alpha_s = tl.where(finite, tl.exp(m_s - m_new), 0.0) + l_g = l_g * alpha_g + l_s * alpha_s + acc = acc * alpha_g[:, None] + o_s * alpha_s[:, None] + m_g = m_new + + inv = tl.where(l_g > 0, 1.0 / l_g, 0.0) + acc = acc * inv[:, None] + tl.store( + OUT + pid_b * sob + pid_h * soh + offs_m[:, None] * som + offs_d[None, :] * sod, + acc.to(OUT.dtype.element_ty), + mask=offs_m[:, None] < M, + ) + + +@triton_op("triton::sdpa_midm", mutates_args={}) +def _sdpa_midm_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + valid_len: int, + scale: float, +) -> torch.Tensor: + """Length-bounded, split-K mid-M flash SDPA (triton_op so AOTI codegens it). + + ``valid_len`` (max valid position + 1) bounds the key range; it is split into + NUM_SPLITS chunks of ``chunk_size`` keys computed in parallel, then reduced. + M / B / H / D are static for the exported verify method, so only chunk_size is + a runtime (backed-SymInt) scalar -- the grid and partial buffers are static. + """ + B, H, M, D = q.shape + HKV = k.shape[1] + out = torch.empty_like(q) + BLOCK_M = max(16, triton.next_power_of_2(M)) + # gemma4 global layers use D=512; a wide key tile + pipelining overflow SMEM + # there, so shrink both. Small D can afford more. 
+ BLOCK_N, num_stages = (32, 1) if D >= 512 else (64, 2) + chunk_size = (valid_len + NUM_SPLITS - 1) // NUM_SPLITS + + Opart = torch.empty((NUM_SPLITS, B, H, M, D), device=q.device, dtype=torch.float32) + Lpart = torch.empty((NUM_SPLITS, B, H, M), device=q.device, dtype=torch.float32) + Mpart = torch.empty((NUM_SPLITS, B, H, M), device=q.device, dtype=torch.float32) + + wrap_triton(_sdpa_midm_splitk_kernel)[(NUM_SPLITS, B, H)]( + q, + k, + v, + input_pos, + Opart, + Lpart, + Mpart, + *q.stride(), + *k.stride(), + *v.stride(), + *Opart.stride(), + *Lpart.stride(), + *Mpart.stride(), + valid_len, + chunk_size, + scale, + H=H, + HKV=HKV, + M=M, + D=D, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=num_stages, + num_warps=4, + ) + wrap_triton(_sdpa_midm_reduce_kernel)[(B, H)]( + Opart, + Lpart, + Mpart, + out, + *Opart.stride(), + *Lpart.stride(), + *Mpart.stride(), + *out.stride(), + NUM_SPLITS=NUM_SPLITS, + M=M, + D=D, + BLOCK_M=BLOCK_M, + num_warps=4, + ) + return out + + +@_sdpa_midm_op.register_fake +def _sdpa_midm_abstract(q, k, v, input_pos, valid_len, scale): + return torch.empty_like(q) + + +def sdpa_midm(q, k, v, input_pos, scale=1.0, valid_len=None): + """Eager/convenience wrapper. ``valid_len`` defaults to max(input_pos)+1 + clamped to the buffer; callers in a traced graph should pass a single + precomputed ``valid_len`` to avoid per-layer SymInts.""" + if valid_len is None: + valid_len = min(int(input_pos[-1]) + 1, k.shape[2]) + return torch.ops.triton.sdpa_midm(q, k, v, input_pos, valid_len, scale) + + +def sdpa_midm_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + scale: float = 1.0, +) -> torch.Tensor: + """Reference: F.sdpa over the valid prefix with a causal additive mask. + + Mirrors the gemma4 full-attention call (is_causal=False, enable_gqa=True, + explicit additive mask) sliced to the valid length, so it equals what the + model computes over the full buffer (the rest is masked to -inf anyway). 
+ """ + valid_len = int(input_pos.max().item()) + 1 + key_idx = torch.arange(valid_len, device=q.device) + keep = key_idx[None, :] <= input_pos[:, None] + attn_mask = torch.where(keep, 0.0, float("-inf")).to(q.dtype) + return F.scaled_dot_product_attention( + q, + k[:, :, :valid_len], + v[:, :, :valid_len], + attn_mask=attn_mask, + is_causal=False, + enable_gqa=True, + scale=scale, + ) + + +def midm_sdpa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, + scale: float = 1.0, + enable: bool = False, + valid_len=None, +) -> torch.Tensor: + """Dispatch: the mid-M op for a small query window when enabled; otherwise + the standard F.sdpa the model already uses (which the replacement pass swaps + for triton::sdpa). M is static per exported method, so the branch resolves at + trace time. ``valid_len`` is the shared per-forward key bound.""" + M = q.shape[2] + if enable and 2 <= M <= MIDM_MAX_M: + return sdpa_midm(q, k, v, input_pos, scale, valid_len=valid_len) + return F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=True, scale=scale + ) diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index 9e96e8fc59a..892749e256b 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -83,8 +83,11 @@ def __init__(self, spec: Eagle3Speculator): super().__init__() self.spec = spec - def forward(self, tokens, input_pos): - return self.spec.target_verify(tokens, input_pos) + def forward(self, tokens, input_pos, kv_window): + # kv_window length = number of valid KV positions; its dynamic dim is a + # backed SymInt that bounds the mid-M SDPA key loop (ignored if mid-M is + # off). Only its shape matters, not its contents. 
+ return self.spec.target_verify(tokens, input_pos, kv_window) class _DraftDecode(nn.Module): @@ -119,6 +122,15 @@ def _export_cuda( inductor_config.coordinate_descent_tuning = False inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + import time + + _t = [time.time()] + + def _lap(msg: str) -> None: + now = time.time() + print(f"[export +{now - _t[0]:6.1f}s] {msg}", flush=True) + _t[0] = now + # Register Int4Tensor dispatch -> executorch_cuda::int4_plain_mm for the # target. main() sets MATVEC_MAX_M (and restores it) around this call. import executorch.backends.cuda.quantize_op_dispatch.int4_dispatch as int4_dispatch @@ -151,17 +163,24 @@ def _export_cuda( dynamic_shapes=({1: prefill_dim}, {0: prefill_dim}), strict=True, ) + _lap("export prefill") print(f"Exporting target_verify (T = {verify_len})...") + # The mid-M SDPA key bound is the dynamic length of kv_window: valid KV + # positions = anchor_pos + chain + 1, in [verify_len, max_seq_len]. + kv_dim = Dim("kv_len", min=verify_len, max=target_config.max_seq_len) with torch.no_grad(): verify_ep = export( _TargetVerify(spec), ( torch.zeros((1, verify_len), dtype=torch.long), torch.arange(verify_len, dtype=torch.long), + torch.zeros((8 * verify_len,), dtype=torch.int32), ), + dynamic_shapes=({}, {}, {0: kv_dim}), strict=True, ) + _lap("export target_verify") # draft_decode: T>1 seeds the draft KV (prompt / newly confirmed tokens), T=1 # steps the chain. 
The feature is hidden-size for both (fused target feature @@ -182,6 +201,7 @@ def _export_cuda( dynamic_shapes=({1: draft_dim}, {1: draft_dim}, {0: draft_dim}), strict=True, ) + _lap("export draft_decode") del spec gc.collect() @@ -227,6 +247,7 @@ def _partitioner(name: str): ) del prefill_ep, verify_ep, draft_ep gc.collect() + _lap("to_edge_transform_and_lower (AOTI compile)") et_program = et_prog.to_executorch( config=ExecutorchBackendConfig( @@ -240,6 +261,7 @@ def _partitioner(name: str): ) del et_prog gc.collect() + _lap("to_executorch") os.makedirs(output_dir, exist_ok=True) pte_path = os.path.join(output_dir, "model.pte") @@ -250,6 +272,7 @@ def _partitioner(name: str): if et_program._tensor_data: et_program.write_tensor_data_to_file(output_dir) print(f" Saved tensor data (.ptd) to {output_dir}/") + _lap("write .pte + .ptd") print("Done.") @@ -278,6 +301,12 @@ def main() -> None: p.add_argument( "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." ) + p.add_argument( + "--no-midm-sdpa", + action="store_true", + help="Disable the length-bounded mid-M SDPA kernel for target_verify " + "(it accelerates full-attention layers at long context).", + ) args = p.parse_args() spec_t = TARGETS[args.target_model] @@ -287,6 +316,13 @@ def main() -> None: print(f"Loading {args.target_model} target from {args.target}...") target = spec_t.load(args.target, args.max_seq_len) + # Route the target's full-attention layers' verify SDPA (M=chain+1) through + # the length-bounded mid-M Triton kernel. Only affects target_verify (prefill + # M is out of range, decode isn't exported); huge win at long context. 
+ if not args.no_midm_sdpa and hasattr(target, "set_midm_sdpa"): + target.set_midm_sdpa(True) + print("Enabled mid-M SDPA for target_verify.") + print(f"Loading draft head from {args.draft}...") draft, _ = Eagle3Draft.from_checkpoint( args.draft, device="cpu", dtype=torch.bfloat16, max_seq_len=args.max_seq_len diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 62471e51db1..5bee9360a3e 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -11,8 +11,13 @@ // Loads the speculator .pte (examples/models/eagle3/export.py) exposing three // methods that share the target / draft KV caches: // prefill(tokens[1,T], pos[T]) -> (next_token[1,1], feat[1,T,H]) -// target_verify(tokens[1,C], pos[C]) -> (greedy_ids[1,C], feat[1,C,H]) -// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], g[1,T,H]) +// target_verify(tokens[1,C], pos[C], kv_window[V]) -> (greedy_ids[1,C], +// feat[1,C,H]) -- kv_window's dynamic length V (= valid KV positions = +// anchor_pos+C) bounds the mid-M SDPA key loop (ignored if mid-M is off). +// Its growing per-round shape means target_verify can't be a CUDA graph when +// mid-M is on, so pass --cuda_graph=false there. +// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], +// g[1,T,H]) // where feat is the fused (hidden-size) draft feature and H is the draft hidden // size. Verification is greedy (argmax), so emitted tokens equal greedy target // decoding (lossless) by construction. @@ -30,23 +35,24 @@ // negligible next to the INT4 31B target forward, and it keeps device-tensor // lifetimes simple. // -// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing the -// CUDA env, and building the eagle3-cuda preset): +// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing +// the CUDA env, and building the eagle3-cuda preset): // eagle3_speculator_runner --model_path