Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions examples/models/voxtral_realtime/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ VoxtralRealtimeModel
attention: LMAttention
wq/wk/wv/wo: Linear (no bias)
kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA)
sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or CudaSDPA (CUDA)
sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
ffn_norm: RMSNorm
ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
feed_forward: LMMLP (w1/w2/w3)
Expand All @@ -116,7 +116,7 @@ StreamingAudioEncoderExport
enc_norm: RMSNorm (shared from encoder.norm)
adapter: AudioLanguageAdapter (shared from model.adapter)
kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA)
sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal/CUDA)
sdpa: SDPA (XNNPACK) or MetalSDPA (Metal, transpose_kv=True) or StandardSDPA (CUDA, transpose_kv=True)
inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
```

Expand Down Expand Up @@ -164,10 +164,12 @@ Handles GQA expansion internally and upcasts to float32.

**Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
overhead of `repeat_interleave`. Uses explicit additive attention masks.
AOTInductor has compatibility issues with the `custom_sdpa` custom op.
overhead of `repeat_interleave`. Uses explicit additive attention masks
that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder
(no GQA, `transpose_kv=True`).

**CUDA:** `CudaSDPA` uses `F.scaled_dot_product_attention` with
**CUDA:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
`repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x).
Uses boolean attention masks (`True`=attend, `False`=masked) as required
by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement
Expand Down Expand Up @@ -225,9 +227,13 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,)
**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
custom op) and `SDPA` (`custom_sdpa`).

**Metal/CUDA:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with
explicit sliding window masks) — AOTI-compatible patterns avoiding custom ops.
**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
buffer) and `MetalSDPA` (native MPS SDPA kernel with `transpose_kv=True`).
Masks are created in the model dtype to match the kernel's `device T*` expectation.

**CUDA:** Uses `StandardEncoderRingKVCache` and `StandardSDPA`
(`F.scaled_dot_product_attention` with `transpose_kv=True` and explicit
sliding window masks).

### Streaming decode loop

Expand Down
177 changes: 72 additions & 105 deletions examples/models/voxtral_realtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,14 +426,20 @@ def forward(


def _build_attn_mask(
input_pos: torch.Tensor, max_seq_len: int, device: torch.device
input_pos: torch.Tensor,
max_seq_len: int,
device: torch.device,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Build float additive attention mask without bool intermediates.
"""Build additive attention mask matching the model dtype.

Metal AOTI doesn't support bool tensor allocation on MPS, so we use
integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for
valid positions (k <= curr_pos) and 0 for invalid, then convert to
additive mask (0.0 = attend, -1e9 = don't attend).

Comment on lines 436 to +440
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says the additive mask uses "-1e9 = don't attend", but the implementation now casts to an arbitrary dtype. For float16 in particular, multiplying by 1e9 overflows to +/-inf, so the mask semantics become 0 and -inf (and the doc becomes inaccurate). Consider either updating the doc to reflect dtype-dependent behavior, or computing the masked value from torch.finfo(dtype) (e.g., a representable large negative) to avoid overflow/magic constants while still matching Q/K/V dtype for Metal.

Copilot uses AI. Check for mistakes.
The mask dtype must match Q/K/V dtype — the Metal SDPA kernel reads
the mask buffer with the same element type as Q/K/V.
"""
seqlen = input_pos.shape[0]
k_pos = torch.arange(max_seq_len, device=device)
Expand All @@ -444,7 +450,7 @@ def _build_attn_mask(
# Decode: [1, max_seq_len]
diff = (input_pos[0] - k_pos + 1).unsqueeze(0)
valid = torch.clamp(diff, min=0, max=1)
return (valid.float() - 1.0) * 1e9
return (valid.to(dtype) - 1.0) * 1e9


def _build_causal_mask_bool(
Expand All @@ -461,18 +467,23 @@ def _build_causal_mask_bool(


class MetalSDPA(nn.Module):
"""Standard SDPA calling the MPS op directly for native GQA support.
"""Scaled dot-product attention using the native MPS kernel.

The Metal SDPA kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads,
avoiding the 4x memory bandwidth overhead of repeat_interleave.
Supports GQA (n_heads != n_kv_heads) and bf16 without requiring
repeat_interleave or manual fp32 upcast. Expects Q in [B, S, H, D]
layout; K/V in [B, H, S, D] by default (set transpose_kv=True if
K/V arrive in [B, S, H, D]).
"""

def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int):
def __init__(
self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.dim = n_heads * head_dim
self.transpose_kv = transpose_kv

def forward(
self,
Expand All @@ -484,47 +495,40 @@ def forward(
seqlen: int,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
input_pos: (seq_len,) position indices.
q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout.
k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache.
bsz, seqlen: batch size and query sequence length.
attn_mask: precomputed float additive mask, or None to compute here.
Returns:
output: (B, seq_len, n_heads * head_dim).
"""
q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim]
q = q.transpose(1, 2)
if self.transpose_kv:
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if attn_mask is None:
attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device)
attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device, q.dtype)

# Call the MPS SDPA op directly — bypasses CompositeImplicitAutograd
# decomposition which would insert repeat_interleave for GQA.
# The Metal kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads.
y, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps(
q, k, v, attn_mask, 0.0, False, None
) # [B, n_heads, seq_len, head_dim]
)

y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim]
y = y.transpose(1, 2).contiguous()
return y.view(bsz, seqlen, self.dim)


