From 37abadbeeca0750f46d259195ec6641730001073 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:30 -0700 Subject: [PATCH 1/4] [ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20405 **+32% SDPA attention-compute (AV +40%)** — register-tile the QK and AV kernels (isolated GPU-timestamp A/B, decode S=1, Chrome Canary / M4 Pro). A kernel-time win, not a wall-clock `forward()` win — `forward()` stays bound by the submit/sync/readback floor (the separate fusion axis). **Problem**: The naive QK/AV kernels compute one output element per thread, so each thread re-loads Q/K/V and the dot products are scalar — poor register reuse, ALU/latency-bound. **Solution**: Each thread computes a 4×4 output tile with the dot products vec4-packed in registers: - **Before**: one thread per output element; scalar accumulate over D (QK) / context (AV). - **After**: one thread per `(head, S-tile, {ctx,D}-tile)`; 4×4 register tile, vec4 dot products. A floating-point accumulation reorder of the same products — no algorithm change. **Implementation**: - `sdpa_compute_attn_weights.wgsl` (QK): one thread per `(head, S-tile, ctx-tile)`, grid `Hq · ceil(S/4) · ceil(ctx/4)`; tile registers are `array<vec4<f32>, TM/TN>` loaded via `for` loops. - `sdpa_compute_out.wgsl` (AV): one thread per `(head, S-tile, D-tile)`, grid `Hq · ceil(S/4) · ceil(D/4)`. - `Sdpa.cpp`: dispatch math moves from an element count to a tile count (`kSdpaTileM/N=4`, shared `utils::div_up`), keeping the uint32 scratch-overflow guard. - Mirrors the Vulkan register-tiled SDPA kernels; the shared `utils::div_up` mirrors Vulkan's `utils::div_up`. **Constraints**: - softmax, `update_cache`, the bind-group layouts, and the scratch-buffer sizes (`Hq*S*ctx`) are unchanged. 
- Scope is tiling only — causal tile-skip, V-cache coalescing, and branchless aligned/tail loads are separate follow-ups; this diff intentionally omits the Vulkan causal tile-skip since it is correctness-neutral (the per-element mask in `store_qk` is identical). See DESIGN_DECISIONS.md. - Output matches the naive kernels within fp tolerance (accumulation reorder only). ghstack-source-id: 396792505 @exported-using-ghexport Differential Revision: [D109081409](https://our.internmc.facebook.com/intern/diff/D109081409/) --- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 22 ++-- .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 106 ++++++++++++++--- .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 108 ++++++++++++++---- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 95 ++++++++++++--- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 97 +++++++++++++--- 5 files changed, 353 insertions(+), 75 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index dd48f6f5902..aaf93c7fda0 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu { namespace { +// Register-tile dims; MUST match TM/TN in the reg WGSL kernels. +constexpr int64_t kSdpaTileM = 4; +constexpr int64_t kSdpaTileN = 4; + // Uniform param structs (all 16-byte aligned, matching the WGSL Params). struct UpdateCacheParams { uint32_t numel; @@ -464,14 +468,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { dynamic_pos, "update_cache(V)"); - // --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element. + // --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile. 
{ if (aw_floats > UINT32_MAX) { throw std::runtime_error( "WebGPU sdpa: Hq*S*context_len exceeds uint32 max"); } + const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * + utils::div_up(context_len, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(aw_floats), qk_wg, "QK"); + device, static_cast(qk_tiles), qk_wg, "QK"); AttnWeightsParams p = make_attn_weights_params( S, Hq, Hkv, D, context_len, input_pos, g, scale); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); @@ -515,12 +521,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { softmax_buf = ubuf; } - // --- Dispatch 5: AV -> out. One thread per (s,h,d) output element. + // --- Dispatch 5: AV -> out. One thread per TM x TN tile. { - const uint64_t out_floats = static_cast(S) * - static_cast(Hq) * static_cast(D); + const int64_t av_tiles = + Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(out_floats), av_wg, "AV"); + device, static_cast(av_tiles), av_wg, "AV"); ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); BufferBinding bindings[3] = { @@ -591,9 +597,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { AttnWeightsParams qp = make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale); wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); + const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * + utils::div_up(ctx, kSdpaTileN); const uint32_t qk_wgc = utils::compute_1d_workgroup_count( gr.device(), - static_cast(aw_floats), + static_cast(qk_tiles), qk_wg, "QK(resize)"); gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index b9905a59376..e694a6b223c 100644 
--- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -19,37 +19,105 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = s * params.Hq * params.D + h * params.D; + if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } + return r; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (c >= params.context_len) { + return r; + } + let base = c * params.Hkv * params.D + kvh * params.D; + if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } + return r; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. 
+ if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + var q: array, TM>; + var k: array, TN>; + for (var i: u32 = 0u; i < TM; i = i + 1u) { + q[i] = load_q_vec4(s0 + i, h, d4); + } + for (var j: u32 = 0u; j < TN; j = j + 1u) { + k[j] = load_k_vec4(c0 + j, kvh, d4); + } + for (var i: u32 = 0u; i < TM; i = i + 1u) { + acc[i] += vec4( + dot(q[i], k[0]), + dot(q[i], k[1]), + dot(q[i], k[2]), + dot(q[i], k[3])); + } + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. 
- if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3f3f3d6b085..3d027b417f0 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1 +// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array; @@ -36,39 +36,107 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = s * params.Hq * params.D + h * params.D; + if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } + return r; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (c >= params.context_len) { + return r; + } + let base = c * params.Hkv * params.D + kvh * params.D; + if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } + if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } + if (d4 + 2u < params.D) { r.z = 
t_k_cache[base + d4 + 2u]; } + if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } + return r; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. + if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + var q: array, TM>; + var k: array, TN>; + for (var i: u32 = 0u; i < TM; i = i + 1u) { + q[i] = load_q_vec4(s0 + i, h, d4); + } + for (var j: u32 = 0u; j < TN; j = j + 1u) { + k[j] = load_k_vec4(c0 + j, kvh, d4); + } + for (var i: u32 = 0u; i < TM; i = i + 1u) { + acc[i] += vec4( + dot(q[i], k[0]), + dot(q[i], k[1]), + dot(q[i], k[2]), + 
dot(q[i], k[3])); + } + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. - if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } )"; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 97642670f60..3ac2339376e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -16,31 +16,98 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (d >= params.D) { + return r; + } + let stride = params.Hkv * params.D; + let off = kvh * params.D + d; + if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } + if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } + if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } + if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + return r; +} + +fn store_out(s: u32, d: u32, h: 
u32, val: f32) { + if (s >= params.S || d >= params.D) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d; + t_out[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= params.context_len) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_vec4(d0 + 0u, kvh, c4); + let v1 = load_v_vec4(d0 + 1u, kvh, c4); + let v2 = load_v_vec4(d0 + 2u, kvh, c4); + let v3 = load_v_vec4(d0 + 3u, kvh, c4); + acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); + acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); + acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); + acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + c4 = c4 + 4u; } - t_out[idx] = acc; + var m: u32 = 0u; + 
loop { + if (m >= TM) { + break; + } + let ov = acc[m]; + store_out(s0 + m, d0 + 0u, h, ov.x); + store_out(s0 + m, d0 + 1u, h, ov.y); + store_out(s0 + m, d0 + 2u, h, ov.z); + store_out(s0 + m, d0 + 3u, h, ov.w); + m = m + 1u; + } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index ce25df06876..cf1d742d7e5 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 67b9c64fbffdcb72264dda42e24b59e414719411c64c504f84f2ba57b5dcfc0f +// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3 inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -33,33 +33,100 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (d >= params.D) { + return r; + } + let stride = params.Hkv * params.D; + let off = kvh * params.D + d; + if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } + if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } + if (c4 + 
2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } + if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + return r; +} + +fn store_out(s: u32, d: u32, h: u32, val: f32) { + if (s >= params.S || d >= params.D) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d; + t_out[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= params.context_len) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_vec4(d0 + 0u, kvh, c4); + let v1 = load_v_vec4(d0 + 1u, kvh, c4); + let v2 = load_v_vec4(d0 + 2u, kvh, c4); + let v3 = load_v_vec4(d0 + 3u, kvh, c4); + acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); + acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, 
v3)); + acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); + acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + c4 = c4 + 4u; } - t_out[idx] = acc; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let ov = acc[m]; + store_out(s0 + m, d0 + 0u, h, ov.x); + store_out(s0 + m, d0 + 1u, h, ov.y); + store_out(s0 + m, d0 + 2u, h, ov.z); + store_out(s0 + m, d0 + 3u, h, ov.w); + m = m + 1u; + } } )"; From a8e4091ebfd1d9949d2955770fe16625fecb5abb Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:31 -0700 Subject: [PATCH 2/4] [ExecuTorch][WebGPU] Coalesce SDPA AV V-cache reads along contiguous head-dim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20459 **~19% faster SDPA attention-output (AV) stage** — 393→317 µs on llama3 prefill (Chrome Canary / M4 Pro). **Problem**: V-cache reads load 4 strided context rows × 1 head-dim lane, missing coalescing. 
**Solution**: Flip access pattern to read 4 contiguous head-dim lanes per context row: - **Before**: `load_v_vec4(d, kvh, c4)` → 4 strided rows, `dot()` along D - **After**: `load_v_d4(c, kvh, d0)` → 4 contiguous D-lanes (16-byte texel), scalar broadcast **Implementation**: - Reindex `load_v` helper to read contiguous head-dim - Replace `dot(A, V)` with `acc += A[c] * V_vec4(d0:d0+3)` - Mirrors Vulkan `load_v_cache_d4` coalescing pattern **Constraints**: - No KV-cache layout change (still `[C, Hkv, D]`) - Output numerically identical (FP-reassociated, max abs diff 1.43e-6 vs torch) ghstack-source-id: 396792504 @exported-using-ghexport Differential Revision: [D109339276](https://our.internmc.facebook.com/intern/diff/D109339276/) --- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 31 +++++++++-------- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 33 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 3ac2339376e..a5c79dfd4e3 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -32,17 +32,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = 
t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -87,14 +86,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index cf1d742d7e5..2fef0e2d8c9 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. 
-// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3 +// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -49,17 +49,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -104,14 +103,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = 
load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; } From 23b14c46d8e54a7859f382669bd692446a628257 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:31 -0700 Subject: [PATCH 3/4] [ExecuTorch][WebGPU] SDPA: skip QK contraction for fully-masked causal tiles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20492 **Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output). **Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`. **Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel: - **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`. - **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before. **Implementation**: - Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition. - Store loop unchanged — runs unconditionally, so no scratch entry is left stale. - Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`). 
**Constraints**: - No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader). - Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`). - `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`. Co-authored with Claude Code. ghstack-source-id: 396792509 @exported-using-ghexport Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/) --- .../webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl | 4 +++- .../runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index e694a6b223c..9a5cd614f8b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -85,9 +85,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3d027b417f0..144af9b5956 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. 
-// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9 +// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array; @@ -102,9 +102,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; From 431acb05e103228d041e25952eb1389be200dca9 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:32 -0700 Subject: [PATCH 4/4] [ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20493 **Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings). **Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses. 
**Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings: - **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element. - **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`. **Implementation**: - QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`. - AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`. - Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row. - Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple. - Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load. - Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array` SDPA bindings. **Constraints**: - Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing.
- Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version. - No KV-cache layout, dispatch, or uniform change. Co-authored with Claude Code. ghstack-source-id: 396792517 @exported-using-ghexport Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/) --- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 5 ++ .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 27 +++----- .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 29 +++------ .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 62 ++++++++++++------ .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 64 +++++++++++++------ backends/webgpu/test/ops/sdpa/test_sdpa.py | 2 + backends/webgpu/test/test_webgpu_native.cpp | 21 ++++++ 7 files changed, 132 insertions(+), 78 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index aaf93c7fda0..b1aa689a09d 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -339,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) { throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q"); } + // QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4. 
+ if (D % 4 != 0) { + throw std::runtime_error( + "WebGPU sdpa: head_dim (D) must be a multiple of 4"); + } if (v.dims[v.dims.size() - 2] != Hkv) { throw std::runtime_error("WebGPU sdpa: v num_heads must match k"); } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index 9a5cd614f8b..fd15767603d 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -1,6 +1,6 @@ @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -22,30 +22,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - 
return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 144af9b5956..ae250959e0e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10 +// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -39,30 +39,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. 
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index a5c79dfd4e3..56242b0ddde 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -1,6 +1,6 @@ -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -19,6 +19,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). 
fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -33,24 +34,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -77,11 +87,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. 
+ let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -94,7 +121,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -102,11 +128,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index 2fef0e2d8c9..6bec079ac2b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. 
-// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc +// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191 inline constexpr const char* kSdpaComputeOutWGSL = R"( -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -36,6 +36,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -50,24 +51,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. 
+fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -94,11 +104,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. 
+ let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -111,7 +138,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -119,11 +145,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/test/ops/sdpa/test_sdpa.py b/backends/webgpu/test/ops/sdpa/test_sdpa.py index b674feae635..1f7a8242591 100644 --- a/backends/webgpu/test/ops/sdpa/test_sdpa.py +++ b/backends/webgpu/test/ops/sdpa/test_sdpa.py @@ -59,6 +59,8 @@ class SdpaConfig: # Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127). 
SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0), SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127), + # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load. + SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0), ] diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 8d987578aa1..6b57254ca33 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -719,6 +719,7 @@ struct SdpaConfig { int input_pos; // prior tokens already in the cache (decode) float denom; // ramp divisor (mirrors Python); small -> large logits bool required = false; // CI (SDPA dir set): absent .pte = FAIL, not skip + bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden }; static const SdpaConfig kSdpaConfigs[] = { @@ -738,6 +739,17 @@ static const SdpaConfig kSdpaConfigs[] = { // pos 127). {"llama1b_prefill", 32, 8, 64, 128, 512, 0, 16.0f}, {"llama1b_decode", 32, 8, 64, 1, 512, 127, 16.0f}, + // D=6 is not a multiple of 4: the head_dim%4 guard must reject it at load. + {"reject_d6", + 4, + 4, + 6, + 4, + 16, + 0, + 16.0f, + /*required=*/false, + /*expect_reject=*/true}, }; // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync). @@ -791,6 +803,15 @@ static bool test_sdpa_config( Module module(model_path); auto err = module.load_forward(); + if (cfg.expect_reject) { + // D not a multiple of 4 must be rejected at load by the head_dim guard. + if (err != Error::Ok) { + printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); + return true; + } + printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name); + return false; + } if (err != Error::Ok) { printf("FAIL: could not load forward method (error %d)\n", (int)err); return false;