Add SDPA attention implementation #512

Open · jlamypoirier wants to merge 12 commits into main from jlp_sdpa-attention

Conversation

jlamypoirier (Collaborator) commented May 7, 2026

Summary

Flash-attn caps at head_size = 256; head_size = 512 models (e.g. Gemma 4's full-attention layers) currently force the backup path, which materializes the full O(S²) attention matrix and OOMs above ~8K context on H100. Add AttentionImplementation.sdpa so those models can train.

The implementation has two paths, selected by device and window configuration, that share the rest of the layer:

  • CUDA, no sliding window: torch.nested.nested_tensor_from_jagged(values, cu_seqlens, min_seqlen=..., max_seqlen=...) + is_causal=True under the EFFICIENT backend. Each document becomes its own batch element, so cross-document attention is excluded by structure rather than by mask. Pre-computed min_seqlen/max_seqlen are passed in to keep the dispatch sync-free; otherwise it reads them back to host on every call (5 host barriers per layer).
  • CUDA + window, and CPU: dense (1, H, total, D) + attn_mask, reusing the backup path's preprocessed causal+document mask. MATH cannot accept nested + is_causal=True, so the mask path is the only viable form on CPU; on CUDA-with-window the mask is needed because nested + is_causal cannot express a sliding window. Per a cluster probe, EFFICIENT also engages on CUDA with an explicit attn_mask, at only ~4 MiB extra over is_causal for the mask itself.

K/V are manually repeat_interleaved across query heads in both paths because SDPA's fused kernels reject broadcasted GQA inputs.
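A minimal sketch of the two paths, assuming a packed `(total_tokens, heads, head_size)` layout and illustrative parameter names (not the exact `_attn_sdpa` signature):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

def attn_sdpa(query, key, value, cu_seqlens, attn_mask=None,
              min_seqlen=None, max_seqlen=None, scale=None):
    # GQA: SDPA's fused kernels reject broadcasted K/V, so expand the key/value
    # heads to match the query heads explicitly (dim -2 is the head dim).
    n_rep = query.size(-2) // key.size(-2)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=-2)
        value = value.repeat_interleave(n_rep, dim=-2)

    if attn_mask is None:
        # CUDA, no window: one jagged batch element per document; causality comes
        # from is_causal=True, cross-document exclusion from the layout itself.
        # Pre-computed min/max_seqlen keep the dispatch from syncing to host.
        offsets = cu_seqlens.to(torch.int64)  # nested_tensor_from_jagged wants int64
        q, k, v = (
            torch.nested.nested_tensor_from_jagged(
                t, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
            ).transpose(1, 2)  # (docs, heads, jagged_seq, head_size)
            for t in (query, key, value)
        )
        out = sdpa(q, k, v, is_causal=True, scale=scale)
        # Unpack back to the packed (total_tokens, heads, head_size) layout.
        return torch.cat(out.transpose(1, 2).unbind())

    # CUDA + window, or CPU: dense (1, heads, total_tokens, head_size) with the
    # preprocessed causal + document (+ window) mask.
    q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (query, key, value))
    out = sdpa(q, k, v, attn_mask=attn_mask, scale=scale)
    return out.squeeze(0).transpose(0, 1)
```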

Auto-fallback simplifies to flash for bf16/fp16 + head_size ≤ 256 + flash available, otherwise sdpa. SDPA now covers every previously-backup case (CPU, windowed without flash, head_size > 256); backup remains available as an explicit `implementation: backup` option, but the auto path no longer reaches it.
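For illustration, the final cascade amounts to something like this (strings stand in for `AttentionImplementation` members; the argument list is assumed):

```python
import torch

_FLASH_MAX_HEAD_SIZE = 256  # flash-attn's head-size limit

def resolve_implementation(impl: str, *, dtype: torch.dtype, head_size: int,
                           window_size: int | None, use_cuda: bool,
                           flash_available: bool) -> str:
    if impl == "auto":
        impl = (
            "flash"
            if flash_available
            and dtype in (torch.bfloat16, torch.float16)
            and head_size <= _FLASH_MAX_HEAD_SIZE
            else "sdpa"
        )
    if impl == "sdpa":
        # Second stage (added in later commits): pick the concrete variant.
        impl = "sdpa_nested" if use_cuda and window_size is None else "sdpa_dense"
    return impl
```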

To pre-compute the seqlens for SDPA, the data preprocessor's max_lengths is changed from a 1-element device tensor to a Python int (flash accepts int natively, verified), min_lengths is added symmetrically, and a return_min_sequence_lengths flag is added to LengthPreprocessingConfig. Plain ints sail through Document.to_device_ since it only moves Tensor fields.
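A sketch of the pre-computation, assuming the preprocessor already holds the per-document lengths as a Python list (the function name and return shape are illustrative; the real values travel through `BlockModelInput` / kwargs):

```python
import itertools
import torch

def length_preprocessing(lengths: list[int], device: torch.device):
    # cu_seqlens stays a device tensor (int32 for flash; the nested-SDPA call
    # site casts to int64), while the max/min lengths become plain Python ints
    # so nothing ever has to be read back from the device.
    cu_seqlens = torch.tensor(
        [0, *itertools.accumulate(lengths)], dtype=torch.int32, device=device
    )
    max_lengths = max(lengths)  # flash-attn accepts a plain int here
    min_lengths = min(lengths)  # only produced when return_min_sequence_lengths is set
    return cu_seqlens, max_lengths, min_lengths
```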

get_preprocessing_config branches by impl: flash needs cu_seqlens + max_seqlens; sdpa-CUDA-no-window needs cu_seqlens + max + min; sdpa-windowed / sdpa-CPU / backup all need document_index (mask is built in preprocess and shared).

Tests: SDPA equivalence check parallel to flash via a small _check_packed closure (CUDA bf16); two head_size=320 cases that exercise the SDPA-only regime; windowed cases now exercise SDPA too. Parametrization refactored from _build_test_cases + single-use variant lists into inline for-loops at module level.

Benchmark: H100 bf16, 20 iterations after 10 warmup, forward + backward wall time.

Llama-7B-shape (32 heads MHA, head_size=128):

| seq | docs | window | backup (ms / GiB) | sdpa-mask (ms / GiB) | sdpa-nested (ms / GiB) |
|-----|------|--------|-------------------|----------------------|------------------------|
| 4K  | 1    | none   | 18.6 / 8.2        | 3.5 / 0.42           | 7.7 / 0.39             |
| 8K  | 1    | none   | 74 / 32.4         | 12.6 / 0.88          | 18.4 / 0.75            |
| 16K | 1    | none   | OOM               | 50 / 2.07            | 60 / 1.57              |
| 16K | 4×4K | none   | OOM               | 50 / 2.07            | 18.6 / 1.57            |
| 16K | 1    | 4K     | OOM               | 50 / 2.07            | n/a                    |

Gemma-4 full-attn (16/8 GQA, head_size=512):

| seq | docs | window | backup (ms / GiB) | sdpa-mask (ms / GiB) | sdpa-nested (ms / GiB) |
|-----|------|--------|-------------------|----------------------|------------------------|
| 4K  | 1    | none   | 11 / 4.4          | 46 / 0.84            | 31 / 0.81              |
| 8K  | 1    | none   | 42 / 16.7         | 161 / 1.67           | 93 / 1.55              |
| 16K | 1    | none   | OOM               | 615 / 3.61           | 331 / 3.11             |
| 16K | 4×4K | none   | OOM               | 616 / 3.61           | 88 / 3.11              |
| 16K | 1    | 4K     | OOM               | 612 / 3.61           | n/a                    |

Multi-document varlen, the typical training case, is where nested + is_causal pulls ahead of the mask path by 2.6×–7×: nested processes each document as its own batch element (4 × 4K² attention work) while the mask path materializes the full 16K² matrix even though cross-document attention is then masked out. Backup OOMs above ~8K at these widths.
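Back-of-envelope for the 16K = 4×4K rows, counting attention-score entries:

```python
dense_mask = 16_384 ** 2      # mask path: full 16K × 16K matrix, masked afterwards
nested     = 4 * 4_096 ** 2   # nested path: four independent 4K × 4K documents
print(dense_mask / nested)    # 4.0, i.e. 4× less score work for the nested path
```

The measured 2.6×–7× gap brackets that 4× score-work ratio.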

Sync events per nested SDPA call (profiled): 0 with pre-computed seqlens; 5 without. Pure wall-clock impact in synthetic bench is ~0.1–1 ms/call, but in a real training loop those 5 host barriers per layer × 30 layers × 8 microbatches = 1200 syncs/step would have prevented host-GPU overlap; with them gone, the nested path's ~6 ms of Python wrapping overhead can hide behind GPU compute.
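One way to reproduce the sync count, assuming a `run_attention_step()` wrapper around a single forward call (not the harness used for the numbers above):

```python
from torch.profiler import ProfilerActivity, profile

def count_host_syncs(run_attention_step) -> int:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        run_attention_step()
    # cudaStreamSynchronize rows in the trace correspond to the host barriers
    # that the pre-computed seqlens are meant to eliminate.
    return sum(
        evt.count
        for evt in prof.key_averages()
        if "cudaStreamSynchronize" in evt.key
    )
```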

Test plan

  • Local pytest -v -n 4 tests/layers/test_attention.py (CPU): 56 passed
  • Cluster pytest -v -n 8 tests/layers/test_attention.py (CUDA): 56 passed; all SDPA equivalence checks run, including windowed cases

🤖 Generated with Claude Code

jlamypoirier and others added 3 commits May 7, 2026 19:12
Flash-attn errors out at head_size > 256, so head_size=512 models
cannot train without materializing the full O(S²) attention matrix
via the backup path.

Add `AttentionImplementation.sdpa` using `torch.nested` to bridge the
packed-varlen layout to SDPA's batched signature, pinning the EFFICIENT
backend. K/V are manually repeat_interleaved to match Q heads because
the fused kernels reject broadcasted GQA inputs.

Auto-fallback: flash when bf16/fp16 + head_size <= 256 + flash is
available; backup for windowed attention (the sdpa path does not
support sliding window); sdpa otherwise.

Tests: SDPA equivalence check parallel to flash, gated on CUDA + bf16;
two head_size=320 cases exercising the SDPA-only regime; refactored
parametrization from `_build_test_cases` plus single-use variant lists
into a few inline for-loops at module level.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The SDPA path uses `nested_tensor_from_jagged + is_causal=True` which
has no viable backend on CPU (math rejects nested + is_causal; the
fused EFFICIENT/Flash backends are CUDA-only). Auto previously routed
CPU runs through SDPA and they would crash; route them to backup.

Also widens the SDPA branch to fp32 explicitly: the EFFICIENT backend
engages on CUDA across bf16/fp16/fp32, and benchmarking confirms it
beats backup on memory at every length and matches it on time at
seq_len >= 4096 (backup grows quadratically; SDPA stays near constant).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous attempt routed CPU and windowed configurations to backup
because the nested + is_causal=True form has no viable backend on CPU
and cannot express sliding window. SDPA actually works fine in those
cases when given an explicit attn_mask: backup's preprocessing already
builds the combined causal+document mask (and threads sliding window
into it), so the SDPA path can reuse it as-is.

CUDA without a window keeps the nested + is_causal path so EFFICIENT
runs without materializing the mask. CUDA with a window and CPU runs
both fall through to dense + attn_mask, which lets MATH engage on CPU
and reuses the windowed mask on CUDA.

Auto-fallback simplifies to flash-or-sdpa: SDPA now covers every case
backup used to (CPU, windowed without flash, head_size > 256).

Verified on H100 bf16 head_size=512 that the dense + attn_mask form
also engages EFFICIENT (peak 323 MiB vs 319 MiB for is_causal — the
4 MiB delta is the mask itself).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
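The backend claim is easy to re-check with the SDPA dispatcher context manager; a sketch with illustrative shapes (restricting the dispatcher to EFFICIENT makes SDPA raise if that backend cannot serve the inputs, instead of silently falling back):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa

q, k, v = (
    torch.randn(1, 16, 4096, 512, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
)
# Boolean mask, True = attend; stands in for the preprocessed causal+document mask.
mask = torch.ones(4096, 4096, dtype=torch.bool, device="cuda").tril()
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = sdpa(q, k, v, attn_mask=mask)  # succeeds only if EFFICIENT engages
print(out.shape, torch.cuda.max_memory_allocated() // 2**20, "MiB")
```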
@jlamypoirier jlamypoirier changed the title Add SDPA attention for head_size > 256 Add SDPA attention implementation May 8, 2026
jlamypoirier and others added 9 commits May 8, 2026 15:54
…on call

The CUDA-no-window and dense-mask paths shared the K/V expansion, the SDPA
call signature (dropout + scale), and the (B, H, S, D) layout requirement.
Lift those out: rebind query/key/value to either nested-jagged or
unsqueeze(0)'d 4D tensors in the per-path setup, build an `sdpa_args` dict
that adds `is_causal=...` for nested or `attn_mask=...` for dense, then
make a single SDPA call that works for both. The unwrap branches on
`output.is_nested`.

Also drops the explicit EFFICIENT_ATTENTION pin from the nested path —
nested + is_causal=True has no other viable backend (MATH and Flash both
reject it), so the auto pick lands on EFFICIENT or the call errors out
either way.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The nested path floors per-call wall around 6 ms because SDPA's nested
dispatch pulls `max_seqlen` / `min_seqlen` to host (5 cudaMemcpyAsync DtoH
+ cudaStreamSynchronize per call). Sync count is fixed regardless of
num_docs, so the path stays much faster than dense+mask in varlen
training; the comment just makes the cost discoverable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PyTorch's nested SDPA dispatch reads `max_seqlen` and `min_seqlen` to
host on every call (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per
call) when they aren't supplied. Both are trivially derivable from the
Python `lengths` list at preprocessing time, so we compute them as
plain ints, thread them through `BlockModelInput` / kwargs, and pass
them to `nested_tensor_from_jagged`.

While doing this, drop the `torch.full((1,), ..., device=...)` wrap
on `max_lengths` — the value was always a Python int, and flash
accepts an int directly (verified). The auto-device-move on the
`Document` base class only moves Tensor fields, so plain ints pass
through to_kwargs untouched.

Sync events per call (Llama-7B-shape, 4 docs × 4096):
  before: 5 cudaStreamSynchronize + 5 Memcpy DtoH
  after:  0

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The sdpa_nested vs sdpa_dense choice was previously a branch inside
_attn_sdpa duplicated across preprocessing config and preprocess(),
with the runtime check using `query.is_cuda` and the preprocessing
checks using `self._distributed_config.use_cuda`. Promoting it to two
enum values resolves the duplication and removes the two sources of truth.

Also add a CPU-side fp32 sdpa_dense equivalence check (the dense path
is now the auto-fallback on CPU and was previously unexercised in CI),
and factor first_length_k out of max_lengths/min_lengths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add `sdpa` enum value that auto-resolves to `sdpa_nested` or `sdpa_dense`
  at init based on `(use_cuda, window_size)`. Auto-fallback now resolves to
  `flash` or `sdpa`; the second stage then picks the concrete variant.
- Update the `implementation` field's desc to describe the new cascade.
- Validate `flash` + `head_size > 256` in `_validate` so explicit configs
  fail at construction with a clear message instead of inside flash-attn.
- Drop the dead `if attention_mask is not None` guard in `_attn_sdpa`'s
  dense branch — `_preprocess_for_backup_attention` always sets a non-None
  mask, so the conditional could never short-circuit.
- Comment near `repeat_interleave` that `enable_gqa=True` is intentionally
  avoided (forces MATH fallback, incompatible with nested-jagged path).
- Extend `_check_packed` to run backward and compare parameter gradients,
  closing the autograd-coverage gap through `nested_tensor_from_jagged`
  and GQA `repeat_interleave`. Bump the bf16 per-seq reference to
  `with_backward=True` and capture `grads_ref_bf16` for the bf16 checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Reject `sdpa_nested` + non-None `window_size` in `_validate`. The nested
  + is_causal path can't express sliding window; previously a config that
  explicitly chose `sdpa_nested` with a window would silently produce
  full-causal output instead.
- Skip the bf16 `sdpa_nested` test for windowed configs (the validation
  above would otherwise reject the test config).
- Add a bf16 `sdpa_dense` equivalence check. Same kernel as backup but
  packed-vs-per-seq reductions diverge at bf16 by ~7e-3 — establishes that
  the bf16 reference itself isn't reproducible to forward tolerance, so
  flash/sdpa_nested divergence is reduction-order noise, not a bug.
- Split `_check_packed`'s rtol into `out_rtol` and `grad_rtol`; bf16 grads
  diverge above the forward bound (up to ~1.3 % for flash). New grad
  tolerance is 1.5e-2 at bf16, 1e-5 at fp32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Extract `_FLASH_MAX_HEAD_SIZE = 256` in config.py; reference it from both
  the auto-resolution in `Attention.__init__` and the `_validate` guard.
- Annotate the two-stage implementation resolution to explain why `auto`
  and `sdpa` share the same cascade structure.
- Comment the int64 cast in `_attn_sdpa` so the cast isn't mistaken for
  unnecessary work (`nested_tensor_from_jagged` requires int64 offsets).
- Qualify the key/value shape comments on `_attn_sdpa` as pre-GQA-expansion
  to match the rebinding after `repeat_interleave`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop the now-dead `with_backward` parameter on `_run_per_seq_reference`;
  both callers default to `True` after the backward-parity refactor.
- Rename `large_head_causal` / `large_head_mqa` to `*_no_norm` so they
  match the suffix convention the cross-product loop establishes.
- Fix the bf16-grad comment that called sdpa_dense the "same-kernel" path;
  sdpa_dense routes through SDPA's MATH/EFFICIENT backend, backup uses its
  own kernel. Reworded to make the divergence-across-kernels point clear.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
_check_packed only closed over data that can be passed as arguments, so
there's no reason for it to be a nested helper. Promoting it to module
scope and consolidating the three bf16 call sites into a single for-loop
also drops the verbose block comments preceding each path; the one
remaining comment justifies the looser bf16 grad rtol.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>