From 8b674b1699b91994ab935a021d4fec409e83ed47 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 6 Mar 2026 20:49:00 +0000 Subject: [PATCH 1/6] Sample full layouts instead of independent per layer mixer sampling --- fast_llm/layers/block/config.py | 1 + fast_llm/layers/block/sequence.py | 6 ++- fast_llm/layers/decoder/config.py | 14 +++++- fast_llm/layers/decoder/stochastic_mixer.py | 53 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index fd76d36cb..f6bd8b896 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -46,6 +46,7 @@ class BlockKwargs: hidden_states = "hidden_states" output_hidden_states = "output_hidden_states" activation_mask = "activation_mask" + num_blocks_in_sequence = "num_blocks_in_sequence" @config_class(registry=True) diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 54a5b3471..a53b6a78d 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase -from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.block.config import BlockKwargs, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.common.peft.config import PeftConfig @@ -56,6 +56,7 @@ def get_layers(self) -> list["Layer"]: return self._layers_with_namespace def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.num_blocks self._layers_with_namespace[0].preprocess(kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: @@ -110,7 +111,8 @@ def get_layers(self) -> list[Layer]: return self._layers_with_namespace def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - for _, index in self._config.preprocessing_layers.items(): + for name, index in self._config.preprocessing_layers.items(): + kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.expanded_pattern.count(name) self._layers_with_namespace[index].preprocess(kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 2f5990ccb..4cab2d39b 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -21,6 +21,8 @@ class StochasticMixerKwargs(BlockKwargs): mixer_name = "stochastic_mixer_name" generator = "stochastic_mixer_generator" + layout = "stochastic_mixer_layout" + layout_counter = "stochastic_mixer_layout_counter" @config_class() @@ -91,6 +93,7 @@ class StochasticMixerSamplingStrategy(enum.StrEnum): uniform = "uniform" weighted = "weighted" + full_layout = "full_layout" @config_class(registry=True) @@ -124,7 +127,8 @@ class StochasticMixerConfig(MixerConfig): _abstract = False - mixers: dict[str, MixerConfig] = Field( + mixers: dict[str, MixerConfig] | None = Field( + default=None, desc="Dict of mixer options to sample from (must contain at least 1). 
" "Keys are mixer names used for debugging and namespacing.", hint=FieldHint.architecture, @@ -162,7 +166,9 @@ class StochasticMixerConfig(MixerConfig): def _validate(self) -> None: super()._validate() - # Validate mixers dict is not empty + # Validate mixers dict is provided and not empty + if self.mixers is None: + raise ValueError("mixers must be provided for StochasticMixerConfig") Assert.gt(len(self.mixers), 0) # Set main_mixer_name to first mixer if not specified @@ -174,6 +180,10 @@ def _validate(self) -> None: if self.main_mixer_name not in self.mixers: raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers") + # Validate full_layout incompatibilities + if self.sampling_strategy == StochasticMixerSamplingStrategy.full_layout and self.sampling_weights is not None: + raise ValueError("sampling_weights is not compatible with full_layout sampling strategy") + # Validate and normalize sampling weights if self.sampling_weights is not None: Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys())) diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 984f34b80..8a05a2556 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -66,7 +66,9 @@ def __init__( } ) - if self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform: + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + self._sampling_probs = None + elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform: self._sampling_probs = torch.ones(len(self.mixers), device="cpu") / len(self.mixers) elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.weighted: if self._config.sampling_weights is None: @@ -108,6 +110,13 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: if not self.training: return self._config.main_mixer_name + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + layout = kwargs[StochasticMixerKwargs.layout] + counter = kwargs[StochasticMixerKwargs.layout_counter] + idx = counter[0] + counter[0] += 1 + return layout[idx] + generator = kwargs[StochasticMixerKwargs.generator] mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item() return list(self.mixers.keys())[mixer_idx] @@ -150,6 +159,33 @@ def _forward( return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) + def _sample_allocation(self, num_layers: int, generator: torch.Generator) -> list[int]: + """ + Sample a composition of num_layers into num_mixers bins uniformly. + + Uses stars-and-bars: picks (M-1) bar positions from {0, ..., N+M-2}, + giving each mixer a count. All integer partitions are equally likely. + """ + M = len(self.mixers) + N = num_layers + if M == 1: + return [N] + bars = torch.randperm(N + M - 1, generator=generator)[: M - 1].sort().values + padded = torch.cat([torch.tensor([-1]), bars, torch.tensor([N + M - 1])]) + counts = (padded[1:] - padded[:-1] - 1).tolist() + return counts + + def _sample_placement(self, counts: list[int], num_layers: int, generator: torch.Generator) -> list[str]: + """ + Given per-mixer counts, create a shuffled layout. 
+ """ + mixer_names = list(self.mixers.keys()) + layout = [] + for name, count in zip(mixer_names, counts): + layout.extend([name] * count) + perm = torch.randperm(num_layers, generator=generator) + return [layout[i] for i in perm.tolist()] + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.layers.block.config import BlockKwargs @@ -160,6 +196,13 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: generator.manual_seed(seed) kwargs[StochasticMixerKwargs.generator] = generator + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + counts = self._sample_allocation(num_layers, generator) + layout = self._sample_placement(counts, num_layers, generator) + kwargs[StochasticMixerKwargs.layout] = layout + kwargs[StochasticMixerKwargs.layout_counter] = [0] + for mixer in self.mixers.values(): mixer.preprocess(kwargs) @@ -173,8 +216,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c """ usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()] - # Weight by sampling probability and return the expected value - expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + if self._sampling_probs is not None: + # Weight by sampling probability and return the expected value + expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + else: + # full_layout: uniform over compositions, so equal expected weight per mixer + expected_usage = sum(usages) / len(usages) return int(expected_usage) From bdd37c7733a52cb723baab880dbbcdfcc13efbcb Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 9 Mar 2026 19:16:45 +0000 Subject: [PATCH 2/6] predefined layout set --- fast_llm/layers/decoder/config.py | 32 +++++++++++++++++++++ fast_llm/layers/decoder/stochastic_mixer.py | 29 +++++++++++++++++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 4cab2d39b..390a8b05f 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -157,6 +157,21 @@ class StochasticMixerConfig(MixerConfig): hint=FieldHint.feature, ) + predefined_layouts: list[list[str]] | None = Field( + default=None, + desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. " + "Mixer names must match keys in the mixers dict.", + hint=FieldHint.feature, + ) + + predefined_layout_probability: float = Field( + default=0.0, + desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. " + "Must be in [0, 1]. 
Only used when predefined_layouts is provided.", + hint=FieldHint.feature, + valid=check_field(Assert.in_range_incl, 0.0, 1.0), + ) + seed_shift: int = Field( default=_BIG_PRIMES[11], desc="Seed shift for mixer sampling reproducibility.", @@ -191,6 +206,23 @@ def _validate(self) -> None: normalized_values = normalize_probabilities(list(self.sampling_weights.values())) self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values)) + # Validate predefined layouts + if self.predefined_layouts is not None: + if len(self.predefined_layouts) == 0: + raise ValueError("predefined_layouts must be non-empty if provided") + mixer_names = set(self.mixers.keys()) + for i, layout in enumerate(self.predefined_layouts): + unknown = set(layout) - mixer_names + if unknown: + raise ValueError( + f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. " + f"Valid names: {mixer_names}" + ) + if self.predefined_layout_probability <= 0: + warnings.warn("predefined_layouts provided but predefined_layout_probability is 0") + elif self.predefined_layout_probability > 0: + raise ValueError("predefined_layout_probability > 0 but predefined_layouts is not provided") + @property def layer_class(self) -> "type[StochasticMixer]": from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 8a05a2556..093daff5d 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -16,6 +16,7 @@ ) from fast_llm.logging import get_model_debug_level from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -110,7 +111,8 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: if not self.training: return self._config.main_mixer_name - if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + # Layout-based selection (full_layout strategy or predefined layout override) + if StochasticMixerKwargs.layout in kwargs: layout = kwargs[StochasticMixerKwargs.layout] counter = kwargs[StochasticMixerKwargs.layout_counter] idx = counter[0] @@ -186,6 +188,21 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch perm = torch.randperm(num_layers, generator=generator) return [layout[i] for i in perm.tolist()] + def _sample_predefined_layout(self, num_layers: int, generator: torch.Generator) -> list[str] | None: + """ + With probability `predefined_layout_probability`, pick a predefined layout uniformly. + Returns None if we should use the normal sampling strategy instead. 
+ """ + if not self._config.predefined_layouts or self._config.predefined_layout_probability <= 0: + return None + coin = torch.rand(1, generator=generator).item() + if coin >= self._config.predefined_layout_probability: + return None + idx = torch.randint(len(self._config.predefined_layouts), (1,), generator=generator).item() + layout = list(self._config.predefined_layouts[idx]) + Assert.eq(len(layout), num_layers) + return layout + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.layers.block.config import BlockKwargs @@ -196,8 +213,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: generator.manual_seed(seed) kwargs[StochasticMixerKwargs.generator] = generator - if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: - num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + predefined = self._sample_predefined_layout(num_layers, generator) + + if predefined is not None: + # Use predefined layout (overrides any sampling strategy) + kwargs[StochasticMixerKwargs.layout] = predefined + kwargs[StochasticMixerKwargs.layout_counter] = [0] + elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: counts = self._sample_allocation(num_layers, generator) layout = self._sample_placement(counts, num_layers, generator) kwargs[StochasticMixerKwargs.layout] = layout From e830aa5272e0cfbf6361832a89a1f58544c0cafe Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 24 Mar 2026 17:45:09 +0000 Subject: [PATCH 3/6] neste list overwrite --- fast_llm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 5411a2078..8363f4e83 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1083,7 +1083,8 @@ def set_nested_dict_value[KeyType, ValueType]( isinstance(d.get(key), (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) ): - raise ValueError("Update not supported for nested lists.") + # Nested lists cannot be meaningfully merged, so replace entirely. + d[key] = value else: d[key] = value else: From 5d60c28242a59ce4e779ac85cbcf149ce3a3a834 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:16:51 -0400 Subject: [PATCH 4/6] Restructure predefined-layout schema per review - Replace scalar `predefined_layout_probability` with parallel `predefined_layout_probabilities: list[float]`, one weight per layout. - Switch `predefined_layouts` to `default_factory=list` (no longer Optional). - Sampling collapses to a single multinomial over `[*probs, residual]`; residual mass falls through to `sampling_strategy`. - Validate shared layout length in `_validate`; assert vs `num_layers` in `preprocess`. - Revert bundled change to `set_nested_dict_value` that silently overwrote nested-list values; restore the original `ValueError`. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/config.py | 3 +- fast_llm/layers/decoder/config.py | 35 +++++++++++---------- fast_llm/layers/decoder/stochastic_mixer.py | 29 ++++++++++------- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index bc3f2d573..61e22737a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1075,8 +1075,7 @@ def set_nested_dict_value[KeyType, ValueType]( isinstance(d.get(key), (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) ): - # Nested lists cannot be meaningfully merged, so replace entirely. - d[key] = value + raise ValueError("Update not supported for nested lists.") else: d[key] = value else: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index b50f2d119..d66e50d56 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -170,19 +170,19 @@ class StochasticMixerConfig(MixerConfig): hint=FieldHint.feature, ) - predefined_layouts: list[list[str]] | None = Field( - default=None, + predefined_layouts: list[list[str]] = Field( + default_factory=list, desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. " - "Mixer names must match keys in the mixers dict.", + "Mixer names must match keys in the mixers dict. All layouts must share a common length.", hint=FieldHint.feature, ) - predefined_layout_probability: float = Field( - default=0.0, - desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. " - "Must be in [0, 1]. Only used when predefined_layouts is provided.", + predefined_layout_probabilities: list[float] = Field( + default_factory=list, + desc="Per-layout sampling probability, parallel to predefined_layouts. " + "Each value must be in [0, 1]; the sum must be <= 1. The residual probability (1 - sum) " + "falls through to sampling_strategy.", hint=FieldHint.feature, - valid=check_field(Assert.in_range_incl, 0.0, 1.0), ) seed_shift: int = Field( @@ -220,21 +220,24 @@ def _validate(self) -> None: self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values)) # Validate predefined layouts - if self.predefined_layouts is not None: - if len(self.predefined_layouts) == 0: - raise ValueError("predefined_layouts must be non-empty if provided") + if self.predefined_layouts: + Assert.eq(len(self.predefined_layout_probabilities), len(self.predefined_layouts)) mixer_names = set(self.mixers.keys()) - for i, layout in enumerate(self.predefined_layouts): + common_length = len(self.predefined_layouts[0]) + for i, (layout, probability) in enumerate( + zip(self.predefined_layouts, self.predefined_layout_probabilities, strict=True) + ): unknown = set(layout) - mixer_names if unknown: raise ValueError( f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. 
" f"Valid names: {mixer_names}" ) - if self.predefined_layout_probability <= 0: - warnings.warn("predefined_layouts provided but predefined_layout_probability is 0") - elif self.predefined_layout_probability > 0: - raise ValueError("predefined_layout_probability > 0 but predefined_layouts is not provided") + Assert.eq(len(layout), common_length) + Assert.in_range_incl(probability, 0.0, 1.0) + Assert.leq(sum(self.predefined_layout_probabilities), 1.0) + elif self.predefined_layout_probabilities: + raise ValueError("predefined_layout_probabilities provided but predefined_layouts is empty") @property def layer_class(self) -> "type[StochasticMixer]": diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 0721e6eca..9d036ac52 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -82,6 +82,13 @@ def __init__( else: raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented") + # Multinomial weights over [*predefined_layouts, residual]; residual mass falls through to sampling_strategy. + if self._config.predefined_layouts: + probs = self._config.predefined_layout_probabilities + self._predefined_layout_probs = torch.tensor([*probs, 1.0 - sum(probs)], dtype=torch.float32, device="cpu") + else: + self._predefined_layout_probs = None + logger.info( f"Initialized StochasticMixer with {len(self.mixers)} mixers: " f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} " @@ -191,20 +198,18 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch perm = torch.randperm(num_layers, generator=generator) return [layout[i] for i in perm.tolist()] - def _sample_predefined_layout(self, num_layers: int, generator: torch.Generator) -> list[str] | None: + def _sample_predefined_layout(self, generator: torch.Generator) -> list[str] | None: """ - With probability `predefined_layout_probability`, pick a predefined layout uniformly. - Returns None if we should use the normal sampling strategy instead. + Draw one multinomial over [*predefined_layouts, residual]. + Returns the chosen layout, or None when the residual is selected + (caller falls through to sampling_strategy). 
""" - if not self._config.predefined_layouts or self._config.predefined_layout_probability <= 0: + if self._predefined_layout_probs is None: return None - coin = torch.rand(1, generator=generator).item() - if coin >= self._config.predefined_layout_probability: + idx = torch.multinomial(self._predefined_layout_probs, num_samples=1, generator=generator).item() + if idx == len(self._config.predefined_layouts): return None - idx = torch.randint(len(self._config.predefined_layouts), (1,), generator=generator).item() - layout = list(self._config.predefined_layouts[idx]) - Assert.eq(len(layout), num_layers) - return layout + return list(self._config.predefined_layouts[idx]) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED @@ -217,7 +222,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[StochasticMixerKwargs.generator] = generator num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] - predefined = self._sample_predefined_layout(num_layers, generator) + if self._config.predefined_layouts: + Assert.eq(len(self._config.predefined_layouts[0]), num_layers) + predefined = self._sample_predefined_layout(generator) if predefined is not None: # Use predefined layout (overrides any sampling strategy) From 100f3a3decf205c4c6abd85999b71f2236dd2996 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:41:09 -0400 Subject: [PATCH 5/6] Fine-review cleanups in StochasticMixer - Drop redundant comment in `preprocess`. - Rename local `predefined` -> `predefined_layout`. - Condense `_sample_predefined_layout` docstring to one line. - Return predefined layout list directly (drop defensive copy). - Rename `_predefined_layout_probs` -> `_predefined_layout_probabilities` and the local `probs` -> `probabilities`; no abbreviations in new code. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/decoder/stochastic_mixer.py | 27 +++++++++------------ 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 9d036ac52..bc00516e8 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -84,10 +84,12 @@ def __init__( # Multinomial weights over [*predefined_layouts, residual]; residual mass falls through to sampling_strategy. if self._config.predefined_layouts: - probs = self._config.predefined_layout_probabilities - self._predefined_layout_probs = torch.tensor([*probs, 1.0 - sum(probs)], dtype=torch.float32, device="cpu") + probabilities = self._config.predefined_layout_probabilities + self._predefined_layout_probabilities = torch.tensor( + [*probabilities, 1.0 - sum(probabilities)], dtype=torch.float32, device="cpu" + ) else: - self._predefined_layout_probs = None + self._predefined_layout_probabilities = None logger.info( f"Initialized StochasticMixer with {len(self.mixers)} mixers: " @@ -199,17 +201,13 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch return [layout[i] for i in perm.tolist()] def _sample_predefined_layout(self, generator: torch.Generator) -> list[str] | None: - """ - Draw one multinomial over [*predefined_layouts, residual]. - Returns the chosen layout, or None when the residual is selected - (caller falls through to sampling_strategy). 
- """ - if self._predefined_layout_probs is None: + """Sample one of the predefined layouts, or return None when the residual is drawn.""" + if self._predefined_layout_probabilities is None: return None - idx = torch.multinomial(self._predefined_layout_probs, num_samples=1, generator=generator).item() + idx = torch.multinomial(self._predefined_layout_probabilities, num_samples=1, generator=generator).item() if idx == len(self._config.predefined_layouts): return None - return list(self._config.predefined_layouts[idx]) + return self._config.predefined_layouts[idx] def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED @@ -224,11 +222,10 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] if self._config.predefined_layouts: Assert.eq(len(self._config.predefined_layouts[0]), num_layers) - predefined = self._sample_predefined_layout(generator) + predefined_layout = self._sample_predefined_layout(generator) - if predefined is not None: - # Use predefined layout (overrides any sampling strategy) - kwargs[StochasticMixerKwargs.layout] = predefined + if predefined_layout is not None: + kwargs[StochasticMixerKwargs.layout] = predefined_layout kwargs[StochasticMixerKwargs.layout_counter] = [0] elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: counts = self._sample_allocation(num_layers, generator) From 810b8a8b73e86c85185d4e662108b7441e196c75 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 17:03:44 -0400 Subject: [PATCH 6/6] Account for predefined layouts in get_compute_usage Previously the compute estimate ignored `predefined_layouts` entirely and returned the sampling-strategy average, which can be far off when `predefined_layout_probabilities` dominates (e.g. [1.0] always runs that layout, never the strategy mix). Blend per-layout average usage (mixer-count weighted across positions) by `predefined_layout_probabilities`, then add the residual fraction times the existing strategy estimate. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/decoder/stochastic_mixer.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index bc00516e8..2db31c914 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -247,11 +247,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()] if self._sampling_probs is not None: - # Weight by sampling probability and return the expected value - expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + strategy_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) else: # full_layout: uniform over compositions, so equal expected weight per mixer - expected_usage = sum(usages) / len(usages) + strategy_usage = sum(usages) / len(usages) + + if self._config.predefined_layouts: + usage_by_name = dict(zip(self.mixers.keys(), usages, strict=True)) + layout_length = len(self._config.predefined_layouts[0]) + predefined_usage = sum( + probability * sum(usage_by_name[name] for name in layout) / layout_length + for probability, layout in zip( + self._config.predefined_layout_probabilities, + self._config.predefined_layouts, + strict=True, + ) + ) + residual = 1.0 - sum(self._config.predefined_layout_probabilities) + expected_usage = predefined_usage + residual * strategy_usage + else: + expected_usage = strategy_usage return int(expected_usage)
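
A standalone sketch of the sampling math in this series, using plain torch
outside the patched classes (function and variable names here are
illustrative, not part of the PR). It checks empirically that the
stars-and-bars draw in `_sample_allocation` is uniform over compositions,
and that the patch-4 multinomial leaves the residual mass for the fallback
strategy:

    from collections import Counter

    import torch

    def sample_allocation(num_layers: int, num_mixers: int, generator: torch.Generator) -> list[int]:
        # Stars-and-bars: choose (M - 1) bar positions among N + M - 1 slots;
        # the gap sizes between consecutive bars are the per-mixer counts.
        if num_mixers == 1:
            return [num_layers]
        bars = torch.randperm(num_layers + num_mixers - 1, generator=generator)[: num_mixers - 1].sort().values
        padded = torch.cat([torch.tensor([-1]), bars, torch.tensor([num_layers + num_mixers - 1])])
        return (padded[1:] - padded[:-1] - 1).tolist()

    generator = torch.Generator().manual_seed(0)

    # N=3 layers, M=2 mixers: the 4 compositions (0,3), (1,2), (2,1), (3,0)
    # should each appear with frequency ~0.25.
    frequencies = Counter(tuple(sample_allocation(3, 2, generator)) for _ in range(20_000))
    print({composition: round(count / 20_000, 3) for composition, count in sorted(frequencies.items())})

    # Patch-4 residual multinomial: weights [0.15, 0.10] over two predefined
    # layouts leave 0.75 for the fallback sampling strategy (index 2).
    weights = torch.tensor([0.15, 0.10, 0.75])
    draws = torch.multinomial(weights, num_samples=20_000, replacement=True, generator=generator)
    print(torch.bincount(draws, minlength=3) / 20_000)  # ~ tensor([0.150, 0.100, 0.750])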