diff --git a/backends/webgpu/runtime/WebGPUUtils.h b/backends/webgpu/runtime/WebGPUUtils.h index 39eb3caa28b..c5c779ffd5e 100644 --- a/backends/webgpu/runtime/WebGPUUtils.h +++ b/backends/webgpu/runtime/WebGPUUtils.h @@ -18,6 +18,12 @@ namespace executorch::backends::webgpu::utils { +// Overflow-safe ceiling division for a >= 0, b > 0 (cf. Vulkan's utils::div_up). +template <typename T> +inline T div_up(T a, T b) { + return a / b + static_cast<T>(a % b != 0); +} + // Clamp workgroup size to device limit (SwiftShader caps at 128). inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) { WGPULimits limits = {}; @@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count( uint32_t num_threads, uint32_t workgroup_size, const char* op_name) { - uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size; + uint32_t count = div_up(num_threads, workgroup_size); WGPULimits limits = {}; uint32_t max_count = wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && @@ -70,4 +76,16 @@ make_uniform(WGPUDevice device, const void* data, size_t size) { return buf; } +// Clamp a 1D workgroup count to the device limit, for grid-stride kernels that +// loop over any excess work (vs compute_1d_workgroup_count, which throws). +inline uint32_t clamp_workgroup_count(WGPUDevice device, uint32_t desired) { + WGPULimits limits = {}; + uint32_t max_count = + wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && + limits.maxComputeWorkgroupsPerDimension > 0 + ? 
limits.maxComputeWorkgroupsPerDimension + : 65535u; // WebGPU spec-default floor + return std::min(desired, max_count); +} + } // namespace executorch::backends::webgpu::utils diff --git a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp index 2597aea10d4..6da05ff2010 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp +++ b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,10 @@ struct Q4gswParams { }; static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes"); +// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl. +constexpr int64_t kQ4gswTileM = 4; +constexpr int64_t kQ4gswTileN = 4; + // et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out]. void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { const int in_id = args.at(0); @@ -85,10 +90,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { "WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)"); } - // One workgroup per output row (M); validate dispatch before any alloc. - const uint32_t workgroup_count = - utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw"); - // fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail. const uint64_t scales_numel = static_cast(num_groups) * static_cast(padded_N); @@ -116,6 +117,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { "WebGPU linear_q4gsw: scales dims too small for K/N"); } + // M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM. + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); + const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u); + const char* shader_src = use_gemv ? 
kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL; + uint32_t workgroup_count; + if (use_gemv) { + // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N. + const uint64_t outputs = + static_cast(M) * static_cast(N); + if (outputs == 0u || outputs > UINT32_MAX) { + throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range"); + } + workgroup_count = + utils::clamp_workgroup_count(device, static_cast(outputs)); + if (workgroup_count == 0u) { + throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch"); + } + } else { + const int64_t total_tiles = utils::div_up(M, kQ4gswTileM) * + utils::div_up(N, kQ4gswTileN); + if (total_tiles > static_cast(UINT32_MAX)) { + throw std::runtime_error( + "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); + } + workgroup_count = utils::compute_1d_workgroup_count( + device, static_cast(total_tiles), wg_size, "linear_q4gsw"); + } + // Optional bias: real buffer if present, else a dummy for the fixed layout. uint32_t has_bias = 0; WGPUBuffer bias_buffer = nullptr; @@ -156,7 +186,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { WGPUShaderSourceWGSL wgsl_desc = {}; wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN}; + wgsl_desc.code = {shader_src, WGPU_STRLEN}; WGPUShaderModuleDescriptor shader_desc = {}; shader_desc.nextInChain = &wgsl_desc.chain; WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); @@ -186,8 +216,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { WGPUPipelineLayout pipeline_layout = wgpuDeviceCreatePipelineLayout(device, &pl_desc); - const uint32_t wg_size = - utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); WGPUConstantEntry wg_size_constant = {}; wg_size_constant.key = {"wg_size", WGPU_STRLEN}; wg_size_constant.value = static_cast(wg_size); @@ -196,8 +224,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { 
pipeline_desc.layout = pipeline_layout; pipeline_desc.compute.module = shader; pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; - pipeline_desc.compute.constantCount = 1; - pipeline_desc.compute.constants = &wg_size_constant; + // coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override. + pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u; + pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant; WGPUComputePipeline pipeline = wgpuDeviceCreateComputePipeline(device, &pipeline_desc); diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl index d0d6e155987..8cea61d331c 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl @@ -18,47 +18,74 @@ struct Params { override wg_size: u32 = 64u; -// One workgroup per row m, threads stride N; loop logical K only (in-bounds). +// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows. +const TM: u32 = 4u; +const TN: u32 = 4u; +const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN + @compute @workgroup_size(wg_size, 1, 1) -fn main( - @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let m = wid.x; - if (m >= params.M) { +fn main(@builtin(global_invocation_id) gid: vec3) { + let nrt = (params.M + TM - 1u) / TM; + let nct = (params.N + TN - 1u) / TN; + let tiles = nrt * nct; + // M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u + // clamps below never underflow (the host also rejects M==0/N==0). 
+ if (gid.x >= tiles) { return; } - let in_base = m * params.K; + let row_tile = gid.x / nct; + let col_tile = gid.x % nct; + let m0 = row_tile * TM; + let n0 = col_tile * TN; + + var acc: array; + for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) { + acc[i] = 0.0; + } - var n: u32 = lid.x; + var k: u32 = 0u; loop { - if (n >= params.N) { + if (k >= params.K) { break; } - var acc: f32 = 0.0; - var k: u32 = 0u; - loop { - if (k >= params.K) { - break; - } - // Packed weight byte for (n, k): row stride K_packed bytes, byte k/2. - let byte_idx = n * params.K_packed + (k >> 1u); + // Load the TM input values for column k once; reused across all TN columns. + var in_reg: array; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m_eff = min(m0 + ml, params.M - 1u); + in_reg[ml] = t_input[m_eff * params.K + k]; + } + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + // Clamp to last valid column; overhang result is never stored. + let n_eff = min(n0 + nl, params.N - 1u); + let byte_idx = n_eff * params.K_packed + (k >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; var nib: u32; if ((k & 1u) == 0u) { - nib = b & 0x0Fu; // even k -> low nibble + nib = b & 0x0Fu; // even k -> low nibble } else { nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] - let scale = t_scales[(k / params.group_size) * params.padded_N + n]; - acc = acc + t_input[in_base + k] * q * scale; - k = k + 1u; + let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff]; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq; + } } - if (params.has_bias != 0u) { - acc = acc + t_bias[n]; + k = k + 1u; + } + + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m = m0 + ml; + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + let n = n0 + nl; + if (m < params.M && n < params.N) { + var v = acc[ml * TN + nl]; + if 
(params.has_bias != 0u) { + v = v + t_bias[n]; + } + t_out[m * params.N + n] = v; + } } - t_out[m * params.N + n] = acc; - n = n + wg_size; } } diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl new file mode 100644 index 00000000000..af6f661279d --- /dev/null +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4.wgsl @@ -0,0 +1,87 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_input: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; +@group(0) @binding(4) var t_bias: array; + +struct Params { + M: u32, + N: u32, + K: u32, + K_packed: u32, + group_size: u32, + padded_N: u32, + has_bias: u32, + _pad: u32, +} +@group(0) @binding(5) var params: Params; + +// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). +const WG: u32 = 64u; +var partial: array; + +@compute @workgroup_size(WG, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) ngrp: vec3, + @builtin(local_invocation_id) lid: vec3) { + let total = params.M * params.N; + let stride = ngrp.x; + let num_words = params.K >> 3u; // K / 8 words per row + let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8) + var idx = wid.x; + loop { + if (idx >= total) { + break; + } + let m = idx / params.N; + let n = idx % params.N; + let in_base = m * params.K; + let wbase = n * row_words; + + var acc: f32 = 0.0; + var w: u32 = lid.x; + loop { + if (w >= num_words) { + break; + } + let word = t_weight[wbase + w]; + let k0 = w << 3u; // first K of this word + let scale = t_scales[(k0 / params.group_size) * params.padded_N + n]; + let ib = in_base + k0; + // 4 bytes, low+high nibble each -> 8 consecutive K. 
+ for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) { + let byte = (word >> (bi * 8u)) & 0xFFu; + let lo = f32(i32(byte & 0x0Fu) - 8); + let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8); + let kk = bi << 1u; + acc = acc + t_input[ib + kk] * lo * scale; + acc = acc + t_input[ib + kk + 1u] * hi * scale; + } + w = w + WG; + } + + partial[lid.x] = acc; + workgroupBarrier(); + var s: u32 = WG >> 1u; + loop { + if (s == 0u) { + break; + } + if (lid.x < s) { + partial[lid.x] = partial[lid.x] + partial[lid.x + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + if (lid.x == 0u) { + var o = partial[0]; + if (params.has_bias != 0u) { + o = o + t_bias[n]; + } + t_out[idx] = o; + } + workgroupBarrier(); + idx = idx + stride; + } +} diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h new file mode 100644 index 00000000000..7bacedfcb17 --- /dev/null +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from q4gsw_linear_coop4.wgsl - DO NOT EDIT. 
+// wgsl-sha256: 3031886e68c375e617dfb263da39c492c6de4d8c1fb4073d70b18823a3e6a4fe +inline constexpr const char* kQ4gswLinearCoop4WGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_input: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; +@group(0) @binding(4) var t_bias: array; + +struct Params { + M: u32, + N: u32, + K: u32, + K_packed: u32, + group_size: u32, + padded_N: u32, + has_bias: u32, + _pad: u32, +} +@group(0) @binding(5) var params: Params; + +// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes). +const WG: u32 = 64u; +var partial: array; + +@compute @workgroup_size(WG, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) ngrp: vec3, + @builtin(local_invocation_id) lid: vec3) { + let total = params.M * params.N; + let stride = ngrp.x; + let num_words = params.K >> 3u; // K / 8 words per row + let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8) + var idx = wid.x; + loop { + if (idx >= total) { + break; + } + let m = idx / params.N; + let n = idx % params.N; + let in_base = m * params.K; + let wbase = n * row_words; + + var acc: f32 = 0.0; + var w: u32 = lid.x; + loop { + if (w >= num_words) { + break; + } + let word = t_weight[wbase + w]; + let k0 = w << 3u; // first K of this word + let scale = t_scales[(k0 / params.group_size) * params.padded_N + n]; + let ib = in_base + k0; + // 4 bytes, low+high nibble each -> 8 consecutive K. 
+ for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) { + let byte = (word >> (bi * 8u)) & 0xFFu; + let lo = f32(i32(byte & 0x0Fu) - 8); + let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8); + let kk = bi << 1u; + acc = acc + t_input[ib + kk] * lo * scale; + acc = acc + t_input[ib + kk + 1u] * hi * scale; + } + w = w + WG; + } + + partial[lid.x] = acc; + workgroupBarrier(); + var s: u32 = WG >> 1u; + loop { + if (s == 0u) { + break; + } + if (lid.x < s) { + partial[lid.x] = partial[lid.x] + partial[lid.x + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + if (lid.x == 0u) { + var o = partial[0]; + if (params.has_bias != 0u) { + o = o + t_bias[n]; + } + t_out[idx] = o; + } + workgroupBarrier(); + idx = idx + stride; + } +} +)"; + +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeX = 64; +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeY = 1; +inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h index d176a01d27f..69494bbc947 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from q4gsw_linear.wgsl - DO NOT EDIT. -// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39 +// wgsl-sha256: dc6a55014ae4543bd80e5e22c3fb52896aca96e0589f700803327d8121ada489 inline constexpr const char* kQ4gswLinearWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_input: array; @@ -35,48 +35,75 @@ struct Params { override wg_size: u32 = 64u; -// One workgroup per row m, threads stride N; loop logical K only (in-bounds). +// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows. 
+const TM: u32 = 4u; +const TN: u32 = 4u; +const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN + @compute @workgroup_size(wg_size, 1, 1) -fn main( - @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let m = wid.x; - if (m >= params.M) { +fn main(@builtin(global_invocation_id) gid: vec3) { + let nrt = (params.M + TM - 1u) / TM; + let nct = (params.N + TN - 1u) / TN; + let tiles = nrt * nct; + // M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u + // clamps below never underflow (the host also rejects M==0/N==0). + if (gid.x >= tiles) { return; } - let in_base = m * params.K; + let row_tile = gid.x / nct; + let col_tile = gid.x % nct; + let m0 = row_tile * TM; + let n0 = col_tile * TN; + + var acc: array; + for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) { + acc[i] = 0.0; + } - var n: u32 = lid.x; + var k: u32 = 0u; loop { - if (n >= params.N) { + if (k >= params.K) { break; } - var acc: f32 = 0.0; - var k: u32 = 0u; - loop { - if (k >= params.K) { - break; - } - // Packed weight byte for (n, k): row stride K_packed bytes, byte k/2. - let byte_idx = n * params.K_packed + (k >> 1u); + // Load the TM input values for column k once; reused across all TN columns. + var in_reg: array; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m_eff = min(m0 + ml, params.M - 1u); + in_reg[ml] = t_input[m_eff * params.K + k]; + } + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + // Clamp to last valid column; overhang result is never stored. 
+ let n_eff = min(n0 + nl, params.N - 1u); + let byte_idx = n_eff * params.K_packed + (k >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; var nib: u32; if ((k & 1u) == 0u) { - nib = b & 0x0Fu; // even k -> low nibble + nib = b & 0x0Fu; // even k -> low nibble } else { nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] - let scale = t_scales[(k / params.group_size) * params.padded_N + n]; - acc = acc + t_input[in_base + k] * q * scale; - k = k + 1u; + let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff]; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq; + } } - if (params.has_bias != 0u) { - acc = acc + t_bias[n]; + k = k + 1u; + } + + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m = m0 + ml; + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + let n = n0 + nl; + if (m < params.M && n < params.N) { + var v = acc[ml * TN + nl]; + if (params.has_bias != 0u) { + v = v + t_bias[n]; + } + t_out[m * params.N + n] = v; + } } - t_out[m * params.N + n] = acc; - n = n + wg_size; } } )"; diff --git a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp index 7de83330810..e73c6e23a88 100644 --- a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp +++ b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -92,10 +93,17 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { graph.add_uniform_buffer_bytes(sizeof(RmsNormParams)); + // Select the vec4 kernel when the row width is a multiple of 4 (every Llama + // hidden size qualifies); fall back to the scalar kernel otherwise. The two + // kernels are equivalent up to floating-point reassociation (the vec4 + // reduction reorders the sum, so not bit-identical) and share the same bind + // group + dispatch. 
+ const bool use_vec4 = (row_width % 4u == 0u); + // Create shader module from built-in WGSL source WGPUShaderSourceWGSL wgsl_desc = {}; wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN}; + wgsl_desc.code = {use_vec4 ? kRmsNormVec4WGSL : kRmsNormWGSL, WGPU_STRLEN}; WGPUShaderModuleDescriptor shader_desc = {}; shader_desc.nextInChain = &wgsl_desc.chain; @@ -176,6 +184,9 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { static_assert( kRmsNormWorkgroupSizeX == 64, "must match @workgroup_size and WG_SIZE in rms_norm.wgsl"); + static_assert( + kRmsNormVec4WorkgroupSizeX == 64, + "must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl"); graph.add_dispatch({pipeline, bind_group, num_rows}); // Release intermediate objects (pipeline and bind_group are kept by dispatch) diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl new file mode 100644 index 00000000000..c2f731e5f60 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl @@ -0,0 +1,74 @@ +@group(0) @binding(0) var t_out: array>; +@group(0) @binding(1) var t_in: array>; +@group(0) @binding(2) var t_weight: array>; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4 +// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq +// divides by it (not rw4). The host selects this only when row_width % 4 == 0. 
+@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let rw4 = params.row_width / 4u; + let base4 = row_idx * rw4; + + var local_sq_sum: f32 = 0.0; + var x4: u32 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + let v = t_in[base4 + x4]; + local_sq_sum = local_sq_sum + dot(v, v); + x4 = x4 + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x4 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4]; + x4 = x4 + WG_SIZE; + } +} diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h new file mode 100644 index 00000000000..633bf3adfc0 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from rms_norm_vec4.wgsl - DO NOT EDIT. 
+// wgsl-sha256: 4c0ba56708bf125a7ec6ea3c51d1288e05ac00a8e2cfa10e38e9a208e230b8df +inline constexpr const char* kRmsNormVec4WGSL = R"( +@group(0) @binding(0) var t_out: array>; +@group(0) @binding(1) var t_in: array>; +@group(0) @binding(2) var t_weight: array>; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4 +// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq +// divides by it (not rw4). The host selects this only when row_width % 4 == 0. +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let rw4 = params.row_width / 4u; + let base4 = row_idx * rw4; + + var local_sq_sum: f32 = 0.0; + var x4: u32 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + let v = t_in[base4 + x4]; + local_sq_sum = local_sq_sum + dot(v, v); + x4 = x4 + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x4 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4]; + x4 = x4 + WG_SIZE; + } +} +)"; + +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeX = 64; +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeY = 1; +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeZ 
= 1; + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/test/op_tests/cases.py b/backends/webgpu/test/op_tests/cases.py index 428c94d3066..febdbd507a8 100644 --- a/backends/webgpu/test/op_tests/cases.py +++ b/backends/webgpu/test/op_tests/cases.py @@ -37,7 +37,7 @@ RmsNormModule, ) -# rms_norm coverage is exactly the 14 cases the native test covered. +# rms_norm coverage is exactly the 15 cases the native test covered. RMS_NORM_CASES = _CASES diff --git a/backends/webgpu/test/op_tests/test_schema.py b/backends/webgpu/test/op_tests/test_schema.py index 9e62c9558cf..bcc03a40fd9 100644 --- a/backends/webgpu/test/op_tests/test_schema.py +++ b/backends/webgpu/test/op_tests/test_schema.py @@ -41,7 +41,7 @@ def test_add_rms_norm_registered(): assert {"add", "rms_norm"} <= set(op_test_registry) assert len(op_test_registry["add"].cases) >= 3 # regular/self/scalar/chained - # Exact parity, no hardcoded literal (real _CASES == 14; import so it can't drift): + # Exact parity, no hardcoded literal (real _CASES == 15; import so it can't drift): assert len(op_test_registry["rms_norm"].cases) == len(cases.RMS_NORM_CASES) # weight is a construction param, NOT a forward input: rms0 = op_test_registry["rms_norm"].cases[0] diff --git a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py index d4f88de672a..57679d6d097 100644 --- a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py +++ b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py @@ -140,6 +140,7 @@ def _weight_zeros_neg(hidden: int) -> torch.Tensor: {"name": "distinct_rows", "shape": (1, 1, 5, 256), "input_fn": _distinct_rows}, {"name": "single_row", "shape": (1, 1, 1, 896)}, {"name": "mixed_sign", "shape": (1, 1, 4, 128), "input_fn": _mixed_sign}, + {"name": "llama_hidden_2048", "shape": (1, 1, 1, 2048)}, {"name": "large_4096", "shape": (1, 1, 1, 4096)}, {"name": "large_8192", "shape": (1, 1, 1, 8192)}, {