Voxtral Realtime: unify SDPA classes and dtype-aware attention masks (#17997)
mergennachin wants to merge 1 commit into main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17997
CI status as of commit e62e797 (merge base 122fdef): 1 Awaiting Approval, 2 New Failures.
Unify CudaSDPA and StandardEncoderSDPA into a single StandardSDPA class with a transpose_kv parameter, mirroring the MetalSDPA unification. This gives a symmetric design: MetalSDPA and StandardSDPA share the same interface (n_heads, n_kv_heads, head_dim, transpose_kv).

Make _build_attn_mask and create_causal_mask dtype-aware: masks are now created in the model dtype instead of always float32. This is required because the Metal SDPA kernel reads the mask buffer as `device T*` (the same type as Q/K/V), so a float32 mask with bf16 Q/K/V would be misinterpreted.
Pull request overview
This PR unifies the Voxtral Realtime SDPA implementations to align CUDA and encoder behavior with the existing Metal design, and updates attention mask creation to be dtype-aware to satisfy Metal SDPA kernel requirements.
Changes:
- Replaced `CudaSDPA` and `StandardEncoderSDPA` with a unified `StandardSDPA` that supports both decoder and streaming-encoder layouts via `transpose_kv`.
- Made `_build_attn_mask` and `StandardEncoderRingKVCache.create_causal_mask` dtype-aware so additive masks are created in the model's dtype (required for Metal SDPA).
- Updated model documentation to reflect the new SDPA classes and mask dtype constraints.
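A unified SDPA class along the lines described above might look like the following sketch. The constructor parameters (`n_heads`, `n_kv_heads`, `head_dim`, `transpose_kv`) come from the PR; the internal layout handling and GQA head expansion are assumptions for illustration, not the actual `model.py` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandardSDPA(nn.Module):
    """Sketch of a transpose_kv-parameterized SDPA module.

    Assumed layouts (hypothetical): Q is (batch, n_heads, seq, head_dim);
    when transpose_kv=True, K/V arrive as (batch, seq, n_kv_heads, head_dim)
    and are transposed to (batch, n_kv_heads, seq, head_dim) first.
    """

    def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int, transpose_kv: bool):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.transpose_kv = transpose_kv

    def forward(self, q, k, v, attn_mask=None):
        if self.transpose_kv:
            # Move the KV head axis in front of the sequence axis.
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
        if self.n_kv_heads < self.n_heads:
            # Expand grouped KV heads to match the query head count (GQA).
            rep = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

With one `transpose_kv` flag, the same class covers both the decoder layout and the streaming-encoder layout, matching the MetalSDPA interface.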
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| examples/models/voxtral_realtime/model.py | Unifies SDPA modules (adds transpose_kv) and makes additive mask generation dtype-aware for Metal compatibility. |
| examples/models/voxtral_realtime/model.md | Updates architecture/docs to reflect StandardSDPA and dtype-matching additive masks for Metal. |
> Metal AOTI doesn't support bool tensor allocation on MPS, so we use
> integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for
> valid positions (k <= curr_pos) and 0 for invalid, then convert to
> additive mask (0.0 = attend, -1e9 = don't attend).
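The integer-arithmetic trick in that docstring can be sketched in isolation as follows. The function name and signature are hypothetical; only the `clamp(curr_pos - k_pos + 1, 0, 1)` formula and the 0.0 / -1e9 additive encoding come from the quoted docstring.

```python
import torch


def additive_causal_mask(curr_pos: int, cache_len: int, dtype=torch.float32):
    """Build an additive causal mask without bool tensors (MPS-friendly sketch).

    clamp(curr_pos - k_pos + 1, 0, 1) is 1 for k_pos <= curr_pos and 0
    otherwise; (valid - 1) * 1e9 then maps 1 -> 0.0 and 0 -> -1e9.
    """
    k_pos = torch.arange(cache_len)
    valid = torch.clamp(curr_pos - k_pos + 1, 0, 1).to(dtype)
    return (valid - 1.0) * 1e9
```

For `curr_pos=2, cache_len=4` this yields `[0.0, 0.0, 0.0, -1e9]`: positions 0 to 2 are attendable, position 3 is masked.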
The docstring says the additive mask uses "-1e9 = don't attend", but the implementation now casts to an arbitrary dtype. For float16 in particular, multiplying by 1e9 overflows to +/-inf, so the mask semantics become 0 and -inf (and the doc becomes inaccurate). Consider either updating the doc to reflect dtype-dependent behavior, or computing the masked value from torch.finfo(dtype) (e.g., a representable large negative) to avoid overflow/magic constants while still matching Q/K/V dtype for Metal.
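The `torch.finfo`-based alternative the reviewer suggests could be sketched like this (the helper name is hypothetical; the overflow behavior of -1e9 in float16 is standard, since float16's maximum finite magnitude is 65504):

```python
import torch


def masked_fill_value(dtype: torch.dtype) -> torch.Tensor:
    """Pick a representable large-negative mask value for the given dtype,
    instead of a hard-coded -1e9 that overflows narrow float types."""
    return torch.tensor(torch.finfo(dtype).min, dtype=dtype)


# -1e9 exceeds the float16 range, so it rounds to -inf on cast:
overflowed = torch.tensor(-1e9, dtype=torch.float16)
# finfo-based value stays finite in every float dtype:
safe = masked_fill_value(torch.float16)
```

This keeps the mask in the Q/K/V dtype (as the Metal kernel requires) while avoiding both overflow and magic constants.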
```python
return torch.where(
    valid,
    torch.zeros(1, dtype=dtype, device=start_pos.device),
    torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device),
)
```
create_causal_mask now materializes new tensors (torch.zeros(1, ...) and torch.tensor(-inf, ...)) every call and uses a (1,) shape that relies on broadcasting. In streaming this mask is computed frequently, so these per-call allocations can add overhead. Prefer 0-dim scalars (shape ()) and/or reuse cached scalar constants (or buffers) to reduce allocations while keeping the requested dtype/device.
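One way to apply the reviewer's suggestion is to cache 0-dim scalar buffers on the module. This is a sketch under assumed names (`MaskConstants`, `causal_mask` are hypothetical, not the PR's code):

```python
import torch
import torch.nn as nn


class MaskConstants(nn.Module):
    """Cache 0-dim scalar mask constants once, instead of allocating
    torch.zeros(1, ...) and torch.tensor(-inf, ...) on every call."""

    def __init__(self, dtype: torch.dtype, device=None):
        super().__init__()
        # Buffers move with .to()/.cuda() along with the rest of the model.
        self.register_buffer("zero", torch.zeros((), dtype=dtype, device=device))
        self.register_buffer(
            "neg_inf", torch.full((), float("-inf"), dtype=dtype, device=device)
        )

    def causal_mask(self, valid: torch.Tensor) -> torch.Tensor:
        # torch.where broadcasts the 0-dim scalars over valid's shape,
        # so no fresh constant tensors are created per call.
        return torch.where(valid, self.zero, self.neg_inf)
```

The shape-() scalars broadcast exactly like the original shape-(1,) tensors, so callers are unchanged.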