diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index dd48f6f5902..b1aa689a09d 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu { namespace { +// Register-tile dims; MUST match TM/TN in the reg WGSL kernels. +constexpr int64_t kSdpaTileM = 4; +constexpr int64_t kSdpaTileN = 4; + // Uniform param structs (all 16-byte aligned, matching the WGSL Params). struct UpdateCacheParams { uint32_t numel; @@ -335,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) { throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q"); } + // QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4. + if (D % 4 != 0) { + throw std::runtime_error( + "WebGPU sdpa: head_dim (D) must be a multiple of 4"); + } if (v.dims[v.dims.size() - 2] != Hkv) { throw std::runtime_error("WebGPU sdpa: v num_heads must match k"); } @@ -464,14 +473,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { dynamic_pos, "update_cache(V)"); - // --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element. + // --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile. { if (aw_floats > UINT32_MAX) { throw std::runtime_error( "WebGPU sdpa: Hq*S*context_len exceeds uint32 max"); } + const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * + utils::div_up(context_len, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(aw_floats), qk_wg, "QK"); + device, static_cast(qk_tiles), qk_wg, "QK"); AttnWeightsParams p = make_attn_weights_params( S, Hq, Hkv, D, context_len, input_pos, g, scale); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); @@ -515,12 +526,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { softmax_buf = ubuf; } - // --- Dispatch 5: AV -> out. One thread per (s,h,d) output element. + // --- Dispatch 5: AV -> out. One thread per TM x TN tile. { - const uint64_t out_floats = static_cast(S) * - static_cast(Hq) * static_cast(D); + const int64_t av_tiles = + Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN); const uint32_t wgc = utils::compute_1d_workgroup_count( - device, static_cast(out_floats), av_wg, "AV"); + device, static_cast(av_tiles), av_wg, "AV"); ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); BufferBinding bindings[3] = { @@ -591,9 +602,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { AttnWeightsParams qp = make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale); wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); + const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * + utils::div_up(ctx, kSdpaTileN); const uint32_t qk_wgc = utils::compute_1d_workgroup_count( gr.device(), - static_cast(aw_floats), + static_cast(qk_tiles), qk_wg, "QK(resize)"); gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index b9905a59376..fd15767603d 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -1,6 +1,6 @@ @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -19,37 +19,98 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + if (c >= params.context_len) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. + if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D || skip_tile) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + var q: array, TM>; + var k: array, TN>; + for (var i: u32 = 0u; i < TM; i = i + 1u) { + q[i] = load_q_vec4(s0 + i, h, d4); + } + for (var j: u32 = 0u; j < TN; j = j + 1u) { + k[j] = load_k_vec4(c0 + j, kvh, d4); + } + for (var i: u32 = 0u; i < TM; i = i + 1u) { + acc[i] += vec4( + dot(q[i], k[0]), + dot(q[i], k[1]), + dot(q[i], k[2]), + dot(q[i], k[3])); + } + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. - if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3f3f3d6b085..ae250959e0e 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1 +// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; -@group(0) @binding(1) var t_q: array; -@group(0) @binding(2) var t_k_cache: array; +@group(0) @binding(1) var t_q: array>; +@group(0) @binding(2) var t_k_cache: array>; struct Params { S: u32, @@ -36,39 +36,100 @@ const NEG_INF: f32 = -1.0e30; override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check. +fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = s * params.Hq * params.D + h * params.D + d4; + return t_q[base / 4u]; +} + +fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4 { + if (c >= params.context_len) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = c * params.Hkv * params.D + kvh * params.D + d4; + return t_k_cache[base / 4u]; +} + +fn store_qk(s: u32, c: u32, h: u32, raw: f32) { + if (s >= params.S || c >= params.context_len) { + return; + } + var val = raw * params.scale; + // Causal mask: position c may not attend beyond s + input_pos. + if (c > s + params.input_pos) { + val = NEG_INF; + } + let idx = h * params.S * params.context_len + s * params.context_len + c; + t_attn_weights[idx] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.Hq * params.S * params.context_len; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.context_len + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let c = idx % params.context_len; - let s = (idx / params.context_len) % params.S; - let h = idx / (params.context_len * params.S); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let c0 = col_tile * TN; - let q_base = s * params.Hq * params.D + h * params.D; - let k_base = c * params.Hkv * params.D + kvh * params.D; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var d: u32 = 0u; + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; + var d4: u32 = 0u; loop { - if (d >= params.D) { + if (d4 >= params.D || skip_tile) { break; } - acc = acc + t_q[q_base + d] * t_k_cache[k_base + d]; - d = d + 1u; + var q: array, TM>; + var k: array, TN>; + for (var i: u32 = 0u; i < TM; i = i + 1u) { + q[i] = load_q_vec4(s0 + i, h, d4); + } + for (var j: u32 = 0u; j < TN; j = j + 1u) { + k[j] = load_k_vec4(c0 + j, kvh, d4); + } + for (var i: u32 = 0u; i < TM; i = i + 1u) { + acc[i] += vec4( + dot(q[i], k[0]), + dot(q[i], k[1]), + dot(q[i], k[2]), + dot(q[i], k[3])); + } + d4 = d4 + 4u; } - acc = acc * params.scale; - // Causal mask: position c may not attend beyond s + input_pos. - if (c > s + params.input_pos) { - acc = NEG_INF; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + let av = acc[m]; + store_qk(s0 + m, c0 + 0u, h, av.x); + store_qk(s0 + m, c0 + 1u, h, av.y); + store_qk(s0 + m, c0 + 2u, h, av.z); + store_qk(s0 + m, c0 + 3u, h, av.w); + m = m + 1u; } - - t_attn_weights[idx] = acc; } )"; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 97642670f60..56242b0ddde 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -1,6 +1,6 @@ -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -16,31 +16,119 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { + if (c >= params.context_len) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; +} + +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= ctx_aligned) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; } - t_out[idx] = acc; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + store_out_vec4(s0 + m, d0, h, acc[m]); + m = m + 1u; + } } diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index ce25df06876..6bec079ac2b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,11 +13,11 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 67b9c64fbffdcb72264dda42e24b59e414719411c64c504f84f2ba57b5dcfc0f +// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191 inline constexpr const char* kSdpaComputeOutWGSL = R"( -@group(0) @binding(0) var t_out: array; +@group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; -@group(0) @binding(2) var t_v_cache: array; +@group(0) @binding(2) var t_v_cache: array>; struct Params { S: u32, @@ -33,33 +33,121 @@ struct Params { override wg_size: u32 = 64; +const TM: u32 = 4u; +const TN: u32 = 4u; + +// Checked loaders mask context lanes past context_len (D%4==0, host-guarded). +fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4 { + var r = vec4(0.0, 0.0, 0.0, 0.0); + if (s >= params.S) { + return r; + } + let base = h * params.S * params.context_len + s * params.context_len; + if (c4 + 0u < params.context_len) { r.x = t_attn_weights_softmax[base + c4 + 0u]; } + if (c4 + 1u < params.context_len) { r.y = t_attn_weights_softmax[base + c4 + 1u]; } + if (c4 + 2u < params.context_len) { r.z = t_attn_weights_softmax[base + c4 + 2u]; } + if (c4 + 3u < params.context_len) { r.w = t_attn_weights_softmax[base + c4 + 3u]; } + return r; +} + +fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4 { + if (c >= params.context_len) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; +} + +// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len. +fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4 { + if (s >= params.S) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + let base = h * params.S * params.context_len + s * params.context_len + c4; + return vec4(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]); +} + +fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4 { + let base = c * params.Hkv * params.D + kvh * params.D + d0; + return t_v_cache[base / 4u]; +} + +fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { + if (s >= params.S) { + return; + } + let idx = s * params.Hq * params.D + h * params.D + d0; + t_out[idx / 4u] = val; +} + @compute @workgroup_size(wg_size, 1, 1) fn main(@builtin(global_invocation_id) gid: vec3) { - let total = params.S * params.Hq * params.D; - let idx = gid.x; - if (idx >= total) { + let nrt = (params.S + TM - 1u) / TM; + let nct = (params.D + TN - 1u) / TN; + let tiles = nrt * nct; + let total = tiles * params.Hq; + if (gid.x >= total) { return; } - let d = idx % params.D; - let h = (idx / params.D) % params.Hq; - let s = idx / (params.D * params.Hq); + let h = gid.x / tiles; + let rem = gid.x % tiles; + let row_tile = rem / nct; + let col_tile = rem % nct; let kvh = h / params.g; + let s0 = row_tile * TM; + let d0 = col_tile * TN; - let aw_base = h * params.S * params.context_len + s * params.context_len; + var acc: array, 4>; + acc[0] = vec4(0.0, 0.0, 0.0, 0.0); + acc[1] = vec4(0.0, 0.0, 0.0, 0.0); + acc[2] = vec4(0.0, 0.0, 0.0, 0.0); + acc[3] = vec4(0.0, 0.0, 0.0, 0.0); - var acc: f32 = 0.0; - var c: u32 = 0u; + // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl. + let ctx_aligned = params.context_len - (params.context_len & 3u); + var c4: u32 = 0u; loop { - if (c >= params.context_len) { + if (c4 >= ctx_aligned) { break; } - let v_off = c * params.Hkv * params.D + kvh * params.D + d; - acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off]; - c = c + 1u; + let a0 = load_a_vec4_nc(s0 + 0u, h, c4); + let a1 = load_a_vec4_nc(s0 + 1u, h, c4); + let a2 = load_a_vec4_nc(s0 + 2u, h, c4); + let a3 = load_a_vec4_nc(s0 + 3u, h, c4); + let v0 = load_v_d4_nc(c4 + 0u, kvh, d0); + let v1 = load_v_d4_nc(c4 + 1u, kvh, d0); + let v2 = load_v_d4_nc(c4 + 2u, kvh, d0); + let v3 = load_v_d4_nc(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; + c4 = c4 + 4u; + } + if (c4 < params.context_len) { + let a0 = load_a_vec4(s0 + 0u, h, c4); + let a1 = load_a_vec4(s0 + 1u, h, c4); + let a2 = load_a_vec4(s0 + 2u, h, c4); + let a3 = load_a_vec4(s0 + 3u, h, c4); + let v0 = load_v_d4(c4 + 0u, kvh, d0); + let v1 = load_v_d4(c4 + 1u, kvh, d0); + let v2 = load_v_d4(c4 + 2u, kvh, d0); + let v3 = load_v_d4(c4 + 3u, kvh, d0); + acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3; + acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3; + acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3; + acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3; } - t_out[idx] = acc; + var m: u32 = 0u; + loop { + if (m >= TM) { + break; + } + store_out_vec4(s0 + m, d0, h, acc[m]); + m = m + 1u; + } } )"; diff --git a/backends/webgpu/test/ops/sdpa/test_sdpa.py b/backends/webgpu/test/ops/sdpa/test_sdpa.py index b674feae635..1f7a8242591 100644 --- a/backends/webgpu/test/ops/sdpa/test_sdpa.py +++ b/backends/webgpu/test/ops/sdpa/test_sdpa.py @@ -59,6 +59,8 @@ class SdpaConfig: # Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127). SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0), SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127), + # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load. + SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0), ] diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 8d987578aa1..6b57254ca33 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -719,6 +719,7 @@ struct SdpaConfig { int input_pos; // prior tokens already in the cache (decode) float denom; // ramp divisor (mirrors Python); small -> large logits bool required = false; // CI (SDPA dir set): absent .pte = FAIL, not skip + bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden }; static const SdpaConfig kSdpaConfigs[] = { @@ -738,6 +739,17 @@ static const SdpaConfig kSdpaConfigs[] = { // pos 127). {"llama1b_prefill", 32, 8, 64, 128, 512, 0, 16.0f}, {"llama1b_decode", 32, 8, 64, 1, 512, 127, 16.0f}, + // D=6 is not a multiple of 4: the head_dim%4 guard must reject it at load. + {"reject_d6", + 4, + 4, + 6, + 4, + 16, + 0, + 16.0f, + /*required=*/false, + /*expect_reject=*/true}, }; // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync). @@ -791,6 +803,15 @@ static bool test_sdpa_config( Module module(model_path); auto err = module.load_forward(); + if (cfg.expect_reject) { + // D not a multiple of 4 must be rejected at load by the head_dim guard. + if (err != Error::Ok) { + printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); + return true; + } + printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name); + return false; + } if (err != Error::Ok) { printf("FAIL: could not load forward method (error %d)\n", (int)err); return false;