Voxtral Realtime: unify SDPA classes and dtype-aware attention masks #17997

Draft
mergennachin wants to merge 1 commit into main from voxtral-sdpa-unification

Conversation

mergennachin (Contributor) commented Mar 8, 2026

Unify CudaSDPA and StandardEncoderSDPA into a single StandardSDPA class
with transpose_kv parameter, mirroring the MetalSDPA unification. This
gives a symmetric design: MetalSDPA and StandardSDPA share the same
interface (n_heads, n_kv_heads, head_dim, transpose_kv).
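A minimal sketch of what such a unified class could look like; the class body, the GQA fallback, and the exact layout convention are assumptions for illustration, not the PR's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandardSDPA(nn.Module):
    """Hypothetical unified SDPA covering decoder and streaming-encoder layouts.

    transpose_kv=False: K/V arrive as (B, n_kv_heads, S, head_dim).
    transpose_kv=True:  K/V arrive as (B, S, n_kv_heads, head_dim) and are
    transposed here, mirroring a MetalSDPA-style interface.
    """

    def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int,
                 transpose_kv: bool = False):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.transpose_kv = transpose_kv

    def forward(self, q, k, v, attn_mask=None):
        if self.transpose_kv:
            # Bring K/V to (B, n_kv_heads, S, head_dim).
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
        if self.n_kv_heads < self.n_heads:
            # Portable GQA expansion: repeat KV heads to match query heads.
            rep = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Both layouts then go through one code path, with `transpose_kv` the only switch between them.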

Make _build_attn_mask and create_causal_mask dtype-aware — masks are now
created in the model dtype instead of always float32. This is required
because the Metal SDPA kernel reads the mask buffer as device T* (same
type as Q/K/V). A float32 mask with bf16 Q/K/V would be misinterpreted.
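A minimal sketch of a dtype-aware additive causal mask; the helper name and signature are assumptions (the PR's `_build_attn_mask` may differ in details):

```python
import torch


def build_attn_mask(seq_len: int, dtype: torch.dtype,
                    device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """Additive causal mask in the model dtype (0 = attend, -inf = masked).

    Creating the mask directly in `dtype` matters when a kernel reads the
    mask buffer as the same element type T as Q/K/V: a float32 buffer under
    a bf16 kernel would be reinterpreted bitwise.
    """
    mask = torch.full((seq_len, seq_len), float("-inf"),
                      dtype=dtype, device=device)
    return torch.triu(mask, diagonal=1)  # zeros on and below the diagonal
```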

Copilot AI review requested due to automatic review settings March 8, 2026 12:27
@mergennachin mergennachin requested a review from lucylq as a code owner March 8, 2026 12:27
pytorch-bot bot commented Mar 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17997

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Awaiting Approval, 2 New Failures

As of commit e62e797 with merge base 122fdef:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 8, 2026
github-actions bot commented Mar 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI (Contributor) left a comment

Pull request overview

This PR unifies the Voxtral Realtime SDPA implementations to align CUDA and encoder behavior with the existing Metal design, and updates attention mask creation to be dtype-aware to satisfy Metal SDPA kernel requirements.

Changes:

  • Replaced CudaSDPA and StandardEncoderSDPA with a unified StandardSDPA that supports both decoder and streaming-encoder layouts via transpose_kv.
  • Made _build_attn_mask and StandardEncoderRingKVCache.create_causal_mask dtype-aware so additive masks are created in the model’s dtype (required for Metal SDPA).
  • Updated model documentation to reflect the new SDPA classes and mask dtype constraints.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • examples/models/voxtral_realtime/model.py: Unifies SDPA modules (adds transpose_kv) and makes additive mask generation dtype-aware for Metal compatibility.
  • examples/models/voxtral_realtime/model.md: Updates architecture docs to reflect StandardSDPA and dtype-matching additive masks for Metal.


Comment on lines 436 to +440
Metal AOTI doesn't support bool tensor allocation on MPS, so we use
integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for
valid positions (k <= curr_pos) and 0 for invalid, then convert to
additive mask (0.0 = attend, -1e9 = don't attend).
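The integer-arithmetic trick quoted above can be sketched as follows, with the dtype parameter this PR introduces; the function name and argument shapes are hypothetical:

```python
import torch


def causal_additive_mask(curr_pos: torch.Tensor, k_positions: torch.Tensor,
                         dtype: torch.dtype) -> torch.Tensor:
    # Avoid bool tensors (unsupported by Metal AOTI on MPS):
    # clamp(curr_pos - k_pos + 1, 0, 1) -> 1 where k_pos <= curr_pos, else 0.
    valid = torch.clamp(curr_pos - k_positions + 1, 0, 1)
    # Turn 1/0 validity into an additive mask: 0.0 = attend, -1e9 = don't.
    # Note: in float16, -1e9 overflows to -inf (see the review comment below
    # the docstring this sketch illustrates).
    return ((valid - 1) * 1e9).to(dtype)
```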

Copilot AI commented Mar 8, 2026

The docstring says the additive mask uses "-1e9 = don't attend", but the implementation now casts to an arbitrary dtype. For float16 in particular, multiplying by 1e9 overflows to +/-inf, so the mask semantics become 0 and -inf (and the doc becomes inaccurate). Consider either updating the doc to reflect dtype-dependent behavior, or computing the masked value from torch.finfo(dtype) (e.g., a representable large negative) to avoid overflow/magic constants while still matching Q/K/V dtype for Metal.
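The reviewer's `torch.finfo` suggestion could look like this; the helper name is hypothetical, and the sketch only demonstrates the overflow it avoids:

```python
import torch


def masked_fill_value(dtype: torch.dtype) -> float:
    """Largest-magnitude negative value representable in `dtype`.

    Using torch.finfo(dtype).min avoids the float16 overflow that a magic
    constant like -1e9 causes (it rounds to -inf, since float16 tops out
    near 65504), while still matching the Q/K/V dtype.
    """
    return torch.finfo(dtype).min
```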

Comment on lines +903 to +907
return torch.where(
valid,
torch.zeros(1, dtype=dtype, device=start_pos.device),
torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device),
)
Copilot AI commented Mar 8, 2026

create_causal_mask now materializes new tensors (torch.zeros(1, ...) and torch.tensor(-inf, ...)) every call and uses a (1,) shape that relies on broadcasting. In streaming this mask is computed frequently, so these per-call allocations can add overhead. Prefer 0-dim scalars (shape ()) and/or reuse cached scalar constants (or buffers) to reduce allocations while keeping the requested dtype/device.
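One way the suggested caching could be sketched; the class and the per-(dtype, device) keying are assumptions, not code from this PR:

```python
import torch


class MaskConstants:
    """Caches 0-dim scalar constants per (dtype, device) so a streaming
    create_causal_mask-style call does not allocate two tensors each step."""

    def __init__(self):
        self._cache = {}

    def get(self, dtype: torch.dtype, device: torch.device):
        key = (dtype, device)
        if key not in self._cache:
            zero = torch.zeros((), dtype=dtype, device=device)
            neg_inf = torch.full((), float("-inf"), dtype=dtype, device=device)
            self._cache[key] = (zero, neg_inf)
        return self._cache[key]


def create_causal_mask(valid: torch.Tensor, dtype: torch.dtype,
                       consts: MaskConstants) -> torch.Tensor:
    # 0-dim scalars broadcast against `valid` without per-call allocation.
    zero, neg_inf = consts.get(dtype, valid.device)
    return torch.where(valid, zero, neg_inf)
```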

@mergennachin mergennachin marked this pull request as draft March 9, 2026 15:05
