Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu {

namespace {

// Register-tile dims; MUST match TM/TN in the reg WGSL kernels.
constexpr int64_t kSdpaTileM = 4;
constexpr int64_t kSdpaTileN = 4;

// Uniform param structs (all 16-byte aligned, matching the WGSL Params).
struct UpdateCacheParams {
uint32_t numel;
Expand Down Expand Up @@ -335,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) {
throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q");
}
// QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4.
if (D % 4 != 0) {
throw std::runtime_error(
"WebGPU sdpa: head_dim (D) must be a multiple of 4");
}
if (v.dims[v.dims.size() - 2] != Hkv) {
throw std::runtime_error("WebGPU sdpa: v num_heads must match k");
}
Expand Down Expand Up @@ -464,14 +473,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
dynamic_pos,
"update_cache(V)");

// --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
{
if (aw_floats > UINT32_MAX) {
throw std::runtime_error(
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
}
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(context_len, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
AttnWeightsParams p = make_attn_weights_params(
S, Hq, Hkv, D, context_len, input_pos, g, scale);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
Expand Down Expand Up @@ -515,12 +526,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
softmax_buf = ubuf;
}

// --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
{
const uint64_t out_floats = static_cast<uint64_t>(S) *
static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
const int64_t av_tiles =
Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(out_floats), av_wg, "AV");
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
BufferBinding bindings[3] = {
Expand Down Expand Up @@ -591,9 +602,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
AttnWeightsParams qp =
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(ctx, kSdpaTileN);
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
gr.device(),
static_cast<uint32_t>(aw_floats),
static_cast<uint32_t>(qk_tiles),
qk_wg,
"QK(resize)");
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;
Expand Down
103 changes: 82 additions & 21 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -19,37 +19,98 @@ const NEG_INF: f32 = -1.0e30;

override wg_size: u32 = 64;

const TM: u32 = 4u;
const TN: u32 = 4u;

// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
if (s >= params.S) {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = s * params.Hq * params.D + h * params.D + d4;
return t_q[base / 4u];
}

fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
if (c >= params.context_len) {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D + d4;
return t_k_cache[base / 4u];
}

fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
if (s >= params.S || c >= params.context_len) {
return;
}
var val = raw * params.scale;
// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
val = NEG_INF;
}
let idx = h * params.S * params.context_len + s * params.context_len + c;
t_attn_weights[idx] = val;
}

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let total = params.Hq * params.S * params.context_len;
let idx = gid.x;
if (idx >= total) {
let nrt = (params.S + TM - 1u) / TM;
let nct = (params.context_len + TN - 1u) / TN;
let tiles = nrt * nct;
let total = tiles * params.Hq;
if (gid.x >= total) {
return;
}
let c = idx % params.context_len;
let s = (idx / params.context_len) % params.S;
let h = idx / (params.context_len * params.S);

let h = gid.x / tiles;
let rem = gid.x % tiles;
let row_tile = rem / nct;
let col_tile = rem % nct;
let kvh = h / params.g;
let s0 = row_tile * TM;
let c0 = col_tile * TN;

let q_base = s * params.Hq * params.D + h * params.D;
let k_base = c * params.Hkv * params.D + kvh * params.D;
var acc: array<vec4<f32>, 4>;
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

var acc: f32 = 0.0;
var d: u32 = 0u;
// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
var d4: u32 = 0u;
loop {
if (d >= params.D) {
if (d4 >= params.D || skip_tile) {
break;
}
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
d = d + 1u;
var q: array<vec4<f32>, TM>;
var k: array<vec4<f32>, TN>;
for (var i: u32 = 0u; i < TM; i = i + 1u) {
q[i] = load_q_vec4(s0 + i, h, d4);
}
for (var j: u32 = 0u; j < TN; j = j + 1u) {
k[j] = load_k_vec4(c0 + j, kvh, d4);
}
for (var i: u32 = 0u; i < TM; i = i + 1u) {
acc[i] += vec4<f32>(
dot(q[i], k[0]),
dot(q[i], k[1]),
dot(q[i], k[2]),
dot(q[i], k[3]));
}
d4 = d4 + 4u;
}
acc = acc * params.scale;

// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
acc = NEG_INF;
var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let av = acc[m];
store_qk(s0 + m, c0 + 0u, h, av.x);
store_qk(s0 + m, c0 + 1u, h, av.y);
store_qk(s0 + m, c0 + 2u, h, av.z);
store_qk(s0 + m, c0 + 3u, h, av.w);
m = m + 1u;
}

t_attn_weights[idx] = acc;
}
105 changes: 83 additions & 22 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
namespace executorch::backends::webgpu {

// @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1
// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254
inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

struct Params {
S: u32,
Expand All @@ -36,39 +36,100 @@ const NEG_INF: f32 = -1.0e30;

override wg_size: u32 = 64;

const TM: u32 = 4u;
const TN: u32 = 4u;

// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
if (s >= params.S) {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = s * params.Hq * params.D + h * params.D + d4;
return t_q[base / 4u];
}

fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
if (c >= params.context_len) {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
let base = c * params.Hkv * params.D + kvh * params.D + d4;
return t_k_cache[base / 4u];
}

fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
if (s >= params.S || c >= params.context_len) {
return;
}
var val = raw * params.scale;
// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
val = NEG_INF;
}
let idx = h * params.S * params.context_len + s * params.context_len + c;
t_attn_weights[idx] = val;
}

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let total = params.Hq * params.S * params.context_len;
let idx = gid.x;
if (idx >= total) {
let nrt = (params.S + TM - 1u) / TM;
let nct = (params.context_len + TN - 1u) / TN;
let tiles = nrt * nct;
let total = tiles * params.Hq;
if (gid.x >= total) {
return;
}
let c = idx % params.context_len;
let s = (idx / params.context_len) % params.S;
let h = idx / (params.context_len * params.S);

let h = gid.x / tiles;
let rem = gid.x % tiles;
let row_tile = rem / nct;
let col_tile = rem % nct;
let kvh = h / params.g;
let s0 = row_tile * TM;
let c0 = col_tile * TN;

let q_base = s * params.Hq * params.D + h * params.D;
let k_base = c * params.Hkv * params.D + kvh * params.D;
var acc: array<vec4<f32>, 4>;
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

var acc: f32 = 0.0;
var d: u32 = 0u;
// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
var d4: u32 = 0u;
loop {
if (d >= params.D) {
if (d4 >= params.D || skip_tile) {
break;
}
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
d = d + 1u;
var q: array<vec4<f32>, TM>;
var k: array<vec4<f32>, TN>;
for (var i: u32 = 0u; i < TM; i = i + 1u) {
q[i] = load_q_vec4(s0 + i, h, d4);
}
for (var j: u32 = 0u; j < TN; j = j + 1u) {
k[j] = load_k_vec4(c0 + j, kvh, d4);
}
for (var i: u32 = 0u; i < TM; i = i + 1u) {
acc[i] += vec4<f32>(
dot(q[i], k[0]),
dot(q[i], k[1]),
dot(q[i], k[2]),
dot(q[i], k[3]));
}
d4 = d4 + 4u;
}
acc = acc * params.scale;

// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
acc = NEG_INF;
var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let av = acc[m];
store_qk(s0 + m, c0 + 0u, h, av.x);
store_qk(s0 + m, c0 + 1u, h, av.y);
store_qk(s0 + m, c0 + 2u, h, av.z);
store_qk(s0 + m, c0 + 3u, h, av.w);
m = m + 1u;
}

t_attn_weights[idx] = acc;
}
)";

Expand Down
Loading
Loading