From a850fa069a9beb90d9b950aae264671642a7d3ab Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:31 -0700 Subject: [PATCH 1/2] [ExecuTorch][WebGPU] SDPA: skip QK contraction for fully-masked causal tiles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20492 **Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output). **Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`. **Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel: - **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`. - **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before. **Implementation**: - Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition. - Store loop unchanged — runs unconditionally, so no scratch entry is left stale. - Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`). **Constraints**: - No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader). - Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`). - `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`. Co-authored with Claude Code. ghstack-source-id: 396792509 @exported-using-ghexport Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/) --- .../webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl | 4 +++- .../runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index e694a6b223c..9a5cd614f8b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -85,9 +85,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3d027b417f0..144af9b5956 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9 +// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array; @@ -102,9 +102,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; From 71de6e720c1ab765e8ddb90afd2e2e08f58f9d2e Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:32 -0700 Subject: [PATCH 2/2] [ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20493 **Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings). **Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses. **Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings: - **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array` accessed element-by-element. - **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`. **Implementation**: - QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`. - AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`. - Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array>`. `t_attn_weights` and the softmax buffer stay `array` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row. - Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple. - Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load. - Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array` SDPA bindings. **Constraints**: - Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing. - Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version. - No KV-cache layout, dispatch, or uniform change. Co-authored with Claude Code. ghstack-source-id: 396792517 @exported-using-ghexport Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/) --- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 5 ++ .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 27 +++----- .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 29 +++------ .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 62 ++++++++++++------ .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 64 +++++++++++++------ backends/webgpu/test/ops/sdpa/test_sdpa.py | 2 + backends/webgpu/test/test_webgpu_native.cpp | 21 ++++++ 7 files changed, 132 insertions(+), 78 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index aaf93c7fda0..b1aa689a09d 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -339,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) { throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q"); } + // QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4. + if (D % 4 != 0) { + throw std::runtime_error( + "WebGPU sdpa: head_dim (D) must be a multiple of 4"); + } if (v.dims[v.dims.size() - 2] != Hkv) { throw std::runtime_error("WebGPU sdpa: v num_heads must match k"); } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index 9a5cd614f8b..fd15767603d 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -1,6 +1,6 @@ @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -22,30 +22,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 144af9b5956..ae250959e0e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10 +// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -39,30 +39,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index a5c79dfd4e3..56242b0ddde 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -1,6 +1,6 @@ -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -19,6 +19,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -33,24 +34,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -77,11 +87,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -94,7 +121,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -102,11 +128,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index 2fef0e2d8c9..6bec079ac2b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc +// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191 inline constexpr const char* kSdpaComputeOutWGSL = R"( -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -36,6 +36,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -50,24 +51,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -94,11 +104,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -111,7 +138,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -119,11 +145,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/test/ops/sdpa/test_sdpa.py b/backends/webgpu/test/ops/sdpa/test_sdpa.py index b674feae635..1f7a8242591 100644 --- a/backends/webgpu/test/ops/sdpa/test_sdpa.py +++ b/backends/webgpu/test/ops/sdpa/test_sdpa.py @@ -59,6 +59,8 @@ class SdpaConfig: # Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127). SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0), SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127), + # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load. + SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0), ] diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 8d987578aa1..6b57254ca33 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -719,6 +719,7 @@ struct SdpaConfig { int input_pos; // prior tokens already in the cache (decode) float denom; // ramp divisor (mirrors Python); small -> large logits bool required = false; // CI (SDPA dir set): absent .pte = FAIL, not skip + bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden }; static const SdpaConfig kSdpaConfigs[] = { @@ -738,6 +739,17 @@ static const SdpaConfig kSdpaConfigs[] = { // pos 127). {"llama1b_prefill", 32, 8, 64, 128, 512, 0, 16.0f}, {"llama1b_decode", 32, 8, 64, 1, 512, 127, 16.0f}, + // D=6 is not a multiple of 4: the head_dim%4 guard must reject it at load. + {"reject_d6", + 4, + 4, + 6, + 4, + 16, + 0, + 16.0f, + /*required=*/false, + /*expect_reject=*/true}, }; // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync). @@ -791,6 +803,15 @@ static bool test_sdpa_config( Module module(model_path); auto err = module.load_forward(); + if (cfg.expect_reject) { + // D not a multiple of 4 must be rejected at load by the head_dim guard. + if (err != Error::Ok) { + printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); + return true; + } + printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name); + return false; + } if (err != Error::Ok) { printf("FAIL: could not load forward method (error %d)\n", (int)err); return false;