From 8e682944365170b135b511745bae6c49680c2111 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 23 Jun 2026 13:21:50 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 31 +++++++++-------- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 33 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 3ac2339376e..a5c79dfd4e3 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -32,17 +32,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -87,14 +86,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index cf1d742d7e5..2fef0e2d8c9 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3 +// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -49,17 +49,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { return r; } -fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4 { +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); - if (d >= params.D) { + if (c >= params.context_len) { return r; } - let stride = params.Hkv * params.D; - let off = kvh * params.D + d; - if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; } - if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; } - if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; } - if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } + if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } + if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } + if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } return r; } @@ -104,14 +103,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); let a3 = load_a_vec4(s0 + 3u, h, c4); - let v0 = load_v_vec4(d0 + 0u, kvh, c4); - let v1 = load_v_vec4(d0 + 1u, kvh, c4); - let v2 = load_v_vec4(d0 + 2u, kvh, c4); - let v3 = load_v_vec4(d0 + 3u, kvh, c4); - acc[0] += vec4(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3)); - acc[1] += vec4(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3)); - acc[2] += vec4(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3)); - acc[3] += vec4(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3)); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; c4 = c4 + 4u; }