
Metal backend: validate SDPA mask dtype matches Q/K/V dtype#17996

Draft
mergennachin wants to merge 1 commit into main from metal-sdpa-mask-dtype

Conversation

@mergennachin
Contributor

The Metal SDPA kernel reads the attention mask buffer as device T*
(same type as Q/K/V). If the mask is created in a different dtype
(e.g., float32 mask with bf16 Q/K/V), the kernel silently reads wrong
values — interpreting float32 bytes as bf16, which destroys masking.

Add a runtime dtype check in op_sdpa.mm that throws a clear error when
mask dtype doesn't match Q/K/V dtype.

Add test models (SDPAStridedFloatMask, SDPAStridedFloatMaskLarge,
SDPAStridedFloatMaskDtypeMismatch) to test_modules.py that reproduce
the streaming encoder attention pattern with strided K/V and float
additive masks.
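
To see why the byte reinterpretation destroys masking, here is a stdlib-only Python sketch (not the Metal kernel itself) that packs a float32 additive mask and re-reads the same buffer at bf16 width; bf16 decoding is emulated by zero-padding the low 16 bits of a float32, since bf16 is the high 16 bits of a float32 with the same value:

```python
import struct

# float32 additive mask: position 0 attends (0.0), position 1 is masked (-inf)
mask_f32 = [0.0, float("-inf")]
raw = struct.pack("<2f", *mask_f32)  # 8 bytes of float32 data

# bf16 is the high 16 bits of a float32 with the same value, so a bf16
# bit pattern decodes by padding the low 16 bits with zeros.
def bf16_to_float(two_bytes: bytes) -> float:
    return struct.unpack("<f", b"\x00\x00" + two_bytes)[0]

# A kernel typed `device bfloat*` steps 2 bytes per element and reads
# only the first len(mask_f32) elements of the same buffer:
as_bf16 = [bf16_to_float(raw[i:i + 2]) for i in range(0, len(raw), 2)]
kernel_sees = as_bf16[:len(mask_f32)]

print(kernel_sees)  # [0.0, 0.0]
```

The kernel sees `[0.0, 0.0]`: the -inf that should mask position 1 lands at a later offset the kernel never reads, so the masked position silently attends.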

Copilot AI review requested due to automatic review settings March 8, 2026 11:54
@pytorch-bot

pytorch-bot bot commented Mar 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17996

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 59a7627 with merge base 122fdef:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 8, 2026
@github-actions

github-actions bot commented Mar 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This PR hardens the Apple Metal backend’s SDPA implementation by validating that the attention mask tensor dtype matches the Q/K/V dtype at runtime, preventing silent corruption caused by the kernel reading the mask buffer as device T*.

Changes:

  • Add a runtime dtype check in op_sdpa.mm that rejects attn_mask when mask.scalar_type() != Q/K/V.scalar_type().
  • Add new Metal SDPA test modules that mimic a streaming encoder pattern with transposed (strided) K/V and additive masks.
  • Add a runtime test that verifies dtype mismatches are rejected (bf16 Q/K/V with float32 mask).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

Changed files:

  • backends/apple/metal/runtime/ops/op_sdpa.mm: Adds runtime validation to reject mask dtype mismatches before dispatching the Metal SDPA kernel.
  • backends/apple/metal/tests/test_modules.py: Adds SDPA strided/mask reproduction modules and a targeted runtime test for dtype mismatch rejection.


Comment on lines +627 to +635
"""SDPA with strided (transposed) K/V and float additive mask.

Reproduces the streaming encoder attention pattern:
- Q: [B, H, seq_len, D] with seq_len=4 (one chunk)
- K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer
- mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked)

Only the first seq_len positions are valid; the rest are -inf masked.
This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks.

Copilot AI Mar 8, 2026


The docstring says this model tests bf16 Q/K/V with float32 masks, but the mask is created with dtype=x.dtype (so it will be bf16 when x is bf16). Either update the docstring to match the actual behavior (mask dtype matches Q/K/V) or change mask construction/registry skips to reflect the intended float32-mask scenario.

Suggested change

Before:
"""SDPA with strided (transposed) K/V and float additive mask.
Reproduces the streaming encoder attention pattern:
- Q: [B, H, seq_len, D] with seq_len=4 (one chunk)
- K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer
- mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked)
Only the first seq_len positions are valid; the rest are -inf masked.
This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks.

After:
"""SDPA with strided (transposed) K/V and additive mask.
Reproduces the streaming encoder attention pattern:
- Q: [B, H, seq_len, D] with seq_len=4 (one chunk)
- K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer
- mask: [seq_len, buf_size] additive (0.0=attend, -inf=masked), with dtype matching Q/K/V
Only the first seq_len positions are valid; the rest are -inf masked.
This tests the Metal SDPA kernel's handling of bf16 Q/K/V with masks matching the Q/K/V dtype.

class SDPAStridedFloatMaskLarge(nn.Module):
"""Same as SDPAStridedFloatMask but with buf_size=1500 (production ring buffer size).

Tests that the bf16 precision issue scales with buffer size.

Copilot AI Mar 8, 2026


This docstring mentions a 'bf16 precision issue' scaling with buffer size, but the underlying bug described in the PR is mask dtype mismatch (float32 vs bf16) causing byte reinterpretation. Consider updating the wording here to describe what this test is actually exercising (e.g., larger KV length/strides) to avoid confusion.

Suggested change

Before:
Tests that the bf16 precision issue scales with buffer size.

After:
Exercises SDPA with a production-size ring buffer (larger K/V length and strides)
and the associated float additive mask behavior.

Comment on lines +541 to +552
// Mask dtype must match Q/K/V dtype — the Metal kernel reads mask
// as device T* where T is the Q/K/V element type.
auto mask_dtype = static_cast<int32_t>(mask_tensor->scalar_type());
if (mask_dtype != dtype) {
ET_LOG(Error,
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
"mask dtype (%d) must match Q/K/V dtype (%d)",
mask_dtype, dtype);
throw std::runtime_error(
"SDPA mask dtype must match Q/K/V dtype. "
"Cast the mask to the model dtype before passing to SDPA.");
}

Copilot AI Mar 8, 2026


With this new runtime validation, any existing SDPA test/module that creates an attn_mask without matching the model dtype (e.g., default float32 masks created via torch.zeros()) will start failing for bfloat16 runs. Update those call sites to construct/cast the mask to the Q/K/V dtype (or explicitly skip bfloat16) so the Metal test suite remains green.
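
A minimal sketch of the call-site fix the reviewer describes, assuming a standard torch SDPA call (the shapes here are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 16, 8, dtype=torch.bfloat16)
v = torch.randn(1, 2, 16, 8, dtype=torch.bfloat16)

# torch.zeros() defaults to float32; cast the additive mask to the
# Q/K/V dtype so the backend reads it at the element width it expects.
mask = torch.zeros(4, 16)
mask[:, 4:] = float("-inf")  # only the first 4 positions are valid
mask = mask.to(q.dtype)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Constructing the mask with `dtype=x.dtype` at creation time (as the new test modules do) achieves the same thing without the extra cast.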

Comment on lines +545 to +552
ET_LOG(Error,
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
"mask dtype (%d) must match Q/K/V dtype (%d)",
mask_dtype, dtype);
throw std::runtime_error(
"SDPA mask dtype must match Q/K/V dtype. "
"Cast the mask to the model dtype before passing to SDPA.");
}

Copilot AI Mar 8, 2026


This logs the dtype mismatch and then throws; the surrounding try/catch already logs e.what(), so this can produce duplicate error logs. Consider consolidating into a single error message (either include the dtype values in the thrown message, or rely on the catch-side log) to reduce noise.

@mergennachin mergennachin marked this pull request as draft March 8, 2026 12:13
@mergennachin mergennachin force-pushed the metal-sdpa-mask-dtype branch from 0a373ff to 59a7627 on March 8, 2026 12:19