From e62e797124f4e623385bb65a565d08c2d314c20b Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Sun, 8 Mar 2026 07:48:33 -0400 Subject: [PATCH] Voxtral Realtime: unify SDPA classes and dtype-aware attention masks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Unify CudaSDPA and StandardEncoderSDPA into a single StandardSDPA class with transpose_kv parameter, mirroring the MetalSDPA unification. This gives a symmetric design: MetalSDPA and StandardSDPA share the same interface (n_heads, n_kv_heads, head_dim, transpose_kv). Make _build_attn_mask and create_causal_mask dtype-aware — masks are now created in the model dtype instead of always float32. This is required because the Metal SDPA kernel reads the mask buffer as device T* (same type as Q/K/V). A float32 mask with bf16 Q/K/V would be misinterpreted. --- examples/models/voxtral_realtime/model.md | 22 ++- examples/models/voxtral_realtime/model.py | 177 +++++++++------------- 2 files changed, 86 insertions(+), 113 deletions(-) diff --git a/examples/models/voxtral_realtime/model.md b/examples/models/voxtral_realtime/model.md index fe240b03d8c..5a58b8c4a1a 100644 --- a/examples/models/voxtral_realtime/model.md +++ b/examples/models/voxtral_realtime/model.md @@ -102,7 +102,7 @@ VoxtralRealtimeModel attention: LMAttention wq/wk/wv/wo: Linear (no bias) kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA) - sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or CudaSDPA (CUDA) + sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA) ffn_norm: RMSNorm ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear) feed_forward: LMMLP (w1/w2/w3) @@ -116,7 +116,7 @@ StreamingAudioEncoderExport enc_norm: RMSNorm (shared from encoder.norm) adapter: AudioLanguageAdapter (shared from model.adapter) kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA) - sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal/CUDA) + sdpa: SDPA (XNNPACK) or MetalSDPA (Metal, transpose_kv=True) or StandardSDPA (CUDA, transpose_kv=True) inv_freq: RoPE inverse frequencies (owned, on-the-fly computation) ``` @@ -164,10 +164,12 @@ Handles GQA expansion internally and upcasts to float32. **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps` which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth -overhead of `repeat_interleave`. Uses explicit additive attention masks. -AOTInductor has compatibility issues with the `custom_sdpa` custom op. +overhead of `repeat_interleave`. Uses explicit additive attention masks +that must match the Q/K/V dtype (the kernel reads masks as `device T*`). +Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder +(no GQA, `transpose_kv=True`). -**CUDA:** `CudaSDPA` uses `F.scaled_dot_product_attention` with +**CUDA:** `StandardSDPA` uses `F.scaled_dot_product_attention` with `repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x). Uses boolean attention masks (`True`=attend, `False`=masked) as required by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement @@ -225,9 +227,13 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,) **XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices` custom op) and `SDPA` (`custom_sdpa`). -**Metal/CUDA:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring -buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with -explicit sliding window masks) — AOTI-compatible patterns avoiding custom ops. +**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring +buffer) and `MetalSDPA` (native MPS SDPA kernel with `transpose_kv=True`). +Masks are created in the model dtype to match the kernel's `device T*` expectation. + +**CUDA:** Uses `StandardEncoderRingKVCache` and `StandardSDPA` +(`F.scaled_dot_product_attention` with `transpose_kv=True` and explicit +sliding window masks). ### Streaming decode loop diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 26778413834..54cba5f109e 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -426,14 +426,20 @@ def forward( def _build_attn_mask( - input_pos: torch.Tensor, max_seq_len: int, device: torch.device + input_pos: torch.Tensor, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """Build float additive attention mask without bool intermediates. + """Build additive attention mask matching the model dtype. Metal AOTI doesn't support bool tensor allocation on MPS, so we use integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for valid positions (k <= curr_pos) and 0 for invalid, then convert to additive mask (0.0 = attend, -1e9 = don't attend). + + The mask dtype must match Q/K/V dtype — the Metal SDPA kernel reads + the mask buffer with the same element type as Q/K/V. """ seqlen = input_pos.shape[0] k_pos = torch.arange(max_seq_len, device=device) @@ -444,7 +450,7 @@ def _build_attn_mask( # Decode: [1, max_seq_len] diff = (input_pos[0] - k_pos + 1).unsqueeze(0) valid = torch.clamp(diff, min=0, max=1) - return (valid.float() - 1.0) * 1e9 + return (valid.to(dtype) - 1.0) * 1e9 def _build_causal_mask_bool( @@ -461,18 +467,23 @@ def _build_causal_mask_bool( class MetalSDPA(nn.Module): - """Standard SDPA calling the MPS op directly for native GQA support. + """Scaled dot-product attention using the native MPS kernel. - The Metal SDPA kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads, - avoiding the 4x memory bandwidth overhead of repeat_interleave. + Supports GQA (n_heads != n_kv_heads) and bf16 without requiring + repeat_interleave or manual fp32 upcast. Expects Q in [B, S, H, D] + layout; K/V in [B, H, S, D] by default (set transpose_kv=True if + K/V arrive in [B, S, H, D]). """ - def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): + def __init__( + self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim self.dim = n_heads * head_dim + self.transpose_kv = transpose_kv def forward( self, @@ -484,47 +495,40 @@ def forward( seqlen: int, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: - """ - Args: - input_pos: (seq_len,) position indices. - q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. - k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache. - bsz, seqlen: batch size and query sequence length. - attn_mask: precomputed float additive mask, or None to compute here. - Returns: - output: (B, seq_len, n_heads * head_dim). - """ - q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] + q = q.transpose(1, 2) + if self.transpose_kv: + k = k.transpose(1, 2) + v = v.transpose(1, 2) if attn_mask is None: - attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device) + attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device, q.dtype) - # Call the MPS SDPA op directly — bypasses CompositeImplicitAutograd - # decomposition which would insert repeat_interleave for GQA. - # The Metal kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads. y, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps( q, k, v, attn_mask, 0.0, False, None - ) # [B, n_heads, seq_len, head_dim] + ) - y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] + y = y.transpose(1, 2).contiguous() return y.view(bsz, seqlen, self.dim) -class CudaSDPA(nn.Module): - """Standard SDPA with GQA support for CUDA/AOTI backend. +class StandardSDPA(nn.Module): + """Scaled dot-product attention using F.scaled_dot_product_attention. - Uses F.scaled_dot_product_attention with repeat_interleave for GQA expansion. - KV cache uses [B, H, S, D] layout from StaticKVCache. Requires boolean - attention masks (Triton SDPA kernel only accepts torch.bool). + Supports GQA via repeat_interleave when n_heads != n_kv_heads. + Expects Q in [B, S, H, D]; K/V in [B, H, S, D] by default + (set transpose_kv=True if K/V arrive in [B, S, H, D]). """ - def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): + def __init__( + self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.n_rep = n_heads // n_kv_heads self.head_dim = head_dim self.dim = n_heads * head_dim + self.transpose_kv = transpose_kv def forward( self, @@ -536,19 +540,11 @@ def forward( seqlen: int, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: - """ - Args: - input_pos: (seq_len,) position indices. - q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. - k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache. - bsz, seqlen: batch size and query sequence length. - attn_mask: precomputed boolean mask (True=attend), or None to compute here. - Returns: - output: (B, seq_len, n_heads * head_dim). - """ - q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] + q = q.transpose(1, 2) + if self.transpose_kv: + k = k.transpose(1, 2) + v = v.transpose(1, 2) - # Expand KV for GQA if self.n_rep > 1: k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) @@ -558,62 +554,9 @@ def forward( y = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, is_causal=False - ) # [B, n_heads, seq_len, head_dim] - - y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] - return y.view(bsz, seqlen, self.dim) - - -class StandardEncoderSDPA(nn.Module): - """Standard SDPA for encoder using F.scaled_dot_product_attention. - - Compatible with AOTI/Metal/CUDA backend. Works with EncoderRingKVCache that uses - [B, S, H, D] layout and sliding window masks. - """ - - def __init__(self, n_heads: int, head_dim: int): - super().__init__() - self.n_heads = n_heads - self.head_dim = head_dim - self.dim = n_heads * head_dim - - def forward( - self, - input_pos: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - bsz: int, - seqlen: int, - mask: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - input_pos: (seq_len,) position indices. - q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. - k, v: (B, buf_size, n_heads, head_dim) in [B, S, H, D] layout from EncoderRingKVCache. - bsz, seqlen: batch size and query sequence length. - mask: (seq_len, buf_size) attention mask. Float additive (0.0=attend, -inf=masked) - for Metal, or boolean (True=attend) for CUDA. - Returns: - output: (B, seq_len, n_heads * head_dim). - """ - # Convert from [B, S, H, D] to [B, H, S, D] for F.scaled_dot_product_attention - q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] - k = k.transpose(1, 2) # [B, n_heads, buf_size, head_dim] - v = v.transpose(1, 2) # [B, n_heads, buf_size, head_dim] + ) - # Apply SDPA with sliding window mask - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - is_causal=False, # We handle masking explicitly via mask parameter - ) # [B, n_heads, seq_len, head_dim] - - # Convert back to [B, S, H, D] and flatten - y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] + y = y.transpose(1, 2).contiguous() return y.view(bsz, seqlen, self.dim) @@ -642,7 +585,7 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim) elif self.backend == "cuda": self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim) - self.sdpa = CudaSDPA(self.n_heads, self.n_kv_heads, self.head_dim) + self.sdpa = StandardSDPA(self.n_heads, self.n_kv_heads, self.head_dim) else: self.kv_cache = KVCache(max_seq_len, self.n_kv_heads, self.head_dim) self.sdpa = SDPA(self.n_heads, self.head_dim) @@ -753,7 +696,9 @@ def forward( attn_mask: torch.Tensor | None = None if self.config.backend == "metal": max_seq_len = self.freqs_cos.shape[0] - attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device) + attn_mask = _build_attn_mask( + input_pos, max_seq_len, input_embeds.device, input_embeds.dtype + ) elif self.config.backend == "cuda": max_seq_len = self.freqs_cos.shape[0] attn_mask = _build_causal_mask_bool( @@ -925,7 +870,11 @@ def update( return self.k_cache, self.v_cache def create_causal_mask( - self, start_pos: torch.Tensor, seq_len: int, bool_mask: bool = False + self, + start_pos: torch.Tensor, + seq_len: int, + bool_mask: bool = False, + dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """Create sliding window attention mask for ring buffer. @@ -933,7 +882,9 @@ def create_causal_mask( start_pos: Tensor containing the starting position (scalar tensor) seq_len: Number of query positions bool_mask: If True, return boolean mask (True=attend). If False, - return float additive mask (0.0=attend, -inf=masked). + return additive mask (0.0=attend, -inf=masked) in the given dtype. + dtype: Mask dtype for additive masks. Must match Q/K/V dtype for + the Metal SDPA kernel. """ total_written = start_pos + seq_len j = torch.arange(self.buf_size, dtype=torch.long, device=start_pos.device) @@ -949,7 +900,11 @@ def create_causal_mask( return valid.unsqueeze(0).unsqueeze( 0 ) # [1, 1, seq_len, buf_size] for Triton - return torch.where(valid, 0.0, float("-inf")) + return torch.where( + valid, + torch.zeros(1, dtype=dtype, device=start_pos.device), + torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device), + ) class StreamingAudioEncoderExport(nn.Module): @@ -1000,8 +955,20 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): ) # Choose SDPA based on backend - if config.backend in ("metal", "cuda"): - self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim) + if config.backend == "metal": + self.sdpa = MetalSDPA( + config.enc_n_heads, + config.enc_n_heads, + config.enc_head_dim, + transpose_kv=True, + ) + elif config.backend == "cuda": + self.sdpa = StandardSDPA( + config.enc_n_heads, + config.enc_n_heads, + config.enc_head_dim, + transpose_kv=True, + ) else: self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) @@ -1078,7 +1045,7 @@ def forward( T = x.size(1) # Pass start position as tensor (not .item()) to avoid unbacked symbols in AOTI mask = self.kv_caches[0].create_causal_mask( - enc_input_pos[0], T, bool_mask=self.bool_mask + enc_input_pos[0], T, bool_mask=self.bool_mask, dtype=x.dtype ) for i, layer in enumerate(self.layers):