-
Notifications
You must be signed in to change notification settings - Fork 870
Voxtral Realtime: unify SDPA classes and dtype-aware attention masks #17997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -426,14 +426,20 @@ def forward( | |
|
|
||
|
|
||
| def _build_attn_mask( | ||
| input_pos: torch.Tensor, max_seq_len: int, device: torch.device | ||
| input_pos: torch.Tensor, | ||
| max_seq_len: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype = torch.float32, | ||
| ) -> torch.Tensor: | ||
| """Build float additive attention mask without bool intermediates. | ||
| """Build additive attention mask matching the model dtype. | ||
|
|
||
| Metal AOTI doesn't support bool tensor allocation on MPS, so we use | ||
| integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for | ||
| valid positions (k <= curr_pos) and 0 for invalid, then convert to | ||
| additive mask (0.0 = attend, -1e9 = don't attend). | ||
|
|
||
| The mask dtype must match Q/K/V dtype — the Metal SDPA kernel reads | ||
| the mask buffer with the same element type as Q/K/V. | ||
| """ | ||
| seqlen = input_pos.shape[0] | ||
| k_pos = torch.arange(max_seq_len, device=device) | ||
|
|
@@ -444,7 +450,7 @@ def _build_attn_mask( | |
| # Decode: [1, max_seq_len] | ||
| diff = (input_pos[0] - k_pos + 1).unsqueeze(0) | ||
| valid = torch.clamp(diff, min=0, max=1) | ||
| return (valid.float() - 1.0) * 1e9 | ||
| return (valid.to(dtype) - 1.0) * 1e9 | ||
|
|
||
|
|
||
| def _build_causal_mask_bool( | ||
|
|
@@ -461,18 +467,23 @@ def _build_causal_mask_bool( | |
|
|
||
|
|
||
| class MetalSDPA(nn.Module): | ||
| """Standard SDPA calling the MPS op directly for native GQA support. | ||
| """Scaled dot-product attention using the native MPS kernel. | ||
|
|
||
| The Metal SDPA kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads, | ||
| avoiding the 4x memory bandwidth overhead of repeat_interleave. | ||
| Supports GQA (n_heads != n_kv_heads) and bf16 without requiring | ||
| repeat_interleave or manual fp32 upcast. Expects Q in [B, S, H, D] | ||
| layout; K/V in [B, H, S, D] by default (set transpose_kv=True if | ||
| K/V arrive in [B, S, H, D]). | ||
| """ | ||
|
|
||
| def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): | ||
| def __init__( | ||
| self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False | ||
| ): | ||
| super().__init__() | ||
| self.n_heads = n_heads | ||
| self.n_kv_heads = n_kv_heads | ||
| self.head_dim = head_dim | ||
| self.dim = n_heads * head_dim | ||
| self.transpose_kv = transpose_kv | ||
|
|
||
| def forward( | ||
| self, | ||
|
|
@@ -484,47 +495,40 @@ def forward( | |
| seqlen: int, | ||
| attn_mask: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| input_pos: (seq_len,) position indices. | ||
| q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. | ||
| k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache. | ||
| bsz, seqlen: batch size and query sequence length. | ||
| attn_mask: precomputed float additive mask, or None to compute here. | ||
| Returns: | ||
| output: (B, seq_len, n_heads * head_dim). | ||
| """ | ||
| q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] | ||
| q = q.transpose(1, 2) | ||
| if self.transpose_kv: | ||
| k = k.transpose(1, 2) | ||
| v = v.transpose(1, 2) | ||
|
|
||
| if attn_mask is None: | ||
| attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device) | ||
| attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device, q.dtype) | ||
|
|
||
| # Call the MPS SDPA op directly — bypasses CompositeImplicitAutograd | ||
| # decomposition which would insert repeat_interleave for GQA. | ||
| # The Metal kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads. | ||
| y, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps( | ||
| q, k, v, attn_mask, 0.0, False, None | ||
| ) # [B, n_heads, seq_len, head_dim] | ||
| ) | ||
|
|
||
| y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] | ||
| y = y.transpose(1, 2).contiguous() | ||
| return y.view(bsz, seqlen, self.dim) | ||
|
|
||
|
|
||
| class CudaSDPA(nn.Module): | ||
| """Standard SDPA with GQA support for CUDA/AOTI backend. | ||
| class StandardSDPA(nn.Module): | ||
| """Scaled dot-product attention using F.scaled_dot_product_attention. | ||
|
|
||
| Uses F.scaled_dot_product_attention with repeat_interleave for GQA expansion. | ||
| KV cache uses [B, H, S, D] layout from StaticKVCache. Requires boolean | ||
| attention masks (Triton SDPA kernel only accepts torch.bool). | ||
| Supports GQA via repeat_interleave when n_heads != n_kv_heads. | ||
| Expects Q in [B, S, H, D]; K/V in [B, H, S, D] by default | ||
| (set transpose_kv=True if K/V arrive in [B, S, H, D]). | ||
| """ | ||
|
|
||
| def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): | ||
| def __init__( | ||
| self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False | ||
| ): | ||
| super().__init__() | ||
| self.n_heads = n_heads | ||
| self.n_kv_heads = n_kv_heads | ||
| self.n_rep = n_heads // n_kv_heads | ||
| self.head_dim = head_dim | ||
| self.dim = n_heads * head_dim | ||
| self.transpose_kv = transpose_kv | ||
|
|
||
| def forward( | ||
| self, | ||
|
|
@@ -536,19 +540,11 @@ def forward( | |
| seqlen: int, | ||
| attn_mask: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| input_pos: (seq_len,) position indices. | ||
| q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. | ||
| k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache. | ||
| bsz, seqlen: batch size and query sequence length. | ||
| attn_mask: precomputed boolean mask (True=attend), or None to compute here. | ||
| Returns: | ||
| output: (B, seq_len, n_heads * head_dim). | ||
| """ | ||
| q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] | ||
| q = q.transpose(1, 2) | ||
| if self.transpose_kv: | ||
| k = k.transpose(1, 2) | ||
| v = v.transpose(1, 2) | ||
|
|
||
| # Expand KV for GQA | ||
| if self.n_rep > 1: | ||
| k = k.repeat_interleave(self.n_rep, dim=1) | ||
| v = v.repeat_interleave(self.n_rep, dim=1) | ||
|
|
@@ -558,62 +554,9 @@ def forward( | |
|
|
||
| y = F.scaled_dot_product_attention( | ||
| q, k, v, attn_mask=attn_mask, is_causal=False | ||
| ) # [B, n_heads, seq_len, head_dim] | ||
|
|
||
| y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] | ||
| return y.view(bsz, seqlen, self.dim) | ||
|
|
||
|
|
||
| class StandardEncoderSDPA(nn.Module): | ||
| """Standard SDPA for encoder using F.scaled_dot_product_attention. | ||
|
|
||
| Compatible with AOTI/Metal/CUDA backend. Works with EncoderRingKVCache that uses | ||
| [B, S, H, D] layout and sliding window masks. | ||
| """ | ||
|
|
||
| def __init__(self, n_heads: int, head_dim: int): | ||
| super().__init__() | ||
| self.n_heads = n_heads | ||
| self.head_dim = head_dim | ||
| self.dim = n_heads * head_dim | ||
|
|
||
| def forward( | ||
| self, | ||
| input_pos: torch.Tensor, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| bsz: int, | ||
| seqlen: int, | ||
| mask: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| input_pos: (seq_len,) position indices. | ||
| q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. | ||
| k, v: (B, buf_size, n_heads, head_dim) in [B, S, H, D] layout from EncoderRingKVCache. | ||
| bsz, seqlen: batch size and query sequence length. | ||
| mask: (seq_len, buf_size) attention mask. Float additive (0.0=attend, -inf=masked) | ||
| for Metal, or boolean (True=attend) for CUDA. | ||
| Returns: | ||
| output: (B, seq_len, n_heads * head_dim). | ||
| """ | ||
| # Convert from [B, S, H, D] to [B, H, S, D] for F.scaled_dot_product_attention | ||
| q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] | ||
| k = k.transpose(1, 2) # [B, n_heads, buf_size, head_dim] | ||
| v = v.transpose(1, 2) # [B, n_heads, buf_size, head_dim] | ||
| ) | ||
|
|
||
| # Apply SDPA with sliding window mask | ||
| y = F.scaled_dot_product_attention( | ||
| q, | ||
| k, | ||
| v, | ||
| attn_mask=mask, | ||
| is_causal=False, # We handle masking explicitly via mask parameter | ||
| ) # [B, n_heads, seq_len, head_dim] | ||
|
|
||
| # Convert back to [B, S, H, D] and flatten | ||
| y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] | ||
| y = y.transpose(1, 2).contiguous() | ||
| return y.view(bsz, seqlen, self.dim) | ||
|
|
||
|
|
||
|
|
@@ -642,7 +585,7 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): | |
| self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim) | ||
| elif self.backend == "cuda": | ||
| self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim) | ||
| self.sdpa = CudaSDPA(self.n_heads, self.n_kv_heads, self.head_dim) | ||
| self.sdpa = StandardSDPA(self.n_heads, self.n_kv_heads, self.head_dim) | ||
| else: | ||
| self.kv_cache = KVCache(max_seq_len, self.n_kv_heads, self.head_dim) | ||
| self.sdpa = SDPA(self.n_heads, self.head_dim) | ||
|
|
@@ -753,7 +696,9 @@ def forward( | |
| attn_mask: torch.Tensor | None = None | ||
| if self.config.backend == "metal": | ||
| max_seq_len = self.freqs_cos.shape[0] | ||
| attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device) | ||
| attn_mask = _build_attn_mask( | ||
| input_pos, max_seq_len, input_embeds.device, input_embeds.dtype | ||
| ) | ||
| elif self.config.backend == "cuda": | ||
| max_seq_len = self.freqs_cos.shape[0] | ||
| attn_mask = _build_causal_mask_bool( | ||
|
|
@@ -925,15 +870,21 @@ def update( | |
| return self.k_cache, self.v_cache | ||
|
|
||
| def create_causal_mask( | ||
| self, start_pos: torch.Tensor, seq_len: int, bool_mask: bool = False | ||
| self, | ||
| start_pos: torch.Tensor, | ||
| seq_len: int, | ||
| bool_mask: bool = False, | ||
| dtype: torch.dtype = torch.float32, | ||
| ) -> torch.Tensor: | ||
| """Create sliding window attention mask for ring buffer. | ||
|
|
||
| Args: | ||
| start_pos: Tensor containing the starting position (scalar tensor) | ||
| seq_len: Number of query positions | ||
| bool_mask: If True, return boolean mask (True=attend). If False, | ||
| return float additive mask (0.0=attend, -inf=masked). | ||
| return additive mask (0.0=attend, -inf=masked) in the given dtype. | ||
| dtype: Mask dtype for additive masks. Must match Q/K/V dtype for | ||
| the Metal SDPA kernel. | ||
| """ | ||
| total_written = start_pos + seq_len | ||
| j = torch.arange(self.buf_size, dtype=torch.long, device=start_pos.device) | ||
|
|
@@ -949,7 +900,11 @@ def create_causal_mask( | |
| return valid.unsqueeze(0).unsqueeze( | ||
| 0 | ||
| ) # [1, 1, seq_len, buf_size] for Triton | ||
| return torch.where(valid, 0.0, float("-inf")) | ||
| return torch.where( | ||
| valid, | ||
| torch.zeros(1, dtype=dtype, device=start_pos.device), | ||
| torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device), | ||
| ) | ||
|
Comment on lines +903 to +907
|
||
|
|
||
|
|
||
| class StreamingAudioEncoderExport(nn.Module): | ||
|
|
@@ -1000,8 +955,20 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): | |
| ) | ||
|
|
||
| # Choose SDPA based on backend | ||
| if config.backend in ("metal", "cuda"): | ||
| self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim) | ||
| if config.backend == "metal": | ||
| self.sdpa = MetalSDPA( | ||
| config.enc_n_heads, | ||
| config.enc_n_heads, | ||
| config.enc_head_dim, | ||
| transpose_kv=True, | ||
| ) | ||
| elif config.backend == "cuda": | ||
| self.sdpa = StandardSDPA( | ||
| config.enc_n_heads, | ||
| config.enc_n_heads, | ||
| config.enc_head_dim, | ||
| transpose_kv=True, | ||
| ) | ||
| else: | ||
| self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) | ||
|
|
||
|
|
@@ -1078,7 +1045,7 @@ def forward( | |
| T = x.size(1) | ||
| # Pass start position as tensor (not .item()) to avoid unbacked symbols in AOTI | ||
| mask = self.kv_caches[0].create_causal_mask( | ||
| enc_input_pos[0], T, bool_mask=self.bool_mask | ||
| enc_input_pos[0], T, bool_mask=self.bool_mask, dtype=x.dtype | ||
| ) | ||
|
|
||
| for i, layer in enumerate(self.layers): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring says the additive mask uses "-1e9 = don't attend", but the implementation now casts to an arbitrary dtype. For float16 in particular, multiplying by 1e9 overflows to +/-inf, so the mask semantics become 0 and -inf (and the doc becomes inaccurate). Consider either updating the doc to reflect dtype-dependent behavior, or computing the masked value from torch.finfo(dtype) (e.g., a representable large negative) to avoid overflow/magic constants while still matching Q/K/V dtype for Metal.