diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index e694a6b223c..9a5cd614f8b 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -85,9 +85,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index 3d027b417f0..144af9b5956 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. -// wgsl-sha256: e9bdc2272bba2716655e96be0597571701b4ec1496f78a85660e29d108f655f9 +// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array; @@ -102,9 +102,11 @@ fn main(@builtin(global_invocation_id) gid: vec3) { acc[2] = vec4(0.0, 0.0, 0.0, 0.0); acc[3] = vec4(0.0, 0.0, 0.0, 0.0); + // Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl. + let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos; var d4: u32 = 0u; loop { - if (d4 >= params.D) { + if (d4 >= params.D || skip_tile) { break; } var q: array, TM>;