diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py
index 530be42ea..a02f92bdf 100644
--- a/fast_llm/data/document/block.py
+++ b/fast_llm/data/document/block.py
@@ -24,8 +24,10 @@ class BlockModelInput(ModelInput):
     lengths: list[int] = None
     cumulative_lengths_q: torch.Tensor | None = None
     cumulative_lengths_k: torch.Tensor | None = None
-    max_length_q: torch.Tensor | None = None
-    max_length_k: torch.Tensor | None = None
+    max_length_q: int | None = None
+    max_length_k: int | None = None
+    min_length_q: int | None = None
+    min_length_k: int | None = None
     document_index_q: torch.Tensor | None = None
     document_index_k: torch.Tensor | None = None
     position_index: torch.Tensor | None = None
@@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k,
             AttentionKwargs.max_seqlen_q: self.max_length_q,
             AttentionKwargs.max_seqlen_k: self.max_length_k,
+            AttentionKwargs.min_seqlen_q: self.min_length_q,
+            AttentionKwargs.min_seqlen_k: self.min_length_k,
             AttentionKwargs.document_index_q: self.document_index_q,
             AttentionKwargs.document_index_k: self.document_index_k,
             LanguageModelKwargs.position_ids: self.position_index,
@@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo
             model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths
         if config.return_max_sequence_lengths or config.return_document_index:
             model_input.max_length_q, model_input.max_length_k = self.max_lengths
+        if config.return_min_sequence_lengths:
+            model_input.min_length_q, model_input.min_length_k = self.min_lengths
         if config.return_document_index:
             model_input.document_index_q, model_input.document_index_k = self.document_index
         if config.return_position_index:
@@ -118,13 +124,20 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
         return cumulative_lengths_q, cumulative_lengths_k
 
     @functools.cached_property
-    def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
+    def _first_length_k(self) -> int:
+        # First doc's K-side length includes the past KV prefix; remaining docs match q-side.
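+        # Illustration (values hypothetical): sequence_k_past=8, lengths=[5, 3] and
+        # first_document_begin=2 give the first document a K-side length of 8 + 5 - 2 = 11.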
+        return self.sequence_k_past + self.lengths[0] - self.first_document_begin
+
+    @functools.cached_property
+    def max_lengths(self) -> tuple[int, int]:
         max_length_q = max(self.lengths)
-        max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin)
-        return (
-            torch.full((1,), max_length_q, dtype=torch.int32, device=self.device),
-            torch.full((1,), max_length_k, dtype=torch.int32, device=self.device),
-        )
+        return max_length_q, max(max_length_q, self._first_length_k)
+
+    @functools.cached_property
+    def min_lengths(self) -> tuple[int, int]:
+        min_length_q = min(self.lengths)
+        min_length_k = min(self._first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else self._first_length_k
+        return min_length_q, min_length_k
 
     @functools.cached_property
     def document_index(self) -> tuple[torch.Tensor, torch.Tensor]:
diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py
index 352311b51..a90bcdebc 100644
--- a/fast_llm/data/document/config.py
+++ b/fast_llm/data/document/config.py
@@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
     distributed: DistributedConfig = Field()
     return_cumulative_sequence_lengths: bool = Field(default=False)
     return_max_sequence_lengths: bool = Field(default=False)
+    return_min_sequence_lengths: bool = Field(default=False)
     return_document_index: bool = Field(default=False)
     return_position_index: bool = Field(default=False)
 
diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 12f85bf28..dba863e94 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -10,7 +10,12 @@
 from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
 from fast_llm.functional.utils import wrap_forward_backward
-from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs
+from fast_llm.layers.attention.config import (
+    _FLASH_MAX_HEAD_SIZE,
+    AttentionConfig,
+    AttentionImplementation,
+    AttentionKwargs,
+)
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.block import BlockWithBias
 from fast_llm.tensor import TensorMeta
@@ -79,12 +84,22 @@ def __init__(
             peft=peft,
             return_bias=return_bias,
         )
+        # Two-stage resolution so callers can set `implementation=sdpa` to get the auto-picked SDPA flavor.
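+        # Resolution order: `auto` -> `flash` (flash-attn available, fp16/bf16, head_size <= 256),
+        # otherwise `sdpa`; `sdpa` -> `sdpa_nested` (CUDA, no sliding window), otherwise `sdpa_dense`.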
         self._implementation = self._config.implementation
         if self._implementation == AttentionImplementation.auto:
-            if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16):
+            if (
+                _flash_available
+                and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16)
+                and self._config.head_size <= _FLASH_MAX_HEAD_SIZE
+            ):
                 self._implementation = AttentionImplementation.flash
             else:
-                self._implementation = AttentionImplementation.backup
+                self._implementation = AttentionImplementation.sdpa
+        if self._implementation == AttentionImplementation.sdpa:
+            if self._distributed_config.use_cuda and self._config.window_size is None:
+                self._implementation = AttentionImplementation.sdpa_nested
+            else:
+                self._implementation = AttentionImplementation.sdpa_dense
 
         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
         self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim(
@@ -258,6 +273,69 @@ def _attn_flash(
             softmax_scale=self._softmax_scale,
         )
 
+    def _attn_sdpa(
+        self,
+        query: torch.Tensor,  # total_q, heads, head_size
+        key: torch.Tensor,  # total_k, head_groups, head_size (pre-GQA-expansion)
+        value: torch.Tensor,  # total_k, head_groups, head_size (pre-GQA-expansion)
+        kwargs: dict[str, typing.Any],
+    ) -> torch.Tensor:  # total_q, heads, head_size
+        # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads.
+        # `enable_gqa=True` would handle this internally, but it forces a fallback to the MATH
+        # backend (which materializes the O(S^2) attention matrix) and is entirely unsupported
+        # by the nested-jagged path; manual `repeat_interleave` keeps EFFICIENT in play.
+        if self._local_heads_per_group > 1:
+            key = key.repeat_interleave(self._local_heads_per_group, dim=1)
+            value = value.repeat_interleave(self._local_heads_per_group, dim=1)
+
+        sdpa_args: dict[str, typing.Any] = {
+            "dropout_p": self._config.dropout if self.training else 0.0,
+            "scale": self._softmax_scale,
+        }
+        if self._implementation == AttentionImplementation.sdpa_nested:
+            # Wrap each document as its own batch element via nested-jagged so cross-doc masking
+            # is structural and EFFICIENT skips materializing the attention mask. Without explicit
+            # values, the SDPA dispatch reads `max_seqlen`/`min_seqlen` back to the host on every
+            # call; passing them in keeps the path sync-free.
+            # `nested_tensor_from_jagged` requires int64 offsets.
+            cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
+            cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
+            query = torch.nested.nested_tensor_from_jagged(
+                query,
+                cu_seqlens_q,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_q],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_q],
+            )
+            key = torch.nested.nested_tensor_from_jagged(
+                key,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
+            value = torch.nested.nested_tensor_from_jagged(
+                value,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
+            sdpa_args["is_causal"] = self._config.causal
+        else:
+            # Dense path: reuse backup's preprocessed causal+document mask. Required on CPU (MATH
+            # rejects nested + is_causal) and on CUDA with sliding window (the nested path can't
+            # express it). Backup builds the mask as (1, sq, 1, sk); SDPA wants (B, H, sq, sk).
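+            # transpose(1, 2) turns (1, sq, 1, sk) into (1, 1, sq, sk), which then
+            # broadcasts across batch and heads.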
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+            value = value.unsqueeze(0)
+            sdpa_args["attn_mask"] = kwargs[AttentionKwargs.attention_mask].transpose(1, 2)
+
+        output = torch.nn.functional.scaled_dot_product_attention(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2),
+            **sdpa_args,
+        ).transpose(1, 2)
+        return output.values() if output.is_nested else output.squeeze(0)
+
     def _apply_norm_with_grad_capture(
         self, norm: torch.nn.Module, x: torch.Tensor
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
@@ -420,6 +498,8 @@ def _forward(
         with set_generator(self._distributed.tp_generator):
             if self._implementation == AttentionImplementation.flash:
                 input_ = self._attn_flash(query, key, value, kwargs)
+            elif self._implementation in (AttentionImplementation.sdpa_nested, AttentionImplementation.sdpa_dense):
+                input_ = self._attn_sdpa(query, key, value, kwargs)
             elif self._implementation == AttentionImplementation.backup:
                 # TODO: Avoid the flattens.
                 input_ = self._attn_backup(query, key, value, kwargs)
@@ -472,7 +552,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
 
         attention_compute = sequence_q * sequence_k * attn_compute_base
 
-        if (not config.hardware) or self._implementation in AttentionImplementation.flash:
+        if (not config.hardware) or self._implementation in (
+            AttentionImplementation.flash,
+            AttentionImplementation.sdpa_nested,
+            AttentionImplementation.sdpa_dense,
+        ):
             # Remove non-causal part. (TODO: Support non-causal)
             # TODO: Compute is overestimated without cross-document attention.
             attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2
@@ -498,19 +582,27 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         )
 
     def get_preprocessing_config(self) -> dict[str, typing.Any]:
-        return (
-            {
+        if self._implementation == AttentionImplementation.flash:
+            return {
                 "return_cumulative_sequence_lengths": True,
                 "return_max_sequence_lengths": True,
                 "causal": self._config.causal,
             }
-            if self._implementation == AttentionImplementation.flash
-            else {"return_document_index": True, "causal": self._config.causal}
-        )
+        elif self._implementation == AttentionImplementation.sdpa_nested:
+            return {
+                "return_cumulative_sequence_lengths": True,
+                "return_max_sequence_lengths": True,
+                "return_min_sequence_lengths": True,
+                "causal": self._config.causal,
+            }
+        elif self._implementation in (AttentionImplementation.sdpa_dense, AttentionImplementation.backup):
+            return {"return_document_index": True, "causal": self._config.causal}
+        else:
+            raise NotImplementedError(self._implementation)
 
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self._rotary.preprocess(kwargs)
-        if self._implementation == AttentionImplementation.backup:
+        if self._implementation in (AttentionImplementation.backup, AttentionImplementation.sdpa_dense):
             self._preprocess_for_backup_attention(kwargs)
 
     def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py
index cc5d80e88..95c67e7a3 100644
--- a/fast_llm/layers/attention/config.py
+++ b/fast_llm/layers/attention/config.py
@@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs):
     cu_seqlens_k = "cu_seqlens_k"
     max_seqlen_q = "max_seqlen_q"
     max_seqlen_k = "max_seqlen_k"
+    min_seqlen_q = "min_seqlen_q"
+    min_seqlen_k = "min_seqlen_k"
     document_index_q = "document_index_q"
     document_index_k = "document_index_k"
     position_ids = "position_ids"
@@ -35,9 +37,15 @@ class AttentionKwargs(MixerKwargs):
     past_key_values = "past_key_values"
 
 
+_FLASH_MAX_HEAD_SIZE = 256
+
+
 class AttentionImplementation(enum.StrEnum):
     auto = "auto"
     flash = "flash"
+    sdpa = "sdpa"
+    sdpa_nested = "sdpa_nested"
+    sdpa_dense = "sdpa_dense"
     backup = "backup"
 
 
@@ -120,7 +128,10 @@ class AttentionConfig(MixerConfig):
     )
     implementation: AttentionImplementation = Field(
         default=AttentionImplementation.auto,
-        desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.",
+        desc="The implementation to use for the attention layer."
+        " `auto` picks `flash` when available (bf16/fp16, head_size <= 256, flash-attn installed), otherwise `sdpa`."
+        " `sdpa` further resolves to `sdpa_nested` on CUDA without sliding window, and to `sdpa_dense` otherwise."
+        " `sdpa_nested` and `sdpa_dense` are explicit overrides; `backup` is a slow pure-PyTorch fallback.",
         hint=FieldHint.feature,
     )
     query_norm: NormalizationConfig | None = Field(
@@ -152,6 +163,14 @@ def _validate(self) -> None:
         if not self.causal:
             assert self.window_size is None, "Non-causal windowed attention is not supported."
 
+        if self.implementation == AttentionImplementation.flash:
+            Assert.leq(self.head_size, _FLASH_MAX_HEAD_SIZE)
+
+        if self.implementation == AttentionImplementation.sdpa_nested:
+            assert (
+                self.window_size is None
+            ), "`sdpa_nested` does not support sliding window; use `sdpa` or `sdpa_dense`."
+
     @property
     def layer_class(self) -> "type[Attention]":
         from fast_llm.layers.attention.attention import Attention
diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py
index d572816b2..4940631b9 100644
--- a/tests/layers/test_attention.py
+++ b/tests/layers/test_attention.py
@@ -174,78 +174,139 @@ def expected_output(
     return torch.nn.functional.linear(attn_out.flatten(1), attention.dense.weight.detach())
 
 
-_base_attention_cases = [
+_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]]
+_LENGTHS_SHORT = [[15], [4, 1, 10]]
+_LENGTHS_SINGLE = [[15]]
+
+_attention_test_cases: list[tuple[AttentionTestConfig, list[int]]] = []
+
+# Mask, group, and window base cases — no norms, swept over all length sets.
+for name, kwargs in (
     ("causal", {"causal": True}),
     ("noncausal", {"causal": False}),
     ("window", {"causal": True, "window_size": 4}),
     ("mqa", {"causal": True, "kv_heads": 1}),
     ("mha", {"causal": True, "kv_heads": _HEADS}),
-]
-
-_attention_rotary_cases = [
-    # Rotary: packing equivalence is skipped for multi-document inputs (packed rotary uses global
-    # positions; per-sequence reference uses per-doc positions). All three checks run for single-doc inputs.
-    ("causal_rotary", {"causal": True, "rotary": True}),
-]
-
-_attention_norm_variants = [
-    ("no_norm", {}),
-    ("query_norm", {"query_norm": True}),
-    ("key_norm", {"key_norm": True}),
-    ("value_norm", {"value_norm": True}),
-    ("both_norms", {"query_norm": True, "key_norm": True}),
-    ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
-]
-
-_attention_shared_key_value_cases = [
-    ("shared_key_value", {"shared_key_value": True}),
-    ("shared_key_value_rotary", {"shared_key_value": True, "rotary": True}),
-    # Gemma 4's full-attention layer combines shared_key_value with ProportionalRotary.
+):
+    for lengths in _LENGTHS_FULL:
+        _attention_test_cases.append((AttentionTestConfig(name=f"{name}_no_norm", **kwargs), lengths))
+
+# Per-head norm variants on causal and shared key/value bases. Rotary bases use single-doc
+# inputs because the packed and per-sequence rotary references diverge across boundaries.
+for base_name, base_kwargs, variants, length_set in (
+    (
+        "causal",
+        {"causal": True},
+        (
+            ("query_norm", {"query_norm": True}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("both_norms", {"query_norm": True, "key_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SHORT,
+    ),
+    (
+        "causal_rotary",
+        {"causal": True, "rotary": True},
+        (
+            ("no_norm", {}),
+            ("query_norm", {"query_norm": True}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("both_norms", {"query_norm": True, "key_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
+    ),
+    (
+        "shared_key_value",
+        {"shared_key_value": True},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SHORT,
+    ),
+    (
+        "shared_key_value_rotary",
+        {"shared_key_value": True, "rotary": True},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
+    ),
     (
         "shared_key_value_proportional_rotary",
         {"shared_key_value": True, "rotary": True, "rotary_partial_rotary_factor": 0.5},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
     ),
-]
-
-_attention_shared_key_value_norm_variants = [
-    ("no_norm", {}),
-    ("key_norm", {"key_norm": True}),
-    ("value_norm", {"value_norm": True}),
-    ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
-]
-
-# Norms apply per-head and don't interact with mask/group structure, so we test all norm
-# variants on a single base (causal) instead of crossing every norm with every base.
-# Lengths matter for packing/flash equivalence checks, so we sweep all lengths on the
-# base × no_norm cases (where seq layout interacts most with attention math).
-_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]]
-_LENGTHS_SHORT = [[15], [4, 1, 10]]
-_LENGTHS_SINGLE = [[15]]
-
+):
+    for variant_name, variant_kwargs in variants:
+        for lengths in length_set:
+            _attention_test_cases.append(
+                (
+                    AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs),
+                    lengths,
+                )
+            )
 
-def _build_test_cases() -> list[tuple[AttentionTestConfig, list[int]]]:
-    cases: list[tuple[AttentionTestConfig, list[int]]] = []
-    for base_name, base_kwargs in _base_attention_cases:
-        config = AttentionTestConfig(name=f"{base_name}_no_norm", **base_kwargs)
-        cases.extend((config, lengths) for lengths in _LENGTHS_FULL)
-    for variant_name, variant_kwargs in _attention_norm_variants:
-        if variant_name == "no_norm":
-            continue
-        config = AttentionTestConfig(name=f"causal_{variant_name}", causal=True, **variant_kwargs)
-        cases.extend((config, lengths) for lengths in _LENGTHS_SHORT)
-    for base_name, base_kwargs in _attention_rotary_cases:
-        for variant_name, variant_kwargs in _attention_norm_variants:
-            config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs)
-            cases.extend((config, lengths) for lengths in _LENGTHS_SINGLE)
-    for base_name, base_kwargs in _attention_shared_key_value_cases:
-        lengths_set = _LENGTHS_SINGLE if base_kwargs.get("rotary") else _LENGTHS_SHORT
-        for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants:
-            config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs)
-            cases.extend((config, lengths) for lengths in lengths_set)
-    return cases
+# head_size > 256 — exercises the SDPA-only regime (flash caps at 256).
+for name, kwargs in (
+    ("large_head_causal_no_norm", {"causal": True, "head_size": 320}),
+    ("large_head_mqa_no_norm", {"causal": True, "head_size": 320, "kv_heads": 1}),
+):
+    for lengths in _LENGTHS_SHORT:
+        _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths))
 
 
-_attention_test_cases = _build_test_cases()
+def _check_packed(
+    implementation: str,
+    config: AttentionTestConfig,
+    hidden_dim: TensorDim,
+    lengths: list[int],
+    attention_f32: Attention,
+    distributed_config: DistributedConfig,
+    distributed: Distributed,
+    hidden_states: torch.Tensor,
+    out_ref: torch.Tensor,
+    grads_ref: list[torch.Tensor],
+    out_rtol: float,
+    grad_rtol: float,
+) -> None:
+    attention_impl: Attention = config.get_attention_config(implementation).get_layer(
+        distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False
+    )
+    stage_impl = get_stage([attention_impl], distributed)
+    for param_impl, param_f32 in zip(attention_impl.parameters(), attention_f32.parameters(), strict=True):
+        param_impl.data.copy_(param_f32.data)
+    (model_input,) = LanguageModelBatch(
+        tokens=torch.empty(sum(lengths), dtype=torch.int64, device=hidden_states.device), lengths=lengths
+    ).get_model_inputs(
+        LanguageModelBatchPreprocessingConfig(
+            distributed=distributed_config,
+            predicted_tokens=0,
+            **attention_impl.get_preprocessing_config(),
+        )
+    )
+    kwargs_impl = model_input.to_kwargs()
+    attention_impl.preprocess(kwargs_impl)
+    out_impl, context = stage_impl.forward(hidden_states, kwargs_impl)
+    stage_impl.backward(torch.ones_like(out_impl), context)
+    Assert.rms_close_relative(out_impl, out_ref, out_rtol, 1e-7)
+    for param_impl, grad_ref in zip(attention_impl.parameters(), grads_ref, strict=True):
+        Assert.rms_close_relative(param_impl.grad_buffer, grad_ref, grad_rtol, 1e-7, msg=implementation)
 
 
 def _run_per_seq_reference(
@@ -255,7 +316,6 @@
     hidden_states: torch.Tensor,
     lengths: list[int],
    device: torch.device,
-    with_backward: bool = True,
 ) -> torch.Tensor:
     out_refs = []
     for length, hidden_slice in zip(lengths, torch.split(hidden_states, lengths, dim=0), strict=True):
@@ -271,8 +331,7 @@
         kwargs = model_input.to_kwargs()
         attention.preprocess(kwargs)
         out, context = stage.forward(hidden_slice, kwargs)
-        if with_backward:
-            stage.backward(torch.ones_like(out), context)
+        stage.backward(torch.ones_like(out), context)
         out_refs.append(out.detach())
 
     return torch.cat(out_refs, dim=0)
@@ -357,50 +416,60 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None:
         Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name)
     stage.reset_gradients()
 
-    # Flash equivalence check: packed flash output must match per-sequence bfloat16 backup reference.
-    if _flash_available:
-        distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True)
-        distributed_bf16 = Distributed(distributed_config_bf16)
+    _check_packed(
+        "sdpa_dense",
+        config,
+        hidden_dim,
+        lengths,
+        attention,
+        distributed_config,
+        distributed,
+        hidden_states,
+        out_ref,
+        grads_ref,
+        1e-5,
+        1e-5,
+    )
 
-        attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer(
-            distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
-        )
-        stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16)
-        for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True):
-            param_bf16.data.copy_(param_f32.data)
-
-        hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16)
-        out_ref_bf16 = _run_per_seq_reference(
-            attention_backup_bf16,
-            stage_backup_bf16,
-            distributed_config_bf16,
-            hidden_states_bf16,
-            lengths,
-            device,
-            with_backward=False,
-        )
+    if not torch.cuda.is_available():
+        return
 
-        attention_flash: Attention = config.get_attention_config("flash").get_layer(
-            distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
-        )
-        stage_flash = get_stage([attention_flash], distributed_bf16)
-        for param_flash, param_f32 in zip(attention_flash.parameters(), attention.parameters(), strict=True):
-            param_flash.data.copy_(param_f32.data)
+    distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True)
+    distributed_bf16 = Distributed(distributed_config_bf16)
 
-        (model_input_flash,) = LanguageModelBatch(
-            tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths
-        ).get_model_inputs(
-            LanguageModelBatchPreprocessingConfig(
-                distributed=distributed_config_bf16,
-                predicted_tokens=0,
-                **attention_flash.get_preprocessing_config(),
-            )
-        )
-        kwargs_flash = model_input_flash.to_kwargs()
-        attention_flash.preprocess(kwargs_flash)
-        out_flash, _ = stage_flash.forward(hidden_states_bf16, kwargs_flash)
+    attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer(
+        distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
+    )
+    stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16)
+    for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True):
+        param_bf16.data.copy_(param_f32.data)
+
+    hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16).requires_grad_()
+    out_ref_bf16 = _run_per_seq_reference(
+        attention_backup_bf16, stage_backup_bf16, distributed_config_bf16, hidden_states_bf16, lengths, device
+    )
+    grads_ref_bf16 = [param.grad_buffer.clone() for param in attention_backup_bf16.parameters()]
 
-        Assert.rms_close_relative(out_flash, out_ref_bf16, 5e-3, 1e-7)
+    # bf16 grad rtol is looser than forward: reduction-order noise compounds through backward.
+    for implementation in ("sdpa_dense", "flash", "sdpa_nested"):
+        if implementation == "flash" and (not _flash_available or config.head_size > 256):
+            continue
+        if implementation == "sdpa_nested" and config.window_size is not None:
+            continue
+        _check_packed(
+            implementation,
+            config,
+            hidden_dim,
+            lengths,
+            attention,
+            distributed_config_bf16,
+            distributed_bf16,
+            hidden_states_bf16,
+            out_ref_bf16,
+            grads_ref_bf16,
+            5e-3,
+            1.5e-2,
+        )
 
 
 @pytest.mark.slow
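
Note for reviewers: the nested-jagged pattern `_attn_sdpa` relies on can be tried outside Fast-LLM. A minimal sketch, assuming PyTorch 2.4+ on a CUDA device; the shapes, lengths, and the `as_nested` helper are illustrative, not part of this diff:

import torch

# Three packed documents of lengths 4, 1 and 10, with 8 heads of size 64.
lengths = torch.tensor([4, 1, 10], device="cuda")
offsets = torch.nn.functional.pad(lengths.cumsum(0), (1, 0)).to(torch.int64)  # [0, 4, 5, 15]
query, key, value = (torch.randn(15, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))


def as_nested(x: torch.Tensor) -> torch.Tensor:
    # Passing min/max seqlens up front avoids the device-to-host sync in the SDPA dispatch.
    return torch.nested.nested_tensor_from_jagged(x, offsets, min_seqlen=1, max_seqlen=10)


# Each document becomes its own batch element, so causal attention cannot cross
# document boundaries and no dense mask is ever materialized.
out = torch.nn.functional.scaled_dot_product_attention(
    as_nested(query).transpose(1, 2),
    as_nested(key).transpose(1, 2),
    as_nested(value).transpose(1, 2),
    is_causal=True,
).transpose(1, 2)
flat = out.values()  # back to (total_tokens, heads, head_size)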