From c24647ac163107e6ccf276cbe300d70969d75c22 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 19:12:39 -0400 Subject: [PATCH 01/12] Add SDPA attention for head_size > 256 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flash-attn errors out at head_size > 256, so head_size=512 models cannot train without materializing the full O(S²) attention matrix via the backup path. Add `AttentionImplementation.sdpa` using `torch.nested` to bridge the packed-varlen layout to SDPA's batched signature, pinning the EFFICIENT backend. K/V are manually repeat_interleaved to match Q heads because the fused kernels reject broadcasted GQA inputs. Auto-fallback: flash when bf16/fp16 + head_size <= 256 + flash is available; backup for windowed attention (the sdpa path does not support sliding window); sdpa otherwise. Tests: SDPA equivalence check parallel to flash, gated on CUDA + bf16; two head_size=320 cases exercising the SDPA-only regime; refactored parametrization from `_build_test_cases` plus single-use variant lists into a few inline for-loops at module level. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 65 +++++++- fast_llm/layers/attention/config.py | 1 + tests/layers/test_attention.py | 220 ++++++++++++++----------- 3 files changed, 183 insertions(+), 103 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 12f85bf28..b9a0cc944 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -81,10 +81,19 @@ def __init__( ) self._implementation = self._config.implementation if self._implementation == AttentionImplementation.auto: - if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + if ( + _flash_available + and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) + and self._config.head_size <= 256 + ): self._implementation = AttentionImplementation.flash - else: + elif self._config.window_size is not None: + # SDPA path doesn't support sliding window; backup is the only fallback that does. self._implementation = AttentionImplementation.backup + else: + self._implementation = AttentionImplementation.sdpa + if self._implementation == AttentionImplementation.sdpa: + assert self._config.window_size is None, "SDPA implementation does not support sliding window." self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -258,6 +267,38 @@ def _attn_flash( softmax_scale=self._softmax_scale, ) + def _attn_sdpa( + self, + query: torch.Tensor, # total_q, heads, head_size + key: torch.Tensor, # total_k, head_groups, head_size + value: torch.Tensor, # total_k, head_groups, head_size + kwargs: dict[str, typing.Any], + ) -> torch.Tensor: # total_q, heads, head_size + # SDPA's EFFICIENT backend (the only one that supports head_size > 256) requires + # Q/K/V to have the same num_heads, so we materialize K/V across query heads. + # Wrap as nested-jagged to give SDPA the per-document mask via batch elements, + # avoiding the pack→pad→gather dance. 
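+        # Shape sketch (hypothetical numbers): packed lengths [3, 5] give cu_seqlens [0, 3, 8],
+        # so wrapping the packed (8, heads, head_size) tensor with nested_tensor_from_jagged
+        # yields a two-element jagged batch, and `is_causal` then applies per document.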
+ if self._local_heads_per_group > 1: + key = key.repeat_interleave(self._local_heads_per_group, dim=1) + value = value.repeat_interleave(self._local_heads_per_group, dim=1) + cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) + cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) + query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) + key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) + value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + output_nested = torch.nn.functional.scaled_dot_product_attention( + query_nested.transpose(1, 2), + key_nested.transpose(1, 2), + value_nested.transpose(1, 2), + is_causal=self._config.causal, + dropout_p=self._config.dropout if self.training else 0.0, + scale=self._softmax_scale, + ).transpose(1, 2) + + return output_nested.values() + def _apply_norm_with_grad_capture( self, norm: torch.nn.Module, x: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: @@ -420,6 +461,8 @@ def _forward( with set_generator(self._distributed.tp_generator): if self._implementation == AttentionImplementation.flash: input_ = self._attn_flash(query, key, value, kwargs) + elif self._implementation == AttentionImplementation.sdpa: + input_ = self._attn_sdpa(query, key, value, kwargs) elif self._implementation == AttentionImplementation.backup: # TODO: Avoid the flattens. input_ = self._attn_backup(query, key, value, kwargs) @@ -472,7 +515,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._implementation in AttentionImplementation.flash: + if (not config.hardware) or self._implementation in ( + AttentionImplementation.flash, + AttentionImplementation.sdpa, + ): # Remove non-causal part. (TODO: Support non-causal) # TODO: Compute is overestimated without cross-document attention. 
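+            # Causal masking skips the strictly upper-triangular sequence_q * (sequence_q - 1) / 2 positions: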
attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 @@ -498,15 +544,18 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) def get_preprocessing_config(self) -> dict[str, typing.Any]: - return ( - { + if self._implementation == AttentionImplementation.flash: + return { "return_cumulative_sequence_lengths": True, "return_max_sequence_lengths": True, "causal": self._config.causal, } - if self._implementation == AttentionImplementation.flash - else {"return_document_index": True, "causal": self._config.causal} - ) + elif self._implementation == AttentionImplementation.sdpa: + return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal} + elif self._implementation == AttentionImplementation.backup: + return {"return_document_index": True, "causal": self._config.causal} + else: + raise NotImplementedError(self._implementation) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(kwargs) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index cc5d80e88..69aa4f484 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -38,6 +38,7 @@ class AttentionKwargs(MixerKwargs): class AttentionImplementation(enum.StrEnum): auto = "auto" flash = "flash" + sdpa = "sdpa" backup = "backup" diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index d572816b2..39a4d5d58 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -174,78 +174,101 @@ def expected_output( return torch.nn.functional.linear(attn_out.flatten(1), attention.dense.weight.detach()) -_base_attention_cases = [ +_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]] +_LENGTHS_SHORT = [[15], [4, 1, 10]] +_LENGTHS_SINGLE = [[15]] + +_attention_test_cases: list[tuple[AttentionTestConfig, list[int]]] = [] + +# Mask, group, and window base cases — no norms, swept over all length sets. +for name, kwargs in ( ("causal", {"causal": True}), ("noncausal", {"causal": False}), ("window", {"causal": True, "window_size": 4}), ("mqa", {"causal": True, "kv_heads": 1}), ("mha", {"causal": True, "kv_heads": _HEADS}), -] - -_attention_rotary_cases = [ - # Rotary: packing equivalence is skipped for multi-document inputs (packed rotary uses global - # positions; per-sequence reference uses per-doc positions). All three checks run for single-doc inputs. - ("causal_rotary", {"causal": True, "rotary": True}), -] - -_attention_norm_variants = [ - ("no_norm", {}), - ("query_norm", {"query_norm": True}), - ("key_norm", {"key_norm": True}), - ("value_norm", {"value_norm": True}), - ("both_norms", {"query_norm": True, "key_norm": True}), - ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), -] - -_attention_shared_key_value_cases = [ - ("shared_key_value", {"shared_key_value": True}), - ("shared_key_value_rotary", {"shared_key_value": True, "rotary": True}), - # Gemma 4's full-attention layer combines shared_key_value with ProportionalRotary. +): + for lengths in _LENGTHS_FULL: + _attention_test_cases.append((AttentionTestConfig(name=f"{name}_no_norm", **kwargs), lengths)) + +# Per-head norm variants on causal and shared key/value bases. Rotary bases use single-doc +# inputs because the packed and per-sequence rotary references diverge across boundaries. 
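+# Each tuple below: (base name, base kwargs, norm variants crossed with the base, length sets swept).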
+for base_name, base_kwargs, variants, length_set in ( + ( + "causal", + {"causal": True}, + ( + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SHORT, + ), + ( + "causal_rotary", + {"causal": True, "rotary": True}, + ( + ("no_norm", {}), + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, + ), + ( + "shared_key_value", + {"shared_key_value": True}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SHORT, + ), + ( + "shared_key_value_rotary", + {"shared_key_value": True, "rotary": True}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, + ), ( "shared_key_value_proportional_rotary", {"shared_key_value": True, "rotary": True, "rotary_partial_rotary_factor": 0.5}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, ), -] - -_attention_shared_key_value_norm_variants = [ - ("no_norm", {}), - ("key_norm", {"key_norm": True}), - ("value_norm", {"value_norm": True}), - ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), -] - -# Norms apply per-head and don't interact with mask/group structure, so we test all norm -# variants on a single base (causal) instead of crossing every norm with every base. -# Lengths matter for packing/flash equivalence checks, so we sweep all lengths on the -# base × no_norm cases (where seq layout interacts most with attention math). 
-_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]] -_LENGTHS_SHORT = [[15], [4, 1, 10]] -_LENGTHS_SINGLE = [[15]] - - -def _build_test_cases() -> list[tuple[AttentionTestConfig, list[int]]]: - cases: list[tuple[AttentionTestConfig, list[int]]] = [] - for base_name, base_kwargs in _base_attention_cases: - config = AttentionTestConfig(name=f"{base_name}_no_norm", **base_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_FULL) - for variant_name, variant_kwargs in _attention_norm_variants: - if variant_name == "no_norm": - continue - config = AttentionTestConfig(name=f"causal_{variant_name}", causal=True, **variant_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_SHORT) - for base_name, base_kwargs in _attention_rotary_cases: - for variant_name, variant_kwargs in _attention_norm_variants: - config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_SINGLE) - for base_name, base_kwargs in _attention_shared_key_value_cases: - lengths_set = _LENGTHS_SINGLE if base_kwargs.get("rotary") else _LENGTHS_SHORT - for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants: - config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) - cases.extend((config, lengths) for lengths in lengths_set) - return cases - +): + for variant_name, variant_kwargs in variants: + for lengths in length_set: + _attention_test_cases.append( + ( + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs), + lengths, + ) + ) -_attention_test_cases = _build_test_cases() +# head_size > 256 — exercises the SDPA-only regime (flash caps at 256). +for name, kwargs in ( + ("large_head_causal", {"causal": True, "head_size": 320}), + ("large_head_mqa", {"causal": True, "head_size": 320, "kv_heads": 1}), +): + for lengths in _LENGTHS_SHORT: + _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths)) def _run_per_seq_reference( @@ -357,50 +380,57 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name) stage.reset_gradients() - # Flash equivalence check: packed flash output must match per-sequence bfloat16 backup reference. - if _flash_available: - distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True) - distributed_bf16 = Distributed(distributed_config_bf16) + # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output must + # match a per-sequence bfloat16 backup reference. 
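+    # Both equivalence checks need a GPU: flash-attn is CUDA-only and this SDPA path pins
+    # CUDA's EFFICIENT backend.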
+ if not torch.cuda.is_available(): + return - attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer( - distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False - ) - stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16) - for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True): - param_bf16.data.copy_(param_f32.data) - - hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) - out_ref_bf16 = _run_per_seq_reference( - attention_backup_bf16, - stage_backup_bf16, - distributed_config_bf16, - hidden_states_bf16, - lengths, - device, - with_backward=False, - ) + distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True) + distributed_bf16 = Distributed(distributed_config_bf16) + + attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer( + distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16) + for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True): + param_bf16.data.copy_(param_f32.data) + + hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) + out_ref_bf16 = _run_per_seq_reference( + attention_backup_bf16, + stage_backup_bf16, + distributed_config_bf16, + hidden_states_bf16, + lengths, + device, + with_backward=False, + ) - attention_flash: Attention = config.get_attention_config("flash").get_layer( + def _check_packed(implementation: str) -> None: + attention_impl: Attention = config.get_attention_config(implementation).get_layer( distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False ) - stage_flash = get_stage([attention_flash], distributed_bf16) - for param_flash, param_f32 in zip(attention_flash.parameters(), attention.parameters(), strict=True): - param_flash.data.copy_(param_f32.data) - - (model_input_flash,) = LanguageModelBatch( + stage_impl = get_stage([attention_impl], distributed_bf16) + for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True): + param_impl.data.copy_(param_f32.data) + (model_input,) = LanguageModelBatch( tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths ).get_model_inputs( LanguageModelBatchPreprocessingConfig( distributed=distributed_config_bf16, predicted_tokens=0, - **attention_flash.get_preprocessing_config(), + **attention_impl.get_preprocessing_config(), ) ) - kwargs_flash = model_input_flash.to_kwargs() - attention_flash.preprocess(kwargs_flash) - out_flash, _ = stage_flash.forward(hidden_states_bf16, kwargs_flash) - - Assert.rms_close_relative(out_flash, out_ref_bf16, 5e-3, 1e-7) + kwargs_impl = model_input.to_kwargs() + attention_impl.preprocess(kwargs_impl) + out_impl, _ = stage_impl.forward(hidden_states_bf16, kwargs_impl) + Assert.rms_close_relative(out_impl, out_ref_bf16, 5e-3, 1e-7) + + if _flash_available and config.head_size <= 256: + _check_packed("flash") + if config.window_size is None: + _check_packed("sdpa") @pytest.mark.slow From 23412b70f41bdb15bb605a033052b1d38ecf19dd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 14:05:43 -0400 Subject: [PATCH 02/12] Route auto-fallback to backup on CPU The SDPA path uses `nested_tensor_from_jagged + is_causal=True` which has no viable backend on CPU (math rejects nested + is_causal; the fused EFFICIENT/Flash backends are 
CUDA-only). Auto previously routed CPU runs through SDPA and they would crash; route them to backup. Also widens the SDPA branch to fp32 explicitly: the EFFICIENT backend engages on CUDA across bf16/fp16/fp32, and benchmarking confirms it beats backup on memory at every length and matches it on time at seq_len >= 4096 (backup grows quadratically; SDPA stays near constant). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index b9a0cc944..e3e461cf5 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -87,11 +87,13 @@ def __init__( and self._config.head_size <= 256 ): self._implementation = AttentionImplementation.flash - elif self._config.window_size is not None: - # SDPA path doesn't support sliding window; backup is the only fallback that does. - self._implementation = AttentionImplementation.backup - else: + elif self._distributed_config.use_cuda and self._config.window_size is None: + # SDPA's EFFICIENT backend handles every dtype on CUDA; on CPU the + # nested + is_causal path has no viable backend, and SDPA does not + # support sliding window so windowed runs need backup either way. self._implementation = AttentionImplementation.sdpa + else: + self._implementation = AttentionImplementation.backup if self._implementation == AttentionImplementation.sdpa: assert self._config.window_size is None, "SDPA implementation does not support sliding window." From bd17da3787be9e6db593623fe8695f913bc44451 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 14:41:07 -0400 Subject: [PATCH 03/12] Use backup mask in SDPA fallback paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous attempt routed CPU and windowed configurations to backup because the nested + is_causal=True form has no viable backend on CPU and cannot express sliding window. SDPA actually works fine in those cases when given an explicit attn_mask: backup's preprocessing already builds the combined causal+document mask (and threads sliding window into it), so the SDPA path can reuse it as-is. CUDA without a window keeps the nested + is_causal path so EFFICIENT runs without materializing the mask. CUDA with a window and CPU runs both fall through to dense + attn_mask, which lets MATH engage on CPU and reuses the windowed mask on CUDA. Auto-fallback simplifies to flash-or-sdpa: SDPA now covers every case backup used to (CPU, windowed without flash, head_size > 256). Verified on H100 bf16 head_size=512 that the dense + attn_mask form also engages EFFICIENT (peak 323 MiB vs 319 MiB for is_causal — the 4 MiB delta is the mask itself). 
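For reference, a self-contained sketch of the dense + attn_mask form the
fallback paths now share (lengths, head counts, and tensor names here are
illustrative, not taken from the repo):

    import torch

    # Two packed documents of lengths [3, 5]; 4 heads, head_size 512 (past flash's cap).
    q = torch.randn(8, 4, 512)
    doc = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])
    # Combined causal + document mask (True = may attend); a sliding window would AND in here.
    mask = torch.ones(8, 8, dtype=torch.bool).tril() & (doc[:, None] == doc[None, :])
    out = torch.nn.functional.scaled_dot_product_attention(
        q.unsqueeze(0).transpose(1, 2),  # (1, heads, sq, head_size)
        q.unsqueeze(0).transpose(1, 2),
        q.unsqueeze(0).transpose(1, 2),
        attn_mask=mask,  # broadcasts over batch and heads
    ).transpose(1, 2).squeeze(0)  # back to (total_q, heads, head_size)

The boolean mask broadcasts from (sq, sk) up to (B, H, sq, sk), which is why
the backup-built mask only needs a transpose rather than an expand.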
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 78 ++++++++++++++++---------- tests/layers/test_attention.py | 3 +- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index e3e461cf5..605074222 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -87,15 +87,8 @@ def __init__( and self._config.head_size <= 256 ): self._implementation = AttentionImplementation.flash - elif self._distributed_config.use_cuda and self._config.window_size is None: - # SDPA's EFFICIENT backend handles every dtype on CUDA; on CPU the - # nested + is_causal path has no viable backend, and SDPA does not - # support sliding window so windowed runs need backup either way. - self._implementation = AttentionImplementation.sdpa else: - self._implementation = AttentionImplementation.backup - if self._implementation == AttentionImplementation.sdpa: - assert self._config.window_size is None, "SDPA implementation does not support sliding window." + self._implementation = AttentionImplementation.sdpa self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -276,30 +269,48 @@ def _attn_sdpa( value: torch.Tensor, # total_k, head_groups, head_size kwargs: dict[str, typing.Any], ) -> torch.Tensor: # total_q, heads, head_size - # SDPA's EFFICIENT backend (the only one that supports head_size > 256) requires - # Q/K/V to have the same num_heads, so we materialize K/V across query heads. - # Wrap as nested-jagged to give SDPA the per-document mask via batch elements, - # avoiding the pack→pad→gather dance. + # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads. if self._local_heads_per_group > 1: key = key.repeat_interleave(self._local_heads_per_group, dim=1) value = value.repeat_interleave(self._local_heads_per_group, dim=1) - cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) - cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) - query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) - key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) - value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) - - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): - output_nested = torch.nn.functional.scaled_dot_product_attention( - query_nested.transpose(1, 2), - key_nested.transpose(1, 2), - value_nested.transpose(1, 2), - is_causal=self._config.causal, + + if query.is_cuda and self._config.window_size is None: + # Most-efficient path: nested-jagged + is_causal lets EFFICIENT skip materializing + # the attention mask. Document boundaries are encoded by the per-doc batch elements. 
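+            # cu_seqlens are built as int32 (flash-attn's varlen convention); the casts below
+            # produce the int64 offsets that nested_tensor_from_jagged expects.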
+ cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) + cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) + query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) + key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) + value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + output_nested = torch.nn.functional.scaled_dot_product_attention( + query_nested.transpose(1, 2), + key_nested.transpose(1, 2), + value_nested.transpose(1, 2), + is_causal=self._config.causal, + dropout_p=self._config.dropout if self.training else 0.0, + scale=self._softmax_scale, + ).transpose(1, 2) + return output_nested.values() + + # CPU MATH rejects nested + is_causal, and the nested path can't express sliding window. + # Both fall back on the same dense + attn_mask form, reusing backup's preprocessed mask. + # Backup builds it as (1, sq, 1, sk) for its head-grouped layout; SDPA wants (B, H, sq, sk). + attention_mask = kwargs[AttentionKwargs.attention_mask] + if attention_mask is not None: + attention_mask = attention_mask.transpose(1, 2) + return ( + torch.nn.functional.scaled_dot_product_attention( + query.unsqueeze(0).transpose(1, 2), + key.unsqueeze(0).transpose(1, 2), + value.unsqueeze(0).transpose(1, 2), + attn_mask=attention_mask, dropout_p=self._config.dropout if self.training else 0.0, scale=self._softmax_scale, - ).transpose(1, 2) - - return output_nested.values() + ) + .transpose(1, 2) + .squeeze(0) + ) def _apply_norm_with_grad_capture( self, norm: torch.nn.Module, x: torch.Tensor @@ -552,16 +563,23 @@ def get_preprocessing_config(self) -> dict[str, typing.Any]: "return_max_sequence_lengths": True, "causal": self._config.causal, } - elif self._implementation == AttentionImplementation.sdpa: + elif ( + self._implementation == AttentionImplementation.sdpa + and self._distributed_config.use_cuda + and self._config.window_size is None + ): return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal} - elif self._implementation == AttentionImplementation.backup: + elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup): return {"return_document_index": True, "causal": self._config.causal} else: raise NotImplementedError(self._implementation) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(kwargs) - if self._implementation == AttentionImplementation.backup: + if self._implementation == AttentionImplementation.backup or ( + self._implementation == AttentionImplementation.sdpa + and (not self._distributed_config.use_cuda or self._config.window_size is not None) + ): self._preprocess_for_backup_attention(kwargs) def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None: diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 39a4d5d58..ef96dbf96 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -429,8 +429,7 @@ def _check_packed(implementation: str) -> None: if _flash_available and config.head_size <= 256: _check_packed("flash") - if config.window_size is None: - _check_packed("sdpa") + _check_packed("sdpa") @pytest.mark.slow From 4c99e884e5cc1edf76c0447788220d5a8ad8eee1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 15:54:25 -0400 Subject: [PATCH 04/12] Unify the two SDPA paths around a single F.scaled_dot_product_attention call 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CUDA-no-window and dense-mask paths shared the K/V expansion, the SDPA call signature (dropout + scale), and the (B, H, S, D) layout requirement. Lift those out: rebind query/key/value to either nested-jagged or unsqueeze(0)'d 4D tensors in the per-path setup, build an `sdpa_args` dict that adds `is_causal=...` for nested or `attn_mask=...` for dense, then make a single SDPA call that works for both. The unwrap branches on `output.is_nested`. Also drops the explicit EFFICIENT_ATTENTION pin from the nested path — nested + is_causal=True has no other viable backend (MATH and Flash both reject it), so the auto pick lands on EFFICIENT or the call errors out either way. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 63 ++++++++++++-------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 605074222..2024cec45 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -274,43 +274,38 @@ def _attn_sdpa( key = key.repeat_interleave(self._local_heads_per_group, dim=1) value = value.repeat_interleave(self._local_heads_per_group, dim=1) + sdpa_args: dict[str, typing.Any] = { + "dropout_p": self._config.dropout if self.training else 0.0, + "scale": self._softmax_scale, + } if query.is_cuda and self._config.window_size is None: - # Most-efficient path: nested-jagged + is_causal lets EFFICIENT skip materializing - # the attention mask. Document boundaries are encoded by the per-doc batch elements. + # Wrap each document as its own batch element via nested-jagged so cross-doc masking + # is structural and EFFICIENT skips materializing the attention mask. cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) - query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) - key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) - value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): - output_nested = torch.nn.functional.scaled_dot_product_attention( - query_nested.transpose(1, 2), - key_nested.transpose(1, 2), - value_nested.transpose(1, 2), - is_causal=self._config.causal, - dropout_p=self._config.dropout if self.training else 0.0, - scale=self._softmax_scale, - ).transpose(1, 2) - return output_nested.values() - - # CPU MATH rejects nested + is_causal, and the nested path can't express sliding window. - # Both fall back on the same dense + attn_mask form, reusing backup's preprocessed mask. - # Backup builds it as (1, sq, 1, sk) for its head-grouped layout; SDPA wants (B, H, sq, sk). 
- attention_mask = kwargs[AttentionKwargs.attention_mask] - if attention_mask is not None: - attention_mask = attention_mask.transpose(1, 2) - return ( - torch.nn.functional.scaled_dot_product_attention( - query.unsqueeze(0).transpose(1, 2), - key.unsqueeze(0).transpose(1, 2), - value.unsqueeze(0).transpose(1, 2), - attn_mask=attention_mask, - dropout_p=self._config.dropout if self.training else 0.0, - scale=self._softmax_scale, - ) - .transpose(1, 2) - .squeeze(0) - ) + query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) + key = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) + value = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) + sdpa_args["is_causal"] = self._config.causal + else: + # Dense + backup's preprocessed causal+document mask. Required on CPU (MATH rejects + # nested + is_causal) and on CUDA with sliding window (the nested path can't express + # it). Backup builds the mask as (1, sq, 1, sk); SDPA wants (B, H, sq, sk). + attention_mask = kwargs[AttentionKwargs.attention_mask] + if attention_mask is not None: + attention_mask = attention_mask.transpose(1, 2) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + sdpa_args["attn_mask"] = attention_mask + + output = torch.nn.functional.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + **sdpa_args, + ).transpose(1, 2) + return output.values() if output.is_nested else output.squeeze(0) def _apply_norm_with_grad_capture( self, norm: torch.nn.Module, x: torch.Tensor From ffa8b7ec34b9ea9d817fe4cb99e9894b6369497d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 16:31:09 -0400 Subject: [PATCH 05/12] Note SDPA nested dispatch sync cost in the attention comment The nested path floors per-call wall around 6 ms because SDPA's nested dispatch pulls `max_seqlen` / `min_seqlen` to host (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per call). Sync count is fixed regardless of num_docs, so the path stays much faster than dense+mask in varlen training; the comment just makes the cost discoverable. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 2024cec45..57b826ab2 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -280,7 +280,9 @@ def _attn_sdpa( } if query.is_cuda and self._config.window_size is None: # Wrap each document as its own batch element via nested-jagged so cross-doc masking - # is structural and EFFICIENT skips materializing the attention mask. + # is structural and EFFICIENT skips materializing the attention mask. SDPA's nested + # dispatch reads `max_seqlen`/`min_seqlen` to host (5 cudaMemcpyAsync DtoH per call), + # which floors per-call wall at ~6 ms; still much faster than dense+mask in varlen. 
cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) From f6958b5aa8a5fb5a77c847e89f9d71fa1a339148 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 16:57:23 -0400 Subject: [PATCH 06/12] Pre-compute min/max seq lengths so SDPA's nested path doesn't sync MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyTorch's nested SDPA dispatch reads `max_seqlen` and `min_seqlen` to host on every call (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per call) when they aren't supplied. Both are trivially derivable from the Python `lengths` list at preprocessing time, so we compute them as plain ints, thread them through `BlockModelInput` / kwargs, and pass them to `nested_tensor_from_jagged`. While doing this, drop the `torch.full((1,), ..., device=...)` wrap on `max_lengths` — the value was always a Python int, and flash accepts an int directly (verified). The auto-device-move on the `Document` base class only moves Tensor fields, so plain ints pass through to_kwargs untouched. Sync events per call (Llama-7B-shape, 4 docs × 4096): before: 5 cudaStreamSynchronize + 5 Memcpy DtoH after: 0 Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/data/document/block.py | 25 +++++++++++++------ fast_llm/data/document/config.py | 1 + fast_llm/layers/attention/attention.py | 34 ++++++++++++++++++++------ fast_llm/layers/attention/config.py | 2 ++ 4 files changed, 48 insertions(+), 14 deletions(-) diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py index 530be42ea..b9f2e30f2 100644 --- a/fast_llm/data/document/block.py +++ b/fast_llm/data/document/block.py @@ -24,8 +24,10 @@ class BlockModelInput(ModelInput): lengths: list[int] = None cumulative_lengths_q: torch.Tensor | None = None cumulative_lengths_k: torch.Tensor | None = None - max_length_q: torch.Tensor | None = None - max_length_k: torch.Tensor | None = None + max_length_q: int | None = None + max_length_k: int | None = None + min_length_q: int | None = None + min_length_k: int | None = None document_index_q: torch.Tensor | None = None document_index_k: torch.Tensor | None = None position_index: torch.Tensor | None = None @@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]: AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k, AttentionKwargs.max_seqlen_q: self.max_length_q, AttentionKwargs.max_seqlen_k: self.max_length_k, + AttentionKwargs.min_seqlen_q: self.min_length_q, + AttentionKwargs.min_seqlen_k: self.min_length_k, AttentionKwargs.document_index_q: self.document_index_q, AttentionKwargs.document_index_k: self.document_index_k, LanguageModelKwargs.position_ids: self.position_index, @@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths if config.return_max_sequence_lengths or config.return_document_index: model_input.max_length_q, model_input.max_length_k = self.max_lengths + if config.return_min_sequence_lengths: + model_input.min_length_q, model_input.min_length_k = self.min_lengths if config.return_document_index: model_input.document_index_q, model_input.document_index_k = self.document_index if config.return_position_index: @@ -118,13 +124,18 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: return cumulative_lengths_q, 
cumulative_lengths_k @functools.cached_property - def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: + def max_lengths(self) -> tuple[int, int]: max_length_q = max(self.lengths) max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin) - return ( - torch.full((1,), max_length_q, dtype=torch.int32, device=self.device), - torch.full((1,), max_length_k, dtype=torch.int32, device=self.device), - ) + return max_length_q, max_length_k + + @functools.cached_property + def min_lengths(self) -> tuple[int, int]: + min_length_q = min(self.lengths) + # First doc's K-side length includes the past KV prefix; remaining docs match q-side. + first_length_k = self.sequence_k_past + self.lengths[0] - self.first_document_begin + min_length_k = min(first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else first_length_k + return min_length_q, min_length_k @functools.cached_property def document_index(self) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 352311b51..a90bcdebc 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig): distributed: DistributedConfig = Field() return_cumulative_sequence_lengths: bool = Field(default=False) return_max_sequence_lengths: bool = Field(default=False) + return_min_sequence_lengths: bool = Field(default=False) return_document_index: bool = Field(default=False) return_position_index: bool = Field(default=False) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 57b826ab2..9ff1f7846 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -280,14 +280,29 @@ def _attn_sdpa( } if query.is_cuda and self._config.window_size is None: # Wrap each document as its own batch element via nested-jagged so cross-doc masking - # is structural and EFFICIENT skips materializing the attention mask. SDPA's nested - # dispatch reads `max_seqlen`/`min_seqlen` to host (5 cudaMemcpyAsync DtoH per call), - # which floors per-call wall at ~6 ms; still much faster than dense+mask in varlen. + # is structural and EFFICIENT skips materializing the attention mask. The dispatch + # otherwise reads `max_seqlen`/`min_seqlen` to host on every call; passing them in + # explicitly keeps the path sync-free. cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) - query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q) - key = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k) - value = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k) + query = torch.nested.nested_tensor_from_jagged( + query, + cu_seqlens_q, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_q], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_q], + ) + key = torch.nested.nested_tensor_from_jagged( + key, + cu_seqlens_k, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_k], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_k], + ) + value = torch.nested.nested_tensor_from_jagged( + value, + cu_seqlens_k, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_k], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_k], + ) sdpa_args["is_causal"] = self._config.causal else: # Dense + backup's preprocessed causal+document mask. 
Required on CPU (MATH rejects @@ -565,7 +580,12 @@ def get_preprocessing_config(self) -> dict[str, typing.Any]: and self._distributed_config.use_cuda and self._config.window_size is None ): - return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal} + return { + "return_cumulative_sequence_lengths": True, + "return_max_sequence_lengths": True, + "return_min_sequence_lengths": True, + "causal": self._config.causal, + } elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup): return {"return_document_index": True, "causal": self._config.causal} else: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 69aa4f484..f69e2129d 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs): cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" + min_seqlen_q = "min_seqlen_q" + min_seqlen_k = "min_seqlen_k" document_index_q = "document_index_q" document_index_k = "document_index_k" position_ids = "position_ids" From f389c22132943619f6aee036d555fb9bbaaedecd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 15:12:00 -0400 Subject: [PATCH 07/12] Split SDPA into sdpa_nested and sdpa_dense implementations The sdpa_nested vs sdpa_dense choice was previously a branch inside _attn_sdpa duplicated across preprocessing config and preprocess(), with the runtime check using `query.is_cuda` and the preprocessing checks using `self._distributed_config.use_cuda`. Promoting it to two enum values resolves the duplication and the two-sources-of-truth. Also add a CPU-side fp32 sdpa_dense equivalence check (the dense path is now the auto-fallback on CPU and was previously unexercised in CI), and factor first_length_k out of max_lengths/min_lengths. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/data/document/block.py | 12 +++--- fast_llm/layers/attention/attention.py | 24 +++++------ fast_llm/layers/attention/config.py | 3 +- tests/layers/test_attention.py | 60 +++++++++++++++----------- 4 files changed, 55 insertions(+), 44 deletions(-) diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py index b9f2e30f2..a02f92bdf 100644 --- a/fast_llm/data/document/block.py +++ b/fast_llm/data/document/block.py @@ -123,18 +123,20 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: cumulative_lengths_k[0] = self.first_document_begin return cumulative_lengths_q, cumulative_lengths_k + @functools.cached_property + def _first_length_k(self) -> int: + # First doc's K-side length includes the past KV prefix; remaining docs match q-side. + return self.sequence_k_past + self.lengths[0] - self.first_document_begin + @functools.cached_property def max_lengths(self) -> tuple[int, int]: max_length_q = max(self.lengths) - max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin) - return max_length_q, max_length_k + return max_length_q, max(max_length_q, self._first_length_k) @functools.cached_property def min_lengths(self) -> tuple[int, int]: min_length_q = min(self.lengths) - # First doc's K-side length includes the past KV prefix; remaining docs match q-side. 
- first_length_k = self.sequence_k_past + self.lengths[0] - self.first_document_begin - min_length_k = min(first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else first_length_k + min_length_k = min(self._first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else self._first_length_k return min_length_q, min_length_k @functools.cached_property diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9ff1f7846..bac62d373 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -87,8 +87,10 @@ def __init__( and self._config.head_size <= 256 ): self._implementation = AttentionImplementation.flash + elif self._distributed_config.use_cuda and self._config.window_size is None: + self._implementation = AttentionImplementation.sdpa_nested else: - self._implementation = AttentionImplementation.sdpa + self._implementation = AttentionImplementation.sdpa_dense self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -278,7 +280,7 @@ def _attn_sdpa( "dropout_p": self._config.dropout if self.training else 0.0, "scale": self._softmax_scale, } - if query.is_cuda and self._config.window_size is None: + if self._implementation == AttentionImplementation.sdpa_nested: # Wrap each document as its own batch element via nested-jagged so cross-doc masking # is structural and EFFICIENT skips materializing the attention mask. The dispatch # otherwise reads `max_seqlen`/`min_seqlen` to host on every call; passing them in @@ -486,7 +488,7 @@ def _forward( with set_generator(self._distributed.tp_generator): if self._implementation == AttentionImplementation.flash: input_ = self._attn_flash(query, key, value, kwargs) - elif self._implementation == AttentionImplementation.sdpa: + elif self._implementation in (AttentionImplementation.sdpa_nested, AttentionImplementation.sdpa_dense): input_ = self._attn_sdpa(query, key, value, kwargs) elif self._implementation == AttentionImplementation.backup: # TODO: Avoid the flattens. @@ -542,7 +544,8 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c if (not config.hardware) or self._implementation in ( AttentionImplementation.flash, - AttentionImplementation.sdpa, + AttentionImplementation.sdpa_nested, + AttentionImplementation.sdpa_dense, ): # Remove non-causal part. (TODO: Support non-causal) # TODO: Compute is overestimated without cross-document attention. 
@@ -575,28 +578,21 @@ def get_preprocessing_config(self) -> dict[str, typing.Any]: "return_max_sequence_lengths": True, "causal": self._config.causal, } - elif ( - self._implementation == AttentionImplementation.sdpa - and self._distributed_config.use_cuda - and self._config.window_size is None - ): + elif self._implementation == AttentionImplementation.sdpa_nested: return { "return_cumulative_sequence_lengths": True, "return_max_sequence_lengths": True, "return_min_sequence_lengths": True, "causal": self._config.causal, } - elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup): + elif self._implementation in (AttentionImplementation.sdpa_dense, AttentionImplementation.backup): return {"return_document_index": True, "causal": self._config.causal} else: raise NotImplementedError(self._implementation) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(kwargs) - if self._implementation == AttentionImplementation.backup or ( - self._implementation == AttentionImplementation.sdpa - and (not self._distributed_config.use_cuda or self._config.window_size is not None) - ): + if self._implementation in (AttentionImplementation.backup, AttentionImplementation.sdpa_dense): self._preprocess_for_backup_attention(kwargs) def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index f69e2129d..009fe6561 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -40,7 +40,8 @@ class AttentionKwargs(MixerKwargs): class AttentionImplementation(enum.StrEnum): auto = "auto" flash = "flash" - sdpa = "sdpa" + sdpa_nested = "sdpa_nested" + sdpa_dense = "sdpa_dense" backup = "backup" diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index ef96dbf96..943a74350 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -380,7 +380,40 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name) stage.reset_gradients() - # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output must + def _check_packed( + implementation: str, + distributed_config_check: DistributedConfig, + distributed_check: Distributed, + hidden_states_check: torch.Tensor, + out_ref_check: torch.Tensor, + rtol: float, + ) -> None: + attention_impl: Attention = config.get_attention_config(implementation).get_layer( + distributed_config_check, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_impl = get_stage([attention_impl], distributed_check) + for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True): + param_impl.data.copy_(param_f32.data) + (model_input,) = LanguageModelBatch( + tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths + ).get_model_inputs( + LanguageModelBatchPreprocessingConfig( + distributed=distributed_config_check, + predicted_tokens=0, + **attention_impl.get_preprocessing_config(), + ) + ) + kwargs_impl = model_input.to_kwargs() + attention_impl.preprocess(kwargs_impl) + out_impl, _ = stage_impl.forward(hidden_states_check, kwargs_impl) + Assert.rms_close_relative(out_impl, out_ref_check, rtol, 1e-7) + + # SDPA-dense equivalence check: sdpa_dense reuses backup's mask, so its packed fp32 output + # must match the per-sequence backup reference. 
This is the only SDPA branch that runs on CPU + # (sdpa_nested needs CUDA), so the check is unconditional rather than CUDA-gated. + _check_packed("sdpa_dense", distributed_config, distributed, hidden_states.detach(), out_ref, 1e-5) + + # Flash and SDPA-nested equivalence checks: each implementation's packed bfloat16 output must # match a per-sequence bfloat16 backup reference. if not torch.cuda.is_available(): return @@ -406,30 +439,9 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: with_backward=False, ) - def _check_packed(implementation: str) -> None: - attention_impl: Attention = config.get_attention_config(implementation).get_layer( - distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False - ) - stage_impl = get_stage([attention_impl], distributed_bf16) - for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True): - param_impl.data.copy_(param_f32.data) - (model_input,) = LanguageModelBatch( - tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths - ).get_model_inputs( - LanguageModelBatchPreprocessingConfig( - distributed=distributed_config_bf16, - predicted_tokens=0, - **attention_impl.get_preprocessing_config(), - ) - ) - kwargs_impl = model_input.to_kwargs() - attention_impl.preprocess(kwargs_impl) - out_impl, _ = stage_impl.forward(hidden_states_bf16, kwargs_impl) - Assert.rms_close_relative(out_impl, out_ref_bf16, 5e-3, 1e-7) - if _flash_available and config.head_size <= 256: - _check_packed("flash") - _check_packed("sdpa") + _check_packed("flash", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, 5e-3) + _check_packed("sdpa_nested", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, 5e-3) @pytest.mark.slow From aa3b113586f1d4ea0c2e94b4b7d6948b4ca9fff8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 18:03:18 -0400 Subject: [PATCH 08/12] Address fifth-round review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `sdpa` enum value that auto-resolves to `sdpa_nested` or `sdpa_dense` at init based on `(use_cuda, window_size)`. Auto-fallback now resolves to `flash` or `sdpa`; the second stage then picks the concrete variant. - Update the `implementation` field's desc to describe the new cascade. - Validate `flash` + `head_size > 256` in `_validate` so explicit configs fail at construction with a clear message instead of inside flash-attn. - Drop the dead `if attention_mask is not None` guard in `_attn_sdpa`'s dense branch — `_preprocess_for_backup_attention` always sets a non-None mask, so the conditional could never short-circuit. - Comment near `repeat_interleave` that `enable_gqa=True` is intentionally avoided (forces MATH fallback, incompatible with nested-jagged path). - Extend `_check_packed` to run backward and compare parameter gradients, closing the autograd-coverage gap through `nested_tensor_from_jagged` and GQA `repeat_interleave`. Bump the bf16 per-seq reference to `with_backward=True` and capture `grads_ref_bf16` for the bf16 checks. 
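A tiny sketch of the expansion the new `repeat_interleave` comment refers to,
with made-up sizes (2 KV groups serving 8 query heads, so heads_per_group = 4):

    import torch

    k = torch.randn(16, 2, 64)                  # (total_k, head_groups, head_size)
    k_expanded = k.repeat_interleave(4, dim=1)  # (total_k, heads, head_size)
    # Group g lands on the contiguous head block [4 * g, 4 * g + 4): query head 5 reads group 1.
    assert torch.equal(k_expanded[:, 5], k[:, 1])

`enable_gqa=True` would express the same head mapping without the copy, but at
the cost of the MATH fallback described above.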
Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py | 13 +++++----
 fast_llm/layers/attention/config.py    |  9 +++++-
 tests/layers/test_attention.py         | 39 ++++++++++++++++----------
 3 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index bac62d373..740e7342f 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -87,7 +87,10 @@ def __init__(
                 and self._config.head_size <= 256
             ):
                 self._implementation = AttentionImplementation.flash
-            elif self._distributed_config.use_cuda and self._config.window_size is None:
+            else:
+                self._implementation = AttentionImplementation.sdpa
+        if self._implementation == AttentionImplementation.sdpa:
+            if self._distributed_config.use_cuda and self._config.window_size is None:
                 self._implementation = AttentionImplementation.sdpa_nested
             else:
                 self._implementation = AttentionImplementation.sdpa_dense
@@ -272,6 +275,9 @@ def _attn_sdpa(
         kwargs: dict[str, typing.Any],
     ) -> torch.Tensor:  # total_q, heads, head_size
         # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads.
+        # `enable_gqa=True` would handle this internally, but it forces a fallback to the MATH
+        # backend (which materializes the O(S^2) attention matrix) and is unsupported by the
+        # nested-jagged path entirely; manual `repeat_interleave` keeps EFFICIENT in play.
         if self._local_heads_per_group > 1:
             key = key.repeat_interleave(self._local_heads_per_group, dim=1)
             value = value.repeat_interleave(self._local_heads_per_group, dim=1)
@@ -310,13 +316,10 @@ def _attn_sdpa(
             # Dense + backup's preprocessed causal+document mask. Required on CPU (MATH rejects
             # nested + is_causal) and on CUDA with sliding window (the nested path can't express
             # it). Backup builds the mask as (1, sq, 1, sk); SDPA wants (B, H, sq, sk).
-            attention_mask = kwargs[AttentionKwargs.attention_mask]
-            if attention_mask is not None:
-                attention_mask = attention_mask.transpose(1, 2)
             query = query.unsqueeze(0)
             key = key.unsqueeze(0)
             value = value.unsqueeze(0)
-            sdpa_args["attn_mask"] = attention_mask
+            sdpa_args["attn_mask"] = kwargs[AttentionKwargs.attention_mask].transpose(1, 2)
 
         output = torch.nn.functional.scaled_dot_product_attention(
             query.transpose(1, 2),
diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py
index 009fe6561..85d133f5c 100644
--- a/fast_llm/layers/attention/config.py
+++ b/fast_llm/layers/attention/config.py
@@ -40,6 +40,7 @@ class AttentionImplementation(enum.StrEnum):
     auto = "auto"
     flash = "flash"
+    sdpa = "sdpa"
     sdpa_nested = "sdpa_nested"
     sdpa_dense = "sdpa_dense"
     backup = "backup"
@@ -124,7 +125,10 @@ class AttentionConfig(MixerConfig):
     )
     implementation: AttentionImplementation = Field(
         default=AttentionImplementation.auto,
-        desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.",
+        desc="The implementation to use for the attention layer."
+        " `auto` picks `flash` when available (bf16/fp16, head_size <= 256, flash-attn installed), otherwise `sdpa`."
+        " `sdpa` further resolves to `sdpa_nested` on CUDA without sliding window, and to `sdpa_dense` otherwise."
+ " `sdpa_nested` and `sdpa_dense` are explicit overrides; `backup` is a slow pure-PyTorch fallback.", hint=FieldHint.feature, ) query_norm: NormalizationConfig | None = Field( @@ -156,6 +160,9 @@ def _validate(self) -> None: if not self.causal: assert self.window_size is None, "Non-causal windowed attention is not supported." + if self.implementation == AttentionImplementation.flash: + Assert.leq(self.head_size, 256) + @property def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 943a74350..70b838aed 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -386,6 +386,7 @@ def _check_packed( distributed_check: Distributed, hidden_states_check: torch.Tensor, out_ref_check: torch.Tensor, + grads_ref_check: list[torch.Tensor], rtol: float, ) -> None: attention_impl: Attention = config.get_attention_config(implementation).get_layer( @@ -405,16 +406,19 @@ def _check_packed( ) kwargs_impl = model_input.to_kwargs() attention_impl.preprocess(kwargs_impl) - out_impl, _ = stage_impl.forward(hidden_states_check, kwargs_impl) + out_impl, context = stage_impl.forward(hidden_states_check, kwargs_impl) + stage_impl.backward(torch.ones_like(out_impl), context) Assert.rms_close_relative(out_impl, out_ref_check, rtol, 1e-7) + for param_impl, grad_ref in zip(attention_impl.parameters(), grads_ref_check, strict=True): + Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, rtol, 1e-7, msg=implementation) # SDPA-dense equivalence check: sdpa_dense reuses backup's mask, so its packed fp32 output - # must match the per-sequence backup reference. This is the only SDPA branch that runs on CPU - # (sdpa_nested needs CUDA), so the check is unconditional rather than CUDA-gated. - _check_packed("sdpa_dense", distributed_config, distributed, hidden_states.detach(), out_ref, 1e-5) + # and parameter gradients must match the per-sequence backup reference. This is the only SDPA + # branch that runs on CPU (sdpa_nested needs CUDA), so the check is unconditional. + _check_packed("sdpa_dense", distributed_config, distributed, hidden_states, out_ref, grads_ref, 1e-5) - # Flash and SDPA-nested equivalence checks: each implementation's packed bfloat16 output must - # match a per-sequence bfloat16 backup reference. + # Flash and SDPA-nested equivalence checks: each implementation's packed bfloat16 output and + # parameter gradients must match a per-sequence bfloat16 backup reference. 
if not torch.cuda.is_available(): return @@ -430,18 +434,23 @@ def _check_packed( hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) out_ref_bf16 = _run_per_seq_reference( - attention_backup_bf16, - stage_backup_bf16, - distributed_config_bf16, - hidden_states_bf16, - lengths, - device, - with_backward=False, + attention_backup_bf16, stage_backup_bf16, distributed_config_bf16, hidden_states_bf16, lengths, device ) + grads_ref_bf16 = [param.grad_buffer.clone() for param in attention_backup_bf16.parameters()] if _flash_available and config.head_size <= 256: - _check_packed("flash", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, 5e-3) - _check_packed("sdpa_nested", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, 5e-3) + _check_packed( + "flash", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, grads_ref_bf16, 5e-3 + ) + _check_packed( + "sdpa_nested", + distributed_config_bf16, + distributed_bf16, + hidden_states_bf16, + out_ref_bf16, + grads_ref_bf16, + 5e-3, + ) @pytest.mark.slow From 311d979b9b9f9c48f3a6614015038ccaa374e6b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 18:49:43 -0400 Subject: [PATCH 09/12] Validate sdpa_nested window incompatibility; widen bf16 grad rtol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reject `sdpa_nested` + non-None `window_size` in `_validate`. The nested + is_causal path can't express a sliding window; previously a config that explicitly chose `sdpa_nested` with a window would silently produce full-causal output instead. - Skip the bf16 `sdpa_nested` test for windowed configs (the validation above would otherwise reject the test config). - Add a bf16 `sdpa_dense` equivalence check. Same masking as backup, but packed-vs-per-seq reductions diverge at bf16 by ~7e-3 — this establishes that the bf16 reference itself isn't reproducible to forward tolerance, so flash/sdpa_nested divergence is reduction-order noise, not a bug. - Split `_check_packed`'s rtol into `out_rtol` and `grad_rtol`; bf16 grads diverge above the forward bound (up to ~1.3% for flash). New grad tolerance is 1.5e-2 at bf16, 1e-5 at fp32. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/config.py | 5 ++++ tests/layers/test_attention.py | 46 +++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 85d133f5c..23a54c2ed 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -163,6 +163,11 @@ def _validate(self) -> None: if self.implementation == AttentionImplementation.flash: Assert.leq(self.head_size, 256) + if self.implementation == AttentionImplementation.sdpa_nested: + assert ( + self.window_size is None + ), "`sdpa_nested` does not support sliding window; use `sdpa` or `sdpa_dense`."
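To make the rejected combination concrete: `is_causal=True` on the nested path can only say "no attending to the future", while a sliding window additionally forbids keys more than `window_size` positions behind the query, which needs an explicit mask of the kind the dense/backup route builds. A hedged sketch of such a mask follows; the exact boundary convention is a guess for illustration, not lifted from the repo.

```python
import torch

def causal_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Boolean (seq_len, seq_len) grid: True = attend, False = masked out."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    causal = j <= i                  # expressible via is_causal=True
    windowed = i - j < window_size   # not expressible on the nested path
    return causal & windowed

print(causal_window_mask(5, 2).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]])
```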
+ @property def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 70b838aed..c2af3d323 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -387,7 +387,8 @@ def _check_packed( hidden_states_check: torch.Tensor, out_ref_check: torch.Tensor, grads_ref_check: list[torch.Tensor], - rtol: float, + out_rtol: float, + grad_rtol: float, ) -> None: attention_impl: Attention = config.get_attention_config(implementation).get_layer( distributed_config_check, hidden_dim, lr_scale=None, peft=None, return_bias=False @@ -408,17 +409,19 @@ def _check_packed( attention_impl.preprocess(kwargs_impl) out_impl, context = stage_impl.forward(hidden_states_check, kwargs_impl) stage_impl.backward(torch.ones_like(out_impl), context) - Assert.rms_close_relative(out_impl, out_ref_check, rtol, 1e-7) + Assert.rms_close_relative(out_impl, out_ref_check, out_rtol, 1e-7) for param_impl, grad_ref in zip(attention_impl.parameters(), grads_ref_check, strict=True): - Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, rtol, 1e-7, msg=implementation) + Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, grad_rtol, 1e-7, msg=implementation) # SDPA-dense equivalence check: sdpa_dense reuses backup's mask, so its packed fp32 output # and parameter gradients must match the per-sequence backup reference. This is the only SDPA # branch that runs on CPU (sdpa_nested needs CUDA), so the check is unconditional. - _check_packed("sdpa_dense", distributed_config, distributed, hidden_states, out_ref, grads_ref, 1e-5) + _check_packed("sdpa_dense", distributed_config, distributed, hidden_states, out_ref, grads_ref, 1e-5, 1e-5) - # Flash and SDPA-nested equivalence checks: each implementation's packed bfloat16 output and - # parameter gradients must match a per-sequence bfloat16 backup reference. + # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output and parameter + # gradients must match a per-sequence bfloat16 backup reference. Backward grad tolerance is + # looser than forward — bf16 reduction-order noise compounds through the backward pass, and + # the same-kernel sdpa_dense path itself diverges from per-seq backup by ~7e-3 at bf16. 
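For context on what `out_rtol` and `grad_rtol` gate: `Assert.rms_close_relative` itself isn't reproduced in this series, but the call sites read naturally as a bound on the RMS of the error relative to the reference's RMS scale. A sketch under that assumption (the real fast_llm helper may differ in detail):

```python
import torch

def rms_close_relative(actual: torch.Tensor, expected: torch.Tensor,
                       rtol: float, atol: float, msg: str = "") -> None:
    # One scalar check over the whole tensor, unlike elementwise torch.allclose;
    # reduction-order noise then averages out instead of failing single elements.
    error = (actual.float() - expected.float()).square().mean().sqrt().item()
    scale = expected.float().square().mean().sqrt().item()
    bound = rtol * scale + atol
    assert error <= bound, f"{msg}: rms error {error:.3e} exceeds {bound:.3e}"

# e.g. a bf16 gradient comparison at the widened tolerance:
# rms_close_relative(param.grad_buffer, grad_ref, 1.5e-2, 1e-7, msg="flash")
```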
if not torch.cuda.is_available(): return @@ -432,25 +435,44 @@ def _check_packed( for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True): param_bf16.data.copy_(param_f32.data) - hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) + hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16).requires_grad_() out_ref_bf16 = _run_per_seq_reference( attention_backup_bf16, stage_backup_bf16, distributed_config_bf16, hidden_states_bf16, lengths, device ) grads_ref_bf16 = [param.grad_buffer.clone() for param in attention_backup_bf16.parameters()] - if _flash_available and config.head_size <= 256: - _check_packed( - "flash", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, grads_ref_bf16, 5e-3 - ) _check_packed( - "sdpa_nested", + "sdpa_dense", distributed_config_bf16, distributed_bf16, hidden_states_bf16, out_ref_bf16, grads_ref_bf16, 5e-3, + 1.5e-2, ) + if _flash_available and config.head_size <= 256: + _check_packed( + "flash", + distributed_config_bf16, + distributed_bf16, + hidden_states_bf16, + out_ref_bf16, + grads_ref_bf16, + 5e-3, + 1.5e-2, + ) + if config.window_size is None: + _check_packed( + "sdpa_nested", + distributed_config_bf16, + distributed_bf16, + hidden_states_bf16, + out_ref_bf16, + grads_ref_bf16, + 5e-3, + 1.5e-2, + ) @pytest.mark.slow From fc88839258ac717fd31c3a039038d30afc357d0d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 19:04:53 -0400 Subject: [PATCH 10/12] Address sixth-round review polish - Extract `_FLASH_MAX_HEAD_SIZE = 256` in config.py; reference it from both the auto-resolution in `Attention.__init__` and the `_validate` guard. - Annotate the two-stage implementation resolution to explain why `auto` and `sdpa` share the same cascade structure. - Comment the int64 cast in `_attn_sdpa` so the cast isn't mistaken for unnecessary work (`nested_tensor_from_jagged` requires int64 offsets). - Qualify the key/value shape comments on `_attn_sdpa` as pre-GQA-expansion to match the rebinding after `repeat_interleave`. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/attention.py | 15 +++++++++++---- fast_llm/layers/attention/config.py | 5 ++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 740e7342f..dba863e94 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -10,7 +10,12 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.utils import wrap_forward_backward -from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs +from fast_llm.layers.attention.config import ( + _FLASH_MAX_HEAD_SIZE, + AttentionConfig, + AttentionImplementation, + AttentionKwargs, +) from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta @@ -79,12 +84,13 @@ def __init__( peft=peft, return_bias=return_bias, ) + # Two-stage resolution so callers can set `implementation=sdpa` to get the auto-picked SDPA flavor. 
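Condensed, the cascade that comment describes looks like the following standalone function. Enum members appear as plain strings and the predicates stand in for the real checks (`flash_ok` abbreviates "flash-attn installed, fp16/bf16 compute, head_size within the flash cap"); it is illustrative only.

```python
def resolve_implementation(requested: str, flash_ok: bool, cuda: bool, windowed: bool) -> str:
    # Stage 1: `auto` picks between flash and the sdpa family.
    if requested == "auto":
        requested = "flash" if flash_ok else "sdpa"
    # Stage 2: `sdpa`, whether user-set or auto-picked, resolves to a concrete flavor.
    if requested == "sdpa":
        requested = "sdpa_nested" if cuda and not windowed else "sdpa_dense"
    return requested

assert resolve_implementation("auto", flash_ok=False, cuda=True, windowed=False) == "sdpa_nested"
assert resolve_implementation("sdpa", flash_ok=True, cuda=True, windowed=True) == "sdpa_dense"
assert resolve_implementation("auto", flash_ok=True, cuda=True, windowed=False) == "flash"
```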
self._implementation = self._config.implementation if self._implementation == AttentionImplementation.auto: if ( _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) - and self._config.head_size <= 256 + and self._config.head_size <= _FLASH_MAX_HEAD_SIZE ): self._implementation = AttentionImplementation.flash else: @@ -270,8 +276,8 @@ def _attn_flash( def _attn_sdpa( self, query: torch.Tensor, # total_q, heads, head_size - key: torch.Tensor, # total_k, head_groups, head_size - value: torch.Tensor, # total_k, head_groups, head_size + key: torch.Tensor, # total_k, head_groups, head_size (pre-GQA-expansion) + value: torch.Tensor, # total_k, head_groups, head_size (pre-GQA-expansion) kwargs: dict[str, typing.Any], ) -> torch.Tensor: # total_q, heads, head_size # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads. @@ -291,6 +297,7 @@ def _attn_sdpa( # is structural and EFFICIENT skips materializing the attention mask. The dispatch # otherwise reads `max_seqlen`/`min_seqlen` to host on every call; passing them in # explicitly keeps the path sync-free. + # `nested_tensor_from_jagged` requires int64 offsets. cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) query = torch.nested.nested_tensor_from_jagged( diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 23a54c2ed..95c67e7a3 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -37,6 +37,9 @@ class AttentionKwargs(MixerKwargs): past_key_values = "past_key_values" +_FLASH_MAX_HEAD_SIZE = 256 + + class AttentionImplementation(enum.StrEnum): auto = "auto" flash = "flash" @@ -161,7 +164,7 @@ def _validate(self) -> None: assert self.window_size is None, "Non-causal windowed attention is not supported." if self.implementation == AttentionImplementation.flash: - Assert.leq(self.head_size, 256) + Assert.leq(self.head_size, _FLASH_MAX_HEAD_SIZE) if self.implementation == AttentionImplementation.sdpa_nested: assert ( From 9ae839bc719f5912207aa658d556928c32ab1a31 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 19:09:36 -0400 Subject: [PATCH 11/12] Address seventh-round review polish - Drop the now-dead `with_backward` parameter on `_run_per_seq_reference`; both remaining call sites rely on the default `True` after the backward-parity refactor. - Rename `large_head_causal` / `large_head_mqa` to `*_no_norm` so they match the suffix convention the cross-product loop establishes. - Fix the bf16-grad comment that called sdpa_dense the "same-kernel" path; sdpa_dense routes through SDPA's MATH/EFFICIENT backend, while backup uses its own kernel. Reworded to make the divergence-across-kernels point clear. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/layers/test_attention.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index c2af3d323..456228386 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -264,8 +264,8 @@ def expected_output( # head_size > 256 — exercises the SDPA-only regime (flash caps at 256).
for name, kwargs in ( - ("large_head_causal", {"causal": True, "head_size": 320}), - ("large_head_mqa", {"causal": True, "head_size": 320, "kv_heads": 1}), + ("large_head_causal_no_norm", {"causal": True, "head_size": 320}), + ("large_head_mqa_no_norm", {"causal": True, "head_size": 320, "kv_heads": 1}), ): for lengths in _LENGTHS_SHORT: _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths)) @@ -278,7 +278,6 @@ def _run_per_seq_reference( hidden_states: torch.Tensor, lengths: list[int], device: torch.device, - with_backward: bool = True, ) -> torch.Tensor: out_refs = [] for length, hidden_slice in zip(lengths, torch.split(hidden_states, lengths, dim=0), strict=True): @@ -294,8 +293,7 @@ def _run_per_seq_reference( kwargs = model_input.to_kwargs() attention.preprocess(kwargs) out, context = stage.forward(hidden_slice, kwargs) - if with_backward: - stage.backward(torch.ones_like(out), context) + stage.backward(torch.ones_like(out), context) out_refs.append(out.detach()) return torch.cat(out_refs, dim=0) @@ -420,8 +418,8 @@ def _check_packed( # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output and parameter # gradients must match a per-sequence bfloat16 backup reference. Backward grad tolerance is - # looser than forward — bf16 reduction-order noise compounds through the backward pass, and - # the same-kernel sdpa_dense path itself diverges from per-seq backup by ~7e-3 at bf16. + # looser than forward — bf16 reduction-order noise compounds through the backward pass, with + # even the sdpa_dense path diverging from the per-seq backup reference by ~7e-3 at bf16. if not torch.cuda.is_available(): return From 77df6029d23e775a3173d3390649151396d01520 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 19:37:30 -0400 Subject: [PATCH 12/12] Lift _check_packed out of _test_attention; trim block comments _check_packed only closed over data that can be passed as arguments, so there's no reason for it to be a nested helper. Promoting it to module scope and consolidating the three bf16 call sites into a single for-loop also drops the verbose block comments preceding each path; the one remaining comment justifies the looser bf16 grad rtol. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/layers/test_attention.py | 129 ++++++++++++++++----------------- 1 file changed, 64 insertions(+), 65 deletions(-) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 456228386..4940631b9 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -271,6 +271,44 @@ def expected_output( _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths)) +def _check_packed( + implementation: str, + config: AttentionTestConfig, + hidden_dim: TensorDim, + lengths: list[int], + attention_f32: Attention, + distributed_config: DistributedConfig, + distributed: Distributed, + hidden_states: torch.Tensor, + out_ref: torch.Tensor, + grads_ref: list[torch.Tensor], + out_rtol: float, + grad_rtol: float, +) -> None: + attention_impl: Attention = config.get_attention_config(implementation).get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_impl = get_stage([attention_impl], distributed) + for param_impl, param_f32 in zip(attention_impl.parameters(), attention_f32.parameters(), strict=True): + param_impl.data.copy_(param_f32.data) + (model_input,) = LanguageModelBatch( + tokens=torch.empty(sum(lengths), dtype=torch.int64, device=hidden_states.device), lengths=lengths + ).get_model_inputs( + LanguageModelBatchPreprocessingConfig( + distributed=distributed_config, + predicted_tokens=0, + **attention_impl.get_preprocessing_config(), + ) + ) + kwargs_impl = model_input.to_kwargs() + attention_impl.preprocess(kwargs_impl) + out_impl, context = stage_impl.forward(hidden_states, kwargs_impl) + stage_impl.backward(torch.ones_like(out_impl), context) + Assert.rms_close_relative(out_impl, out_ref, out_rtol, 1e-7) + for param_impl, grad_ref in zip(attention_impl.parameters(), grads_ref, strict=True): + Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, grad_rtol, 1e-7, msg=implementation) + + def _run_per_seq_reference( attention: Attention, stage, @@ -378,48 +416,21 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name) stage.reset_gradients() - def _check_packed( - implementation: str, - distributed_config_check: DistributedConfig, - distributed_check: Distributed, - hidden_states_check: torch.Tensor, - out_ref_check: torch.Tensor, - grads_ref_check: list[torch.Tensor], - out_rtol: float, - grad_rtol: float, - ) -> None: - attention_impl: Attention = config.get_attention_config(implementation).get_layer( - distributed_config_check, hidden_dim, lr_scale=None, peft=None, return_bias=False - ) - stage_impl = get_stage([attention_impl], distributed_check) - for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True): - param_impl.data.copy_(param_f32.data) - (model_input,) = LanguageModelBatch( - tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths - ).get_model_inputs( - LanguageModelBatchPreprocessingConfig( - distributed=distributed_config_check, - predicted_tokens=0, - **attention_impl.get_preprocessing_config(), - ) - ) - kwargs_impl = model_input.to_kwargs() - attention_impl.preprocess(kwargs_impl) - out_impl, context = stage_impl.forward(hidden_states_check, kwargs_impl) - stage_impl.backward(torch.ones_like(out_impl), context) - Assert.rms_close_relative(out_impl, out_ref_check, out_rtol, 1e-7) - for param_impl, grad_ref in zip(attention_impl.parameters(), 
grads_ref_check, strict=True): - Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, grad_rtol, 1e-7, msg=implementation) - - # SDPA-dense equivalence check: sdpa_dense reuses backup's mask, so its packed fp32 output - # and parameter gradients must match the per-sequence backup reference. This is the only SDPA - # branch that runs on CPU (sdpa_nested needs CUDA), so the check is unconditional. - _check_packed("sdpa_dense", distributed_config, distributed, hidden_states, out_ref, grads_ref, 1e-5, 1e-5) - - # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output and parameter - # gradients must match a per-sequence bfloat16 backup reference. Backward grad tolerance is - # looser than forward — bf16 reduction-order noise compounds through the backward pass, with - # even the sdpa_dense path diverging from the per-seq backup reference by ~7e-3 at bf16. + _check_packed( + "sdpa_dense", + config, + hidden_dim, + lengths, + attention, + distributed_config, + distributed, + hidden_states, + out_ref, + grads_ref, + 1e-5, + 1e-5, + ) + if not torch.cuda.is_available(): return @@ -439,30 +450,18 @@ def _check_packed( ) grads_ref_bf16 = [param.grad_buffer.clone() for param in attention_backup_bf16.parameters()] - _check_packed( - "sdpa_dense", - distributed_config_bf16, - distributed_bf16, - hidden_states_bf16, - out_ref_bf16, - grads_ref_bf16, - 5e-3, - 1.5e-2, - ) - if _flash_available and config.head_size <= 256: - _check_packed( - "flash", - distributed_config_bf16, - distributed_bf16, - hidden_states_bf16, - out_ref_bf16, - grads_ref_bf16, - 5e-3, - 1.5e-2, - ) - if config.window_size is None: + # bf16 grad rtol is looser than forward: reduction-order noise compounds through backward. + for implementation in ("sdpa_dense", "flash", "sdpa_nested"): + if implementation == "flash" and (not _flash_available or config.head_size > 256): + continue + if implementation == "sdpa_nested" and config.window_size is not None: + continue _check_packed( - "sdpa_nested", + implementation, + config, + hidden_dim, + lengths, + attention, distributed_config_bf16, distributed_bf16, hidden_states_bf16,