Metal backend: validate SDPA mask dtype matches Q/K/V dtype #17996
base: main
```diff
@@ -40,7 +40,7 @@
     constant uint3& qkv_head_strides [[buffer(6)]],
     constant uint3& qkv_seq_strides [[buffer(7)]],
     constant float& scale [[buffer(8)]],
-    const device T* mask [[buffer(9)]],  // Changed from bool* to T* for floating point masks
+    const device T* mask [[buffer(9)]],  // Must match Q/K/V dtype
     constant uint3& mask_strides [[buffer(10)]],
     constant bool& has_mask [[buffer(11)]],
     constant uint3& qkv_batch_strides [[buffer(12)]],  // NEW: batch strides for Q, K, V
```
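Because the kernel reads the mask as `device T*`, a mask buffer written in a different dtype gets silently byte-reinterpreted rather than converted. A minimal numpy sketch of that failure mode (hypothetical illustration, not ExecuTorch code; float16 stands in for bf16, which numpy lacks, and little-endian byte order is assumed):

```python
import numpy as np

# A float32 additive mask (0.0 = attend, -inf = masked) ...
mask_f32 = np.array([0.0, -np.inf, 0.0, -np.inf], dtype=np.float32)

# ... whose raw bytes are reinterpreted as float16, mimicking what the
# kernel did when `mask` was bound with the wrong element type T.
mask_as_f16 = mask_f32.view(np.float16)

# The element count doubles and the -inf sentinels are destroyed:
# the high half of each -inf becomes a NaN, the low half becomes 0.0,
# so "masked" positions no longer mask anything meaningful.
print(mask_as_f16.shape)  # (8,) instead of (4,)
print(mask_as_f16)
```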
```diff
@@ -537,6 +537,20 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
   uint mask_q_seq_stride = 0;
   if (has_mask_val) {
     auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);

+    // Mask dtype must match Q/K/V dtype — the Metal kernel reads mask
+    // as device T* where T is the Q/K/V element type.
+    auto mask_dtype = static_cast<int32_t>(mask_tensor->scalar_type());
+    if (mask_dtype != dtype) {
+      ET_LOG(Error,
+             "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
+             "mask dtype (%d) must match Q/K/V dtype (%d)",
+             mask_dtype, dtype);
+      throw std::runtime_error(
+          "SDPA mask dtype must match Q/K/V dtype. "
+          "Cast the mask to the model dtype before passing to SDPA.");
+    }
+
     int nd = mask_tensor->dim();
     mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast<uint>(mask_tensor->strides()[nd - 1]) : 0;
     mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast<uint>(mask_tensor->strides()[nd - 2]) : 0;
```
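The logic of this new guard is simple to mirror in Python. The sketch below (a hypothetical helper of my own naming, not the ExecuTorch API) shows the same fail-fast behavior a caller now sees, and the fix: cast the mask to the model dtype before the call.

```python
import numpy as np

def check_sdpa_mask_dtype(qkv_dtype, mask):
    # Python mirror of the new runtime check: reject masks whose dtype
    # differs from the Q/K/V element type instead of silently
    # reinterpreting their bytes in the kernel.
    if mask.dtype != np.dtype(qkv_dtype):
        raise TypeError(
            f"SDPA mask dtype ({mask.dtype}) must match Q/K/V dtype "
            f"({np.dtype(qkv_dtype)}); cast the mask to the model dtype "
            "before passing it to SDPA."
        )
    return mask

# A float32 mask against float16 Q/K/V now fails fast ...
try:
    check_sdpa_mask_dtype(np.float16, np.zeros((4, 16), dtype=np.float32))
except TypeError as exc:
    print("rejected:", exc)

# ... while a mask cast to the model dtype passes through unchanged.
ok = check_sdpa_mask_dtype(
    np.float16, np.zeros((4, 16), dtype=np.float32).astype(np.float16)
)
```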
```diff
@@ -499,7 +499,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         value = torch.as_strided(self.value, size=self.v_size, stride=self.v_stride)
         attn_mask = None
         if self.attn_mask_size:
-            attn_mask = torch.zeros(self.attn_mask_size)
+            attn_mask = torch.zeros(self.attn_mask_size, dtype=x.dtype)

         sdpa_output = torch.nn.functional.scaled_dot_product_attention(
             query,
```
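The one-argument fix matters because `torch.zeros` without an explicit `dtype` uses the default float type, which mismatches reduced-precision Q/K/V. The same pitfall is easy to show with numpy, whose zero-constructor likewise defaults to a wider float (float64) than a half-precision input:

```python
import numpy as np

x = np.ones((2, 4), dtype=np.float16)   # stand-in for a reduced-precision model input

bad_mask = np.zeros((2, 4))             # dtype defaults to float64: mismatch
good_mask = np.zeros((2, 4), dtype=x.dtype)  # built in the input's dtype

print(bad_mask.dtype, good_mask.dtype)
```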
|
```diff
@@ -622,6 +622,152 @@ def __init__(self):
     }


+# -------------------------------------------------------------------------
+class SDPAStridedFloatMask(nn.Module):
+    """SDPA with strided (transposed) K/V and float additive mask.
+
+    Reproduces the streaming encoder attention pattern:
+    - Q: [B, H, seq_len, D] with seq_len=4 (one chunk)
+    - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer
+    - mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked)
+
+    Only the first seq_len positions are valid; the rest are -inf masked.
+    This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks.
+    """
```
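The additive mask this docstring describes can be sketched directly. The helper below is hypothetical (the name and signature are mine, not from the PR): it builds a `[seq_len, buf_size]` mask with 0.0 over the valid first `seq_len` key positions and `-inf` elsewhere, constructed in the Q/K/V dtype so it passes the new validation.

```python
import numpy as np

def make_streaming_mask(seq_len, buf_size, dtype=np.float16):
    # Additive attention mask: 0.0 = attend, -inf = masked.
    # Only the first seq_len key positions of the ring buffer are valid.
    mask = np.full((seq_len, buf_size), -np.inf, dtype=dtype)
    mask[:, :seq_len] = 0.0
    return mask

m = make_streaming_mask(4, 16)
print(m.dtype, m.shape)
```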
| """SDPA with strided (transposed) K/V and float additive mask. | |
| Reproduces the streaming encoder attention pattern: | |
| - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) | |
| - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer | |
| - mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked) | |
| Only the first seq_len positions are valid; the rest are -inf masked. | |
| This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks. | |
| """SDPA with strided (transposed) K/V and additive mask. | |
| Reproduces the streaming encoder attention pattern: | |
| - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) | |
| - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer | |
| - mask: [seq_len, buf_size] additive (0.0=attend, -inf=masked), with dtype matching Q/K/V | |
| Only the first seq_len positions are valid; the rest are -inf masked. | |
| This tests the Metal SDPA kernel's handling of bf16 Q/K/V with masks matching the Q/K/V dtype. |
Copilot AI (Mar 8, 2026):
This docstring mentions a 'bf16 precision issue' scaling with buffer size, but the underlying bug described in the PR is mask dtype mismatch (float32 vs bf16) causing byte reinterpretation. Consider updating the wording here to describe what this test is actually exercising (e.g., larger KV length/strides) to avoid confusion.

Suggested change:

```diff
-    Tests that the bf16 precision issue scales with buffer size.
+    Exercises SDPA with a production-size ring buffer (larger K/V length and strides)
+    and the associated float additive mask behavior.
```
Copilot AI:
With this new runtime validation, any existing SDPA test/module that creates an attn_mask without matching the model dtype (e.g., default float32 masks created via torch.zeros()) will start failing for bfloat16 runs. Update those call sites to construct/cast the mask to the Q/K/V dtype (or explicitly skip bfloat16) so the Metal test suite remains green.