[ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels#20493
[ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels#20493JulianCloudNTH wants to merge 3 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20493
Note: Links to docs will display an error until the docs builds have been completed. ❗ 2 Active SEVsThere are 2 currently active SEVs. If your PR is affected, please view them below:
❌ 1 New Failure, 11 Pending, 2 Unrelated FailuresAs of commit 192f09f with merge base e03f777 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
… kernels Pull Request resolved: #20493 **Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings). **Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses. **Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings: - **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element. - **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`. **Implementation**: - QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`. - AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`. - Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row. - Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple. - Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load. - Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array<vec4>` SDPA bindings. **Constraints**: - Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing. - Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version. - No KV-cache layout, dispatch, or uniform change. Co-authored with Claude Code. ghstack-source-id: 396717582 @exported-using-ghexport Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)
SS-JIA
left a comment
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
… kernels Pull Request resolved: #20493 **Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings). **Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses. **Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings: - **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element. - **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`. **Implementation**: - QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`. - AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`. - Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row. - Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple. - Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load. - Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array<vec4>` SDPA bindings. **Constraints**: - Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing. - Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version. - No KV-cache layout, dispatch, or uniform change. Co-authored with Claude Code. ghstack-source-id: 396792517 @exported-using-ghexport Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)
Stack from ghstack (oldest at bottom):
Branchless aligned/tail loads + vec4 storage bindings — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as
array<vec4<f32>>so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings).Problem: The tiled QK/AV vec4 loaders run 4 per-lane
ifbounds checks on every load, every contraction iteration (8 loads/iter). Buthead_dimis always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declaredarray<f32>, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses.Solution: Remove the dead checks, split the ragged axis, and vectorize the bindings:
load_q_vec4/load_k_vec4(and AVload_a_vec4/load_v_d4) do 4 per-lane boundsifs per call; the AVc4loop runs checked loads for every chunk;t_q/t_k_cache/t_v_cache/t_outarearray<f32>accessed element-by-element.vec4(D%4==0, host-guarded); AV runs a branch-free aligned body overc4 in [0, context_len - context_len%4)then a 0-or-1 checked tail; the head-dim-indexed bufferst_q/t_k_cache/t_v_cache/t_outarearray<vec4<f32>>indexed[base/4u], and AV writes a single alignedstore_out_vec4.Implementation:
load_q_vec4/load_k_vec4drop the per-lane D checks and returnt_q[base/4u]/t_k_cache[base/4u].load_a_vec4_nc/load_v_d4_ncfor the aligned body; checkedload_a_vec4/load_v_d4for the tail; V readst_v_cache[base/4u]; output is one alignedstore_out_vec4.t_q,t_k_cache(QK) andt_v_cache,t_out(AV) arearray<vec4<f32>>.t_attn_weightsand the softmax buffer stayarray<f32>— they arecontext_len-indexed (row stride not 4-aligned) and written per-element under the causal mask, so avec4binding there would need a padded scratch row.D % 4 == 0guard inSdpa.cpp— WGSL has noSDPA_PAD_Dpad-load, so fail loud rather than read past the row; this guard also makes every[base/4u]index 4-aligned and every buffer a 16-byte multiple.reject_d6(head_dim=6) config + anexpect_rejectharness branch asserting the guard rejects a non-aligned head_dim at load.sdpa_compute_out_tiled.glsl(aligned/tail split) and Vulkan'sarray<vec4>SDPA bindings.Constraints:
head_dim % 4 == 0(true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing.vec4bindings read/write the same bytes as the scalar version.Co-authored with Claude Code.
@exported-using-ghexport
Differential Revision: D109521069
Differential Revision: D109521069