Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion backends/apple/metal/runtime/ops/op_sdpa.mm
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
constant uint3& qkv_head_strides [[buffer(6)]],
constant uint3& qkv_seq_strides [[buffer(7)]],
constant float& scale [[buffer(8)]],
const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks
const device T* mask [[buffer(9)]], // Must match Q/K/V dtype
constant uint3& mask_strides [[buffer(10)]],
constant bool& has_mask [[buffer(11)]],
constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V
Expand Down Expand Up @@ -537,6 +537,20 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
uint mask_q_seq_stride = 0;
if (has_mask_val) {
auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);

// Mask dtype must match Q/K/V dtype — the Metal kernel reads mask
// as device T* where T is the Q/K/V element type.
auto mask_dtype = static_cast<int32_t>(mask_tensor->scalar_type());
if (mask_dtype != dtype) {
ET_LOG(Error,
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
"mask dtype (%d) must match Q/K/V dtype (%d)",
mask_dtype, dtype);
throw std::runtime_error(
"SDPA mask dtype must match Q/K/V dtype. "
"Cast the mask to the model dtype before passing to SDPA.");
}
Comment on lines +541 to +552
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this new runtime validation, any existing SDPA test/module that creates an attn_mask without matching the model dtype (e.g., default float32 masks created via torch.zeros()) will start failing for bfloat16 runs. Update those call sites to construct/cast the mask to the Q/K/V dtype (or explicitly skip bfloat16) so the Metal test suite remains green.

Copilot uses AI. Check for mistakes.
Comment on lines +545 to +552
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logs the dtype mismatch and then throws; the surrounding try/catch already logs e.what(), so this can produce duplicate error logs. Consider consolidating into a single error message (either include the dtype values in the thrown message, or rely on the catch-side log) to reduce noise.

Copilot uses AI. Check for mistakes.

int nd = mask_tensor->dim();
mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast<uint>(mask_tensor->strides()[nd - 1]) : 0;
mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast<uint>(mask_tensor->strides()[nd - 2]) : 0;
Expand Down
187 changes: 186 additions & 1 deletion backends/apple/metal/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
value = torch.as_strided(self.value, size=self.v_size, stride=self.v_stride)
attn_mask = None
if self.attn_mask_size:
attn_mask = torch.zeros(self.attn_mask_size)
attn_mask = torch.zeros(self.attn_mask_size, dtype=x.dtype)

sdpa_output = torch.nn.functional.scaled_dot_product_attention(
query,
Expand Down Expand Up @@ -622,6 +622,152 @@ def __init__(self):
}


# -------------------------------------------------------------------------
class SDPAStridedFloatMask(nn.Module):
    """SDPA with strided (transposed) K/V and a float additive mask.

    Reproduces the streaming encoder attention pattern:
    - Q: [B, H, seq_len, D] with seq_len=4 (one chunk)
    - K/V: [B, H, buf_size, D] transposed from a [B, buf_size, H, D] ring buffer
    - mask: [seq_len, buf_size] additive (0.0 = attend, -inf = masked), created
      with the same dtype as Q/K/V — the Metal SDPA kernel reads the mask as
      the Q/K/V element type, so the dtypes must match (covers bfloat16 runs).

    Only the first seq_len key positions are reachable; the rest stay -inf.
    """

    def __init__(self, n_heads=32, seq_len=4, buf_size=128, head_dim=64):
        super().__init__()
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.buf_size = buf_size
        self.head_dim = head_dim

        # K/V stored as [B, buf_size, H, D] (ring buffer layout); viewed as
        # [B, H, buf_size, D] in forward() via transpose.
        self.k_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim))
        self.v_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Q comes straight from the input: [B, H, seq_len, D]
        q = x

        # K/V: transpose [B, S, H, D] -> [B, H, S, D] (simulates the strided
        # ring-buffer view; the result is intentionally non-contiguous).
        k = self.k_buf.transpose(1, 2)
        v = self.v_buf.transpose(1, 2)

        # Additive mask in the model dtype: only the first seq_len positions
        # are valid, causally (query i attends to key positions 0..i).
        mask = torch.full(
            (self.seq_len, self.buf_size), float("-inf"), device=x.device, dtype=x.dtype
        )
        for i in range(self.seq_len):
            # Causal: query i attends to positions 0..i
            mask[i, : i + 1] = 0.0

        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        return output + x


# Registered for the standard per-dtype module test sweep (float32 + bfloat16).
MODULE_REGISTRY["sdpa_strided_float_mask"] = {
"model_class": SDPAStridedFloatMask,
"input_shapes": [(1, 32, 4, 64)],
"description": "Streaming encoder SDPA: strided K/V (transposed ring buffer) with float additive mask",
"atol_float32": 1e-4,
"atol_bfloat16": 5e-2,
}


