[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20507
[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20507pytorchbot wants to merge 1 commit into
Conversation
Pull Request resolved: #20405 **+32% SDPA attention-compute (AV +40%)** — register-tile the QK and AV kernels (isolated GPU-timestamp A/B, decode S=1, Chrome Canary / M4 Pro). A kernel-time win, not a wall-clock `forward()` win — `forward()` stays bound by the submit/sync/readback floor (the separate fusion axis). **Problem**: The naive QK/AV kernels compute one output element per thread, so each thread re-loads Q/K/V and the dot products are scalar — poor register reuse, ALU/latency-bound. **Solution**: Each thread computes a 4×4 output tile with the dot products vec4-packed in registers: - **Before**: one thread per output element; scalar accumulate over D (QK) / context (AV). - **After**: one thread per `(head, S-tile, {ctx,D}-tile)`; 4×4 register tile, vec4 dot products. A floating-point accumulation reorder of the same products — no algorithm change. **Implementation**: - `sdpa_compute_attn_weights.wgsl` (QK): one thread per `(head, S-tile, ctx-tile)`, grid `Hq · ceil(S/4) · ceil(ctx/4)`; tile registers are `array<vec4<f32>, TM/TN>` loaded via `for` loops. - `sdpa_compute_out.wgsl` (AV): one thread per `(head, S-tile, D-tile)`, grid `Hq · ceil(S/4) · ceil(D/4)`. - `Sdpa.cpp`: dispatch math moves from an element count to a tile count (`kSdpaTileM/N=4`, shared `utils::div_up`), keeping the uint32 scratch-overflow guard. - Mirrors the Vulkan register-tiled SDPA kernels; the shared `utils::div_up` mirrors Vulkan's `utils::div_up`. **Constraints**: - softmax, `update_cache`, the bind-group layouts, and the scratch-buffer sizes (`Hq*S*ctx`) are unchanged. - Scope is tiling only — causal tile-skip, V-cache coalescing, and branchless aligned/tail loads are separate follow-ups; this diff intentionally omits the Vulkan causal tile-skip since it is correctness-neutral (the per-element mask in `store_qk` is identical). See DESIGN_DECISIONS.md. - Output matches the naive kernels within fp tolerance (accumulation reorder only). ghstack-source-id: 396792505 @exported-using-ghexport Differential Revision: [D109081409](https://our.internmc.facebook.com/intern/diff/D109081409/)
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20507
Note: Links to docs will display an error until the docs builds have been completed. ❗ 2 Active SEVsThere are 2 currently active SEVs. If your PR is affected, please view them below:
❌ 1 New Failure, 2 Unrelated FailuresAs of commit 37abadb with merge base e03f777 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20405 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/orig
@diff-train-skip-merge