diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index fb665e538bf..86807745f48 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -45,6 +45,15 @@ def _is_power_of_2(n: int) -> bool: return n > 0 and (n & (n - 1)) == 0 +# KV length at/above which decode (L_q == 1) uses the split-K flash-decoding +# kernel instead of the standard kernel. Mirrors the threshold the CUDA +# replacement pass uses to pick triton.sdpa_decode_splitk. +_SPLITK_LKV_THRESHOLD = 256 + +# FlashDecoding++ unified-max constant used by the split-K decode path. +_DEFAULT_SPLITK_PHI = 5.0 + + def _next_power_of_2(x: int) -> int: """Get the next power of 2 >= x, clamped to [16, 256]. @@ -160,6 +169,7 @@ def _sdpa_fwd_kernel_non_pow2( v_ptr, o_ptr, mask_ptr, + kv_len_ptr, B, H_grid, LQ, @@ -191,6 +201,7 @@ def _sdpa_fwd_kernel_non_pow2( BLOCK_D: tl.constexpr, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, NUM_GROUPS: tl.constexpr, PACK_GQA: tl.constexpr, ): @@ -254,9 +265,15 @@ def _sdpa_fwd_kernel_non_pow2( NEG_INF: tl.constexpr = float("-inf") - for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): + # Bound the KV loop to valid (filled) positions; see pow2 body for details. 
+ if HAS_KV_LEN: + kv_len = tl.load(kv_len_ptr) + else: + kv_len = LK + + for start_n in tl.range(0, kv_len, BLOCK_N, num_stages=2): offs_n = start_n + tl.arange(0, BLOCK_N) - kv_col_mask = offs_n < LK + kv_col_mask = offs_n < kv_len k_ptrs = k_base + (offs_n[:, None] * stride_kl + offs_d[None, :] * stride_kd) k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) @@ -332,6 +349,7 @@ def _sdpa_fwd_kernel_body( V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -358,6 +376,7 @@ def _sdpa_fwd_kernel_body( sm_scale: tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, @@ -422,6 +441,18 @@ def _sdpa_fwd_kernel_body( offs_n_init = tl.arange(0, BLOCK_N) + # Bound the KV loop to the number of valid (filled) positions instead of the + # full pre-allocated buffer Lk. For decode this is input_pos+1; for a prefill + # chunk it is chunk_end. This makes full-attention (global) layers O(context) + # rather than O(max_seq_len) — the empty tail of the cache is never touched. + # kv_len is read from a GPU scalar so the bound updates across CUDA-graph + # replays (decode is graph-captured). When not provided (HAS_KV_LEN False) it + # falls back to Lk, preserving the original behavior exactly. + if HAS_KV_LEN: + kv_len = tl.load(KV_LEN_ptr) + else: + kv_len = Lk + # Window-aware early-exit. A KV block that is fully masked (sliding-window # or causal) contributes nothing to the online softmax — every entry is # -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up @@ -434,7 +465,7 @@ def _sdpa_fwd_kernel_body( if IS_CAUSAL: max_seq_pos = tl.max(seq_pos) - for start_n in tl.range(0, Lk, BLOCK_N): + for start_n in tl.range(0, kv_len, BLOCK_N): offs_n = start_n + offs_n_init # Decide whether any row in this tile actually attends to this KV block. 
@@ -444,7 +475,7 @@ def _sdpa_fwd_kernel_body( + (seq_pos[:, None] * stride_mq) + (offs_n[None, :] * stride_mk) ) - mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk) + mn_mask = row_valid[:, None] & (offs_n[None, :] < kv_len) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) block_active = tl.sum(mask_block.to(tl.int32)) > 0 elif IS_CAUSAL: @@ -461,7 +492,7 @@ def _sdpa_fwd_kernel_body( + (offs_n[:, None] * stride_kn) + (offs_d[None, :] * stride_kd) ) - k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k_mask = (offs_n[:, None] < kv_len) & (offs_d[None, :] < HEAD_DIM) k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) @@ -493,7 +524,7 @@ def _sdpa_fwd_kernel_body( + (offs_n[:, None] * stride_vn) + (offs_d[None, :] * stride_vd) ) - v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v_mask = (offs_n[:, None] < kv_len) & (offs_d[None, :] < HEAD_DIM) v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) p_bf16 = p_f32.to(tl.bfloat16) @@ -523,111 +554,68 @@ def _sdpa_fwd_kernel_body( tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), - ], - key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], -) -@triton.jit -def _sdpa_fwd_kernel_m64( - Q_ptr, - K_ptr, - V_ptr, - O_ptr, - Mask_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - 
stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale: tl.float32, - HAS_MASK: tl.constexpr, - IS_CAUSAL: tl.constexpr, - HEAD_DIM: tl.constexpr, - NUM_GROUPS: tl.constexpr, - PACK_GQA: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - _sdpa_fwd_kernel_body( - Q_ptr, - K_ptr, - V_ptr, - O_ptr, - Mask_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale, - HAS_MASK=HAS_MASK, - IS_CAUSAL=IS_CAUSAL, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - HEAD_DIM=HEAD_DIM, - NUM_GROUPS=NUM_GROUPS, - PACK_GQA=PACK_GQA, - ) +# Prefill / standard-path tile configs. ONE autotuned kernel spanning BLOCK_M in +# {16..128}; `_sdpa_prefill_prune` drops configs whose fp32 accumulator +# acc[BLOCK_M, HEAD_DIM] would spill registers for the runtime HEAD_DIM, so the +# kernel is high-occupancy for every power-of-2 HEAD_DIM (64/128/256/512). This +# replaces the old fixed BLOCK_M=64 (m64) / BLOCK_M=32 (m32) wrappers + Python +# CTA-count selector: at HEAD_DIM=512 the m64 path spilled acc[64,512] fp32 +# (128 KB/CTA -> ~280 reg spills -> ~30 TFLOP/s); the autotuner now picks a +# non-spilling, well-pipelined tile per HEAD_DIM (e.g. BLOCK_M=32 at 512). 
+_SDPA_PREFILL_CONFIGS = [ + triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3), +] + + +def _sdpa_prefill_prune(configs, nargs, **kwargs): + """Drop configs whose fp32 acc[BLOCK_M, HEAD_DIM] would spill registers. + + Keeps ``BLOCK_M * HEAD_DIM <= 4096 * num_warps`` (the measured A100 no-spill + boundary: HEAD_DIM=512 -> BLOCK_M<=32 at 4 warps / <=64 at 8 warps; + HEAD_DIM=128 -> BLOCK_M<=128 at 4 warps). This guarantees a high-occupancy + pick for any HEAD_DIM and a non-empty result (the BLOCK_M=16 configs satisfy + the budget for every HEAD_DIM<=1024). SMEM-OOR tiles (large + BLOCK_N*HEAD_DIM*num_stages) are pruned by the autotuner at benchmark time. 
+ """ + head_dim = kwargs.get("HEAD_DIM") + if head_dim is None and nargs is not None: + head_dim = nargs.get("HEAD_DIM") + if head_dim is None: + return configs + kept = [ + c + for c in configs + if c.kwargs["BLOCK_M"] * head_dim <= 4096 * c.num_warps + ] + if not kept and configs: + kept = [min(configs, key=lambda c: c.kwargs["BLOCK_M"] / c.num_warps)] + return kept @triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), - ], + configs=_SDPA_PREFILL_CONFIGS, key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], + prune_configs_by={"early_config_prune": _sdpa_prefill_prune}, ) @triton.jit -def _sdpa_fwd_kernel_m32( +def _sdpa_fwd_kernel( Q_ptr, K_ptr, V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -654,6 +642,7 @@ def _sdpa_fwd_kernel_m32( sm_scale: tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, HEAD_DIM: tl.constexpr, NUM_GROUPS: tl.constexpr, PACK_GQA: tl.constexpr, @@ -666,6 +655,7 @@ def _sdpa_fwd_kernel_m32( V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -692,6 +682,7 @@ def _sdpa_fwd_kernel_m32( sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=IS_CAUSAL, + HAS_KV_LEN=HAS_KV_LEN, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, @@ -785,6 +776,8 @@ def _launch_pow2_kernel( is_causal: bool, num_groups: int, pack_gqa: bool, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: """Launch power-of-2 optimized SDPA kernel.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -802,18 +795,18 @@ def _launch_pow2_kernel( def grid(meta): return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) - total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid) - threshold = 4 * 84 - 
kernel = ( - _sdpa_fwd_kernel_m32 if total_ctas_m64 < threshold else _sdpa_fwd_kernel_m64 - ) - - wrap_triton(kernel)[grid]( + # Single autotuned kernel: the config set spans BLOCK_M in {16..128} and + # `_sdpa_prefill_prune` keeps only non-spilling tiles for this HEAD_DIM, so + # the autotuner picks a high-occupancy tile (small BLOCK_M for large HEAD_DIM, + # larger BLOCK_M / more CTAs for small problems) — subsuming the old + # CTA-count m32/m64 selector. + wrap_triton(_sdpa_fwd_kernel)[grid]( query, key, value, out, Mask_ptr if HAS_MASK else 0, + kv_len_ptr if HAS_KV_LEN else 0, B, H_grid, L_q, @@ -840,6 +833,7 @@ def grid(meta): sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + HAS_KV_LEN=HAS_KV_LEN, HEAD_DIM=D, NUM_GROUPS=num_groups, PACK_GQA=pack_gqa, @@ -863,6 +857,8 @@ def _launch_non_pow2_kernel( is_causal: bool, num_groups: int, pack_gqa: bool, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: """Launch non-power-of-2 SDPA kernel with dynamic HEAD_DIM masking.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -902,6 +898,7 @@ def grid_non_pow2(meta): value, out, mask_ptr, + kv_len_ptr if HAS_KV_LEN else 0, B, H_grid, L_q, @@ -933,6 +930,7 @@ def grid_non_pow2(meta): BLOCK_D=BLOCK_D, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + HAS_KV_LEN=HAS_KV_LEN, NUM_GROUPS=num_groups, PACK_GQA=pack_gqa, num_warps=num_warps, @@ -950,6 +948,7 @@ def sdpa( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Triton fused Scaled Dot-Product Attention with GQA pack optimization. @@ -967,6 +966,15 @@ def sdpa( is_causal: apply causal masking scale: attention scale (default: 1/sqrt(D)) enable_gqa: allow H_q != H_kv (GQA/MQA) + kv_len: Optional GPU int scalar = number of valid (filled) KV positions. 
+ When provided, the inner KV loop is bounded to ``kv_len`` instead of + the full pre-allocated ``L_kv``, making attention O(context) instead + of O(max_seq_len). It is read on-device (no host sync) so the bound + updates correctly under CUDA-graph replay (decode). For decode pass + ``input_pos + 1``; for a prefill chunk pass ``chunk_end``. When None + the loop runs over the full ``L_kv`` (original behavior). Supplying + it for an L_q==1 decode with a large buffer also routes through the + split-K flash-decoding kernel for occupancy. Returns: Output tensor [B, H_q, L_q, D], dtype torch.bfloat16 """ @@ -984,6 +992,59 @@ def sdpa( "For decode (L_q < L_kv), use an explicit bool mask instead." ) + out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( + attn_mask, B, L_q, L_kv + ) + + # Optional length bound: device int32 scalar, clamped to the buffer size for + # OOB safety. Reshaped to [1] so the kernel can ``tl.load`` element 0. No + # ``.item()`` — keeps it CUDA-graph-safe (value updates on replay). + HAS_KV_LEN = kv_len is not None + if HAS_KV_LEN: + kv_len_t = torch.clamp( + kv_len.reshape(1).to(torch.int32), max=int(L_kv) + ).contiguous() + else: + kv_len_t = None + + # Split-K decode dispatch: L_q == 1 with a kv_len bound and a large KV + # buffer. Flash-decoding partitions the KV sequence across many CTAs for + # better occupancy (L_q=1 launches too few CTAs otherwise). The split is + # static (from buffer size L_kv, not the runtime kv_len value) so it is + # export/AOTI-traceable; the kernel still bounds each split's loop by kv_len + # on-device (CUDA-graph safe). Only taken when kv_len is supplied, so callers + # that don't pass kv_len keep the exact original (standard-kernel) dispatch. 
+ if ( + HAS_KV_LEN + and L_q == 1 + and _is_power_of_2(D) + and L_kv >= _SPLITK_LKV_THRESHOLD + ): + _launch_decode_splitk( + query, + key, + value, + out, + B, + H_q, + H_kv, + L_kv, + D, + sm_scale, + HAS_MASK, + Mask_ptr, + stride_mb, + stride_mq, + stride_mk, + num_groups, + _DEFAULT_SPLITK_PHI, + kv_len_t, + HAS_KV_LEN, + ) + return out + # Decide whether to pack GQA based on tile utilization heuristic. # Use the actual BLOCK_M that the launched kernel will use: # - non-pow2 path always uses BLOCK_M=32 @@ -995,12 +1056,6 @@ def sdpa( block_m = 32 if total_ctas_m64 < 4 * 84 else 64 pack_gqa = _should_pack_gqa(L_q, num_groups, block_m) - out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) - sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale - HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( - attn_mask, B, L_q, L_kv - ) - if _is_power_of_2(D): _launch_pow2_kernel( query, @@ -1022,6 +1077,8 @@ def sdpa( is_causal, num_groups, pack_gqa, + kv_len_t, + HAS_KV_LEN, ) else: _launch_non_pow2_kernel( @@ -1041,6 +1098,8 @@ def sdpa( is_causal, num_groups, pack_gqa, + kv_len_t, + HAS_KV_LEN, ) return out @@ -1058,6 +1117,7 @@ def _sdpa_abstract( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Abstract/fake implementation for torch.export. @@ -1104,6 +1164,7 @@ def _sdpa_decode_splitk_kernel( O_partial_ptr, L_partial_ptr, Mask_ptr, + KV_LEN_ptr, B, H_kv, Lk, @@ -1133,6 +1194,7 @@ def _sdpa_decode_splitk_kernel( phi: tl.float32, chunk_size, HAS_MASK: tl.constexpr, + HAS_KV_LEN: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, NUM_GROUPS: tl.constexpr, @@ -1144,7 +1206,15 @@ def _sdpa_decode_splitk_kernel( h_kv = pid_bh % H_kv start_n = split_id * chunk_size - end_n = tl.minimum(start_n + chunk_size, Lk) + # Bound the decode KV sweep to the valid (filled) positions. 
Splits whose + # chunk starts past kv_len do no work (end_n <= start_n) and store the zero + # partials they were initialized with, so the reduce is unaffected. kv_len is + # read on-device (CUDA-graph safe); falls back to Lk when not provided. + if HAS_KV_LEN: + kv_len = tl.load(KV_LEN_ptr) + else: + kv_len = Lk + end_n = tl.minimum(start_n + chunk_size, kv_len) offs_d = tl.arange(0, HEAD_DIM) offs_g = tl.arange(0, BLOCK_G) @@ -1293,6 +1363,8 @@ def _launch_decode_splitk( stride_mk: int, num_groups: int, phi: float, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128) chunk_size = triton.cdiv(L_kv, num_splits) @@ -1319,6 +1391,7 @@ def _launch_decode_splitk( O_partial, L_partial, Mask_ptr if HAS_MASK else 0, + kv_len_ptr if HAS_KV_LEN else 0, B, H_kv, L_kv, @@ -1348,6 +1421,7 @@ def _launch_decode_splitk( phi, chunk_size, HAS_MASK=HAS_MASK, + HAS_KV_LEN=HAS_KV_LEN, HEAD_DIM=D, NUM_GROUPS=num_groups, BLOCK_G=_next_power_of_2_unclamped(num_groups), @@ -1387,6 +1461,7 @@ def sdpa_decode_splitk( scale: float = 0.0, enable_gqa: bool = False, phi: float = 5.0, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Split-K flash-decoding SDPA for L_q=1 (decode step). @@ -1396,6 +1471,10 @@ def sdpa_decode_splitk( Signature mirrors sdpa() for drop-in use with torch.cond dispatch. enable_gqa is accepted but ignored — GQA is handled natively via H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1. + + kv_len: optional GPU int scalar bounding the KV sweep to the valid + (filled) positions (O(context) instead of O(max_seq_len)). Read + on-device, CUDA-graph safe. When None, sweeps the full L_kv. 
""" _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) @@ -1431,6 +1510,14 @@ def sdpa_decode_splitk( attn_mask, B, L_q, L_kv ) + HAS_KV_LEN = kv_len is not None + if HAS_KV_LEN: + kv_len_t = torch.clamp( + kv_len.reshape(1).to(torch.int32), max=int(L_kv) + ).contiguous() + else: + kv_len_t = None + _launch_decode_splitk( query, key, @@ -1449,6 +1536,8 @@ def sdpa_decode_splitk( stride_mk, num_groups, phi, + kv_len_t, + HAS_KV_LEN, ) return out @@ -1464,6 +1553,7 @@ def _sdpa_decode_splitk_abstract( scale: float = 0.0, enable_gqa: bool = False, phi: float = 5.0, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" B, H_q, L_q, D = query.shape diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 6609178e084..8e3358e71a2 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -28,6 +28,11 @@ # Importing this module registers ``torch.ops.triton.tq4_sdpa``. import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401 +# Importing this module registers ``torch.ops.triton.sdpa`` / +# ``torch.ops.triton.sdpa_decode_splitk`` (the length-aware bf16 attention ops +# used by the non-TurboQuant full-attention path below). +import executorch.backends.cuda.triton.kernels.sdpa # noqa: F401 + import torch import torch.nn as nn import torch.nn.functional as F @@ -111,6 +116,79 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _lenaware_attention_forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, +) -> torch.Tensor: + """Drop-in ``Gemma4Attention.forward`` for full-attention layers on the + non-TurboQuant CUDA path that bounds SDPA to the valid context length. 
+ + Identical to the default forward (plain bf16 KV cache) except the final + ``F.scaled_dot_product_attention`` is replaced with + ``torch.ops.triton.sdpa(..., kv_len=...)``. Passing ``kv_len`` bounds the + attention KV loop to the actual filled context instead of the full + pre-allocated buffer (``max_seq_len`` for global layers), making decode + O(context) instead of O(max_seq_len) — and routes L_q==1 decode through the + length-aware split-K flash-decoding kernel. Sliding-window layers are not + patched (they already use a bounded ring buffer). + """ + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # (B, H, T, D) for SDPA / KV cache. + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE: same code path as default forward. + freqs = torch.outer(input_pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = torch.cos(emb) + sin = torch.sin(emb) + q, k = apply_rotary_emb(q, k, cos, sin) + + # Update cache and read back the full (pre-allocated) K/V buffers. + k, v = self.kv_cache.update(input_pos, k, v) + + # Number of valid (filled) KV positions = input_pos[0] + T. Passing this to + # sdpa bounds its KV loop to the actual context instead of the full + # pre-allocated buffer (max_seq_len for global layers), making attention + # O(context) instead of O(max_seq_len). Kept as a GPU scalar (no ``.item()``) + # so the bound is captured correctly by the decode CUDA graph. Decode: T=1 -> + # input_pos+1; prefill chunk: T -> chunk_end. + kv_len = input_pos[0] + T + + # ``scale=self.scaling`` (= 1.0 for Gemma 4) — Gemma's QK-norm has absorbed + # the 1/sqrt(d) factor into trained weights. 
``enable_gqa=True`` lets the + # kernel handle the head ratio without materializing expanded K/V. + y = torch.ops.triton.sdpa( + q, + k, + v, + attn_mask, + 0.0, # dropout_p + False, # is_causal: attn_mask already encodes causal masking + self.scaling, + True, # enable_gqa + kv_len, + ) + + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. @@ -233,6 +311,22 @@ def cuda_source_transformations( _fuse_gate_up_proj(model) if not use_turboquant: + # Non-TurboQuant path: keep the bf16 KV cache but bound full-attention + # SDPA to the valid context length via a runtime kv_len scalar (routes + # through torch.ops.triton.sdpa, which dispatches L_q==1 decode to the + # length-aware split-K flash-decoding kernel). Sliding-window layers + # already use a bounded ring buffer, so they are left untouched. + n_bounded = 0 + for layer in model.layers: + attn = layer.self_attn + if attn.is_sliding: + continue + attn.forward = types.MethodType(_lenaware_attention_forward, attn) + n_bounded += 1 + print( + f"[gemma4_31b cuda] length-aware SDPA: bounded {n_bounded} " + f"full-attention layers to runtime kv_len (O(context) attention)" + ) return config = model.config