class CudaSDPA(nn.Module):
"""Standard SDPA with GQA support for CUDA/AOTI backend.
class StandardSDPA(nn.Module):
"""Scaled dot-product attention using F.scaled_dot_product_attention.

Uses F.scaled_dot_product_attention with repeat_interleave for GQA expansion.
KV cache uses [B, H, S, D] layout from StaticKVCache. Requires boolean
attention masks (Triton SDPA kernel only accepts torch.bool).
Supports GQA via repeat_interleave when n_heads != n_kv_heads.
Expects Q in [B, S, H, D]; K/V in [B, H, S, D] by default
(set transpose_kv=True if K/V arrive in [B, S, H, D]).
"""

def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int):
def __init__(
self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool = False
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.head_dim = head_dim
self.dim = n_heads * head_dim
self.transpose_kv = transpose_kv

def forward(
self,
Expand All @@ -536,19 +540,11 @@ def forward(
seqlen: int,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
input_pos: (seq_len,) position indices.
q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout.
k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache.
bsz, seqlen: batch size and query sequence length.
attn_mask: precomputed boolean mask (True=attend), or None to compute here.
Returns:
output: (B, seq_len, n_heads * head_dim).
"""
q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim]
q = q.transpose(1, 2)
if self.transpose_kv:
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Expand KV for GQA
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
Expand All @@ -558,62 +554,9 @@ def forward(

y = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=False
) # [B, n_heads, seq_len, head_dim]

y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim]
return y.view(bsz, seqlen, self.dim)


class StandardEncoderSDPA(nn.Module):
"""Standard SDPA for encoder using F.scaled_dot_product_attention.

Compatible with AOTI/Metal/CUDA backend. Works with EncoderRingKVCache that uses
[B, S, H, D] layout and sliding window masks.
"""

