Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions fast_llm/layers/decoder/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,21 @@ class StochasticMixerConfig(MixerConfig):
hint=FieldHint.feature,
)

# Optional hand-picked per-layer mixer layouts to oversample during training.
# Each entry is one full layout (one mixer name per decoder layer); entries are
# paired positionally with `predefined_layout_probabilities`.
predefined_layouts: list[list[str]] = Field(
default_factory=list,
desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. "
"Mixer names must match keys in the mixers dict. All layouts must share a common length.",
hint=FieldHint.feature,
)

# Parallel list of per-layout probabilities; any residual mass (1 - sum) is
# handled by the regular sampling_strategy instead of a predefined layout.
predefined_layout_probabilities: list[float] = Field(
default_factory=list,
desc="Per-layout sampling probability, parallel to predefined_layouts. "
"Each value must be in [0, 1]; the sum must be <= 1. The residual probability (1 - sum) "
"falls through to sampling_strategy.",
hint=FieldHint.feature,
)

seed_shift: int = Field(
default=_BIG_PRIMES[11],
desc="Seed shift for mixer sampling reproducibility.",
Expand Down Expand Up @@ -204,6 +219,26 @@ def _validate(self) -> None:
normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))

# Validate predefined layouts: each layout must pair with a probability,
# reference only configured mixers, and share a common length; the total
# predefined probability mass may not exceed 1.
if self.predefined_layouts:
# Probabilities must pair 1:1 with layouts (also enforced by strict=True below).
Assert.eq(len(self.predefined_layout_probabilities), len(self.predefined_layouts))
mixer_names = set(self.mixers.keys())
# All layouts must have the same length (one mixer name per layer).
common_length = len(self.predefined_layouts[0])
for i, (layout, probability) in enumerate(
zip(self.predefined_layouts, self.predefined_layout_probabilities, strict=True)
):
# Every name in the layout must be a key of the mixers dict.
unknown = set(layout) - mixer_names
if unknown:
raise ValueError(
f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. "
f"Valid names: {mixer_names}"
)
Assert.eq(len(layout), common_length)
Assert.in_range_incl(probability, 0.0, 1.0)
# Residual mass (1 - sum) falls through to sampling_strategy, so the sum may be < 1 but not > 1.
Assert.leq(sum(self.predefined_layout_probabilities), 1.0)
elif self.predefined_layout_probabilities:
# Probabilities without matching layouts is a configuration error.
raise ValueError("predefined_layout_probabilities provided but predefined_layouts is empty")

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
Expand Down
55 changes: 48 additions & 7 deletions fast_llm/layers/decoder/stochastic_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from fast_llm.logging import get_model_debug_level
from fast_llm.tensor import TensorMeta
from fast_llm.utils import safe_merge_dicts
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,6 +82,15 @@ def __init__(
else:
raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")

# Multinomial weights over [*predefined_layouts, residual]; residual mass falls through to sampling_strategy.
if self._config.predefined_layouts:
probabilities = self._config.predefined_layout_probabilities
self._predefined_layout_probabilities = torch.tensor(
[*probabilities, 1.0 - sum(probabilities)], dtype=torch.float32, device="cpu"
)
else:
self._predefined_layout_probabilities = None

logger.info(
f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} "
Expand Down Expand Up @@ -111,7 +120,8 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
if not self.training:
return self._config.main_mixer_name

if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
# Layout-based selection (full_layout strategy or predefined layout override)
if StochasticMixerKwargs.layout in kwargs:
layout = kwargs[StochasticMixerKwargs.layout]
counter = kwargs[StochasticMixerKwargs.layout_counter]
idx = counter[0]
Expand Down Expand Up @@ -190,6 +200,15 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch
perm = torch.randperm(num_layers, generator=generator)
return [layout[i] for i in perm.tolist()]

def _sample_predefined_layout(self, generator: torch.Generator) -> list[str] | None:
"""Sample one of the predefined layouts, or return None when the residual is drawn."""
if self._predefined_layout_probabilities is None:
return None
idx = torch.multinomial(self._predefined_layout_probabilities, num_samples=1, generator=generator).item()
if idx == len(self._config.predefined_layouts):
return None
return self._config.predefined_layouts[idx]

def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
from fast_llm.engine.distributed.config import MAX_SEED
from fast_llm.layers.block.config import BlockKwargs
Expand All @@ -200,8 +219,15 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
generator.manual_seed(seed)
kwargs[StochasticMixerKwargs.generator] = generator

if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
if self._config.predefined_layouts:
Assert.eq(len(self._config.predefined_layouts[0]), num_layers)
predefined_layout = self._sample_predefined_layout(generator)

if predefined_layout is not None:
kwargs[StochasticMixerKwargs.layout] = predefined_layout
kwargs[StochasticMixerKwargs.layout_counter] = [0]
elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
counts = self._sample_allocation(num_layers, generator)
layout = self._sample_placement(counts, num_layers, generator)
kwargs[StochasticMixerKwargs.layout] = layout
Expand All @@ -221,11 +247,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()]

if self._sampling_probs is not None:
# Weight by sampling probability and return the expected value
expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))
strategy_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))
else:
# full_layout: uniform over compositions, so equal expected weight per mixer
expected_usage = sum(usages) / len(usages)
strategy_usage = sum(usages) / len(usages)

if self._config.predefined_layouts:
usage_by_name = dict(zip(self.mixers.keys(), usages, strict=True))
layout_length = len(self._config.predefined_layouts[0])
predefined_usage = sum(
probability * sum(usage_by_name[name] for name in layout) / layout_length
for probability, layout in zip(
self._config.predefined_layout_probabilities,
self._config.predefined_layouts,
strict=True,
)
)
residual = 1.0 - sum(self._config.predefined_layout_probabilities)
expected_usage = predefined_usage + residual * strategy_usage
else:
expected_usage = strategy_usage

return int(expected_usage)

Expand Down
Loading