diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 1c0c10c87..d66e50d56 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -170,6 +170,21 @@ class StochasticMixerConfig(MixerConfig):
         hint=FieldHint.feature,
     )
 
+    predefined_layouts: list[list[str]] = Field(
+        default_factory=list,
+        desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. "
+        "Mixer names must match keys in the mixers dict. All layouts must share a common length.",
+        hint=FieldHint.feature,
+    )
+
+    predefined_layout_probabilities: list[float] = Field(
+        default_factory=list,
+        desc="Per-layout sampling probability, parallel to predefined_layouts. "
+        "Each value must be in [0, 1]; the sum must be <= 1. The residual probability (1 - sum) "
+        "falls through to sampling_strategy.",
+        hint=FieldHint.feature,
+    )
+
     seed_shift: int = Field(
         default=_BIG_PRIMES[11],
         desc="Seed shift for mixer sampling reproducibility.",
@@ -204,6 +219,26 @@ def _validate(self) -> None:
             normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
             self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))
 
+        # Validate predefined layouts
+        if self.predefined_layouts:
+            Assert.eq(len(self.predefined_layout_probabilities), len(self.predefined_layouts))
+            mixer_names = set(self.mixers.keys())
+            common_length = len(self.predefined_layouts[0])
+            for i, (layout, probability) in enumerate(
+                zip(self.predefined_layouts, self.predefined_layout_probabilities, strict=True)
+            ):
+                unknown = set(layout) - mixer_names
+                if unknown:
+                    raise ValueError(
+                        f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. "
+                        f"Valid names: {mixer_names}"
+                    )
+                Assert.eq(len(layout), common_length)
+                Assert.in_range_incl(probability, 0.0, 1.0)
+            Assert.leq(sum(self.predefined_layout_probabilities), 1.0)
+        elif self.predefined_layout_probabilities:
+            raise ValueError("predefined_layout_probabilities provided but predefined_layouts is empty")
+
     @property
     def layer_class(self) -> "type[StochasticMixer]":
         from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py
index a3ea8b846..2db31c914 100644
--- a/fast_llm/layers/decoder/stochastic_mixer.py
+++ b/fast_llm/layers/decoder/stochastic_mixer.py
@@ -16,7 +16,7 @@
 )
 from fast_llm.logging import get_model_debug_level
 from fast_llm.tensor import TensorMeta
-from fast_llm.utils import safe_merge_dicts
+from fast_llm.utils import Assert, safe_merge_dicts
 
 logger = logging.getLogger(__name__)
 
@@ -82,6 +82,15 @@ def __init__(
         else:
             raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")
 
+        # Multinomial weights over [*predefined_layouts, residual]; residual mass falls through to sampling_strategy.
+        if self._config.predefined_layouts:
+            probabilities = self._config.predefined_layout_probabilities
+            self._predefined_layout_probabilities = torch.tensor(
+                [*probabilities, 1.0 - sum(probabilities)], dtype=torch.float32, device="cpu"
+            )
+        else:
+            self._predefined_layout_probabilities = None
+
         logger.info(
             f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
             f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} "
@@ -111,7 +120,8 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
         if not self.training:
             return self._config.main_mixer_name
 
-        if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
+        # Layout-based selection (full_layout strategy or predefined layout override)
+        if StochasticMixerKwargs.layout in kwargs:
             layout = kwargs[StochasticMixerKwargs.layout]
             counter = kwargs[StochasticMixerKwargs.layout_counter]
             idx = counter[0]
@@ -190,6 +200,15 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch
         perm = torch.randperm(num_layers, generator=generator)
         return [layout[i] for i in perm.tolist()]
 
+    def _sample_predefined_layout(self, generator: torch.Generator) -> list[str] | None:
+        """Sample one of the predefined layouts, or return None when the residual is drawn."""
+        if self._predefined_layout_probabilities is None:
+            return None
+        idx = torch.multinomial(self._predefined_layout_probabilities, num_samples=1, generator=generator).item()
+        if idx == len(self._config.predefined_layouts):
+            return None
+        return self._config.predefined_layouts[idx]
+
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         from fast_llm.engine.distributed.config import MAX_SEED
         from fast_llm.layers.block.config import BlockKwargs
@@ -200,8 +219,15 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         generator.manual_seed(seed)
         kwargs[StochasticMixerKwargs.generator] = generator
 
-        if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
-            num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
+        num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
+        if self._config.predefined_layouts:
+            Assert.eq(len(self._config.predefined_layouts[0]), num_layers)
+        predefined_layout = self._sample_predefined_layout(generator)
+
+        if predefined_layout is not None:
+            kwargs[StochasticMixerKwargs.layout] = predefined_layout
+            kwargs[StochasticMixerKwargs.layout_counter] = [0]
+        elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
             counts = self._sample_allocation(num_layers, generator)
             layout = self._sample_placement(counts, num_layers, generator)
             kwargs[StochasticMixerKwargs.layout] = layout
@@ -221,11 +247,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()]
 
         if self._sampling_probs is not None:
-            # Weight by sampling probability and return the expected value
-            expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))
+            strategy_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))
         else:
             # full_layout: uniform over compositions, so equal expected weight per mixer
-            expected_usage = sum(usages) / len(usages)
+            strategy_usage = sum(usages) / len(usages)
+
+        if self._config.predefined_layouts:
+            usage_by_name = dict(zip(self.mixers.keys(), usages, strict=True))
+            layout_length = len(self._config.predefined_layouts[0])
+            predefined_usage = sum(
+                probability * sum(usage_by_name[name] for name in layout) / layout_length
+                for probability, layout in zip(
+                    self._config.predefined_layout_probabilities,
+                    self._config.predefined_layouts,
+                    strict=True,
+                )
+            )
+            residual = 1.0 - sum(self._config.predefined_layout_probabilities)
+            expected_usage = predefined_usage + residual * strategy_usage
+        else:
+            expected_usage = strategy_usage
 
         return int(expected_usage)
 