def __init__(self, n_heads: int, head_dim: int):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.dim = n_heads * head_dim

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz: int,
seqlen: int,
mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
input_pos: (seq_len,) position indices.
q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout.
k, v: (B, buf_size, n_heads, head_dim) in [B, S, H, D] layout from EncoderRingKVCache.
bsz, seqlen: batch size and query sequence length.
mask: (seq_len, buf_size) attention mask. Float additive (0.0=attend, -inf=masked)
for Metal, or boolean (True=attend) for CUDA.
Returns:
output: (B, seq_len, n_heads * head_dim).
"""
# Convert from [B, S, H, D] to [B, H, S, D] for F.scaled_dot_product_attention
q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim]
k = k.transpose(1, 2) # [B, n_heads, buf_size, head_dim]
v = v.transpose(1, 2) # [B, n_heads, buf_size, head_dim]
)

# Apply SDPA with sliding window mask
y = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
is_causal=False, # We handle masking explicitly via mask parameter
) # [B, n_heads, seq_len, head_dim]

# Convert back to [B, S, H, D] and flatten
y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim]
y = y.transpose(1, 2).contiguous()
return y.view(bsz, seqlen, self.dim)


Expand Down Expand Up @@ -642,7 +585,7 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int):
self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
elif self.backend == "cuda":
self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim)
self.sdpa = CudaSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
self.sdpa = StandardSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
else:
self.kv_cache = KVCache(max_seq_len, self.n_kv_heads, self.head_dim)
self.sdpa = SDPA(self.n_heads, self.head_dim)
Expand Down Expand Up @@ -753,7 +696,9 @@ def forward(
attn_mask: torch.Tensor | None = None
if self.config.backend == "metal":
max_seq_len = self.freqs_cos.shape[0]
attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device)
attn_mask = _build_attn_mask(
input_pos, max_seq_len, input_embeds.device, input_embeds.dtype
)
elif self.config.backend == "cuda":
max_seq_len = self.freqs_cos.shape[0]
attn_mask = _build_causal_mask_bool(
Expand Down Expand Up @@ -925,15 +870,21 @@ def update(
return self.k_cache, self.v_cache

def create_causal_mask(
self, start_pos: torch.Tensor, seq_len: int, bool_mask: bool = False
self,
start_pos: torch.Tensor,
seq_len: int,
bool_mask: bool = False,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Create sliding window attention mask for ring buffer.

Args:
start_pos: Tensor containing the starting position (scalar tensor)
seq_len: Number of query positions
bool_mask: If True, return boolean mask (True=attend). If False,
return float additive mask (0.0=attend, -inf=masked).
return additive mask (0.0=attend, -inf=masked) in the given dtype.
dtype: Mask dtype for additive masks. Must match Q/K/V dtype for
the Metal SDPA kernel.
"""
total_written = start_pos + seq_len
j = torch.arange(self.buf_size, dtype=torch.long, device=start_pos.device)
Expand All @@ -949,7 +900,11 @@ def create_causal_mask(
return valid.unsqueeze(0).unsqueeze(
0
) # [1, 1, seq_len, buf_size] for Triton
return torch.where(valid, 0.0, float("-inf"))
return torch.where(
valid,
torch.zeros(1, dtype=dtype, device=start_pos.device),
torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device),
)
Comment on lines +903 to +907
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_causal_mask now materializes new tensors (torch.zeros(1, ...) and torch.tensor(-inf, ...)) every call and uses a (1,) shape that relies on broadcasting. In streaming this mask is computed frequently, so these per-call allocations can add overhead. Prefer 0-dim scalars (shape ()) and/or reuse cached scalar constants (or buffers) to reduce allocations while keeping the requested dtype/device.

Copilot uses AI. Check for mistakes.


class StreamingAudioEncoderExport(nn.Module):
Expand Down Expand Up @@ -1000,8 +955,20 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
)

# Choose SDPA based on backend
if config.backend in ("metal", "cuda"):
self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim)
if config.backend == "metal":
self.sdpa = MetalSDPA(
config.enc_n_heads,
config.enc_n_heads,
config.enc_head_dim,
transpose_kv=True,
)
elif config.backend == "cuda":
self.sdpa = StandardSDPA(
config.enc_n_heads,
config.enc_n_heads,
config.enc_head_dim,
transpose_kv=True,
)
else:
self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim)

Expand Down Expand Up @@ -1078,7 +1045,7 @@ def forward(
T = x.size(1)
# Pass start position as tensor (not .item()) to avoid unbacked symbols in AOTI
mask = self.kv_caches[0].create_causal_mask(
enc_input_pos[0], T, bool_mask=self.bool_mask
enc_input_pos[0], T, bool_mask=self.bool_mask, dtype=x.dtype
)

for i, layer in enumerate(self.layers):
Expand Down
Loading