Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion backends/webgpu/runtime/WebGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

namespace executorch::backends::webgpu::utils {

// Ceiling division for non-negative integers (mirrors Vulkan's utils::div_up).
template <typename T>
inline T div_up(T a, T b) {
return (a + b - 1) / b;
}

// Clamp workgroup size to device limit (SwiftShader caps at 128).
inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
WGPULimits limits = {};
Expand All @@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count(
uint32_t num_threads,
uint32_t workgroup_size,
const char* op_name) {
uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size;
uint32_t count = div_up(num_threads, workgroup_size);
WGPULimits limits = {};
uint32_t max_count =
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
Expand Down Expand Up @@ -70,4 +76,16 @@ make_uniform(WGPUDevice device, const void* data, size_t size) {
return buf;
}

// Clamp a 1D workgroup count to the device limit, for grid-stride kernels that
// loop over any excess work (vs compute_1d_workgroup_count, which throws).
inline uint32_t clamp_workgroup_count(WGPUDevice device, uint32_t desired) {
WGPULimits limits = {};
uint32_t max_count =
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
limits.maxComputeWorkgroupsPerDimension > 0
? limits.maxComputeWorkgroupsPerDimension
: 65535u; // WebGPU spec-default floor
return std::min(desired, max_count);
}

} // namespace executorch::backends::webgpu::utils
47 changes: 38 additions & 9 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>

#include <webgpu/webgpu.h>
Expand All @@ -34,6 +35,10 @@ struct Q4gswParams {
};
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");

// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl.
constexpr int64_t kQ4gswTileM = 4;
constexpr int64_t kQ4gswTileN = 4;

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
Expand Down Expand Up @@ -85,10 +90,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
}

// One workgroup per output row (M); validate dispatch before any alloc.
const uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
Expand Down Expand Up @@ -116,6 +117,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: scales dims too small for K/N");
}

// M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM.
const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u);
const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL;
uint32_t workgroup_count;
if (use_gemv) {
// coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N.
const uint64_t outputs =
static_cast<uint64_t>(M) * static_cast<uint64_t>(N);
if (outputs == 0u || outputs > UINT32_MAX) {
throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range");
}
workgroup_count =
utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs));
if (workgroup_count == 0u) {
throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch");
}
} else {
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
utils::div_up<int64_t>(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
}

// Optional bias: real buffer if present, else a dummy for the fixed layout.
uint32_t has_bias = 0;
WGPUBuffer bias_buffer = nullptr;
Expand Down Expand Up @@ -156,7 +186,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
wgsl_desc.code = {shader_src, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
Expand Down Expand Up @@ -186,8 +216,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);
Expand All @@ -196,8 +224,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
// coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override.
pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u;
pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

Expand Down
77 changes: 52 additions & 25 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,74 @@ struct Params {

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
const TM: u32 = 4u;
const TN: u32 = 4u;
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN

@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let nrt = (params.M + TM - 1u) / TM;
let nct = (params.N + TN - 1u) / TN;
let tiles = nrt * nct;
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
// clamps below never underflow (the host also rejects M==0/N==0).
if (gid.x >= tiles) {
return;
}
let in_base = m * params.K;
let row_tile = gid.x / nct;
let col_tile = gid.x % nct;
let m0 = row_tile * TM;
let n0 = col_tile * TN;

var acc: array<f32, TILE_ELEMS>;
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
acc[i] = 0.0;
}

var n: u32 = lid.x;
var k: u32 = 0u;
loop {
if (n >= params.N) {
if (k >= params.K) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
// Load the TM input values for column k once; reused across all TN columns.
var in_reg: array<f32, TM>;
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m_eff = min(m0 + ml, params.M - 1u);
in_reg[ml] = t_input[m_eff * params.K + k];
}
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
// Clamp to last valid column; overhang result is never stored.
let n_eff = min(n0 + nl, params.N - 1u);
let byte_idx = n_eff * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
}
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
k = k + 1u;
}

for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m = m0 + ml;
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
let n = n0 + nl;
if (m < params.M && n < params.N) {
var v = acc[ml * TN + nl];
if (params.has_bias != 0u) {
v = v + t_bias[n];
}
t_out[m * params.N + n] = v;
}
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

struct Params {
M: u32,
N: u32,
K: u32,
K_packed: u32,
group_size: u32,
padded_N: u32,
has_bias: u32,
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes).
const WG: u32 = 64u;
var<workgroup> partial: array<f32, WG>;

@compute @workgroup_size(WG, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) ngrp: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let total = params.M * params.N;
let stride = ngrp.x;
let num_words = params.K >> 3u; // K / 8 words per row
let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8)
var idx = wid.x;
loop {
if (idx >= total) {
break;
}
let m = idx / params.N;
let n = idx % params.N;
let in_base = m * params.K;
let wbase = n * row_words;

var acc: f32 = 0.0;
var w: u32 = lid.x;
loop {
if (w >= num_words) {
break;
}
let word = t_weight[wbase + w];
let k0 = w << 3u; // first K of this word
let scale = t_scales[(k0 / params.group_size) * params.padded_N + n];
let ib = in_base + k0;
// 4 bytes, low+high nibble each -> 8 consecutive K.
for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) {
let byte = (word >> (bi * 8u)) & 0xFFu;
let lo = f32(i32(byte & 0x0Fu) - 8);
let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8);
let kk = bi << 1u;
acc = acc + t_input[ib + kk] * lo * scale;
acc = acc + t_input[ib + kk + 1u] * hi * scale;
}
w = w + WG;
}

partial[lid.x] = acc;
workgroupBarrier();
var s: u32 = WG >> 1u;
loop {
if (s == 0u) {
break;
}
if (lid.x < s) {
partial[lid.x] = partial[lid.x] + partial[lid.x + s];
}
workgroupBarrier();
s = s >> 1u;
}
if (lid.x == 0u) {
var o = partial[0];
if (params.has_bias != 0u) {
o = o + t_bias[n];
}
t_out[idx] = o;
}
workgroupBarrier();
idx = idx + stride;
}
}
Loading
Loading