diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index aaf93c7fda0..b1aa689a09d 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -339,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) { throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q"); } + // QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4. + if (D % 4 != 0) { + throw std::runtime_error( + "WebGPU sdpa: head_dim (D) must be a multiple of 4"); + } if (v.dims[v.dims.size() - 2] != Hkv) { throw std::runtime_error("WebGPU sdpa: v num_heads must match k"); } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index e694a6b223c..fd15767603d 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -1,6 +1,6 @@ @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -22,30 +22,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { @@ -85,9 +76,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3d027b417f0..ae250959e0e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9 +// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -39,30 +39,21 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = s * params.Hq * params.D + h * params.D; - if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; } - return r; + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; } fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } - let base = c * params.Hkv * params.D + kvh * params.D; - if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; } - if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; } - if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; } - if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; } - return r; + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; } fn store_qk(s: u32, c: u32, h: u32, raw: f32) { @@ -102,9 +93,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index a5c79dfd4e3..56242b0ddde 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -1,6 +1,6 @@ -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -19,6 +19,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -33,24 +34,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -77,11 +87,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -94,7 +121,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -102,11 +128,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index 2fef0e2d8c9..6bec079ac2b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc +// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191 inline constexpr const char* kSdpaComputeOutWGSL = R"( -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -36,6 +36,7 @@ override wg_size: u32 = 64; const TM: u32 = 4u; const TN: u32 = 4u; +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { var r = vec4(0.0, 0.0, 0.0, 0.0); if (s >= params.S) { @@ -50,24 +51,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { } fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { - var r = vec4(0.0, 0.0, 0.0, 0.0); if (c >= params.context_len) { - return r; + return vec4(0.0, 0.0, 0.0, 0.0); } let base = c * params.Hkv * params.D + kvh * params.D + d0; - if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; } - if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; } - if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; } - if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; } - return r; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; } -fn store_out(s: u32, d: u32, h: u32, val: f32) { - if (s >= params.S || d >= params.D) { +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { return; } - let idx = s * params.Hq * params.D + h * params.D + d; - t_out[idx] = val; + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; } @compute @workgroup_size(wg_size, 1, 1) @@ -94,11 +104,28 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); var c4: u32 = 0u; loop { - if (c4 >= params.context_len) { + if (c4 >= ctx_aligned) { break; } + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { let a0 = load_a_vec4(s0 + 0u, h, c4); let a1 = load_a_vec4(s0 + 1u, h, c4); let a2 = load_a_vec4(s0 + 2u, h, c4); @@ -111,7 +138,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; - c4 = c4 + 4u; } var m: u32 = 0u; @@ -119,11 +145,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (m >= TM) { break; } - let ov = acc[m]; - store_out(s0 + m, d0 + 0u, h, ov.x); - store_out(s0 + m, d0 + 1u, h, ov.y); - store_out(s0 + m, d0 + 2u, h, ov.z); - store_out(s0 + m, d0 + 3u, h, ov.w); + store_out_vec4(s0 + m, d0, h, acc[m]); m = m + 1u; } } diff --git a/backends/webgpu/test/ops/sdpa/test_sdpa.py b/backends/webgpu/test/ops/sdpa/test_sdpa.py index b674feae635..1f7a8242591 100644 --- a/backends/webgpu/test/ops/sdpa/test_sdpa.py +++ b/backends/webgpu/test/ops/sdpa/test_sdpa.py @@ -59,6 +59,8 @@ class SdpaConfig: # Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127). SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0), SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127), + # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load. + SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0), ] diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 8d987578aa1..6b57254ca33 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -719,6 +719,7 @@ struct SdpaConfig { int input_pos; // prior tokens already in the cache (decode) float denom; // ramp divisor (mirrors Python); small -> large logits bool required = false; // CI (SDPA dir set): absent .pte = FAIL, not skip + bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden }; static const SdpaConfig kSdpaConfigs[] = { @@ -738,6 +739,17 @@ static const SdpaConfig kSdpaConfigs[] = { // pos 127). {"llama1b_prefill", 32, 8, 64, 128, 512, 0, 16.0f}, {"llama1b_decode", 32, 8, 64, 1, 512, 127, 16.0f}, + // D=6 is not a multiple of 4: the head_dim%4 guard must reject it at load. + {"reject_d6", + 4, + 4, + 6, + 4, + 16, + 0, + 16.0f, + /*required=*/false, + /*expect_reject=*/true}, }; // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync). @@ -791,6 +803,15 @@ static bool test_sdpa_config( Module module(model_path); auto err = module.load_forward(); + if (cfg.expect_reject) { + // D not a multiple of 4 must be rejected at load by the head_dim guard. + if (err != Error::Ok) { + printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); + return true; + } + printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name); + return false; + } if (err != Error::Ok) { printf("FAIL: could not load forward method (error %d)\n", (int)err); return false;