From eb9073c27311f8dd498cc82ffeefb1a548a8c58b Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 19:35:31 -0700 Subject: [PATCH] [ExecuTorch][WebGPU] Coalesce SDPA AV V-cache reads along contiguous head-dim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20459 **~19% faster SDPA attention-output (AV) stage** — 393→317 µs on llama3 prefill (Chrome Canary / M4 Pro). **Problem**: V-cache reads load 4 strided context rows × 1 head-dim lane, so the loads are not coalesced. **Solution**: Flip access pattern to read 4 contiguous head-dim lanes per context row: - **Before**: `load_v_vec4(d, kvh, c4)` → 4 strided rows, `dot()` along D - **After**: `load_v_d4(c, kvh, d0)` → 4 contiguous D-lanes (16-byte texel), scalar broadcast **Implementation**: - Reindex `load_v` helper to read contiguous head-dim - Replace `dot(A, V)` with `acc += A[c] * V_vec4(d0:d0+3)` - Mirrors Vulkan `load_v_cache_d4` coalescing pattern **Constraints**: - No KV-cache layout change (still `[C, Hkv, D]`) - Output numerically equivalent within FP tolerance, not bit-identical (accumulation is FP-reassociated; max abs diff 1.43e-6 vs torch) ghstack-source-id: 396792504 @exported-using-ghexport Differential Revision: [D109339276](https://our.internmc.facebook.com/intern/diff/D109339276/) --- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 31 +++++++++-------- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 33 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 3ac2339376e..a5c79dfd4e3 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -32,17 +32,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = 
vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -87,14 +86,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index 
cf1d742d7e5..2fef0e2d8c9 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3 +// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -49,17 +49,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -104,14 +103,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), 
dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; }