29 changes: 21 additions & 8 deletions fast_llm/data/document/block.py
@@ -24,8 +24,10 @@ class BlockModelInput(ModelInput):
lengths: list[int] = None
cumulative_lengths_q: torch.Tensor | None = None
cumulative_lengths_k: torch.Tensor | None = None
max_length_q: torch.Tensor | None = None
max_length_k: torch.Tensor | None = None
max_length_q: int | None = None
max_length_k: int | None = None
min_length_q: int | None = None
min_length_k: int | None = None
document_index_q: torch.Tensor | None = None
document_index_k: torch.Tensor | None = None
position_index: torch.Tensor | None = None
@@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]:
AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k,
AttentionKwargs.max_seqlen_q: self.max_length_q,
AttentionKwargs.max_seqlen_k: self.max_length_k,
AttentionKwargs.min_seqlen_q: self.min_length_q,
AttentionKwargs.min_seqlen_k: self.min_length_k,
AttentionKwargs.document_index_q: self.document_index_q,
AttentionKwargs.document_index_k: self.document_index_k,
LanguageModelKwargs.position_ids: self.position_index,
@@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo
model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths
if config.return_max_sequence_lengths or config.return_document_index:
model_input.max_length_q, model_input.max_length_k = self.max_lengths
if config.return_min_sequence_lengths:
model_input.min_length_q, model_input.min_length_k = self.min_lengths
if config.return_document_index:
model_input.document_index_q, model_input.document_index_k = self.document_index
if config.return_position_index:
@@ -118,13 +124,20 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
return cumulative_lengths_q, cumulative_lengths_k

@functools.cached_property
def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
def _first_length_k(self) -> int:
# First doc's K-side length includes the past KV prefix; remaining docs match q-side.
return self.sequence_k_past + self.lengths[0] - self.first_document_begin

@functools.cached_property
def max_lengths(self) -> tuple[int, int]:
max_length_q = max(self.lengths)
max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin)
return (
torch.full((1,), max_length_q, dtype=torch.int32, device=self.device),
torch.full((1,), max_length_k, dtype=torch.int32, device=self.device),
)
return max_length_q, max(max_length_q, self._first_length_k)

@functools.cached_property
def min_lengths(self) -> tuple[int, int]:
min_length_q = min(self.lengths)
min_length_k = min(self._first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else self._first_length_k
return min_length_q, min_length_k

@functools.cached_property
def document_index(self) -> tuple[torch.Tensor, torch.Tensor]:
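A standalone sketch of the new length bookkeeping, using made-up values for `lengths`, `sequence_k_past`, and `first_document_begin` (the names come from the diff; the surrounding class is not reproduced):

```python
# Per-document token counts in the packed q-side window (made-up values).
lengths = [5, 3, 7]
sequence_k_past = 4         # tokens already cached on the k side (past KV prefix)
first_document_begin = 2    # offset of the first document's start within the window

# First doc's k-side length includes the past KV prefix; remaining docs match q-side.
first_length_k = sequence_k_past + lengths[0] - first_document_begin  # 4 + 5 - 2 = 7

max_length_q = max(lengths)                       # 7
max_length_k = max(max_length_q, first_length_k)  # 7
min_length_q = min(lengths)                       # 3
min_length_k = min(first_length_k, *lengths[1:]) if len(lengths) > 1 else first_length_k  # 3
```

Returning plain ints here, rather than the one-element int32 GPU tensors the old `max_lengths` built, is what lets the nested-SDPA path below stay free of device-to-host syncs.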
1 change: 1 addition & 0 deletions fast_llm/data/document/config.py
@@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
distributed: DistributedConfig = Field()
return_cumulative_sequence_lengths: bool = Field(default=False)
return_max_sequence_lengths: bool = Field(default=False)
return_min_sequence_lengths: bool = Field(default=False)
return_document_index: bool = Field(default=False)
return_position_index: bool = Field(default=False)

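A hypothetical usage sketch: the flag combination matches what `get_preprocessing_config` requests for the nested-SDPA path in attention.py below, but constructor arguments other than the flags are assumptions.

```python
# Hypothetical: enable the length metadata the nested-SDPA path needs.
# `distributed_config` is assumed to exist in the caller's scope.
config = LengthPreprocessingConfig(
    distributed=distributed_config,
    return_cumulative_sequence_lengths=True,
    return_max_sequence_lengths=True,
    return_min_sequence_lengths=True,  # new in this PR
)
```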
112 changes: 102 additions & 10 deletions fast_llm/layers/attention/attention.py
@@ -10,7 +10,12 @@
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.utils import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs
from fast_llm.layers.attention.config import (
_FLASH_MAX_HEAD_SIZE,
AttentionConfig,
AttentionImplementation,
AttentionKwargs,
)
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.tensor import TensorMeta
@@ -79,12 +84,22 @@ def __init__(
peft=peft,
return_bias=return_bias,
)
# Two-stage resolution so callers can set `implementation=sdpa` to get the auto-picked SDPA flavor.
self._implementation = self._config.implementation
if self._implementation == AttentionImplementation.auto:
if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16):
if (
_flash_available
and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16)
and self._config.head_size <= _FLASH_MAX_HEAD_SIZE
):
self._implementation = AttentionImplementation.flash
else:
self._implementation = AttentionImplementation.backup
self._implementation = AttentionImplementation.sdpa
if self._implementation == AttentionImplementation.sdpa:
if self._distributed_config.use_cuda and self._config.window_size is None:
self._implementation = AttentionImplementation.sdpa_nested
else:
self._implementation = AttentionImplementation.sdpa_dense

self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim(
@@ -258,6 +273,69 @@ def _attn_flash(
softmax_scale=self._softmax_scale,
)

def _attn_sdpa(
self,
query: torch.Tensor, # total_q, heads, head_size
key: torch.Tensor, # total_k, head_groups, head_size (pre-GQA-expansion)
value: torch.Tensor, # total_k, head_groups, head_size (pre-GQA-expansion)
kwargs: dict[str, typing.Any],
) -> torch.Tensor: # total_q, heads, head_size
# SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads.
# `enable_gqa=True` would handle this internally, but it forces a fallback to the MATH
# backend (which materializes the O(S^2) attention matrix) and is unsupported by the
# nested-jagged path entirely; manual `repeat_interleave` keeps EFFICIENT in play.
if self._local_heads_per_group > 1:
key = key.repeat_interleave(self._local_heads_per_group, dim=1)
value = value.repeat_interleave(self._local_heads_per_group, dim=1)

sdpa_args: dict[str, typing.Any] = {
"dropout_p": self._config.dropout if self.training else 0.0,
"scale": self._softmax_scale,
}
if self._implementation == AttentionImplementation.sdpa_nested:
# Wrap each document as its own batch element via nested-jagged so cross-doc masking
# is structural and EFFICIENT skips materializing the attention mask. Without explicit
# `max_seqlen`/`min_seqlen`, dispatch computes them lazily and syncs them to the host
# on every call; passing them in keeps the path sync-free.
# `nested_tensor_from_jagged` requires int64 offsets.
cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
query = torch.nested.nested_tensor_from_jagged(
query,
cu_seqlens_q,
min_seqlen=kwargs[AttentionKwargs.min_seqlen_q],
max_seqlen=kwargs[AttentionKwargs.max_seqlen_q],
)
key = torch.nested.nested_tensor_from_jagged(
key,
cu_seqlens_k,
min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
)
value = torch.nested.nested_tensor_from_jagged(
value,
cu_seqlens_k,
min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
)
sdpa_args["is_causal"] = self._config.causal
else:
# Dense path: reuse the backup implementation's preprocessed causal+document mask.
# Required on CPU (the MATH backend rejects nested + is_causal) and on CUDA with a
# sliding window (the nested path can't express it). Backup builds the mask as
# (1, sq, 1, sk); SDPA expects a (B, H, sq, sk)-broadcastable mask.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
sdpa_args["attn_mask"] = kwargs[AttentionKwargs.attention_mask].transpose(1, 2)

output = torch.nn.functional.scaled_dot_product_attention(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
**sdpa_args,
).transpose(1, 2)
return output.values() if output.is_nested else output.squeeze(0)

def _apply_norm_with_grad_capture(
self, norm: torch.nn.Module, x: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
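The GQA expansion at the top of `_attn_sdpa`, in shapes; the sizes are made up:

```python
import torch

# (total_k, head_groups, head_size) -> (total_k, heads, head_size),
# with heads = head_groups * local_heads_per_group.
key = torch.randn(15, 2, 64)           # 2 KV head groups
key = key.repeat_interleave(4, dim=1)  # 4 query heads per group
assert key.shape == (15, 8, 64)
```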
Expand Down Expand Up @@ -420,6 +498,8 @@ def _forward(
with set_generator(self._distributed.tp_generator):
if self._implementation == AttentionImplementation.flash:
input_ = self._attn_flash(query, key, value, kwargs)
elif self._implementation in (AttentionImplementation.sdpa_nested, AttentionImplementation.sdpa_dense):
input_ = self._attn_sdpa(query, key, value, kwargs)
elif self._implementation == AttentionImplementation.backup:
# TODO: Avoid the flattens.
input_ = self._attn_backup(query, key, value, kwargs)
@@ -472,7 +552,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c

attention_compute = sequence_q * sequence_k * attn_compute_base

if (not config.hardware) or self._implementation in AttentionImplementation.flash:
if (not config.hardware) or self._implementation in (
AttentionImplementation.flash,
AttentionImplementation.sdpa_nested,
AttentionImplementation.sdpa_dense,
):
# Remove non-causal part. (TODO: Support non-causal)
# TODO: Compute is overestimated without cross-document attention.
attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2
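A toy check of the causal correction, valid in the square case `sequence_q == sequence_k`; numbers are made up:

```python
sequence_q = sequence_k = 8
attn_compute_base = 1  # per-(q, k)-pair cost, collapsed to 1 for the toy

full = sequence_q * sequence_k * attn_compute_base                        # 64
causal = full - (sequence_q * (sequence_q - 1) * attn_compute_base) // 2  # 64 - 28 = 36
assert causal == sequence_q * (sequence_q + 1) // 2  # lower triangle incl. diagonal
```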
@@ -498,19 +582,27 @@
)

def get_preprocessing_config(self) -> dict[str, typing.Any]:
return (
{
if self._implementation == AttentionImplementation.flash:
return {
"return_cumulative_sequence_lengths": True,
"return_max_sequence_lengths": True,
"causal": self._config.causal,
}
if self._implementation == AttentionImplementation.flash
else {"return_document_index": True, "causal": self._config.causal}
)
elif self._implementation == AttentionImplementation.sdpa_nested:
return {
"return_cumulative_sequence_lengths": True,
"return_max_sequence_lengths": True,
"return_min_sequence_lengths": True,
"causal": self._config.causal,
}
elif self._implementation in (AttentionImplementation.sdpa_dense, AttentionImplementation.backup):
return {"return_document_index": True, "causal": self._config.causal}
else:
raise NotImplementedError(self._implementation)

def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
self._rotary.preprocess(kwargs)
if self._implementation == AttentionImplementation.backup:
if self._implementation in (AttentionImplementation.backup, AttentionImplementation.sdpa_dense):
self._preprocess_for_backup_attention(kwargs)

def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
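A minimal standalone repro of the nested-jagged path, outside the class. The shapes and the int64-offsets requirement follow the diff; the sizes, dtype, and device are made up, and it assumes a recent PyTorch (the `min_seqlen`/`max_seqlen` kwargs are 2.4-era) plus a CUDA device for the fused backends:

```python
import torch

device = torch.device("cuda")
heads, head_size = 8, 64
lengths = [5, 3, 7]  # one packed sequence holding 3 documents
total = sum(lengths)

# int64 offsets, as nested_tensor_from_jagged requires.
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int64, device=device)

q = torch.randn(total, heads, head_size, device=device, dtype=torch.bfloat16)
k = torch.randn_like(q)  # GQA expansion assumed already applied
v = torch.randn_like(q)

def to_jagged(x: torch.Tensor) -> torch.Tensor:
    # (total, heads, head_size) -> nested (batch=3, jagged_seq, heads, head_size).
    # Passing min/max seqlen up front avoids a host sync inside SDPA dispatch.
    return torch.nested.nested_tensor_from_jagged(
        x, cu_seqlens, min_seqlen=min(lengths), max_seqlen=max(lengths)
    )

q_n, k_n, v_n = (to_jagged(t) for t in (q, k, v))

# SDPA takes (batch, heads, seq, head_size); each document is its own batch
# element, so is_causal masks within documents and cross-doc attention is
# structurally impossible.
out = torch.nn.functional.scaled_dot_product_attention(
    q_n.transpose(1, 2), k_n.transpose(1, 2), v_n.transpose(1, 2), is_causal=True
).transpose(1, 2)

flat = out.values()  # back to (total, heads, head_size)
```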
21 changes: 20 additions & 1 deletion fast_llm/layers/attention/config.py
@@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs):
cu_seqlens_k = "cu_seqlens_k"
max_seqlen_q = "max_seqlen_q"
max_seqlen_k = "max_seqlen_k"
min_seqlen_q = "min_seqlen_q"
min_seqlen_k = "min_seqlen_k"
document_index_q = "document_index_q"
document_index_k = "document_index_k"
position_ids = "position_ids"
@@ -35,9 +37,15 @@ class AttentionKwargs(MixerKwargs):
past_key_values = "past_key_values"


_FLASH_MAX_HEAD_SIZE = 256


class AttentionImplementation(enum.StrEnum):
auto = "auto"
flash = "flash"
sdpa = "sdpa"
sdpa_nested = "sdpa_nested"
sdpa_dense = "sdpa_dense"
backup = "backup"


@@ -120,7 +128,10 @@ class AttentionConfig(MixerConfig):
)
implementation: AttentionImplementation = Field(
default=AttentionImplementation.auto,
desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.",
desc="The implementation to use for the attention layer."
" `auto` picks `flash` when available (bf16/fp16, head_size <= 256, flash-attn installed), otherwise `sdpa`."
" `sdpa` further resolves to `sdpa_nested` on CUDA without sliding window, and to `sdpa_dense` otherwise."
" `sdpa_nested` and `sdpa_dense` are explicit overrides; `backup` is a slow pure-PyTorch fallback.",
hint=FieldHint.feature,
)
query_norm: NormalizationConfig | None = Field(
@@ -152,6 +163,14 @@ def _validate(self) -> None:
if not self.causal:
assert self.window_size is None, "Non-causal windowed attention is not supported."

if self.implementation == AttentionImplementation.flash:
Assert.leq(self.head_size, _FLASH_MAX_HEAD_SIZE)

if self.implementation == AttentionImplementation.sdpa_nested:
assert (
self.window_size is None
), "`sdpa_nested` does not support sliding window; use `sdpa` or `sdpa_dense`."

@property
def layer_class(self) -> "type[Attention]":
from fast_llm.layers.attention.attention import Attention
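The resolution logic from `Attention.__init__`, pulled out as a sketch of the two-stage behavior the `desc` above documents. The function and its parameters are illustrative, not a public API:

```python
def resolve_implementation(
    requested: AttentionImplementation,
    flash_available: bool,
    compute_dtype_is_half: bool,  # fp16 or bf16
    head_size: int,
    use_cuda: bool,
    window_size: int | None,
) -> AttentionImplementation:
    impl = requested
    # Stage 1: `auto` -> `flash` when usable, otherwise `sdpa`.
    if impl == AttentionImplementation.auto:
        if flash_available and compute_dtype_is_half and head_size <= _FLASH_MAX_HEAD_SIZE:
            impl = AttentionImplementation.flash
        else:
            impl = AttentionImplementation.sdpa
    # Stage 2: `sdpa` -> nested (jagged) on CUDA without sliding window, else dense.
    if impl == AttentionImplementation.sdpa:
        impl = (
            AttentionImplementation.sdpa_nested
            if use_cuda and window_size is None
            else AttentionImplementation.sdpa_dense
        )
    return impl
```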