# -------------------------------------------------------------------------
class SDPAStridedFloatMaskLarge(nn.Module):
    """Same as SDPAStridedFloatMask but with buf_size=1500 (production ring buffer size).

    Exercises SDPA with a production-size ring buffer — a much larger K/V
    sequence length and correspondingly larger strides — using the same
    additive mask construction (mask dtype matches Q/K/V).
    """

    def __init__(self, n_heads=32, seq_len=4, buf_size=1500, head_dim=64):
        super().__init__()
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.buf_size = buf_size
        self.head_dim = head_dim

        # Ring-buffer layout [B, buf_size, H, D]; transposed in forward().
        self.k_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim))
        self.v_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = x
        # Strided (non-contiguous) K/V views: [B, S, H, D] -> [B, H, S, D]
        k = self.k_buf.transpose(1, 2)
        v = self.v_buf.transpose(1, 2)

        # Additive causal mask in the model dtype over the full buffer width.
        mask = torch.full(
            (self.seq_len, self.buf_size), float("-inf"), device=x.device, dtype=x.dtype
        )
        for i in range(self.seq_len):
            mask[i, : i + 1] = 0.0

        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        return output + x


# Production-size variant (buf_size=1500); same tolerances as the small one.
MODULE_REGISTRY["sdpa_strided_float_mask_large"] = {
"model_class": SDPAStridedFloatMaskLarge,
"input_shapes": [(1, 32, 4, 64)],
"description": "Streaming encoder SDPA with production-size ring buffer (1500)",
"atol_float32": 1e-4,
"atol_bfloat16": 5e-2,
}


# -------------------------------------------------------------------------
class SDPAStridedFloatMaskDtypeMismatch(nn.Module):
    """SDPA whose attn_mask dtype is deliberately float32, mismatching Q/K/V.

    When run with bf16 Q/K/V, this should trip the C++ dtype validation in
    op_sdpa.mm: the kernel reads the mask as device T* (the Q/K/V element
    type), so a float32 mask would otherwise be byte-reinterpreted.
    """

    def __init__(self, n_heads=32, seq_len=4, buf_size=128, head_dim=64):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, head_dim
        self.seq_len, self.buf_size = seq_len, buf_size

        # Ring-buffer layout [B, buf_size, H, D], transposed in forward().
        buf_shape = (1, buf_size, n_heads, head_dim)
        self.k_buf = nn.Parameter(torch.randn(*buf_shape))
        self.v_buf = nn.Parameter(torch.randn(*buf_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        queries = x
        keys = self.k_buf.transpose(1, 2)
        values = self.v_buf.transpose(1, 2)

        # Causal additive mask built in float32 regardless of x.dtype —
        # the mismatch is the entire point of this module. Position (i, j)
        # is 0.0 when j <= i, else -inf; identical to filling row i's
        # first i+1 entries with zero.
        rows = torch.arange(self.seq_len, device=x.device).unsqueeze(1)
        cols = torch.arange(self.buf_size, device=x.device).unsqueeze(0)
        attn_mask = torch.zeros(
            self.seq_len, self.buf_size, device=x.device, dtype=torch.float32
        ).masked_fill(cols > rows, float("-inf"))

        attended = torch.nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        return attended + x


# "skip" excludes this entry from the auto-generated per-dtype sweep; the
# intentional mask-dtype mismatch is exercised by the dedicated runtime test.
MODULE_REGISTRY["sdpa_strided_float_mask_dtype_mismatch"] = {
"model_class": SDPAStridedFloatMaskDtypeMismatch,
"input_shapes": [(1, 32, 4, 64)],
"description": "SDPA with mismatched mask dtype — tested via test_sdpa_mask_dtype_mismatch_rejected",
"skip": "Tested separately — intentional dtype mismatch for validation",
}


# =============================================================================
# Helper Functions
# =============================================================================
Expand Down Expand Up @@ -1137,6 +1283,45 @@ def run_test_in_directory(test_dir: Path) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
run_test_in_directory(Path(tmpdir))

@unittest.skipIf(SKIP_RUNTIME_TESTS, SKIP_RUNTIME_REASON)
def test_sdpa_mask_dtype_mismatch_rejected(self):
    """Runtime rejection of an SDPA mask whose dtype differs from Q/K/V.

    Exports the bf16 model that deliberately builds a float32 mask, then
    checks the executor run fails instead of silently computing wrong
    results from byte-reinterpreted mask data.
    """
    module_key = "sdpa_strided_float_mask_dtype_mismatch"
    model, example_inputs = get_model_and_inputs(module_key, dtype=torch.bfloat16)

    with tempfile.TemporaryDirectory() as workdir:
        out_dir = Path(workdir) / "dtype_mismatch"
        out_dir.mkdir()

        pte_path, _ = export_model_to_pte(model, example_inputs, out_dir, module_key)
        self.assertTrue(pte_path.exists())

        success, error_msg = run_executor_runner(pte_path, out_dir / "output")

        self.assertFalse(
            success,
            "Expected executor_runner to fail with mask dtype mismatch, but it succeeded",
        )
        # The C++ validation throws inside the SDPA kernel; the AOTI wrapper
        # may surface it as a generic "api call failed" message instead.
        error_lower = (error_msg or "").lower()
        self.assertTrue(
            "mask dtype" in error_lower or "api call failed" in error_lower,
            f"Expected 'mask dtype' or 'api call failed' in error, got: {error_msg[:300] if error_msg else 'None'}",
        )


# =============================================================================
# Dynamically generate test methods for each module and dtype in MODULE_REGISTRY
Expand Down
Loading