From 4c2ed37a325641a8cfa2e5d1da3dde0afaccc320 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 23 Jun 2026 13:16:44 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- backends/webgpu/runtime/WebGPUUtils.h | 12 ++ .../ops/quantized_linear/QuantizedLinear.cpp | 49 +++++--- .../quantized_linear/q4gsw_linear_coop4.wgsl | 87 ++++++++++++++ .../q4gsw_linear_coop4_wgsl.h | 111 ++++++++++++++++++ 4 files changed, 244 insertions(+), 15 deletions(-) create mode 100644 backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl create mode 100644 backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h diff --git a/backends/webgpu/runtime/WebGPUUtils.h b/backends/webgpu/runtime/WebGPUUtils.h index 39eb3caa28b..ddffc6029db 100644 --- a/backends/webgpu/runtime/WebGPUUtils.h +++ b/backends/webgpu/runtime/WebGPUUtils.h @@ -70,4 +70,16 @@ make_uniform(WGPUDevice device, const void* data, size_t size) { return buf; } +// Clamp a 1D workgroup count to the device limit, for grid-stride kernels that +// loop over any excess work (vs compute_1d_workgroup_count, which throws). +inline uint32_t clamp_workgroup_count(WGPUDevice device, uint32_t desired) { + WGPULimits limits = {}; + uint32_t max_count = + wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && + limits.maxComputeWorkgroupsPerDimension > 0 + ? limits.maxComputeWorkgroupsPerDimension + : 65535u; // WebGPU spec-default floor + return std::min(desired, max_count); +} + } // namespace executorch::backends::webgpu::utils diff --git a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp index 99af2db52fa..d33e89e96e2 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp +++ b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -93,18 +94,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { "WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)"); } - // Register-tiled GEMM: one thread per TM x TN tile; validate before alloc. - const uint32_t wg_size = - utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); - const int64_t total_tiles = - q4gsw_ceil_div(M, kQ4gswTileM) * q4gsw_ceil_div(N, kQ4gswTileN); - if (total_tiles > static_cast(UINT32_MAX)) { - throw std::runtime_error( - "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); - } - const uint32_t workgroup_count = utils::compute_1d_workgroup_count( - device, static_cast(total_tiles), wg_size, "linear_q4gsw"); - // fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail. const uint64_t scales_numel = static_cast(num_groups) * static_cast(padded_N); @@ -132,6 +121,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { "WebGPU linear_q4gsw: scales dims too small for K/N"); } + // M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM. + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); + const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u); + const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL; + uint32_t workgroup_count; + if (use_gemv) { + // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N. + const uint64_t outputs = + static_cast(M) * static_cast(N); + if (outputs == 0u || outputs > UINT32_MAX) { + throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range"); + } + workgroup_count = + utils::clamp_workgroup_count(device, static_cast(outputs)); + if (workgroup_count == 0u) { + throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch"); + } + } else { + const int64_t total_tiles = + q4gsw_ceil_div(M, kQ4gswTileM) * q4gsw_ceil_div(N, kQ4gswTileN); + if (total_tiles > static_cast(UINT32_MAX)) { + throw std::runtime_error( + "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); + } + workgroup_count = utils::compute_1d_workgroup_count( + device, static_cast(total_tiles), wg_size, "linear_q4gsw"); + } + // Optional bias: real buffer if present, else a dummy for the fixed layout. uint32_t has_bias = 0; WGPUBuffer bias_buffer = nullptr; @@ -172,7 +190,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { WGPUShaderSourceWGSL wgsl_desc = {}; wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN}; + wgsl_desc.code = {shader_src, WGPU_STRLEN}; WGPUShaderModuleDescriptor shader_desc = {}; shader_desc.nextInChain = &wgsl_desc.chain; WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); @@ -210,8 +228,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { pipeline_desc.layout = pipeline_layout; pipeline_desc.compute.module = shader; pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; - pipeline_desc.compute.constantCount = 1; - pipeline_desc.compute.constants = &wg_size_constant; + // coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override. + pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u; + pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant; WGPUComputePipeline pipeline = wgpuDeviceCreateComputePipeline(device, &pipeline_desc); diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl new file mode 100644 index 00000000000..04e3d5eb3a6 --- /dev/null +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl @@ -0,0 +1,87 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_input: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; +@group(0) @binding(4) var t_bias: array; + +struct Params { + M: u32, + N: u32, + K: u32, + K_packed: u32, + group_size: u32, + padded_N: u32, + has_bias: u32, + _pad: u32, +} +@group(0) @binding(5) var params: Params; + +// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). +const WG: u32 = 64u; +var partial: array; + +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) ngrp: vec3, + @builtin(local_invocation_id) lid: vec3) { + let total = params.M * params.N; + let stride = ngrp.x; + let num_words = params.K >> 3u; // K / 8 words per row + let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8) + var idx = wid.x; + loop { + if (idx >= total) { + break; + } + let m = idx / params.N; + let n = idx % params.N; + let in_base = m * params.K; + let wbase = n * row_words; + + var acc: f32 = 0.0; + var w: u32 = lid.x; + loop { + if (w >= num_words) { + break; + } + let word = t_weight[wbase + w]; + let k0 = w << 3u; // first K of this word + let scale = t_scales[(k0 / params.group_size) * params.padded_N + n]; + let ib = in_base + k0; + // 4 bytes, low+high nibble each -> 8 consecutive K. + for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) { + let byte = (word >> (bi * 8u)) & 0xFFu; + let lo = f32(i32(byte & 0x0Fu) - 8); + let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8); + let kk = bi << 1u; + acc = acc + t_input[ib + kk] * lo * scale; + acc = acc + t_input[ib + kk + 1u] * hi * scale; + } + w = w + WG; + } + + partial[lid.x] = acc; + workgroupBarrier(); + var s: u32 = WG >> 1u; + loop { + if (s == 0u) { + break; + } + if (lid.x < s) { + partial[lid.x] = partial[lid.x] + partial[lid.x + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + if (lid.x == 0u) { + var o = partial[0]; + if (params.has_bias != 0u) { + o = o + t_bias[n]; + } + t_out[idx] = o; + } + workgroupBarrier(); + idx = idx + stride; + } +} diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h new file mode 100644 index 00000000000..507f772090f --- /dev/null +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from q4gsw_linear_coop4.wgsl - DO NOT EDIT. +// wgsl-sha256: 6e296f0583118d1ff0df914dd3ac078e7f4e526d99be7d233531a47fddb93f89 +inline constexpr const char* kQ4gswLinearCoop4WGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_input: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; +@group(0) @binding(4) var t_bias: array; + +struct Params { + M: u32, + N: u32, + K: u32, + K_packed: u32, + group_size: u32, + padded_N: u32, + has_bias: u32, + _pad: u32, +} +@group(0) @binding(5) var params: Params; + +// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). +const WG: u32 = 64u; +var partial: array; + +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) ngrp: vec3, + @builtin(local_invocation_id) lid: vec3) { + let total = params.M * params.N; + let stride = ngrp.x; + let num_words = params.K >> 3u; // K / 8 words per row + let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8) + var idx = wid.x; + loop { + if (idx >= total) { + break; + } + let m = idx / params.N; + let n = idx % params.N; + let in_base = m * params.K; + let wbase = n * row_words; + + var acc: f32 = 0.0; + var w: u32 = lid.x; + loop { + if (w >= num_words) { + break; + } + let word = t_weight[wbase + w]; + let k0 = w << 3u; // first K of this word + let scale = t_scales[(k0 / params.group_size) * params.padded_N + n]; + let ib = in_base + k0; + // 4 bytes, low+high nibble each -> 8 consecutive K. + for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) { + let byte = (word >> (bi * 8u)) & 0xFFu; + let lo = f32(i32(byte & 0x0Fu) - 8); + let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8); + let kk = bi << 1u; + acc = acc + t_input[ib + kk] * lo * scale; + acc = acc + t_input[ib + kk + 1u] * hi * scale; + } + w = w + WG; + } + + partial[lid.x] = acc; + workgroupBarrier(); + var s: u32 = WG >> 1u; + loop { + if (s == 0u) { + break; + } + if (lid.x < s) { + partial[lid.x] = partial[lid.x] + partial[lid.x + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + if (lid.x == 0u) { + var o = partial[0]; + if (params.has_bias != 0u) { + o = o + t_bias[n]; + } + t_out[idx] = o; + } + workgroupBarrier(); + idx = idx + stride; + } +} +)"; + +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeX = 64; +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeY = 1; +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu From 6da8fa9d99ebf4fde39b0e7301617e7bc985e431 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 24 Jun 2026 11:22:05 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- .../runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl | 4 ++-- .../runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl index 04e3d5eb3a6..af6f661279d 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl @@ -18,9 +18,9 @@ struct Params { // Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). const WG: u32 = 64u; -var partial: array; +var partial: array; -@compute @workgroup_size(64, 1, 1) +@compute @workgroup_size(WG, 1, 1) fn main( @builtin(workgroup_id) wid: vec3, @builtin(num_workgroups) ngrp: vec3, diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h index 507f772090f..7bacedfcb17 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from q4gsw_linear_coop4.wgsl - DO NOT EDIT. -// wgsl-sha256: 6e296f0583118d1ff0df914dd3ac078e7f4e526d99be7d233531a47fddb93f89 +// wgsl-sha256: 3031886e68c375e617dfb263da39c492c6de4d8c1fb4073d70b18823a3e6a4fe inline constexpr const char* kQ4gswLinearCoop4WGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_input: array; @@ -35,9 +35,9 @@ struct Params { // Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). const WG: u32 = 64u; -var partial: array; +var partial: array; -@compute @workgroup_size(64, 1, 1) +@compute @workgroup_size(WG, 1, 1) fn main( @builtin(workgroup_id) wid: vec3, @builtin(num_workgroups) ngrp: vec3,