Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) {
throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q");
}
// QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4.
if (D % 4 != 0) {
throw std::runtime_error(
"WebGPU sdpa: head_dim (D) must be a multiple of 4");
}
if (v.dims[v.dims.size() - 2] != Hkv) {
throw std::runtime_error("WebGPU sdpa: v num_heads must match k");
}
Expand Down
31 changes: 12 additions & 19 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -22,30 +22,21 @@ override wg_size: u32 = 64;
const TM: u32 = 4u;
const TN: u32 = 4u;

// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (s >= params.S) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = s * params.Hq * params.D + h * params.D;
if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
return r;
let base = s * params.Hq * params.D + h * params.D + d4;
return t_q[base / 4u];
}

fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (c >= params.context_len) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D;
if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
return r;
let base = c * params.Hkv * params.D + kvh * params.D + d4;
return t_k_cache[base / 4u];
}

fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
Expand Down Expand Up @@ -85,9 +76,11 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
var d4: u32 = 0u;
loop {
if (d4 >= params.D) {
if (d4 >= params.D || skip_tile) {
break;
}
var q: array<vec4<f32>, TM>;
Expand Down
33 changes: 13 additions & 20 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
namespace executorch::backends::webgpu {

// @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9
// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254
inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -39,30 +39,21 @@ override wg_size: u32 = 64;
const TM: u32 = 4u;
const TN: u32 = 4u;

// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (s >= params.S) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = s * params.Hq * params.D + h * params.D;
if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
return r;
let base = s * params.Hq * params.D + h * params.D + d4;
return t_q[base / 4u];
}

fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (c >= params.context_len) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D;
if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
return r;
let base = c * params.Hkv * params.D + kvh * params.D + d4;
return t_k_cache[base / 4u];
}

fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
Expand Down Expand Up @@ -102,9 +93,11 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
var d4: u32 = 0u;
loop {
if (d4 >= params.D) {
if (d4 >= params.D || skip_tile) {
break;
}
var q: array<vec4<f32>, TM>;
Expand Down
62 changes: 42 additions & 20 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
@group(0) @binding(2) var<storage, read> t_v_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -19,6 +19,7 @@ override wg_size: u32 = 64;
const TM: u32 = 4u;
const TN: u32 = 4u;

// Checked loaders mask context lanes past context_len (D%4==0, host-guarded).
fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (s >= params.S) {
Expand All @@ -33,24 +34,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
}

fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (c >= params.context_len) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D + d0;
if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
return r;
return t_v_cache[base / 4u];
}

// Unchecked (_nc) loaders for the aligned body: no per-lane context_len checks; caller guarantees c4..c4+3 < context_len.
// Loads four consecutive softmax weights for query row `s`, head `h`,
// starting at context position `c4`, without any per-lane context_len check.
// Returns all zeros when `s` is at or past params.S. The flat f32 offset is
// h * S * context_len + s * context_len + c4.
fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4<f32> {
    if (s >= params.S) {
        return vec4<f32>(0.0);
    }
    // Same offset as h*S*ctx + s*ctx + c4, factored over context_len.
    let first = (h * params.S + s) * params.context_len + c4;
    let w0 = t_attn_weights_softmax[first];
    let w1 = t_attn_weights_softmax[first + 1u];
    let w2 = t_attn_weights_softmax[first + 2u];
    let w3 = t_attn_weights_softmax[first + 3u];
    return vec4<f32>(w0, w1, w2, w3);
}

// Loads one vec4 from t_v_cache for context position `c`, KV head `kvh`,
// starting at head-dim element `d0`, with no bounds check on `c`.
// The flat f32 offset c * Hkv * D + kvh * D + d0 is converted to a vec4
// index by dividing by 4.
fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
    let row_stride = params.Hkv * params.D;
    let elem = c * row_stride + kvh * params.D + d0;
    // t_v_cache is bound as array<vec4<f32>>; elem >> 2 equals elem / 4 for u32.
    return t_v_cache[elem >> 2u];
}

fn store_out(s: u32, d: u32, h: u32, val: f32) {
if (s >= params.S || d >= params.D) {
fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4<f32>) {
if (s >= params.S) {
return;
}
let idx = s * params.Hq * params.D + h * params.D + d;
t_out[idx] = val;
let idx = s * params.Hq * params.D + h * params.D + d0;
t_out[idx / 4u] = val;
}

@compute @workgroup_size(wg_size, 1, 1)
Expand All @@ -77,11 +87,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

// Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl.
let ctx_aligned = params.context_len - (params.context_len & 3u);
var c4: u32 = 0u;
loop {
if (c4 >= params.context_len) {
if (c4 >= ctx_aligned) {
break;
}
let a0 = load_a_vec4_nc(s0 + 0u, h, c4);
let a1 = load_a_vec4_nc(s0 + 1u, h, c4);
let a2 = load_a_vec4_nc(s0 + 2u, h, c4);
let a3 = load_a_vec4_nc(s0 + 3u, h, c4);
let v0 = load_v_d4_nc(c4 + 0u, kvh, d0);
let v1 = load_v_d4_nc(c4 + 1u, kvh, d0);
let v2 = load_v_d4_nc(c4 + 2u, kvh, d0);
let v3 = load_v_d4_nc(c4 + 3u, kvh, d0);
acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
c4 = c4 + 4u;
}
if (c4 < params.context_len) {
let a0 = load_a_vec4(s0 + 0u, h, c4);
let a1 = load_a_vec4(s0 + 1u, h, c4);
let a2 = load_a_vec4(s0 + 2u, h, c4);
Expand All @@ -94,19 +121,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
c4 = c4 + 4u;
}

var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let ov = acc[m];
store_out(s0 + m, d0 + 0u, h, ov.x);
store_out(s0 + m, d0 + 1u, h, ov.y);
store_out(s0 + m, d0 + 2u, h, ov.z);
store_out(s0 + m, d0 + 3u, h, ov.w);
store_out_vec4(s0 + m, d0, h, acc[m]);
m = m + 1u;
}
}
64 changes: 43 additions & 21 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
namespace executorch::backends::webgpu {

// @generated from sdpa_compute_out.wgsl - DO NOT EDIT.
// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc
// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191
inline constexpr const char* kSdpaComputeOutWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
@group(0) @binding(2) var<storage, read> t_v_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -36,6 +36,7 @@ override wg_size: u32 = 64;
const TM: u32 = 4u;
const TN: u32 = 4u;

// Checked loaders mask context lanes past context_len (D%4==0, host-guarded).
fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (s >= params.S) {
Expand All @@ -50,24 +51,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
}

fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (c >= params.context_len) {
return r;
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D + d0;
if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
return r;
return t_v_cache[base / 4u];
}

// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len.
fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4<f32> {
if (s >= params.S) {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = h * params.S * params.context_len + s * params.context_len + c4;
return vec4<f32>(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]);
}

fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
let base = c * params.Hkv * params.D + kvh * params.D + d0;
return t_v_cache[base / 4u];
}

fn store_out(s: u32, d: u32, h: u32, val: f32) {
if (s >= params.S || d >= params.D) {
fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4<f32>) {
if (s >= params.S) {
return;
}
let idx = s * params.Hq * params.D + h * params.D + d;
t_out[idx] = val;
let idx = s * params.Hq * params.D + h * params.D + d0;
t_out[idx / 4u] = val;
}

@compute @workgroup_size(wg_size, 1, 1)
Expand All @@ -94,11 +104,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

// Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl.
let ctx_aligned = params.context_len - (params.context_len & 3u);
var c4: u32 = 0u;
loop {
if (c4 >= params.context_len) {
if (c4 >= ctx_aligned) {
break;
}
let a0 = load_a_vec4_nc(s0 + 0u, h, c4);
let a1 = load_a_vec4_nc(s0 + 1u, h, c4);
let a2 = load_a_vec4_nc(s0 + 2u, h, c4);
let a3 = load_a_vec4_nc(s0 + 3u, h, c4);
let v0 = load_v_d4_nc(c4 + 0u, kvh, d0);
let v1 = load_v_d4_nc(c4 + 1u, kvh, d0);
let v2 = load_v_d4_nc(c4 + 2u, kvh, d0);
let v3 = load_v_d4_nc(c4 + 3u, kvh, d0);
acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
c4 = c4 + 4u;
}
if (c4 < params.context_len) {
let a0 = load_a_vec4(s0 + 0u, h, c4);
let a1 = load_a_vec4(s0 + 1u, h, c4);
let a2 = load_a_vec4(s0 + 2u, h, c4);
Expand All @@ -111,19 +138,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
c4 = c4 + 4u;
}

var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let ov = acc[m];
store_out(s0 + m, d0 + 0u, h, ov.x);
store_out(s0 + m, d0 + 1u, h, ov.y);
store_out(s0 + m, d0 + 2u, h, ov.z);
store_out(s0 + m, d0 + 3u, h, ov.w);
store_out_vec4(s0 + m, d0, h, acc[m]);
m = m + 1u;
}
}
Expand Down
2 changes: 2 additions & 0 deletions backends/webgpu/test/ops/sdpa/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class SdpaConfig:
# Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127).
SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0),
SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127),
# D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load.
SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0),
]


Expand Down
Loading
Loading