Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu {

namespace {

// Register-tile dimensions used to size the tiled SDPA dispatches below: the
// host divides the output extents by these (rounding up) so that one GPU
// thread is launched per kSdpaTileM x kSdpaTileN tile instead of per element.
// MUST match the TM/TN constants in the register-tiled WGSL kernels.
constexpr int64_t kSdpaTileM = 4;
constexpr int64_t kSdpaTileN = 4;

// Uniform param structs (all 16-byte aligned, matching the WGSL Params).
struct UpdateCacheParams {
uint32_t numel;
Expand Down Expand Up @@ -464,14 +468,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
dynamic_pos,
"update_cache(V)");

// --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
{
if (aw_floats > UINT32_MAX) {
throw std::runtime_error(
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
}
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(context_len, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
AttnWeightsParams p = make_attn_weights_params(
S, Hq, Hkv, D, context_len, input_pos, g, scale);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
Expand Down Expand Up @@ -515,12 +521,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
softmax_buf = ubuf;
}

// --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
{
const uint64_t out_floats = static_cast<uint64_t>(S) *
static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
const int64_t av_tiles =
Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(out_floats), av_wg, "AV");
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
BufferBinding bindings[3] = {
Expand Down Expand Up @@ -591,9 +597,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
AttnWeightsParams qp =
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(ctx, kSdpaTileN);
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
gr.device(),
static_cast<uint32_t>(aw_floats),
static_cast<uint32_t>(qk_tiles),
qk_wg,
"QK(resize)");
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;
Expand Down
106 changes: 87 additions & 19 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,105 @@ const NEG_INF: f32 = -1.0e30;

override wg_size: u32 = 64;

const TM: u32 = 4u;
const TN: u32 = 4u;

// Reads up to four consecutive head-dim values of Q for row `s`, head `h`,
// starting at element `d4`. Lanes whose element index is not below params.D
// stay 0.0, and a row with s >= params.S yields an all-zero vector.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
  var lanes = vec4<f32>(0.0);
  if (s < params.S) {
    // Flat offset of element d4 within the (s, h) row of t_q.
    let first = (s * params.Hq + h) * params.D + d4;
    if (d4 < params.D) {
      lanes.x = t_q[first];
    }
    if (d4 + 1u < params.D) {
      lanes.y = t_q[first + 1u];
    }
    if (d4 + 2u < params.D) {
      lanes.z = t_q[first + 2u];
    }
    if (d4 + 3u < params.D) {
      lanes.w = t_q[first + 3u];
    }
  }
  return lanes;
}

// Reads up to four consecutive head-dim values of the K cache for context
// position `c`, KV head `kvh`, starting at element `d4`. Lanes whose element
// index is not below params.D stay 0.0, and a position with
// c >= params.context_len yields an all-zero vector.
fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
  var lanes = vec4<f32>(0.0);
  if (c < params.context_len) {
    // Flat offset of element d4 within the (c, kvh) row of t_k_cache.
    let first = (c * params.Hkv + kvh) * params.D + d4;
    if (d4 < params.D) {
      lanes.x = t_k_cache[first];
    }
    if (d4 + 1u < params.D) {
      lanes.y = t_k_cache[first + 1u];
    }
    if (d4 + 2u < params.D) {
      lanes.z = t_k_cache[first + 2u];
    }
    if (d4 + 3u < params.D) {
      lanes.w = t_k_cache[first + 3u];
    }
  }
  return lanes;
}

// Writes one attention-weight element for (head h, query row s, context
// column c). Coordinates outside [0, S) x [0, context_len) are ignored.
// The stored value is raw * params.scale, or NEG_INF when the causal mask
// applies (c lies beyond s + input_pos).
fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
  let in_range = s < params.S && c < params.context_len;
  if (!in_range) {
    return;
  }
  // Causal mask: position c may not attend beyond s + input_pos.
  let masked = c > s + params.input_pos;
  let row_start = (h * params.S + s) * params.context_len;
  t_attn_weights[row_start + c] = select(raw * params.scale, NEG_INF, masked);
}

// Entry point for the QK attention-weights kernel.
// NOTE(review): this span is a diff rendering that mixes the previous
// per-element implementation with the new tiled one (several names are
// declared twice); the comments below describe the tiled lines only.
@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let total = params.Hq * params.S * params.context_len;
let idx = gid.x;
if (idx >= total) {
// Tile grid per head: nrt row tiles of TM query rows by nct column tiles of
// TN context positions (both rounded up); one invocation per tile per head.
let nrt = (params.S + TM - 1u) / TM;
let nct = (params.context_len + TN - 1u) / TN;
let tiles = nrt * nct;
let total = tiles * params.Hq;
if (gid.x >= total) {
return;
}
let c = idx % params.context_len;
let s = (idx / params.context_len) % params.S;
let h = idx / (params.context_len * params.S);

// Unflatten gid.x into (head, row tile, column tile); s0/c0 are the tile's
// first query row and first context column.
let h = gid.x / tiles;
let rem = gid.x % tiles;
let row_tile = rem / nct;
let col_tile = rem % nct;
let kvh = h / params.g;
let s0 = row_tile * TM;
let c0 = col_tile * TN;

let q_base = s * params.Hq * params.D + h * params.D;
let k_base = c * params.Hkv * params.D + kvh * params.D;
// acc[i] accumulates the dot products of query row s0 + i against the four
// context columns c0 .. c0 + 3 (one per vector lane).
var acc: array<vec4<f32>, 4>;
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

var acc: f32 = 0.0;
var d: u32 = 0u;
var d4: u32 = 0u;
loop {
if (d >= params.D) {
if (d4 >= params.D) {
break;
}
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
d = d + 1u;
// Load a 4-wide slice of D for the tile's TM Q rows and TN K rows (the
// loaders return zeros for out-of-range rows/lanes), then add every
// row-by-column partial dot product into acc.
var q: array<vec4<f32>, TM>;
var k: array<vec4<f32>, TN>;
for (var i: u32 = 0u; i < TM; i = i + 1u) {
q[i] = load_q_vec4(s0 + i, h, d4);
}
for (var j: u32 = 0u; j < TN; j = j + 1u) {
k[j] = load_k_vec4(c0 + j, kvh, d4);
}
for (var i: u32 = 0u; i < TM; i = i + 1u) {
acc[i] += vec4<f32>(
dot(q[i], k[0]),
dot(q[i], k[1]),
dot(q[i], k[2]),
dot(q[i], k[3]));
}
d4 = d4 + 4u;
}
acc = acc * params.scale;

// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
acc = NEG_INF;
// Write the tile out row by row; store_qk applies the scale and causal mask
// and skips coordinates outside S x context_len.
var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let av = acc[m];
store_qk(s0 + m, c0 + 0u, h, av.x);
store_qk(s0 + m, c0 + 1u, h, av.y);
store_qk(s0 + m, c0 + 2u, h, av.z);
store_qk(s0 + m, c0 + 3u, h, av.w);
m = m + 1u;
}

t_attn_weights[idx] = acc;
}
108 changes: 88 additions & 20 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1
// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9
inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
Expand All @@ -36,39 +36,107 @@ const NEG_INF: f32 = -1.0e30;

override wg_size: u32 = 64;

const TM: u32 = 4u;
const TN: u32 = 4u;

fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (s >= params.S) {
return r;
}
let base = s * params.Hq * params.D + h * params.D;
if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
return r;
}

fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (c >= params.context_len) {
return r;
}
let base = c * params.Hkv * params.D + kvh * params.D;
if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
return r;
}

fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
if (s >= params.S || c >= params.context_len) {
return;
}
var val = raw * params.scale;
// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
val = NEG_INF;
}
let idx = h * params.S * params.context_len + s * params.context_len + c;
t_attn_weights[idx] = val;
}

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let total = params.Hq * params.S * params.context_len;
let idx = gid.x;
if (idx >= total) {
let nrt = (params.S + TM - 1u) / TM;
let nct = (params.context_len + TN - 1u) / TN;
let tiles = nrt * nct;
let total = tiles * params.Hq;
if (gid.x >= total) {
return;
}
let c = idx % params.context_len;
let s = (idx / params.context_len) % params.S;
let h = idx / (params.context_len * params.S);

let h = gid.x / tiles;
let rem = gid.x % tiles;
let row_tile = rem / nct;
let col_tile = rem % nct;
let kvh = h / params.g;
let s0 = row_tile * TM;
let c0 = col_tile * TN;

let q_base = s * params.Hq * params.D + h * params.D;
let k_base = c * params.Hkv * params.D + kvh * params.D;
var acc: array<vec4<f32>, 4>;
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

var acc: f32 = 0.0;
var d: u32 = 0u;
var d4: u32 = 0u;
loop {
if (d >= params.D) {
if (d4 >= params.D) {
break;
}
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
d = d + 1u;
var q: array<vec4<f32>, TM>;
var k: array<vec4<f32>, TN>;
for (var i: u32 = 0u; i < TM; i = i + 1u) {
q[i] = load_q_vec4(s0 + i, h, d4);
}
for (var j: u32 = 0u; j < TN; j = j + 1u) {
k[j] = load_k_vec4(c0 + j, kvh, d4);
}
for (var i: u32 = 0u; i < TM; i = i + 1u) {
acc[i] += vec4<f32>(
dot(q[i], k[0]),
dot(q[i], k[1]),
dot(q[i], k[2]),
dot(q[i], k[3]));
}
d4 = d4 + 4u;
}
acc = acc * params.scale;

// Causal mask: position c may not attend beyond s + input_pos.
if (c > s + params.input_pos) {
acc = NEG_INF;
var m: u32 = 0u;
loop {
if (m >= TM) {
break;
}
let av = acc[m];
store_qk(s0 + m, c0 + 0u, h, av.x);
store_qk(s0 + m, c0 + 1u, h, av.y);
store_qk(s0 + m, c0 + 2u, h, av.z);
store_qk(s0 + m, c0 + 3u, h, av.w);
m = m + 1u;
}

t_attn_weights[idx] = acc;
}
)";

Expand Down
Loading
Loading