Metal backend: validate SDPA mask dtype matches Q/K/V dtype #17996
mergennachin wants to merge 1 commit into main
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17996

Dr. CI status as of commit 59a7627 (merge base 122fdef): 2 new failures, plus 1 job that failed but was likely due to flakiness present on trunk. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR hardens the Apple Metal backend’s SDPA implementation by validating that the attention mask tensor dtype matches the Q/K/V dtype at runtime, preventing silent corruption caused by the kernel reading the mask buffer as device T*.
Changes:
- Add a runtime dtype check in `op_sdpa.mm` that rejects `attn_mask` when `mask.scalar_type() != Q/K/V.scalar_type()`.
- Add new Metal SDPA test modules that mimic a streaming encoder pattern with transposed (strided) K/V and additive masks.
- Add a runtime test that verifies dtype mismatches are rejected (bf16 Q/K/V with float32 mask).
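The silent corruption the PR describes can be illustrated without Metal at all: reinterpreting the raw bytes of a float32 `-inf` mask entry as two bfloat16 values produces one value that no longer masks anything. A minimal standard-library sketch (the helper names below are hypothetical, for illustration only):

```python
import struct

def f32_bytes(x: float) -> bytes:
    """Little-endian float32 byte encoding of x."""
    return struct.pack("<f", x)

def bf16_from_bytes(b: bytes) -> float:
    # bfloat16 is the top 16 bits of a float32; reconstruct its value by
    # zero-padding the low 16 bits and decoding as float32.
    return struct.unpack("<f", b"\x00\x00" + b)[0]

# A float32 -inf mask entry is 4 bytes: 00 00 80 ff (little-endian).
neg_inf = f32_bytes(float("-inf"))

# A kernel reading this buffer as bf16 (`device T*` with T = bf16) sees
# TWO elements where the author wrote one:
lo = bf16_from_bytes(neg_inf[:2])  # low half  -> 0.0 (position becomes attendable!)
hi = bf16_from_bytes(neg_inf[2:])  # high half -> -inf
print(lo, hi)
```

So every float32 `-inf` decays into the pair `(0.0, -inf)`, and all subsequent mask indices shift, which is why masking is destroyed rather than merely imprecise.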
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| backends/apple/metal/runtime/ops/op_sdpa.mm | Adds runtime validation to reject mask dtype mismatches before dispatching the Metal SDPA kernel. |
| backends/apple/metal/tests/test_modules.py | Adds SDPA strided/mask reproduction modules and a targeted runtime test for dtype mismatch rejection. |
| """SDPA with strided (transposed) K/V and float additive mask. | ||
|
|
||
| Reproduces the streaming encoder attention pattern: | ||
| - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) | ||
| - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer | ||
| - mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked) | ||
|
|
||
| Only the first seq_len positions are valid; the rest are -inf masked. | ||
| This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks. |
The docstring says this model tests bf16 Q/K/V with float32 masks, but the mask is created with dtype=x.dtype (so it will be bf16 when x is bf16). Either update the docstring to match the actual behavior (mask dtype matches Q/K/V) or change mask construction/registry skips to reflect the intended float32-mask scenario.
| """SDPA with strided (transposed) K/V and float additive mask. | |
| Reproduces the streaming encoder attention pattern: | |
| - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) | |
| - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer | |
| - mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked) | |
| Only the first seq_len positions are valid; the rest are -inf masked. | |
| This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks. | |
| """SDPA with strided (transposed) K/V and additive mask. | |
| Reproduces the streaming encoder attention pattern: | |
| - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) | |
| - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer | |
| - mask: [seq_len, buf_size] additive (0.0=attend, -inf=masked), with dtype matching Q/K/V | |
| Only the first seq_len positions are valid; the rest are -inf masked. | |
| This tests the Metal SDPA kernel's handling of bf16 Q/K/V with masks matching the Q/K/V dtype. |
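To make the docstring and the code agree, the mask can be built directly in the model dtype. A sketch of that construction, assuming torch; the helper name is hypothetical and mirrors the `dtype=x.dtype` pattern the reviewer describes:

```python
import torch

def make_additive_mask(seq_len: int, buf_size: int, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper: build the additive mask in the model dtype
    # (dtype=x.dtype in the test module) so it matches Q/K/V instead of
    # defaulting to float32.
    mask = torch.zeros(seq_len, buf_size, dtype=dtype)
    mask[:, seq_len:] = float("-inf")  # only the first seq_len positions attend
    return mask

mask = make_additive_mask(4, 16, torch.bfloat16)
```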
```python
class SDPAStridedFloatMaskLarge(nn.Module):
    """Same as SDPAStridedFloatMask but with buf_size=1500 (production ring buffer size).

    Tests that the bf16 precision issue scales with buffer size.
```
This docstring mentions a 'bf16 precision issue' scaling with buffer size, but the underlying bug described in the PR is mask dtype mismatch (float32 vs bf16) causing byte reinterpretation. Consider updating the wording here to describe what this test is actually exercising (e.g., larger KV length/strides) to avoid confusion.
Suggested change:

```python
    Exercises SDPA with a production-size ring buffer (larger K/V length and strides)
    and the associated float additive mask behavior.
```
```cpp
// Mask dtype must match Q/K/V dtype — the Metal kernel reads mask
// as device T* where T is the Q/K/V element type.
auto mask_dtype = static_cast<int32_t>(mask_tensor->scalar_type());
if (mask_dtype != dtype) {
  ET_LOG(Error,
      "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
      "mask dtype (%d) must match Q/K/V dtype (%d)",
      mask_dtype, dtype);
  throw std::runtime_error(
      "SDPA mask dtype must match Q/K/V dtype. "
      "Cast the mask to the model dtype before passing to SDPA.");
}
```
With this new runtime validation, any existing SDPA test/module that creates an attn_mask without matching the model dtype (e.g., default float32 masks created via torch.zeros()) will start failing for bfloat16 runs. Update those call sites to construct/cast the mask to the Q/K/V dtype (or explicitly skip bfloat16) so the Metal test suite remains green.
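One way to update such a call site, sketched with torch's portable SDPA API (shapes are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes for the streaming-encoder pattern described in the PR.
B, H, seq_len, buf_size, D = 1, 2, 4, 16, 8

q = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16)
k = torch.randn(B, H, buf_size, D, dtype=torch.bfloat16)
v = torch.randn(B, H, buf_size, D, dtype=torch.bfloat16)

# torch.zeros() defaults to float32 -- exactly the mismatch the new check rejects.
mask = torch.zeros(seq_len, buf_size)
mask[:, seq_len:] = float("-inf")

# Cast the mask to the Q/K/V dtype before it reaches SDPA.
mask = mask.to(q.dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

For bfloat16 runs that cannot represent the mask faithfully, the alternative the comment mentions is to skip bfloat16 for those modules in the test registry.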
```cpp
  ET_LOG(Error,
      "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
      "mask dtype (%d) must match Q/K/V dtype (%d)",
      mask_dtype, dtype);
  throw std::runtime_error(
      "SDPA mask dtype must match Q/K/V dtype. "
      "Cast the mask to the model dtype before passing to SDPA.");
}
```
This logs the dtype mismatch and then throws; the surrounding try/catch already logs e.what(), so this can produce duplicate error logs. Consider consolidating into a single error message (either include the dtype values in the thrown message, or rely on the catch-side log) to reduce noise.
The Metal SDPA kernel reads the attention mask buffer as `device T*` (same type as Q/K/V). If the mask is created in a different dtype (e.g., float32 mask with bf16 Q/K/V), the kernel silently reads wrong values — interpreting float32 bytes as bf16, which destroys masking.

Add a runtime dtype check in op_sdpa.mm that throws a clear error when mask dtype doesn't match Q/K/V dtype.

Add test models (SDPAStridedFloatMask, SDPAStridedFloatMaskLarge, SDPAStridedFloatMaskDtypeMismatch) to test_modules.py that reproduce the streaming encoder attention pattern with strided K/V and float additive masks.
Force-pushed from 0a373ff to 59a7627.