From 59a7627726a1c88f685b54d1b18996c032c24460 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Sun, 8 Mar 2026 07:48:08 -0400 Subject: [PATCH] Metal backend: validate SDPA mask dtype matches Q/K/V dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Metal SDPA kernel reads the attention mask buffer as `device T*` (same type as Q/K/V). If the mask is created in a different dtype (e.g., float32 mask with bf16 Q/K/V), the kernel silently reads wrong values — interpreting float32 bytes as bf16, which destroys masking. Add a runtime dtype check in op_sdpa.mm that throws a clear error when mask dtype doesn't match Q/K/V dtype. Add test models (SDPAStridedFloatMask, SDPAStridedFloatMaskLarge, SDPAStridedFloatMaskDtypeMismatch) to test_modules.py that reproduce the streaming encoder attention pattern with strided K/V and float additive masks. --- backends/apple/metal/runtime/ops/op_sdpa.mm | 16 +- backends/apple/metal/tests/test_modules.py | 187 +++++++++++++++++++- 2 files changed, 201 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/ops/op_sdpa.mm b/backends/apple/metal/runtime/ops/op_sdpa.mm index fdaabcf6b0b..915779a1803 100644 --- a/backends/apple/metal/runtime/ops/op_sdpa.mm +++ b/backends/apple/metal/runtime/ops/op_sdpa.mm @@ -40,7 +40,7 @@ constant uint3& qkv_head_strides [[buffer(6)]], constant uint3& qkv_seq_strides [[buffer(7)]], constant float& scale [[buffer(8)]], - const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks + const device T* mask [[buffer(9)]], // Must match Q/K/V dtype constant uint3& mask_strides [[buffer(10)]], constant bool& has_mask [[buffer(11)]], constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V @@ -537,6 +537,20 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( uint mask_q_seq_stride = 0; if (has_mask_val) { auto* mask_tensor = reinterpret_cast(*attn_mask); + + // Mask dtype must match Q/K/V dtype — the Metal kernel reads mask + // as device T* where T is the Q/K/V element type. + auto mask_dtype = static_cast(mask_tensor->scalar_type()); + if (mask_dtype != dtype) { + ET_LOG(Error, + "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: " + "mask dtype (%d) must match Q/K/V dtype (%d)", + mask_dtype, dtype); + throw std::runtime_error( + "SDPA mask dtype must match Q/K/V dtype. " + "Cast the mask to the model dtype before passing to SDPA."); + } + int nd = mask_tensor->dim(); mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast(mask_tensor->strides()[nd - 1]) : 0; mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast(mask_tensor->strides()[nd - 2]) : 0; diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index ec8c6078a85..820288de9b2 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -499,7 +499,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: value = torch.as_strided(self.value, size=self.v_size, stride=self.v_stride) attn_mask = None if self.attn_mask_size: - attn_mask = torch.zeros(self.attn_mask_size) + attn_mask = torch.zeros(self.attn_mask_size, dtype=x.dtype) sdpa_output = torch.nn.functional.scaled_dot_product_attention( query, @@ -622,6 +622,152 @@ def __init__(self): } +# ------------------------------------------------------------------------- +class SDPAStridedFloatMask(nn.Module): + """SDPA with strided (transposed) K/V and float additive mask. + + Reproduces the streaming encoder attention pattern: + - Q: [B, H, seq_len, D] with seq_len=4 (one chunk) + - K/V: [B, H, buf_size, D] transposed from [B, buf_size, H, D] ring buffer + - mask: [seq_len, buf_size] float additive (0.0=attend, -inf=masked) + + Only the first seq_len positions are valid; the rest are -inf masked. + This tests the Metal SDPA kernel's handling of bf16 Q/K/V with float32 masks. + """ + + def __init__(self, n_heads=32, seq_len=4, buf_size=128, head_dim=64): + super().__init__() + self.n_heads = n_heads + self.seq_len = seq_len + self.buf_size = buf_size + self.head_dim = head_dim + + # K/V stored as [B, buf_size, H, D] (ring buffer layout), viewed as [B, H, buf_size, D] + self.k_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + self.v_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Q from input: [B, H, seq_len, D] + q = x + + # K/V: transpose from [B, S, H, D] to [B, H, S, D] (simulates ring buffer view) + k = self.k_buf.transpose(1, 2) + v = self.v_buf.transpose(1, 2) + + # Float additive mask: only first seq_len positions are valid + mask = torch.full( + (self.seq_len, self.buf_size), float("-inf"), device=x.device, dtype=x.dtype + ) + for i in range(self.seq_len): + # Causal: query i attends to positions 0..i + mask[i, : i + 1] = 0.0 + + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + return output + x + + +MODULE_REGISTRY["sdpa_strided_float_mask"] = { + "model_class": SDPAStridedFloatMask, + "input_shapes": [(1, 32, 4, 64)], + "description": "Streaming encoder SDPA: strided K/V (transposed ring buffer) with float additive mask", + "atol_float32": 1e-4, + "atol_bfloat16": 5e-2, +} + + +# ------------------------------------------------------------------------- +class SDPAStridedFloatMaskLarge(nn.Module): + """Same as SDPAStridedFloatMask but with buf_size=1500 (production ring buffer size). + + Tests that the bf16 precision issue scales with buffer size. + """ + + def __init__(self, n_heads=32, seq_len=4, buf_size=1500, head_dim=64): + super().__init__() + self.n_heads = n_heads + self.seq_len = seq_len + self.buf_size = buf_size + self.head_dim = head_dim + + self.k_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + self.v_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + q = x + k = self.k_buf.transpose(1, 2) + v = self.v_buf.transpose(1, 2) + + mask = torch.full( + (self.seq_len, self.buf_size), float("-inf"), device=x.device, dtype=x.dtype + ) + for i in range(self.seq_len): + mask[i, : i + 1] = 0.0 + + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + return output + x + + +MODULE_REGISTRY["sdpa_strided_float_mask_large"] = { + "model_class": SDPAStridedFloatMaskLarge, + "input_shapes": [(1, 32, 4, 64)], + "description": "Streaming encoder SDPA with production-size ring buffer (1500)", + "atol_float32": 1e-4, + "atol_bfloat16": 5e-2, +} + + +# ------------------------------------------------------------------------- +class SDPAStridedFloatMaskDtypeMismatch(nn.Module): + """SDPA with intentionally mismatched mask dtype (float32 mask with bf16 Q/K/V). + + This should trigger the C++ dtype validation error in op_sdpa.mm when + the mask dtype doesn't match Q/K/V dtype. The kernel reads mask as + device T*, so a float32 mask with bf16 Q/K/V would be misinterpreted. + """ + + def __init__(self, n_heads=32, seq_len=4, buf_size=128, head_dim=64): + super().__init__() + self.n_heads = n_heads + self.seq_len = seq_len + self.buf_size = buf_size + self.head_dim = head_dim + + self.k_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + self.v_buf = nn.Parameter(torch.randn(1, buf_size, n_heads, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + q = x + k = self.k_buf.transpose(1, 2) + v = self.v_buf.transpose(1, 2) + + # Intentionally create mask in float32 regardless of x.dtype + mask = torch.full( + (self.seq_len, self.buf_size), + float("-inf"), + device=x.device, + dtype=torch.float32, + ) + for i in range(self.seq_len): + mask[i, : i + 1] = 0.0 + + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + return output + x + + +MODULE_REGISTRY["sdpa_strided_float_mask_dtype_mismatch"] = { + "model_class": SDPAStridedFloatMaskDtypeMismatch, + "input_shapes": [(1, 32, 4, 64)], + "description": "SDPA with mismatched mask dtype — tested via test_sdpa_mask_dtype_mismatch_rejected", + "skip": "Tested separately — intentional dtype mismatch for validation", +} + + # ============================================================================= # Helper Functions # ============================================================================= @@ -1137,6 +1283,45 @@ def run_test_in_directory(test_dir: Path) -> None: with tempfile.TemporaryDirectory() as tmpdir: run_test_in_directory(Path(tmpdir)) + @unittest.skipIf(SKIP_RUNTIME_TESTS, SKIP_RUNTIME_REASON) + def test_sdpa_mask_dtype_mismatch_rejected(self): + """Verify that op_sdpa.mm rejects mask dtype != Q/K/V dtype at runtime. + + Exports a bf16 model with a float32 mask (intentional mismatch). + The C++ kernel should raise an error instead of silently computing + wrong results. + """ + model, example_inputs = get_model_and_inputs( + "sdpa_strided_float_mask_dtype_mismatch", dtype=torch.bfloat16 + ) + + with tempfile.TemporaryDirectory() as tmpdir: + model_output_dir = Path(tmpdir) / "dtype_mismatch" + model_output_dir.mkdir() + + pte_path, _ = export_model_to_pte( + model, + example_inputs, + model_output_dir, + "sdpa_strided_float_mask_dtype_mismatch", + ) + self.assertTrue(pte_path.exists()) + + output_base_path = model_output_dir / "output" + success, error_msg = run_executor_runner(pte_path, output_base_path) + + self.assertFalse( + success, + "Expected executor_runner to fail with mask dtype mismatch, but it succeeded", + ) + # The C++ validation throws inside the SDPA kernel, but the error + # surfaces through the AOTI wrapper as a generic "api call failed". + error_lower = error_msg.lower() if error_msg else "" + self.assertTrue( + "mask dtype" in error_lower or "api call failed" in error_lower, + f"Expected 'mask dtype' or 'api call failed' in error, got: {error_msg[:300] if error_msg else 'None'}", + ) + # ============================================================================= # Dynamically generate test methods for each module and dtype in MODULE_REGISTRY