fast_llm/layers/attention/attention.py: 4 changes (2 additions, 2 deletions)
@@ -324,7 +324,7 @@ def _forward(
token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim])
token_shape = tuple(dim.size for dim in token_dims)
query = query.unflatten(0, token_shape)
key_value = key_value.unflatten(0, token_shape)
key_value = key_value.unflatten(0, (token_shape[0], token_shape[1] * self._sequence_data_parallel_dim.size))

# TODO: Move the rest to function.

@@ -457,7 +457,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
seq_ids = torch.stack(
[
torch.cat([torch.full((x,), i, device=device) for i, x in enumerate(sample_lens)])
for sample_lens in kwargs[AttentionKwargs.sequence_lengths]
for sample_lens in kwargs[AttentionKwargs.lengths]
]
)
document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None])[
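
For reference, a minimal sketch of the document mask that the backup-attention path builds from the renamed lengths kwarg; the literal values below are illustrative, not taken from the PR.

import torch

# Two packed samples: the first holds documents of length 3 and 2, the second a single
# document of length 5 (illustrative values only).
lengths = [[3, 2], [5]]

# One document id per token, stacked into a (batch, seq) tensor, mirroring the seq_ids
# construction in the hunk above.
seq_ids = torch.stack(
    [torch.cat([torch.full((n,), i) for i, n in enumerate(sample)]) for sample in lengths]
)
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0]])

# Tokens may only attend to tokens of the same document; the causal mask is applied separately.
document_mask = seq_ids[:, None, :] == seq_ids[:, :, None]
print(document_mask.shape)  # torch.Size([2, 5, 5])
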
fast_llm/layers/attention/config.py: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ class MixerKwargs(BlockKwargs):
cu_seqlens_k = "cu_seqlens_k"
max_seqlen_q = "max_seqlen_q"
max_seqlen_k = "max_seqlen_k"
seq_idx = "seq_idx"
document_index_q = "document_index_q"
position_ids = "position_ids"


fast_llm/layers/attention/preprocessing.py: 6 changes (2 additions, 4 deletions)
@@ -28,9 +28,7 @@ def preprocess_for_varlen(
Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size)

sequence_lengths = [
sequence_length
for sequence_lengths in kwargs[MixerKwargs.sequence_lengths]
for sequence_length in sequence_lengths
sequence_length for sequence_lengths in kwargs[MixerKwargs.lengths] for sequence_length in sequence_lengths
]
if return_cu_seqlens:
cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum(
@@ -43,7 +41,7 @@
kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q
kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q
if return_seq_idx:
kwargs[MixerKwargs.seq_idx] = torch.cat(
kwargs[MixerKwargs.document_index_q] = torch.cat(
[
torch.full((sequence_length,), i, dtype=torch.int32, device=device)
for i, sequence_length in enumerate(sequence_lengths)
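
A rough sketch of the varlen metadata this helper produces, assuming the flattened per-document lengths below; the names mirror the kwargs in the diff, the values are made up.

import torch

sequence_lengths = [3, 2, 4]  # flattened per-document lengths (illustrative)

# Cumulative sequence boundaries consumed by varlen attention kernels.
cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32).cumsum(0, dtype=torch.int32)
# tensor([0, 3, 5, 9], dtype=torch.int32)
max_seqlen_q = max(sequence_lengths)  # 4

# One document index per token, matching the renamed document_index_q kwarg.
document_index_q = torch.cat(
    [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(sequence_lengths)]
)
# tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
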
fast_llm/layers/block/config.py: 16 changes (15 additions, 1 deletion)
@@ -38,7 +38,7 @@ class BlockKwargs:
hidden_token_dim = "hidden_token_dim"
# TODO: These are confusing
sequence_length = "sequence_length"
sequence_lengths = "sequence_lengths"
lengths = "sequence_lengths"
# TODO: Belongs elsewhere?
grad_output = "grad_output"
activation_distillation_targets = "activation_distillation_targets"
@@ -84,13 +84,15 @@ def get_layer(
*,
lr_scale: float | None,
peft: PeftConfig | None,
**kwargs,
) -> "BlockBase":
return self.layer_class(
self,
distributed_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
**kwargs,
)

def get_reference_models(self) -> set[str]:
@@ -106,6 +108,10 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
return FixedBlockSequenceConfig._from_dict(default, strict)
return super()._from_dict(default, strict=strict)

@property
def last_block_config(self) -> BlockConfig:
raise NotImplementedError()


@config_class(dynamic_type={BlockSequenceConfig: "fixed"})
class FixedBlockSequenceConfig(BlockSequenceConfig):
@@ -130,6 +136,10 @@ def layer_class(self) -> "type[FixedBlockSequence]":
def get_reference_models(self) -> set[str]:
return self.block.get_reference_models()

@property
def last_block_config(self) -> BlockConfig:
return self.block


@config_class(dynamic_type={BlockSequenceConfig: "pattern"})
class PatternBlockSequenceConfig(BlockSequenceConfig):
@@ -161,6 +171,10 @@ def _validate(self):

super()._validate()

@property
def last_block_config(self) -> BlockConfig:
return self.blocks[self.expanded_pattern[-1]]

@property
def layer_class(self) -> "type[PatternBlockSequence]":
from fast_llm.layers.block.sequence import PatternBlockSequence
fast_llm/layers/block/sequence.py: 16 changes (14 additions, 2 deletions)
@@ -24,6 +24,7 @@ def __init__(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_last_layer_input: bool = False,
):
super().__init__(
config,
@@ -40,8 +41,13 @@
hidden_dim,
lr_scale=self._lr_scale,
peft=self._peft,
**(
{"return_input": True}
if return_last_layer_input and block_index == self._config.num_blocks - 1
else {}
),
)
for _ in range(self._config.num_blocks)
for block_index in range(self._config.num_blocks)
]
)

@@ -75,6 +81,7 @@ def __init__(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_last_layer_input: bool = False,
):
super().__init__(
config,
@@ -90,8 +97,13 @@
hidden_dim,
lr_scale=self._lr_scale,
peft=self._peft,
**(
{"return_input": True}
if return_last_layer_input and block_index == self._config.num_blocks - 1
else {}
),
)
for name in self._config.expanded_pattern
for block_index, name in enumerate(self._config.expanded_pattern)
]
)

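
The **kwargs plumbing above means only the final block of a sequence is constructed with return_input=True when return_last_layer_input is set; a toy illustration (the helper name is hypothetical):

def per_block_kwargs(num_blocks: int, return_last_layer_input: bool) -> list[dict]:
    # Only the last block is asked to also return its input, mirroring the list
    # comprehensions in FixedBlockSequence and PatternBlockSequence.
    return [
        {"return_input": True} if return_last_layer_input and i == num_blocks - 1 else {}
        for i in range(num_blocks)
    ]

print(per_block_kwargs(3, True))   # [{}, {}, {'return_input': True}]
print(per_block_kwargs(3, False))  # [{}, {}, {}]
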
fast_llm/layers/common/linear/convolution.py: 20 changes (13 additions, 7 deletions)
@@ -34,11 +34,13 @@ def __init__(
else self._forward_torch
)

def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
if kwargs:
raise NotImplementedError(
f"Arguments {tuple(kwargs)} not implemented for torch implementation of 1d convolution."
)
def _forward_torch(
self, input_: torch.Tensor, document_index: torch.Tensor | None = None, lengths: list[int] | None = None
) -> torch.Tensor:
if document_index is not None and lengths is None:
raise ValueError("Torch implementation of CausalConv1d requires lengths.")
if lengths is not None:
return torch.cat([self._forward_torch(x) for x in input_.split(lengths, dim=-1)], dim=-1)
return self._activation.activation_fn(
torch.nn.functional.conv1d(
input_,
@@ -49,13 +51,17 @@ def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
)[..., : input_.size(1)]
)

def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
def _forward_causal_conv1d(
self, input_: torch.Tensor, document_index: torch.Tensor | None = None, lengths: list[int] | None = None
) -> torch.Tensor:
if lengths is not None and document_index is None:
raise ValueError("Compiled implementation of CausalConv1d requires document indices.")
return _causal_conv1d_fn(
input_,
self.weight.squeeze(1),
self.bias,
activation=(None if self._activation == ActivationType.identity else self._activation.value),
**kwargs,
seq_idx=document_index,
)

def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int:
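
A standalone sketch of the per-document torch fallback introduced here: packed documents are split along the sequence dimension and convolved independently, so the causal convolution cannot leak across document boundaries. The shapes, the depthwise layout, and the helper name are assumptions for illustration, not the exact fast_llm layout.

import torch

def causal_conv1d_per_document(input_: torch.Tensor, weight: torch.Tensor, lengths: list[int]) -> torch.Tensor:
    # Depthwise causal convolution over one chunk: pad, then crop to the original length.
    def one_chunk(x: torch.Tensor) -> torch.Tensor:
        kernel_size = weight.size(-1)
        return torch.nn.functional.conv1d(
            x, weight, padding=kernel_size - 1, groups=x.size(1)
        )[..., : x.size(-1)]

    # Split packed documents along the sequence dimension and convolve each one on its own.
    return torch.cat([one_chunk(x) for x in input_.split(lengths, dim=-1)], dim=-1)

channels, kernel_size = 4, 3
x = torch.randn(1, channels, 9)                  # one packed "batch" of 9 tokens
w = torch.randn(channels, 1, kernel_size)        # depthwise kernel, one filter per channel
y = causal_conv1d_per_document(x, w, [3, 2, 4])  # documents of length 3, 2 and 4
print(y.shape)                                   # torch.Size([1, 4, 9])
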
fast_llm/layers/language_model/config.py: 131 changes (30 additions, 101 deletions)
@@ -1,4 +1,3 @@
import abc
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
@@ -14,7 +13,7 @@

if typing.TYPE_CHECKING:
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.language_model import LanguageModel
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

@@ -95,41 +94,8 @@ def layer_class(self) -> "type[LanguageModelEmbedding]":
return LanguageModelEmbedding


@config_class(registry=True)
class LanguageModelHeadBaseConfig(BlockConfig):
@classmethod
def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None:
# Default subclass.
return LanguageModelHeadConfig._from_dict(default, strict)
return super()._from_dict(default, strict=strict)

def get_layer(
self,
distributed_config: DistributedConfig,
embeddings_config: LanguageModelEmbeddingsConfig,
*,
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
) -> "LanguageModelHeadBase":
return self.layer_class(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
)

@property
@abc.abstractmethod
def max_prediction_distance(self) -> int:
pass


@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"})
class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
@config_class()
class LanguageModelHeadConfig(BlockConfig):
_abstract = False
normalization: NormalizationConfig = Field(
desc="Configuration for the final normalization layer.",
@@ -160,6 +126,18 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
prediction_loss_coefficient: list[float] | None = Field(
default=None,
desc="Loss coefficient for each prediction head.",
doc="If not provided, all heads are equally weighted.",
hint=FieldHint.feature,
)

def get_layer(
self,
@@ -169,85 +147,36 @@ def get_layer(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
prediction_distance: int = 0,
prediction_heads: int = 1,
loss_coefficient: float = 1.0,
):
return self.layer_class(
block_config: DecoderBlockConfig | None = None,
) -> "tuple[LanguageModelHead, MultiTokenPrediction]":
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

return LanguageModelHead(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
), MultiTokenPrediction(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
prediction_distance=prediction_distance,
prediction_heads=prediction_heads,
loss_coefficient=loss_coefficient,
block_config=block_config,
)

@property
def layer_class(self) -> "type[LanguageModelHead]":
from fast_llm.layers.language_model.head import LanguageModelHead

return LanguageModelHead

def _validate(self) -> None:
super()._validate()
assert LM_HEAD_LOSS_NAME not in self.losses

@property
def max_prediction_distance(self) -> int:
return 1

def get_reference_models(self) -> set[str]:
return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()}


@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"})
class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig):
_abstract = False
# Needs to be `DecoderBlockConfig` for the `return_input` interface.
# TODO: Make a generic wrapper for returning input instead?
block: DecoderBlockConfig = Field(
desc="Configuration for the decoder block before each head.",
hint=FieldHint.architecture,
)
# TODO: Generalize? (needs the extra initialization arguments)
head: LanguageModelHeadConfig = Field(
desc="Configuration for the multi-token-prediction heads.",
hint=FieldHint.architecture,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
prediction_loss_coefficient: list[float] | None = Field(
default=None,
desc="Loss coefficient for each prediction head.",
doc="If not provided, all heads are equally weighted.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
super()._validate()
if isinstance(self.prediction_loss_coefficient, list):
Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
for coeff in self.prediction_loss_coefficient:
Assert.geq(coeff, 0)

@property
def layer_class(self) -> "type[MultiTokenPrediction]":
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

return MultiTokenPrediction

@property
def max_prediction_distance(self) -> int:
return self.prediction_heads


@config_class()
class LanguageModelConfig(BlockConfig):
decoder: BlockSequenceConfig = Field(
@@ -258,7 +187,7 @@ class LanguageModelConfig(BlockConfig):
hint=FieldHint.architecture,
desc="Configuration for the language model embeddings.",
)
head: LanguageModelHeadBaseConfig = Field(
head: LanguageModelHeadConfig = Field(
hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
)
tied_embedding_weight: bool = Field(
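
With MultiTokenPredictionConfig folded away, multi-token prediction is now expressed directly on the head config. A hypothetical sketch of that section of a model config, with the sanity checks one would expect from the validation previously in MultiTokenPredictionConfig; all field values are invented.

# Hypothetical head section of a model config.
head = {
    "prediction_heads": 3,                            # predict tokens t+1, t+2, t+3
    "prediction_loss_coefficient": [1.0, 0.5, 0.25],  # optional; omitted means equal weighting
}

# One non-negative coefficient per prediction head.
coefficients = head["prediction_loss_coefficient"]
assert len(coefficients) == head["prediction_heads"]
assert all(c >= 0 for c in coefficients)
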
fast_llm/layers/language_model/embedding.py: 2 changes (1 addition, 1 deletion)
@@ -87,7 +87,7 @@ def _forward(
if self._vocab_parallel:
token_mask = (token_ids >= self._vocab_start_index) * (token_ids < self._vocab_end_index)
masked_input = (token_ids - self._vocab_start_index) * token_mask
embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(2) # noqa
embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(-1) # noqa
embeddings = reduce_forward(embeddings, group)
# TODO: Input masking of position embeddings inconsistent with non-vocab-parallel
if self.position_embeddings_weight is not None:
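
A small sketch of the vocab-parallel masking around the changed line, assuming this rank owns token ids in [vocab_start, vocab_end) and embeddings are reduced across ranks afterwards; values below are illustrative.

import torch

vocab_start, vocab_end, hidden = 100, 200, 8
weight = torch.randn(vocab_end - vocab_start, hidden)

token_ids = torch.randint(0, 300, (2, 5))
token_mask = (token_ids >= vocab_start) * (token_ids < vocab_end)
masked_input = (token_ids - vocab_start) * token_mask  # out-of-range ids are clamped to row 0

# unsqueeze(-1) broadcasts the mask over the embedding dimension regardless of how many
# leading token dimensions there are, whereas unsqueeze(2) assumed exactly two.
embeddings = torch.embedding(weight, masked_input) * token_mask.unsqueeze(-1)
print(embeddings.shape)  # torch.Size([2, 5, 8])
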