From 6720c18222f0d37164b41c22a7493ba66f548fd5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:52:56 -0400 Subject: [PATCH 01/27] Reclassify architecture-impacting fields under FieldHint.architecture Eight config fields whose values directly affect model architecture were tagged as feature/core/(none). They drive the upcoming declarative-converter coverage check, which uses FieldHint.architecture as the source of truth for "must be handled by every checkpoint format". - AttentionConfig.dense_layer (output projection presence) - AttentionConfig.softmax_scale_power (attention scaling) - MLPConfig.activation (forward-pass activation type) - MoEMLPConfig.router (routing weights drive token assignment) - Llama3RotaryConfig: scale_factor, low_frequency_factor, high_frequency_factor, original_context_length - YarnRotaryConfig: scale_factor, attention_factor, beta_fast, beta_slow, original_context_length - StochasticMixerConfig.main_mixer_name (selects inference mixer) - PatchEmbeddingsConfig.patch_height/patch_width (input tokenization) Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/config.py | 3 ++- fast_llm/layers/attention/rotary/config.py | 18 +++++++++--------- fast_llm/layers/decoder/config.py | 2 +- fast_llm/layers/decoder/mlp/config.py | 4 ++-- fast_llm/layers/vision/config.py | 4 ++-- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index fcb5bfaf6..efdec6c99 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -62,7 +62,7 @@ class AttentionConfig(MixerConfig): ) dense_layer: AffineLinearConfig = Field( desc="Initialization configuration for the dense layer.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) # TODO: Review names rotary: RotaryConfig = Field( @@ -115,6 +115,7 @@ class AttentionConfig(MixerConfig): " Under Standard Parameterization (SP): default to 0.5. " " Under muP (if scaling head_size size): use 1. " " Under muP (if scaling number of heads instead of head_size): use 0.5.", + hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) implementation: AttentionImplementation = Field( diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 80f499748..e5e5c8d34 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -78,10 +78,10 @@ class Llama3RotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. - scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) - high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) + low_frequency_factor: float = Field(default=1.0, hint=FieldHint.architecture) + high_frequency_factor: float = Field(default=4.0, hint=FieldHint.architecture) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: super()._validate() @@ -102,20 +102,20 @@ class YarnRotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. 
- scale_factor: float = Field(default=8.0, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) attention_factor: None | float = Field( default=None, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_fast: float = Field( default=32.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_slow: float = Field( default=1.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: if self.attention_factor is None: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 6ab259b2b..ea2ba5fa3 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -156,7 +156,7 @@ class StochasticMixerConfig(MixerConfig): "Used for inference/eval, checkpoint loading (receives pretrained weights), " "and checkpoint saving (only this mixer is exported). " "If None, uses the first mixer in the dict.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) seed_shift: int = Field( diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 997cf9d2a..01f5bc052 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -62,7 +62,7 @@ class MLPConfig(MLPBaseConfig): activation: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto recompute_level: MLPRecomputeLevel = Field( @@ -95,7 +95,7 @@ class MoEMLPConfig(MLPConfig): router: LinearConfig = Field( # TODO: Improve default? desc="Configuration for the MoE router.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) experts: int = Field( default=2, diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 5920a85ee..47cf43391 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -34,12 +34,12 @@ class PatchEmbeddingsConfig(BlockConfig): patch_height: int = Field( default=16, desc="Height of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) patch_width: int = Field( default=16, desc="Width of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) full_precision_residual: bool = Field( default=False, From 0d393b220047684594ac06851408d9e42e07fb7a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:53:36 -0400 Subject: [PATCH 02/27] Add declarative ConfigConverter primitives and section-converter ABC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reintroduces the declarative config-conversion shape that pre-dated PR #362, applied within the post-#362 modular per-section structure. Replaces the imperative import_config/export_config bodies with a small set of named primitives and a recursive walker driven by per-section declarations. 
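As a taste of the target shape (a sketch only — the concrete Llama declarations land in the next patch), a section converter for the RMS-norm config reduces to a class attribute plus one declaration dict built from the primitives listed below; import_config/export_config are inherited from the base class and drive these declarations in both directions:

    class LlamaNormalizationConverter(ConfigSectionConverter):
        fast_llm_config_class = RMSNormalizationConfig

        @classmethod
        def _create_config_converters(cls) -> dict:
            return {
                "type": ConstantImportConfigConverter(("type",), "rms_norm"),
                "epsilon": RenameConfigConverter(("epsilon",), ("rms_norm_eps",)),
                "weight": IgnoredConfigConverter(("weight",)),
                "zero_centered": ConstantImportConfigConverter(("zero_centered",), False),
            }
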
Primitives in fast_llm.engine.checkpoint.external: - RenameConfigConverter — 1:1 path rename - ConstantExportConfigConverter — write constant on export, assert on import - ConstantImportConfigConverter — assert on export, inject on import - DefaultConfigConverter — rename with HF-side fallback - OptionalConfigConverter — emit/import only when non-sentinel - IgnoredConfigConverter — declare a field as intentionally not converted - CustomConfigConverter — escape hatch for cross-field transforms - NestedConfigConverter — recurse into a fixed-typed sub-config; flat-merges HF output into the parent (transformer side is assumed flat) - DispatchConfigConverter — runtime type dispatch for polymorphic sub-configs ConfigSectionConverter is the per-Fast-LLM-class converter base. Subclasses declare their conversion via _create_config_converters() and inherit import_config/export_config concretely. The architecture-coverage check fires only when type(config) exactly matches the converter's declared fast_llm_config_class — strict subclass types defer to a more specific converter, allowing yet-to-be-migrated subclasses (e.g., Mixtral on Llama) to call super().export_config() without tripping the parent's check on fields the parent doesn't know about. The walker is implicit: NestedConfigConverter / DispatchConfigConverter call the public import_config/export_config on the sub-converter class so subclass overrides participate, rather than a private path that bypasses them. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/external.py | 381 ++++++++++++++++++++++++- 1 file changed, 380 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 886c706c1..9d2e09fbb 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,4 +1,5 @@ import abc +import dataclasses import logging import pathlib import typing @@ -6,7 +7,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import Config +from fast_llm.config import Config, FieldHint, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -18,6 +19,384 @@ logger = logging.getLogger(__name__) +_MISSING = object() + + +def _get_nested(d: dict, path: tuple[str, ...], default=_MISSING): + cur = d + for key in path: + if not isinstance(cur, dict) or key not in cur: + if default is _MISSING: + raise KeyError(f"Missing key {'.'.join(path)} in HF config dict") + return default + cur = cur[key] + return cur + + +def _has_nested(d: dict, path: tuple[str, ...]) -> bool: + cur = d + for key in path: + if not isinstance(cur, dict) or key not in cur: + return False + cur = cur[key] + return True + + +def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: + cur = config + for name in path: + cur = getattr(cur, name) + return cur + + +# ============================================================ +# Config conversion primitives (declarative) +# ============================================================ + + +class ConfigConverter(abc.ABC): + """A declarative description of how one or more Fast-LLM config fields map to one or more HF config keys. 
+ + Each primitive owns a set of ``fast_llm_paths`` (tuples of attribute names rooted at the section's config) and + ``hf_paths`` (tuples of dict keys rooted at the section's HF subdict). The walker calls ``export_to`` to produce + HF entries from a Fast-LLM config object, and ``import_to`` to produce a Fast-LLM config dict from an HF dict. + """ + + fast_llm_paths: tuple[tuple[str, ...], ...] = () + hf_paths: tuple[tuple[str, ...], ...] = () + + @property + def consumed_fast_llm_fields(self) -> set[str]: + """Top-level Fast-LLM field names this primitive consumes at the current section level. + + Used by the section walker for the architecture-hint coverage check. + """ + return {path[0] for path in self.fast_llm_paths if path} + + @abc.abstractmethod + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: ... + + @abc.abstractmethod + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: ... + + +class RenameConfigConverter(ConfigConverter): + """One-to-one rename between a Fast-LLM attribute path and an HF dict path.""" + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...]): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + value = _get_nested(hf_dict, self.hf_paths[0]) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class ConstantExportConfigConverter(ConfigConverter): + """Write a constant to the HF dict on export. On import, assert that the HF dict has this constant value. + + Used when a HF format requires a key whose value Fast-LLM doesn't store (or always pins to a constant). + """ + + def __init__(self, hf_path: tuple[str, ...], value: typing.Any): + self.hf_paths = (hf_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + set_nested_dict_value(hf_out, self.hf_paths[0], self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + actual = _get_nested(hf_dict, self.hf_paths[0]) + Assert.eq(actual, self._value) + + +class ConstantImportConfigConverter(ConfigConverter): + """Inject a constant into the Fast-LLM dict on import. On export, assert the config matches the constant. + + Used when a Fast-LLM field is required but the HF format implies a fixed value (e.g., gated MLP for Llama). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], value: typing.Any): + self.fast_llm_paths = (fast_llm_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + actual = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + Assert.eq(actual, self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], self._value) + + +class DefaultConfigConverter(ConfigConverter): + """Rename with an HF-side fallback used when the HF key is missing on import. + + ``hf_default_fn`` is called with the full HF dict if the path is absent; otherwise it's a plain rename. + On export, behaves like ``RenameConfigConverter``. 
+ """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + hf_default_fn: typing.Callable[[dict], typing.Any], + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._hf_default_fn = hf_default_fn + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + value = _get_nested(hf_dict, self.hf_paths[0]) + else: + value = self._hf_default_fn(hf_dict) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class OptionalConfigConverter(ConfigConverter): + """Emit/import only when the value differs from a sentinel (default ``None``). + + Useful for fields that round-trip cleanly only when present (e.g. ``window_size``). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...], sentinel: typing.Any = None): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._sentinel = sentinel + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + if value != self._sentinel: + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + value = _get_nested(hf_dict, self.hf_paths[0]) + if value != self._sentinel: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class IgnoredConfigConverter(ConfigConverter): + """Declares Fast-LLM architecture fields as intentionally not converted by this format. + + Use when the HF format has no representation for the field and the Fast-LLM default round-trips correctly. + Acts as a no-op on both directions while satisfying the architecture-coverage check. + """ + + def __init__(self, *fast_llm_paths: tuple[str, ...]): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = () + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + return + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + return + + +class CustomConfigConverter(ConfigConverter): + """Escape hatch for cross-field transforms (e.g., rotary, where one HF blob ↔ several Fast-LLM fields). + + The export/import callables receive the section's full config and return/produce arbitrary mappings within + the declared paths. Both ``fast_llm_paths`` and ``hf_paths`` are still declared so the coverage check works. 
+ """ + + def __init__( + self, + fast_llm_paths: tuple[tuple[str, ...], ...], + hf_paths: tuple[tuple[str, ...], ...], + export_fn: typing.Callable[[Config], dict], + import_fn: typing.Callable[[dict, dict | None], dict], + ): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths + self._export_fn = export_fn + self._import_fn = import_fn + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + produced = self._export_fn(fast_llm_config) + for path, value in produced.items(): + set_nested_dict_value(hf_out, path, value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + produced = self._import_fn(hf_dict, parent_context) + for path, value in produced.items(): + set_nested_dict_value(fast_llm_out, path, value) + + +class NestedConfigConverter(ConfigConverter): + """Recurse into a fixed-typed sub-config field via another section converter class. + + Exists for Fast-LLM-side modularity: lets a parent converter delegate handling of a sub-config to its own + converter class. The HF side is assumed flat — the sub-converter's output is merged into the parent's HF dict. + For non-flat HF formats, use ``CustomConfigConverter``. + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + converter_class: "type[ConfigSectionConverter]", + ): + self.fast_llm_paths = (fast_llm_path,) + self._converter_class = converter_class + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + sub_hf = self._converter_class.export_config(sub_config) + for key, value in sub_hf.items(): + if key in hf_out: + Assert.eq(hf_out[key], value) + else: + hf_out[key] = value + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + sub_fast_llm = self._converter_class.import_config(hf_dict) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +class DispatchConfigConverter(ConfigConverter): + """Polymorphic sub-config dispatch. + + The Fast-LLM field's runtime type selects the section converter; the HF format selects via a ``type`` discriminator. + Both registries (Fast-LLM type → converter class, HF discriminator → converter class) must agree at runtime. + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...] 
| None, + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) if hf_path is not None else () + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at {'.'.join(self.fast_llm_paths[0])}" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + if self.hf_paths: + set_nested_dict_value(hf_out, self.hf_paths[0], sub_hf) + else: + for key, value in sub_hf.items(): + hf_out[key] = value + + def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + sub_hf = _get_nested(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at " f"{'.'.join(self.fast_llm_paths[0])}" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +# ============================================================ +# Section converter — converts one Fast-LLM config class +# ============================================================ + + +class ConfigSectionConverter(abc.ABC): + """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. + + Subclasses declare the conversion via ``_create_config_converters`` (config side) and + ``_create_weight_converters`` (weight side; receives the live config). + + Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value + used by the HF format (e.g. ``"attention"``, ``"mamba"``). + """ + + fast_llm_config_class: typing.ClassVar[type[Config]] + hf_type_name: typing.ClassVar[str | None] = None + + @classmethod + @abc.abstractmethod + def _create_config_converters(cls) -> dict[str, ConfigConverter]: + """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key.""" + + @classmethod + def _create_weight_converters( + cls, config: Config, fast_llm_prefix: str, hf_prefix: str + ) -> dict[str, "WeightConverter"]: + """Return weight converters keyed by stable string name. 
Default is empty (no weights at this level).""" + return {} + + @classmethod + def export_config(cls, config: Config) -> dict: + """Convert a Fast-LLM config object to an HF config dict via this section's declarations.""" + declarations = cls._create_config_converters() + cls._check_architecture_coverage(config, declarations) + out: dict = {} + for converter in declarations.values(): + converter.export_to(config, out) + return out + + @classmethod + def import_config(cls, hf_dict: dict) -> dict: + """Convert an HF config dict to a Fast-LLM config dict via this section's declarations.""" + out: dict = {} + for converter in cls._create_config_converters().values(): + converter.import_to(hf_dict, out) + return out + + @classmethod + def _check_architecture_coverage(cls, config: Config, declarations: dict[str, ConfigConverter]) -> None: + """Raise if any architecture-hint field on the section's declared config class is not consumed. + + Coverage is structural (based on field hints), not value-based: every architecture field must be + explicitly accounted for, even if it currently holds its Fast-LLM default. Sub-config fields are + consumed by ``NestedConfigConverter``/``DispatchConfigConverter``, which delegate the deeper coverage + check to the nested section's own converter. + + The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the + config is a strict subclass (e.g. ``MoEMLPConfig`` fed via ``super().export_config()`` from a yet-to-be- + migrated ``MixtralMLPConverter``), the subclass converter is responsible for declaring the additional + fields and running its own check. + """ + declared_class = getattr(cls, "fast_llm_config_class", None) + if declared_class is None or type(config) is not declared_class: + return + consumed: set[str] = set() + for converter in declarations.values(): + consumed |= converter.consumed_fast_llm_fields + missing: list[str] = [] + for name, field in type(config).fields(): + if field._field_type != dataclasses._FIELD: + continue + if not field.init: + continue + if field.hint != FieldHint.architecture: + continue + if name in consumed: + continue + missing.append(name) + if missing: + raise ValueError( + f"{cls.__name__}: architecture-hint fields on {type(config).__name__} " + f"have no converter declaration: {missing}" + ) + + class WeightConverter: def __init__( self, From 0c406db6e50d2272da0bf4f378c87d3ae52395fd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:54:07 -0400 Subject: [PATCH 03/27] Migrate Llama config converters to declarative primitives MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pilot of the new ConfigSectionConverter framework. Each Llama section converter (Normalization/MLP/Attention/Block/Embeddings/Head/BaseModel) now declares its conversion via _create_config_converters() instead of imperative import_config/export_config bodies. Weight side is unchanged. Notable shape decisions: - LlamaDecoderConverter stays as a regular (imperative) class because Fixed/Pattern block-sequence dispatch doesn't lend itself to the declarative shape. LlamaBaseModelConverter wires it in via a small CustomConfigConverter; subclasses (Mistral, Qwen2, MTP-Llama, ...) continue to plug in different block converters via block_converter_class. - _check_config is retained as an overridable classmethod and called from the linear_layers CustomConfigConverter, so Qwen2 can keep its asymmetric Q/K/V bias rule without re-implementing the export. 
- IgnoredConfigConverter is used for ParameterConfig sub-fields with no architecture-significant content (weight, output_weight, word_embeddings), and for prediction_heads (which Llama HF doesn't expose; subclass MTP-Llama adds it imperatively). - peft uses CustomConfigConverter to assert NoPeftConfig on export. Llama HF format cannot represent PEFT, so a configured LoRA now fails loudly rather than being silently dropped. - Rotary remains in CustomConfigConverter — the v4/v5 transformers split (rope_theta/rope_scaling vs. rope_parameters) and three rope_type variants don't fit pure rename primitives. Verified with live round-trips of Llama-3, Qwen2, Mistral, Mixtral, and MTP-Llama HF configs, plus tests/models/test_checkpoint.py for all GPT formats (139 passed, 0 failed). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/models/gpt/conversion/llama.py | 628 ++++++++++++++---------- 1 file changed, 380 insertions(+), 248 deletions(-) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index f8f36dc23..61c4799fe 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -5,10 +5,18 @@ import torch import transformers +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, IgnoreExportWeightConverter, IgnoreImportWeightConverter, + NestedConfigConverter, + RenameConfigConverter, SplitWeightConverter, WeightConverter, ) @@ -30,13 +38,18 @@ from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div, safe_merge_dicts +from fast_llm.utils import Assert, div _TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) logger = logging.getLogger(__name__) +# ============================================================ +# Weight converters (imperative — kept as-is during config migration) +# ============================================================ + + def get_parameter_converter( fast_llm_name: str | tuple[str, ...], hf_name: str | tuple[str, ...], @@ -97,16 +110,195 @@ def get_weight_and_bias_converters( return converters -class LlamaNormalizationConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. + # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class QueryWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + return (query,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + return (query,) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +# ============================================================ +# Config converters (declarative) +# ============================================================ + + +def _llama_rotary_export(config: AttentionConfig) -> dict: + """Build the HF rotary block(s) from a Fast-LLM rotary config. + + Returns a dict keyed by the (Llama-flat) HF paths the converter declares; values vary with rotary subtype and + the active transformers major version (v4 puts ``rope_theta`` flat with optional ``rope_scaling``; + v5 consolidates everything into ``rope_parameters``). + """ + rotary = config.rotary + rope_parameters = {"rope_theta": rotary.theta} + if type(rotary) is DefaultRotaryConfig: + rope_parameters["rope_type"] = "default" + elif type(rotary) is Llama3RotaryConfig: + rope_parameters.update( + { + "rope_type": "llama3", + "factor": rotary.scale_factor, + "low_freq_factor": rotary.low_frequency_factor, + "high_freq_factor": rotary.high_frequency_factor, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + elif type(rotary) is YarnRotaryConfig: + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": rotary.attention_factor, + "beta_fast": rotary.beta_fast, + "beta_slow": rotary.beta_slow, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + if _TRANSFORMERS_V4: + out: dict = {("rope_theta",): rope_parameters["rope_theta"]} + if type(rotary) is not DefaultRotaryConfig: + out[("rope_scaling",)] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} + return out + return {("rope_parameters",): rope_parameters} + + +def _llama_rotary_import(hf_dict: dict, _parent_context: dict | None) -> dict: + """Reverse of :func:`_llama_rotary_export`. 
Detects v4/v5 layout from the HF dict.""" + if "rope_parameters" in hf_dict: # transformers v5 + rope_params = hf_dict["rope_parameters"] + rope_theta = rope_params["rope_theta"] + else: # transformers v4 + rope_params = hf_dict.get("rope_scaling") or {} + rope_theta = hf_dict["rope_theta"] + rope_type = rope_params.get("rope_type", "default") + rotary_config: dict = {"type": rope_type, "theta": rope_theta} + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") + return {("rotary",): rotary_config} + + +def _llama_attention_check_config(config: AttentionConfig) -> None: + """Default attention layer-bias check for Llama-style formats. + + Subclasses that diverge (e.g. Qwen2 always-on Q/K/V biases with no dense bias) override + :py:meth:`LlamaAttentionConverter._check_config` rather than re-implementing the export function. + """ + Assert.is_(type(config), AttentionConfig) + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + +def _llama_mlp_layers_export(config: MLPConfig) -> dict: + """Validate that MLP layer biases match the parent ``add_linear_biases`` flag and emit nothing. + + Mirrors the attention-layer check: Llama does not expose per-layer bias overrides, so configurations + that disagree with the flat ``mlp_bias`` field are rejected rather than silently dropped. + """ + Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) + return {} + + +def _llama_position_embeddings_export(config: LanguageModelEmbeddingsConfig) -> dict: + Assert.incl(config.position_embeddings.enabled, (None, False)) + return {} + + +def _llama_peft_export(config: GPTBaseModelConfig) -> dict: + """Assert PEFT is at default (no PEFT). Llama format cannot represent PEFT — non-default values change + the parameter set and must fail loudly rather than be silently dropped on export. 
+ """ + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + return {} + + +class LlamaNormalizationConverter(ConfigSectionConverter): + """Converts ``RMSNormalizationConfig`` ↔ Llama's flat ``rms_norm_eps`` field.""" + + fast_llm_config_class = RMSNormalizationConfig @classmethod - def export_config(cls, config: RMSNormalizationConfig) -> dict: - Assert.custom(isinstance, config, RMSNormalizationConfig) - assert not config.zero_centered - return {"rms_norm_eps": config.epsilon} + def _create_config_converters(cls) -> dict: + return { + "type": ConstantImportConfigConverter(("type",), "rms_norm"), + "epsilon": RenameConfigConverter(("epsilon",), ("rms_norm_eps",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -124,28 +316,36 @@ def get_converters( ) -class LlamaMLPConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "intermediate_size": config["intermediate_size"], - "add_linear_biases": config["mlp_bias"], - "activation": ActivationType.from_hf_name(config["hidden_act"]), - "gated": True, - } +class LlamaMLPConverter(ConfigSectionConverter): + """Converts ``MLPConfig`` ↔ Llama's flat ``intermediate_size``/``mlp_bias``/``hidden_act`` fields. + + Llama is always gated (``ConstantImportConfigConverter(("gated",), True)``). + """ + + fast_llm_config_class = MLPConfig @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) - Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - assert config.gated + def _create_config_converters(cls) -> dict: return { - "intermediate_size": config.intermediate_size, - "mlp_bias": config.add_linear_biases, - "hidden_act": config.activation.hf_name, + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("mlp_bias",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("hidden_act",),), + export_fn=lambda c: {("hidden_act",): c.activation.hf_name}, + import_fn=lambda hf, _ctx: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, + ), + "gated": ConstantImportConfigConverter(("gated",), True), + "layers": CustomConfigConverter( + fast_llm_paths=(("layer_1",), ("layer_2",)), + hf_paths=(), + export_fn=_llama_mlp_layers_export, + import_fn=lambda hf, _ctx: {}, + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -172,126 +372,61 @@ def get_converters( ] -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) +class LlamaAttentionConverter(ConfigSectionConverter): + """Converts ``AttentionConfig`` ↔ Llama's flat attention fields. - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) + Notable wrinkles: + - ``head_dim`` is computed from ``hidden_size // num_attention_heads`` when missing on import. + - Rotary handling is delegated to a :class:`CustomConfigConverter` because it spans v4/v5 transformers + layouts and three rotary subtypes. + - Per-layer linear biases (query/key/value/dense) are validated to match ``add_linear_biases``; Llama + does not expose layer-level overrides, so the sub-config fields are blanket-consumed. + """ + fast_llm_config_class = AttentionConfig -class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict) -> dict: - # Normalize rope params to a single dict before dispatching on rope_type. - # transformers v5 consolidates rope_theta + rope_scaling into rope_parameters. - # transformers v4: rope_theta at top level, rope_scaling dict for non-default types. - # Note: detection is on checkpoint format, not transformers version — old checkpoints - # remain loadable with v5 transformers. - if "rope_parameters" in config: # transformers v5 - rope_params = config["rope_parameters"] - rope_theta = rope_params["rope_theta"] - else: # transformers v4 - rope_params = config.get("rope_scaling") or {} - rope_theta = config["rope_theta"] - rope_type = rope_params.get("rope_type", "default") - rotary_config = {"type": rope_type, "theta": rope_theta} - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_params["factor"], - "low_frequency_factor": rope_params["low_freq_factor"], - "high_frequency_factor": rope_params["high_freq_factor"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_params["attention_factor"], - "beta_fast": rope_params["beta_fast"], - "beta_slow": rope_params["beta_slow"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {rope_type}") - out = { - "rotary": rotary_config, - "heads": config["num_attention_heads"], - "head_groups": config["num_key_value_heads"], - "head_size": config.get("head_dim"), - "add_linear_biases": config["attention_bias"], - "dropout": config["attention_dropout"], - } - if out["head_size"] is None: - out["head_size"] = div(config["hidden_size"], out["heads"]) + def _check_config(cls, config: AttentionConfig) -> None: + """Hook for subclasses to enforce format-specific per-layer bias rules. - return out + Default: Llama requires per-layer biases to be unset (``None``) or to match the parent + ``add_linear_biases``. Subclasses (e.g. Qwen2) override to relax or replace this rule. 
+ """ + _llama_attention_check_config(config) @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - cls._check_config(config) - Assert.eq(config.softmax_scale_power, 0.5) - rope_parameters = {"rope_theta": config.rotary.theta} - if type(config.rotary) is DefaultRotaryConfig: - rope_parameters["rope_type"] = "default" - elif type(config.rotary) is Llama3RotaryConfig: - rope_parameters.update( - { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - elif type(config.rotary) is YarnRotaryConfig: - rope_parameters.update( - { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - common = { - "num_attention_heads": config.heads, - "num_key_value_heads": config.head_groups, - "head_dim": config.head_size, - "attention_bias": config.add_linear_biases, - "attention_dropout": config.dropout, + def _create_config_converters(cls) -> dict: + def _layers_export(config: AttentionConfig) -> dict: + cls._check_config(config) + return {} + + return { + "heads": RenameConfigConverter(("heads",), ("num_attention_heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("num_key_value_heads",)), + "head_size": DefaultConfigConverter( + ("head_size",), + ("head_dim",), + hf_default_fn=lambda hf: div(hf["hidden_size"], hf["num_attention_heads"]), + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("attention_bias",)), + "dropout": RenameConfigConverter(("dropout",), ("attention_dropout",)), + "causal": ConstantImportConfigConverter(("causal",), True), + "softmax_scale_power": ConstantImportConfigConverter(("softmax_scale_power",), 0.5), + "linear_layers": CustomConfigConverter( + fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), + hf_paths=(), + export_fn=_layers_export, + import_fn=lambda hf, _ctx: {}, + ), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rope_parameters",), ("rope_theta",), ("rope_scaling",)), + export_fn=_llama_rotary_export, + import_fn=_llama_rotary_import, + ), } - if _TRANSFORMERS_V4: - out = {**common, "rope_theta": rope_parameters["rope_theta"]} - if type(config.rotary) is not DefaultRotaryConfig: - out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} - return out - return {**common, "rope_parameters": rope_parameters} - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - # Opportunity to make derived classes less constrained. - Assert.is_(type(config), AttentionConfig) - Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + # --- weight side (imperative) --- @classmethod def get_converters( @@ -327,67 +462,29 @@ def get_converters( ] -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. 
- _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: AttentionConfig +class LlamaBlockConverter(ConfigSectionConverter): + """Converts ``DecoderBlockConfig`` ↔ Llama block fields (flat-merged into the parent's HF dict).""" - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - return key, value - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - key_value = torch.cat([key[:], value[:]]) - return (key_value,) + fast_llm_config_class = DecoderBlockConfig + mixer_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaAttentionConverter + mlp_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter -class LlamaBlockConverter: - mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter - mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "self_attn" hf_mlp_name: typing.ClassVar[str] = "mlp" hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config), - "mlp": cls.mlp_converter_class.import_config(config), - "normalization": cls.normalization_converter_class.import_config(config), + "mixer": NestedConfigConverter(("mixer",), cls.mixer_converter_class), + "mlp": NestedConfigConverter(("mlp",), cls.mlp_converter_class), + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), } - @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - Assert.custom(isinstance, config, DecoderBlockConfig) - return safe_merge_dicts( - cls.mixer_converter_class.export_config(config.mixer), - cls.mlp_converter_class.export_config(config.mlp), - cls.normalization_converter_class.export_config(config.normalization), - ) + # --- weight side (imperative) --- @classmethod def get_converters( @@ -422,34 +519,34 @@ def get_converters( class LlamaDecoderConverter: - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + """Converts ``BlockSequenceConfig`` (polymorphic Fixed/Pattern) ↔ Llama's flat block + ``num_hidden_layers``. + + Kept as a regular class (not a :class:`ConfigSectionConverter`) so it can stay imperative — the polymorphism + between Fixed/Pattern block sequences doesn't lend itself to the declarative shape, and subclasses (Mistral, + Qwen2, MTP-Llama, ...) 
plug in different block converters via ``block_converter_class``. + """ + + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict) -> dict: + def import_config(cls, hf_dict: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "block": cls.block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: - if isinstance(config, PatternBlockSequenceConfig): - # All exported block configs must be equal - exported_block_configs = [ - safe_merge_dicts( - cls.block_converter_class.export_config(block_config), - {"num_hidden_layers": config.num_blocks}, - ) - for block_config in config.blocks.values() - ] - for other in exported_block_configs[1:]: - Assert.eq(exported_block_configs[0], other) - return exported_block_configs[0] - Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, - ) + def export_config(cls, decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: + if isinstance(decoder_config, PatternBlockSequenceConfig): + exports = [cls.block_converter_class.export_config(block) for block in decoder_config.blocks.values()] + for other in exports[1:]: + Assert.eq(exports[0], other) + block_hf = exports[0] + elif isinstance(decoder_config, FixedBlockSequenceConfig): + block_hf = cls.block_converter_class.export_config(decoder_config.block) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return {**block_hf, "num_hidden_layers": decoder_config.num_blocks} @classmethod def get_converters( @@ -459,11 +556,10 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config block_config = ( config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values())) ) - converters = [] + converters: list[WeightConverter] = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( block_config, @@ -474,16 +570,29 @@ def get_converters( return converters -class LlamaEmbeddingsConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return {"vocab_size": config["vocab_size"]} +class LlamaEmbeddingsConverter(ConfigSectionConverter): + """Converts ``LanguageModelEmbeddingsConfig`` ↔ Llama (flat ``vocab_size``). + + Llama has no learnable position embeddings; ``num_position_embeddings`` is irrelevant when + ``position_embeddings.enabled`` is ``False``/``None`` and is therefore blanket-consumed. 
+ """ + + fast_llm_config_class = LanguageModelEmbeddingsConfig @classmethod - def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) - assert not config.position_embeddings.enabled - return {"vocab_size": config.vocab_size} + def _create_config_converters(cls) -> dict: + return { + "vocab_size": RenameConfigConverter(("vocab_size",), ("vocab_size",)), + "word_embeddings": IgnoredConfigConverter(("word_embeddings",)), + "position_embeddings": CustomConfigConverter( + fast_llm_paths=(("position_embeddings",), ("num_position_embeddings",)), + hf_paths=(), + export_fn=_llama_position_embeddings_export, + import_fn=lambda hf, _ctx: {}, + ), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -492,18 +601,27 @@ def get_converters( return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] -class LlamaHeadConverter: - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter +class LlamaHeadConverter(ConfigSectionConverter): + """Converts ``LanguageModelHeadConfig`` ↔ Llama final-norm fields (flat-merged).""" - @classmethod - def import_config(cls, config: dict) -> dict: - return {"normalization": cls.normalization_converter_class.import_config(config)} + fast_llm_config_class = LanguageModelHeadConfig + + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter + # Used by MTP-Llama subclass to emit per-prediction-head block weight converters; Llama itself doesn't read it. + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelHeadConfig) - return cls.normalization_converter_class.export_config(config.normalization) + def _create_config_converters(cls) -> dict: + return { + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + "output_weight": IgnoredConfigConverter(("output_weight",)), + # ``prediction_heads`` is architecture (>1 enables multi-token prediction); Llama HF format does + # not represent it. We don't pin it to 1 here so MTP-Llama (a Llama-derived format) can override + # the declaration with a Rename without first hitting an assertion in the inherited path. + "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -526,34 +644,48 @@ def get_converters( ] -class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): +class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict.""" + + fast_llm_config_class = GPTBaseModelConfig + # TODO: Peft? 
decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter - embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[LlamaHeadConverter]] = LlamaHeadConverter + embeddings_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaHeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: + decoder_converter_class = cls.decoder_converter_class + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_converter_class.export_config(parent.decoder).items()} + + def _decoder_import(hf_dict: dict, _parent_context: dict | None) -> dict: + return {("decoder",): decoder_converter_class.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "head": NestedConfigConverter(("head",), cls.head_converter_class), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + hf_paths=(("num_hidden_layers",),), + export_fn=_decoder_export, + import_fn=_decoder_import, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + # Llama format does not represent PEFT. Strictly assert ``NoPeftConfig`` on export so a + # user-configured LoRA fails clearly rather than being silently dropped. + "peft": CustomConfigConverter( + fast_llm_paths=(("peft",),), + hf_paths=(), + export_fn=_llama_peft_export, + import_fn=lambda hf, _ctx: {}, + ), } - @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: From 0a061c4adc93b52492023b329a500fb20ded040b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 07:14:48 -0400 Subject: [PATCH 04/27] Address PR review: validator hook + cleanup Adds `_validate_export(cls, config)` classmethod hook on `ConfigSectionConverter`, called automatically from `export_config` after the architecture-coverage check. Replaces five `CustomConfigConverter`-as-validator blocks (`linear_layers`/`layers` in attention and MLP, `position_embeddings` in embeddings, `peft` in base model, plus the `_check_config` chain on attention) with `IgnoredConfigConverter` for field-claiming + small `_validate_export` overrides. Mistral and Qwen2 rename their `_check_config` overrides accordingly; Pixtral's imperative export updates its `cls._check_config(config)` call site. 
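The resulting shape on the attention converter, roughly (illustrative sketch, not the verbatim llama.py hunk from this patch; the bias asserts are the ones previously in _check_config):

    class LlamaAttentionConverter(ConfigSectionConverter):
        # in _create_config_converters(): the per-layer fields are merely claimed
        # for the architecture-coverage check, no HF output is produced for them
        #   "linear_layers": IgnoredConfigConverter(
        #       ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)
        #   ),

        @classmethod
        def _validate_export(cls, config: AttentionConfig) -> None:
            # formerly _check_config; export_config() now calls this hook automatically
            Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases))
            Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases))
            Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases))
            Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases))
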
Also addresses several reviewer-flagged correctness/cleanup items: - Drop the half-removed `parent_context` parameter from every primitive's `import_to` signature (and from `CustomConfigConverter`'s `import_fn`). It was unreachable through the walker. - `_check_architecture_coverage` now reads `cls.fast_llm_config_class` directly instead of `getattr(..., None)`, surfacing missing class-attribute declarations as `AttributeError` rather than silently disabling the safety net. - Drop the unused `hf_paths` parameter from `CustomConfigConverter.__init__`. There is no symmetric HF-side coverage check yet, so the field was cosmetic. - Add a TODO note in `_check_architecture_coverage` documenting that the `MoEMLPConfig`/`MambaConfig`/etc. safety net is gated on later migrations. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/external.py | 59 ++++---- fast_llm/models/gpt/conversion/llama.py | 126 ++++++------------ fast_llm/models/gpt/conversion/mistral.py | 2 +- fast_llm/models/gpt/conversion/qwen2.py | 2 +- .../models/multimodal/conversion/llava.py | 2 +- 5 files changed, 77 insertions(+), 114 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 9d2e09fbb..6c98ccf3b 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -77,7 +77,7 @@ def consumed_fast_llm_fields(self) -> set[str]: def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: ... @abc.abstractmethod - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: ... + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: ... class RenameConfigConverter(ConfigConverter): @@ -91,7 +91,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) set_nested_dict_value(hf_out, self.hf_paths[0], value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: value = _get_nested(hf_dict, self.hf_paths[0]) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) @@ -109,7 +109,7 @@ def __init__(self, hf_path: tuple[str, ...], value: typing.Any): def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], self._value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: if _has_nested(hf_dict, self.hf_paths[0]): actual = _get_nested(hf_dict, self.hf_paths[0]) Assert.eq(actual, self._value) @@ -129,7 +129,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: actual = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) Assert.eq(actual, self._value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], self._value) @@ -154,7 +154,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) set_nested_dict_value(hf_out, self.hf_paths[0], value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: if 
_has_nested(hf_dict, self.hf_paths[0]): value = _get_nested(hf_dict, self.hf_paths[0]) else: @@ -178,7 +178,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: if value != self._sentinel: set_nested_dict_value(hf_out, self.hf_paths[0], value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: if _has_nested(hf_dict, self.hf_paths[0]): value = _get_nested(hf_dict, self.hf_paths[0]) if value != self._sentinel: @@ -199,26 +199,26 @@ def __init__(self, *fast_llm_paths: tuple[str, ...]): def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: return - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: return class CustomConfigConverter(ConfigConverter): """Escape hatch for cross-field transforms (e.g., rotary, where one HF blob ↔ several Fast-LLM fields). - The export/import callables receive the section's full config and return/produce arbitrary mappings within - the declared paths. Both ``fast_llm_paths`` and ``hf_paths`` are still declared so the coverage check works. + ``fast_llm_paths`` is declared so the coverage check sees the fields as consumed. The HF side is intentionally + not declared — there is no symmetric HF-side coverage check yet, so an ``hf_paths`` argument would be cosmetic. + Cross-field validators that produce nothing on the HF side belong on :py:meth:`ConfigSectionConverter._validate_export` + instead; this primitive is for shape-changing transforms. """ def __init__( self, fast_llm_paths: tuple[tuple[str, ...], ...], - hf_paths: tuple[tuple[str, ...], ...], export_fn: typing.Callable[[Config], dict], - import_fn: typing.Callable[[dict, dict | None], dict], + import_fn: typing.Callable[[dict], dict], ): self.fast_llm_paths = fast_llm_paths - self.hf_paths = hf_paths self._export_fn = export_fn self._import_fn = import_fn @@ -227,8 +227,8 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: for path, value in produced.items(): set_nested_dict_value(hf_out, path, value) - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: - produced = self._import_fn(hf_dict, parent_context) + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + produced = self._import_fn(hf_dict) for path, value in produced.items(): set_nested_dict_value(fast_llm_out, path, value) @@ -258,7 +258,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: else: hf_out[key] = value - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: sub_fast_llm = self._converter_class.import_config(hf_dict) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) @@ -299,7 +299,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: for key, value in sub_hf.items(): hf_out[key] = value - def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | None = None) -> None: + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: sub_hf = _get_nested(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict type_name = sub_hf.get(self._hf_discriminator_key) converter_class = self._hf_to_class.get(type_name) @@ -319,8 +319,10 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict, parent_context: dict | No 
class ConfigSectionConverter(abc.ABC): """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. - Subclasses declare the conversion via ``_create_config_converters`` (config side) and - ``_create_weight_converters`` (weight side; receives the live config). + Subclasses declare the conversion via ``_create_config_converters``. Format-specific cross-field + invariants go on the ``_validate_export`` hook. The weight side is still imperative (per-converter + ``get_converters`` classmethods on the concrete subclasses); a declarative weight-side primitive will be + added when the weight-converter migration lands. Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value used by the HF format (e.g. ``"attention"``, ``"mamba"``). @@ -335,17 +337,22 @@ def _create_config_converters(cls) -> dict[str, ConfigConverter]: """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key.""" @classmethod - def _create_weight_converters( - cls, config: Config, fast_llm_prefix: str, hf_prefix: str - ) -> dict[str, "WeightConverter"]: - """Return weight converters keyed by stable string name. Default is empty (no weights at this level).""" - return {} + def _validate_export(cls, config: Config) -> None: + """Hook for format-specific export-time validation. Default no-op. + + Runs after the architecture-coverage check and before any declaration emits. Use this for cross-field + invariants the format imposes on the Fast-LLM config (e.g. per-layer biases must match a parent flag, + certain sub-configs must be at their default). Subclasses override; super-calls are not required when + the rule is fully replaced (e.g. Qwen2 vs Llama attention biases). + """ + return @classmethod def export_config(cls, config: Config) -> dict: """Convert a Fast-LLM config object to an HF config dict via this section's declarations.""" declarations = cls._create_config_converters() cls._check_architecture_coverage(config, declarations) + cls._validate_export(config) out: dict = {} for converter in declarations.values(): converter.export_to(config, out) @@ -371,10 +378,10 @@ def _check_architecture_coverage(cls, config: Config, declarations: dict[str, Co The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the config is a strict subclass (e.g. ``MoEMLPConfig`` fed via ``super().export_config()`` from a yet-to-be- migrated ``MixtralMLPConverter``), the subclass converter is responsible for declaring the additional - fields and running its own check. + fields and running its own check. TODO: Once Mixtral/Apriel/Apriel2 migrate, the safety net for + ``MoEMLPConfig``/``MambaConfig``/etc. is gated on those migrations landing. 
""" - declared_class = getattr(cls, "fast_llm_config_class", None) - if declared_class is None or type(config) is not declared_class: + if type(config) is not cls.fast_llm_config_class: return consumed: set[str] = set() for converter in declarations.values(): diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 61c4799fe..578098764 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -210,7 +210,7 @@ def _llama_rotary_export(config: AttentionConfig) -> dict: return {("rope_parameters",): rope_parameters} -def _llama_rotary_import(hf_dict: dict, _parent_context: dict | None) -> dict: +def _llama_rotary_import(hf_dict: dict) -> dict: """Reverse of :func:`_llama_rotary_export`. Detects v4/v5 layout from the HF dict.""" if "rope_parameters" in hf_dict: # transformers v5 rope_params = hf_dict["rope_parameters"] @@ -245,45 +245,6 @@ def _llama_rotary_import(hf_dict: dict, _parent_context: dict | None) -> dict: return {("rotary",): rotary_config} -def _llama_attention_check_config(config: AttentionConfig) -> None: - """Default attention layer-bias check for Llama-style formats. - - Subclasses that diverge (e.g. Qwen2 always-on Q/K/V biases with no dense bias) override - :py:meth:`LlamaAttentionConverter._check_config` rather than re-implementing the export function. - """ - Assert.is_(type(config), AttentionConfig) - Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) - - -def _llama_mlp_layers_export(config: MLPConfig) -> dict: - """Validate that MLP layer biases match the parent ``add_linear_biases`` flag and emit nothing. - - Mirrors the attention-layer check: Llama does not expose per-layer bias overrides, so configurations - that disagree with the flat ``mlp_bias`` field are rejected rather than silently dropped. - """ - Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) - Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - return {} - - -def _llama_position_embeddings_export(config: LanguageModelEmbeddingsConfig) -> dict: - Assert.incl(config.position_embeddings.enabled, (None, False)) - return {} - - -def _llama_peft_export(config: GPTBaseModelConfig) -> dict: - """Assert PEFT is at default (no PEFT). Llama format cannot represent PEFT — non-default values change - the parameter set and must fail loudly rather than be silently dropped on export. 
- """ - from fast_llm.layers.common.peft.config import NoPeftConfig - - Assert.custom(isinstance, config.peft, NoPeftConfig) - return {} - - class LlamaNormalizationConverter(ConfigSectionConverter): """Converts ``RMSNormalizationConfig`` ↔ Llama's flat ``rms_norm_eps`` field.""" @@ -331,19 +292,19 @@ def _create_config_converters(cls) -> dict: "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("mlp_bias",)), "activation": CustomConfigConverter( fast_llm_paths=(("activation",),), - hf_paths=(("hidden_act",),), export_fn=lambda c: {("hidden_act",): c.activation.hf_name}, - import_fn=lambda hf, _ctx: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, ), "gated": ConstantImportConfigConverter(("gated",), True), - "layers": CustomConfigConverter( - fast_llm_paths=(("layer_1",), ("layer_2",)), - hf_paths=(), - export_fn=_llama_mlp_layers_export, - import_fn=lambda hf, _ctx: {}, - ), + # Llama doesn't expose per-layer bias overrides; the bias-match check lives on _validate_export. + "layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), } + @classmethod + def _validate_export(cls, config: MLPConfig) -> None: + Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) + # --- weight side (imperative) --- @classmethod @@ -379,27 +340,15 @@ class LlamaAttentionConverter(ConfigSectionConverter): - ``head_dim`` is computed from ``hidden_size // num_attention_heads`` when missing on import. - Rotary handling is delegated to a :class:`CustomConfigConverter` because it spans v4/v5 transformers layouts and three rotary subtypes. - - Per-layer linear biases (query/key/value/dense) are validated to match ``add_linear_biases``; Llama - does not expose layer-level overrides, so the sub-config fields are blanket-consumed. + - Per-layer linear biases (query/key/value/dense) are validated to match ``add_linear_biases`` on + ``_validate_export``; Llama does not expose layer-level overrides, so the sub-config fields are + blanket-consumed via :class:`IgnoredConfigConverter`. """ fast_llm_config_class = AttentionConfig - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - """Hook for subclasses to enforce format-specific per-layer bias rules. - - Default: Llama requires per-layer biases to be unset (``None``) or to match the parent - ``add_linear_biases``. Subclasses (e.g. Qwen2) override to relax or replace this rule. 
- """ - _llama_attention_check_config(config) - @classmethod def _create_config_converters(cls) -> dict: - def _layers_export(config: AttentionConfig) -> dict: - cls._check_config(config) - return {} - return { "heads": RenameConfigConverter(("heads",), ("num_attention_heads",)), "head_groups": RenameConfigConverter(("head_groups",), ("num_key_value_heads",)), @@ -412,20 +361,28 @@ def _layers_export(config: AttentionConfig) -> dict: "dropout": RenameConfigConverter(("dropout",), ("attention_dropout",)), "causal": ConstantImportConfigConverter(("causal",), True), "softmax_scale_power": ConstantImportConfigConverter(("softmax_scale_power",), 0.5), - "linear_layers": CustomConfigConverter( - fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), - hf_paths=(), - export_fn=_layers_export, - import_fn=lambda hf, _ctx: {}, + "linear_layers": IgnoredConfigConverter( + ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",) ), "rotary": CustomConfigConverter( fast_llm_paths=(("rotary",),), - hf_paths=(("rope_parameters",), ("rope_theta",), ("rope_scaling",)), export_fn=_llama_rotary_export, import_fn=_llama_rotary_import, ), } + @classmethod + def _validate_export(cls, config: AttentionConfig) -> None: + """Default: Llama requires per-layer biases to be unset (``None``) or to match ``add_linear_biases``. + + Subclasses (e.g. Qwen2 with always-on Q/K/V biases and no dense bias) override. + """ + Assert.is_(type(config), AttentionConfig) + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + # --- weight side (imperative) --- @classmethod @@ -584,14 +541,13 @@ def _create_config_converters(cls) -> dict: return { "vocab_size": RenameConfigConverter(("vocab_size",), ("vocab_size",)), "word_embeddings": IgnoredConfigConverter(("word_embeddings",)), - "position_embeddings": CustomConfigConverter( - fast_llm_paths=(("position_embeddings",), ("num_position_embeddings",)), - hf_paths=(), - export_fn=_llama_position_embeddings_export, - import_fn=lambda hf, _ctx: {}, - ), + "position_embeddings": IgnoredConfigConverter(("position_embeddings",), ("num_position_embeddings",)), } + @classmethod + def _validate_export(cls, config: LanguageModelEmbeddingsConfig) -> None: + Assert.incl(config.position_embeddings.enabled, (None, False)) + # --- weight side (imperative) --- @classmethod @@ -661,7 +617,7 @@ def _create_config_converters(cls) -> dict: def _decoder_export(parent: Config) -> dict: return {(k,): v for k, v in decoder_converter_class.export_config(parent.decoder).items()} - def _decoder_import(hf_dict: dict, _parent_context: dict | None) -> dict: + def _decoder_import(hf_dict: dict) -> dict: return {("decoder",): decoder_converter_class.import_config(hf_dict)} return { @@ -669,22 +625,22 @@ def _decoder_import(hf_dict: dict, _parent_context: dict | None) -> dict: "head": NestedConfigConverter(("head",), cls.head_converter_class), "decoder": CustomConfigConverter( fast_llm_paths=(("decoder",),), - hf_paths=(("num_hidden_layers",),), export_fn=_decoder_export, import_fn=_decoder_import, ), "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), - # Llama format does not represent 
PEFT. Strictly assert ``NoPeftConfig`` on export so a - # user-configured LoRA fails clearly rather than being silently dropped. - "peft": CustomConfigConverter( - fast_llm_paths=(("peft",),), - hf_paths=(), - export_fn=_llama_peft_export, - import_fn=lambda hf, _ctx: {}, - ), + # Llama format cannot represent PEFT; the NoPeftConfig assertion lives on _validate_export so a + # user-configured LoRA fails clearly rather than being silently dropped on export. + "peft": IgnoredConfigConverter(("peft",)), } + @classmethod + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + # --- weight side (imperative) --- @classmethod diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index d4a669b22..106f1f0cc 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -35,7 +35,7 @@ def export_config(cls, config: AttentionConfig) -> dict: return out @classmethod - def _check_config(cls, config: AttentionConfig) -> None: + def _validate_export(cls, config: AttentionConfig) -> None: # Mistral doesn't support biases. assert not config.add_linear_biases diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 9aa2f8c8e..437431ed1 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -45,7 +45,7 @@ def export_config(cls, config: AttentionConfig) -> dict: return out @classmethod - def _check_config(cls, config: AttentionConfig) -> None: + def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) # There are multiple ways to enable biases on QKV only if config.add_linear_biases: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index fe7c77f5e..2a2e03502 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -60,7 +60,7 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: AttentionConfig) -> dict: - cls._check_config(config) + cls._validate_export(config) Assert.eq(config.softmax_scale_power, 0.5) Assert.is_(type(config.rotary), Rotary2DConfig) assert not config.add_linear_biases From 79f8364352472d26d454e4a9195cf4e9bf07bc3b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 09:48:30 -0400 Subject: [PATCH 05/27] Mark PatternBlockSequenceConfig.blocks as architecture The dict of named per-block configs is unambiguously architecture metadata; without an explicit hint it defaulted to `unknown`, hiding it from the architecture-coverage check used by declarative checkpoint converters. 
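For orientation, the architecture-coverage check this hint feeds reduces to roughly the sketch below. Only the declared-converter walk (via the `consumed_fast_llm_fields` property introduced earlier in this series) is taken from the patch; the `architecture_field_names` argument stands in for however the real `_check_architecture_coverage` enumerates `FieldHint.architecture` fields.

    def sketch_architecture_coverage(config, declarations, architecture_field_names):
        # Union of the Fast-LLM fields each declared primitive claims.
        consumed = set()
        for converter in declarations.values():
            consumed |= converter.consumed_fast_llm_fields
        missing = set(architecture_field_names) - consumed
        if missing:
            raise ValueError(
                f"{type(config).__name__}: architecture fields not covered by any declaration: {sorted(missing)}"
            )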
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/block/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index aa47a5f2e..25c5fcc82 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -146,7 +146,10 @@ def last_block_config(self) -> BlockConfig: @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) class PatternBlockSequenceConfig(BlockSequenceConfig): _abstract = False - blocks: dict[str, BlockConfig] = Field() + blocks: dict[str, BlockConfig] = Field( + desc="Named block configurations referenced by `pattern`.", + hint=FieldHint.architecture, + ) pattern: list[str] = Field( default=None, desc="The name of each block (key in `blocks`) in the repeated pattern.", From 0484dc524bc4754e9ca69824a3053414250fa7ab Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 09:48:44 -0400 Subject: [PATCH 06/27] Extend converter framework with nested-HF and typed-dict primitives Two additions, both required by Apriel2's nested HF schema: - `NestedConfigConverter` gains an optional `hf_path` kwarg. When set, the sub-converter's output is placed under that nested key instead of being flat-merged. Existing flat-merge behavior is unchanged when `hf_path` is omitted. - New `TypedDictContainerConfigConverter` for `dict[str, Config]` fields where each entry is round-tripped through a per-class section converter. Polymorphic dispatch via the entry's runtime type on export and the HF discriminator on import. A homogeneous mode (single registered class with `hf_type_name = None`) skips the discriminator entirely. Both `DispatchConfigConverter` and `TypedDictContainerConfigConverter` now also inject the Fast-LLM `dynamic_type_name` discriminator into the imported sub-dict so the parent's `from_dict` dispatches to the right `Config` subclass without a separate ConstantImport. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/external.py | 98 +++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 9 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 6c98ccf3b..851ea13b0 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -236,30 +236,38 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: class NestedConfigConverter(ConfigConverter): """Recurse into a fixed-typed sub-config field via another section converter class. - Exists for Fast-LLM-side modularity: lets a parent converter delegate handling of a sub-config to its own - converter class. The HF side is assumed flat — the sub-converter's output is merged into the parent's HF dict. - For non-flat HF formats, use ``CustomConfigConverter``. + Default (``hf_path=None``): the HF side is flat-merged — the sub-converter's output becomes top-level keys + of the parent's HF dict, asserting any pre-existing keys agree. + + With ``hf_path`` set: the sub-converter's output is placed under that nested key. Use this for HF formats + that mirror Fast-LLM's modular layout (e.g. Apriel2's ``"decoder": {...}`` and ``"head": {...}`` blocks). """ def __init__( self, fast_llm_path: tuple[str, ...], converter_class: "type[ConfigSectionConverter]", + hf_path: tuple[str, ...] 
| None = None, ): self.fast_llm_paths = (fast_llm_path,) self._converter_class = converter_class + self._hf_path = hf_path def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) sub_hf = self._converter_class.export_config(sub_config) - for key, value in sub_hf.items(): - if key in hf_out: - Assert.eq(hf_out[key], value) - else: - hf_out[key] = value + if self._hf_path is None: + for key, value in sub_hf.items(): + if key in hf_out: + Assert.eq(hf_out[key], value) + else: + hf_out[key] = value + else: + set_nested_dict_value(hf_out, self._hf_path, sub_hf) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - sub_fast_llm = self._converter_class.import_config(hf_dict) + sub_hf = _get_nested(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict + sub_fast_llm = self._converter_class.import_config(sub_hf) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) @@ -308,9 +316,81 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: f"No converter registered for HF discriminator {type_name!r} at " f"{'.'.join(self.fast_llm_paths[0])}" ) sub_fast_llm = converter_class.import_config(sub_hf) + # Inject the Fast-LLM dynamic-type discriminator so the parent's `from_dict` dispatches to the + # correct subclass. Reads from the registered Config class rather than the HF discriminator so + # mismatched Fast-LLM/HF type names work too. + fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) +class TypedDictContainerConfigConverter(ConfigConverter): + """Maps a Fast-LLM ``dict[str, Config]`` field to an HF ``dict[str, dict]`` where each entry is round-tripped + through a per-class section converter selected via the entry's runtime type (export) or HF discriminator (import). + + Each entry's HF subdict carries a discriminator key (``"type"`` by default) populated from the converter's + ``hf_type_name``. For homogeneous dicts, register a single class with ``hf_type_name = None``; the discriminator + is then omitted on export and ignored on import. 
+ """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + self._homogeneous = len(registry) == 1 and next(iter(registry.values())).hf_type_name is None + if self._homogeneous: + self._homogeneous_class = next(iter(registry.values())) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_dict = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + out: dict = {} + for name, sub_config in sub_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at " + f"{'.'.join(self.fast_llm_paths[0])}[{name!r}]" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + out[name] = sub_hf + set_nested_dict_value(hf_out, self.hf_paths[0], out) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf_dict = _get_nested(hf_dict, self.hf_paths[0]) + out: dict = {} + for name, sub_hf in sub_hf_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at " + f"{'.'.join(self.hf_paths[0])}[{name!r}]" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} + out[name] = sub_fast_llm + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], out) + + # ============================================================ # Section converter — converts one Fast-LLM config class # ============================================================ From fcbd282fca7a56519c0341f894d225217064c7d3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 09:48:58 -0400 Subject: [PATCH 07/27] Migrate Apriel2 config converters to declarative primitives MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stress-tests the framework's polymorphic dispatch and typed-dict support: Apriel2's HF schema is nested (`decoder.block.mixer.{...}`, `head.normalization`, `mixers.{name}`) and the mixer field is heterogeneously polymorphic (Attention/Mamba/StochasticMixer/GDN/KDA). Migrated converters: per-mixer (Attention/Mamba/GDN/KDA), the StochasticMixer container (driven by TypedDictContainer over a leaf-mixer registry), per-normalization (RMS/LayerNorm/NoNorm), MLP, Block, Fixed/Pattern decoder variants (selected by Dispatch on runtime BlockSequenceConfig type), Head, and BaseModel. The imperative weight-side `get_converters` methods are preserved unchanged so the multimodal Apriel2 converter (which inherits from the text-only one) keeps working without modification. 
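For orientation, a rough sketch of the nested HF config shape these converters round-trip. All values are made up and only a subset of keys is shown; the key names follow the declarations in this patch.

    apriel2_hf_config_sketch = {
        "hidden_size": 4096,
        "tie_word_embeddings": False,
        "decoder": {
            "type": "pattern",
            "num_blocks": 32,
            "pattern": ["attn", "gdn"],
            "blocks": {
                "attn": {
                    "mixer": {
                        "type": "attention",
                        "heads": 32,
                        "head_groups": 8,
                        "head_size": 128,
                        "rotary": {"type": "mistral_1d", "theta": 10000.0},
                    },
                    "mlp": {"type": "mlp", "intermediate_size": 14336, "activation": "silu",
                            "gated": True, "add_linear_biases": False},
                    "normalization": {"type": "rms_norm", "epsilon": 1e-5},
                },
                "gdn": {
                    "mixer": {
                        "type": "gdn",
                        "value_heads": 32,
                        "key_heads": 8,
                        "key_head_dim": 64,
                        "value_head_dim": 64,
                        "convolution_layer": {"kernel_size": 4},
                    },
                    "mlp": {"type": "mlp", "intermediate_size": 14336, "activation": "silu",
                            "gated": True, "add_linear_biases": False},
                    "normalization": {"type": "rms_norm", "epsilon": 1e-5},
                },
            },
        },
        "head": {"normalization": {"type": "rms_norm", "epsilon": 1e-5}},
    }

The `type` discriminators under `mixer` and `normalization` are what the dispatch converters key on; the `blocks` entries themselves carry no discriminator.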
PatternDecoder's `blocks` dict uses the homogeneous mode of TypedDictContainer (single-class registry, no discriminator). The attention rotary-type translation (default ↔ mistral_1d) and Mamba's auxiliary HF fields (d_conv, conv_bias, dt_proj_bias derived from linear-config bias flags) remain on `CustomConfigConverter` since they're shape-changing transforms. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/models/gpt/conversion/apriel2.py | 852 +++++++++++----------- 1 file changed, 415 insertions(+), 437 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 9b6657b03..263af4d02 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -5,10 +5,32 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DispatchConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + RenameConfigConverter, + TypedDictContainerConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig, StochasticMixerSamplingStrategy +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + RMSNormalizationConfig, +) +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat @@ -25,82 +47,94 @@ from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts +# ============================================================ +# Helpers +# ============================================================ + + +def _per_layer_bias_export(config, layer_names: tuple[str, ...]) -> dict: + """Emit per-layer ``{layer: {"bias": {"enabled": bool}}}`` only for layers whose bias is explicitly set.""" + out: dict = {} + for layer_name in layer_names: + layer = getattr(config, layer_name) + if layer.bias.enabled is not None: + out[(layer_name,)] = {"bias": {"enabled": layer.bias.enabled}} + return out + + +def _per_layer_bias_import(hf_dict: dict, layer_names: tuple[str, ...]) -> dict: + """Pass through HF ``{layer: {"bias": {...}}}`` entries to the Fast-LLM dict.""" + out: dict = {} + for layer_name in layer_names: + if layer_name in hf_dict: + out[(layer_name,)] = hf_dict[layer_name] + return out + + +# ============================================================ +# Mixer converters +# ============================================================ + 
+ +def _apriel2_attention_rotary_export(config: AttentionConfig) -> dict: + """Emit Apriel2's typed rotary subdict. + + Asymmetric with the Fast-LLM type only for the default→``mistral_1d`` rename; ``llama3``/``yarn`` round-trip + by name. Mirrors current behavior: only ``type`` and ``theta`` are emitted (scale fields are dropped). + """ + rotary = config.rotary + if type(rotary) is DefaultRotaryConfig: + rotary_type = "mistral_1d" + elif type(rotary) is Llama3RotaryConfig: + rotary_type = "llama3" + elif type(rotary) is YarnRotaryConfig: + rotary_type = "yarn" + else: + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + return {("rotary",): {"type": rotary_type, "theta": rotary.theta}} + + +def _apriel2_attention_rotary_import(hf_dict: dict) -> dict: + rotary = dict(hf_dict["rotary"]) + if rotary.get("type") == "mistral_1d": + rotary["type"] = "default" + return {("rotary",): rotary} -class Apriel2AttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - rotary = config["rotary"] - # Map Apriel2 HuggingFace rotary type to Fast-LLM internal type - if rotary.get("type") == "mistral_1d": - rotary = {**rotary, "type": "default"} - result = { - "type": "attention", - "heads": config["heads"], - "head_groups": config["head_groups"], - "head_size": config["head_size"], - "rotary": rotary, - } - # Per-layer bias configuration mirroring Fast-LLM structure - # If per-layer configs exist, use them; otherwise fall back to add_linear_biases - if "query_layer" in config: - result["query_layer"] = config["query_layer"] - if "key_layer" in config: - result["key_layer"] = config["key_layer"] - if "value_layer" in config: - result["value_layer"] = config["value_layer"] - if "dense_layer" in config: - result["dense_layer"] = config["dense_layer"] - # add_linear_biases serves as default for layers without explicit config - if "add_linear_biases" in config: - result["add_linear_biases"] = config["add_linear_biases"] - if "window_size" in config: - result["window_size"] = config["window_size"] - return result + +class Apriel2AttentionConverter(ConfigSectionConverter): + fast_llm_config_class = AttentionConfig + hf_type_name = "attention" @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig - - if type(config.rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - elif type(config.rotary) is Llama3RotaryConfig: - rotary_type = "llama3" - elif type(config.rotary) is YarnRotaryConfig: - rotary_type = "yarn" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - result = { - "type": "attention", - "heads": config.heads, - "head_groups": config.head_groups, - "head_size": config.head_size, - "rotary": { - "type": rotary_type, - "theta": config.rotary.theta, - }, + def _create_config_converters(cls) -> dict: + layer_names = ("query_layer", "key_layer", "value_layer", "dense_layer") + return { + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("head_groups",)), + "head_size": RenameConfigConverter(("head_size",), ("head_size",)), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + export_fn=_apriel2_attention_rotary_export, + import_fn=_apriel2_attention_rotary_import, + ), + # Apriel2 emits add_linear_biases only when False; the True default is implicit. 
+ "add_linear_biases": OptionalConfigConverter( + ("add_linear_biases",), ("add_linear_biases",), sentinel=True + ), + "window_size": OptionalConfigConverter(("window_size",), ("window_size",)), + "linear_layers": CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + ), + "causal": IgnoredConfigConverter(("causal",)), + "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), } - if config.window_size is not None: - result["window_size"] = config.window_size - # Export per-layer bias configuration - # Only include if explicitly set (not None) - if config.query_layer.bias.enabled is not None: - result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} - if config.key_layer.bias.enabled is not None: - result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} - if config.value_layer.bias.enabled is not None: - result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} - if config.dense_layer.bias.enabled is not None: - result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} - # add_linear_biases as fallback default; omit when True (the Fast-LLM default) to avoid - # round-trip inflation on configs that don't set it explicitly. - if not config.add_linear_biases: - result["add_linear_biases"] = config.add_linear_biases - return result + + # --- weight side (imperative) --- @classmethod def _get_effective_bias(cls, layer_config, default: bool) -> bool: - """Get effective bias setting: use layer-specific if set, else default.""" if layer_config.bias.enabled is not None: return layer_config.bias.enabled return default @@ -113,13 +147,11 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Determine effective bias for each projection q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) - # For key_value, both k and v must have same bias setting - # (they're combined in Fast-LLM's key_value layer) + # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. kv_bias = k_bias and v_bias return [ @@ -148,40 +180,50 @@ def get_converters( ] -class Apriel2MambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "mamba", - "state_size": config["state_size"], - "d_inner": config["d_inner"], - "add_linear_biases": config["add_linear_biases"], - } - if "d_xb" in config: - result["d_xb"] = config["d_xb"] - if "dt_rank" in config: - result["dt_rank"] = config["dt_rank"] - return result +def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: + """Emit Apriel2's mamba-specific HF auxiliaries (``d_conv`` from convolution kernel size, plus the + convolution and dt-projection effective bias flags). 
These have no flat Fast-LLM analogue.""" + return { + ("d_conv",): config.convolution_layer.kernel_size, + ("conv_bias",): config.convolution_layer.bias.enabled, + ("dt_proj_bias",): config.dt_layer.bias.enabled, + } - @classmethod - def export_config(cls, config: MambaConfig) -> dict: - exported = { - "type": "mamba", - "state_size": config.state_size, - "d_inner": config.d_inner, - "d_conv": config.convolution_layer.kernel_size, - "add_linear_biases": config.add_linear_biases, - "conv_bias": config.convolution_layer.bias.enabled, - "dt_proj_bias": config.dt_layer.bias.enabled, - } - if config.d_xb is not None: - exported["d_xb"] = config.d_xb +class Apriel2MambaConverter(ConfigSectionConverter): + fast_llm_config_class = MambaConfig + hf_type_name = "mamba" - if config.dt_rank != "auto": - exported["dt_rank"] = config.dt_rank + @classmethod + def _create_config_converters(cls) -> dict: + return { + "state_size": RenameConfigConverter(("state_size",), ("state_size",)), + "d_inner": RenameConfigConverter(("d_inner",), ("d_inner",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "d_xb": OptionalConfigConverter(("d_xb",), ("d_xb",)), + "dt_rank": OptionalConfigConverter(("dt_rank",), ("dt_rank",)), + "aux": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("dt_layer",)), + export_fn=_apriel2_mamba_aux_export, + # The d_conv/conv_bias/dt_proj_bias HF fields are not reflected in the Fast-LLM mamba dict — + # current Apriel2 import simply ignores them and lets Fast-LLM use its own defaults. + import_fn=lambda hf: {}, + ), + # Architecture fields with no HF counterpart; they round-trip at their Fast-LLM defaults. + "layers_unmapped": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ("a_log_weight",), + ("d_weight",), + ("repeat_kv_before_conv",), + ), + } - return exported + # --- weight side (imperative) --- @classmethod def get_converters( @@ -235,33 +277,37 @@ def get_converters( ] -class Apriel2GatedDeltaNetConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "gdn", - "value_heads": config["value_heads"], - "key_heads": config["key_heads"], - "key_head_dim": config["key_head_dim"], - "value_head_dim": config["value_head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - return result +class Apriel2GatedDeltaNetConverter(ConfigSectionConverter): + fast_llm_config_class = GatedDeltaNetConfig + hf_type_name = "gdn" @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "gdn", - "value_heads": config.value_heads, - "key_heads": config.key_heads, - "key_head_dim": config.key_head_dim, - "value_head_dim": config.value_head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, + "value_heads": RenameConfigConverter(("value_heads",), ("value_heads",)), + "key_heads": RenameConfigConverter(("key_heads",), ("key_heads",)), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("key_head_dim",)), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ( + {("convolution_layer",): 
hf["convolution_layer"]} if "convolution_layer" in hf else {} + ), + ), + # Architecture fields not surfaced in HF; round-trip at default. + "layers_unmapped": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -314,34 +360,45 @@ def get_converters( ] -class Apriel2KimiDeltaAttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "kda", - "heads": config["heads"], - "head_dim": config["head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - if "normalization" in config: - result["normalization"] = config["normalization"] - return result +class Apriel2KimiDeltaAttentionConverter(ConfigSectionConverter): + fast_llm_config_class = KimiDeltaAttentionConfig + hf_type_name = "kda" @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "kda", - "heads": config.heads, - "head_dim": config.head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, - "normalization": { - "epsilon": config.normalization.epsilon, - }, + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ( + {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} + ), + ), + "normalization": CustomConfigConverter( + fast_llm_paths=(("normalization",),), + export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, + import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), + ), + # Architecture fields not surfaced in HF; round-trip at default. 
+ "layers_unmapped": IgnoredConfigConverter( + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -350,11 +407,7 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Fast-LLM KDA uses abbreviated names matching the external module: - # q_proj, k_proj, v_proj, q_conv, k_conv, v_conv, f_a_proj, f_b_proj, - # g_a_proj, g_b_proj, beta_proj, o_proj, A_log, dt_bias, norm return [ - # Q/K/V projections *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_proj", f"{hf_prefix}.q_proj", @@ -373,7 +426,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Convolutions (Q, K, V) *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_conv", f"{hf_prefix}.q_conv", @@ -392,7 +444,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Gate projections (f_a, f_b, g_a, g_b) *get_weight_and_bias_converters( f"{fast_llm_prefix}.f_a_proj", f"{hf_prefix}.f_a_proj", @@ -417,21 +468,18 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Beta projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.beta_proj", f"{hf_prefix}.beta_proj", False, drop_on_export=drop_on_export, ), - # Output projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.o_proj", f"{hf_prefix}.o_proj", False, drop_on_export=drop_on_export, ), - # Learnable parameters get_parameter_converter( f"{fast_llm_prefix}.A_log", f"{hf_prefix}.A_log", @@ -442,7 +490,6 @@ def get_converters( f"{hf_prefix}.dt_bias", drop_on_export=drop_on_export, ), - # Normalization *LlamaNormalizationConverter.get_converters( config.normalization, f"{fast_llm_prefix}.norm", @@ -452,56 +499,38 @@ def get_converters( ] -class Apriel2StochasticMixerConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - mixers = {} - for name, sub_mixer_config in config["mixers"].items(): - mixer_type = sub_mixer_config["type"] - if mixer_type == "attention": - mixers[name] = Apriel2AttentionConverter.import_config(sub_mixer_config) - elif mixer_type == "mamba": - mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config) - elif mixer_type == "gdn": - mixers[name] = Apriel2GatedDeltaNetConverter.import_config(sub_mixer_config) - elif mixer_type == "kda": - mixers[name] = Apriel2KimiDeltaAttentionConverter.import_config(sub_mixer_config) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config["main_mixer_name"], - } - if "sampling_strategy" in config: - result["sampling_strategy"] = config["sampling_strategy"] - return result +# Mixer dispatch registry — used inside StochasticMixer (no nested-stochastic) and at the block level. 
+APRIEL2_LEAF_MIXER_REGISTRY: dict = { + AttentionConfig: Apriel2AttentionConverter, + MambaConfig: Apriel2MambaConverter, + GatedDeltaNetConfig: Apriel2GatedDeltaNetConverter, + KimiDeltaAttentionConfig: Apriel2KimiDeltaAttentionConverter, +} + + +class Apriel2StochasticMixerConverter(ConfigSectionConverter): + fast_llm_config_class = StochasticMixerConfig + hf_type_name = "stochastic" @classmethod - def export_config(cls, config: StochasticMixerConfig) -> dict: - mixers = {} - for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - mixers[name] = Apriel2AttentionConverter.export_config(sub_mixer) - elif mixer_type is MambaConfig: - mixers[name] = Apriel2MambaConverter.export_config(sub_mixer) - elif mixer_type is GatedDeltaNetConfig: - mixers[name] = Apriel2GatedDeltaNetConverter.export_config(sub_mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixers[name] = Apriel2KimiDeltaAttentionConverter.export_config(sub_mixer) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config.main_mixer_name, + def _create_config_converters(cls) -> dict: + from fast_llm.layers.decoder.config import StochasticMixerSamplingStrategy + + return { + "mixers": TypedDictContainerConfigConverter( + fast_llm_path=("mixers",), + hf_path=("mixers",), + registry=APRIEL2_LEAF_MIXER_REGISTRY, + ), + "main_mixer_name": RenameConfigConverter(("main_mixer_name",), ("main_mixer_name",)), + "sampling_strategy": OptionalConfigConverter( + ("sampling_strategy",), + ("sampling_strategy",), + sentinel=StochasticMixerSamplingStrategy.uniform, + ), } - if config.sampling_strategy != StochasticMixerSamplingStrategy.uniform: - result["sampling_strategy"] = config.sampling_strategy.value - return result + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -513,136 +542,128 @@ def get_converters( ) -> list[WeightConverter]: converters = [] for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + converter_class = APRIEL2_LEAF_MIXER_REGISTRY.get(type(sub_mixer)) + if converter_class is None: + raise ValueError(f"Unknown sub-mixer type: {type(sub_mixer)}") converters.extend( converter_class.get_converters( sub_mixer, f"{fast_llm_prefix}.mixers.{name}", - hf_sub_mixer_prefix, + f"{hf_prefix}.mixers.{name}", drop_on_export=drop_on_export, ) ) - return converters -class Apriel2BlockConverter: - @classmethod - def import_config(cls, config: dict, block_config: dict) -> dict: - mixer_config = block_config["mixer"] - mixer_type = mixer_config["type"] - - if mixer_type == "attention": - mixer = Apriel2AttentionConverter.import_config(mixer_config) - elif mixer_type == "mamba": - mixer = Apriel2MambaConverter.import_config(mixer_config) - elif mixer_type == "stochastic": - mixer = 
Apriel2StochasticMixerConverter.import_config(mixer_config) - elif mixer_type == "gdn": - mixer = Apriel2GatedDeltaNetConverter.import_config(mixer_config) - elif mixer_type == "kda": - mixer = Apriel2KimiDeltaAttentionConverter.import_config(mixer_config) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") +# Block-level mixer registry includes StochasticMixer (which can wrap leaf mixers). +APRIEL2_BLOCK_MIXER_REGISTRY: dict = { + **APRIEL2_LEAF_MIXER_REGISTRY, + StochasticMixerConfig: Apriel2StochasticMixerConverter, +} + - from fast_llm.functional.config import ActivationType +# ============================================================ +# Normalization converters +# ============================================================ - mlp_config = block_config["mlp"] - mlp = { - "type": "mlp", - "intermediate_size": mlp_config["intermediate_size"], - "activation": ActivationType.from_hf_name(mlp_config["activation"]), - "gated": mlp_config["gated"], - "add_linear_biases": mlp_config["add_linear_biases"], + +class Apriel2RMSNormConverter(ConfigSectionConverter): + fast_llm_config_class = RMSNormalizationConfig + hf_type_name = "rms_norm" + + @classmethod + def _create_config_converters(cls) -> dict: + return { + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } - # Import per-layer MLP bias settings (layer_1, layer_2) - for layer_name in ("layer_1", "layer_2"): - if layer_name in mlp_config: - layer_cfg = mlp_config[layer_name] - if "bias" in layer_cfg: - mlp[layer_name] = {"bias": layer_cfg["bias"]} - normalization = block_config["normalization"] +class Apriel2LayerNormConverter(ConfigSectionConverter): + fast_llm_config_class = LayerNormalizationConfig + hf_type_name = "layer_norm" + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "bias": IgnoredConfigConverter(("bias",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } + +class Apriel2NoNormConverter(ConfigSectionConverter): + fast_llm_config_class = NoNormalizationConfig + hf_type_name = "none" + @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - from fast_llm.layers.common.normalization.config import ( - LayerNormalizationConfig, - NoNormalizationConfig, - RMSNormalizationConfig, - ) + def _create_config_converters(cls) -> dict: + return {} - mixer_type = type(config.mixer) - - if mixer_type is AttentionConfig: - mixer = Apriel2AttentionConverter.export_config(config.mixer) - elif mixer_type is MambaConfig: - mixer = Apriel2MambaConverter.export_config(config.mixer) - elif mixer_type is StochasticMixerConfig: - mixer = Apriel2StochasticMixerConverter.export_config(config.mixer) - elif mixer_type is GatedDeltaNetConfig: - mixer = Apriel2GatedDeltaNetConverter.export_config(config.mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixer = Apriel2KimiDeltaAttentionConverter.export_config(config.mixer) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - norm_type = type(config.normalization) - if norm_type is RMSNormalizationConfig: - norm_type_str = "rms_norm" - elif norm_type is LayerNormalizationConfig: - norm_type_str = "layer_norm" - elif norm_type is NoNormalizationConfig: - norm_type_str = "none" - else: - 
raise ValueError(f"Unknown normalization type: {norm_type}") - from fast_llm.layers.decoder.mlp.config import MLPConfig +APRIEL2_NORM_REGISTRY: dict = { + RMSNormalizationConfig: Apriel2RMSNormConverter, + LayerNormalizationConfig: Apriel2LayerNormConverter, + NoNormalizationConfig: Apriel2NoNormConverter, +} - if not isinstance(config.mlp, MLPConfig): - raise ValueError(f"Unsupported MLP type: {type(config.mlp)}") - mlp = { - "type": "mlp", - "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.value, - "gated": config.mlp.gated, - "add_linear_biases": config.mlp.add_linear_biases, +# ============================================================ +# MLP, Block, Decoder, Head +# ============================================================ + + +class Apriel2MLPConverter(ConfigSectionConverter): + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + + @classmethod + def _create_config_converters(cls) -> dict: + layer_names = ("layer_1", "layer_2") + return { + # MLP is wrapped via NestedConfigConverter (no Dispatch discriminator), so emit the HF + # ``"type": "mlp"`` discriminator from inside this converter. + "hf_type": ConstantExportConfigConverter(("type",), "mlp"), + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "layers": CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + ), } - # Export per-layer MLP bias settings (layer_1, layer_2) - if config.mlp.layer_1.bias.enabled is not None: - mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} - if config.mlp.layer_2.bias.enabled is not None: - mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} - normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} +class Apriel2BlockConverter(ConfigSectionConverter): + fast_llm_config_class = DecoderBlockConfig + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "mixer": DispatchConfigConverter( + fast_llm_path=("mixer",), + hf_path=("mixer",), + registry=APRIEL2_BLOCK_MIXER_REGISTRY, + ), + "mlp": NestedConfigConverter(("mlp",), Apriel2MLPConverter, hf_path=("mlp",)), + "normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -651,46 +672,30 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - converters = [] - mixer_type = type(config.mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is StochasticMixerConfig: - converter_class = Apriel2StochasticMixerConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif 
mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - converters.extend( + converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) + if converter_class is None: + raise ValueError(f"Unknown mixer type: {type(config.mixer)}") + converters: list[WeightConverter] = list( converter_class.get_converters( config.mixer, f"{fast_llm_prefix}.mixer", - hf_mixer_prefix, + f"{hf_prefix}.mixer", drop_on_export=drop_on_export, ) ) - # Per-layer MLP bias: use layer-specific setting if set, else default - def get_mlp_layer_bias(layer_config, default: bool) -> bool: - if layer_config.bias.enabled is not None: - return layer_config.bias.enabled - return default - - layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) - layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + layer_1_bias = ( + config.mlp.layer_1.bias.enabled + if config.mlp.layer_1.bias.enabled is not None + else config.mlp.add_linear_biases + ) + layer_2_bias = ( + config.mlp.layer_2.bias.enabled + if config.mlp.layer_2.bias.enabled is not None + else config.mlp.add_linear_biases + ) if config.mlp.gated: - # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 converters.extend( [ *get_weight_and_bias_converters( @@ -710,8 +715,6 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: ] ) else: - # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 - # Note: layer_2 still needs MLPLayer2Converter for the transpose converters.extend( [ *get_weight_and_bias_converters( @@ -747,73 +750,52 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: ), ] ) - return converters -class Apriel2DecoderConverter: - block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter +class Apriel2FixedDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = FixedBlockSequenceConfig + hf_type_name = "fixed" @classmethod - def import_config(cls, config: dict) -> dict: - decoder_config = config["decoder"] - decoder_type = decoder_config["type"] - - if decoder_type == "fixed": - block_config = decoder_config["block"] - imported_block = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "fixed", - "num_blocks": decoder_config["num_blocks"], - "block": imported_block, - } - - elif decoder_type == "pattern": - blocks = {} - for name, block_config in decoder_config["blocks"].items(): - blocks[name] = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "pattern", - "blocks": blocks, - "pattern": decoder_config["pattern"], - "num_blocks": decoder_config["num_blocks"], - } + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "block": NestedConfigConverter(("block",), Apriel2BlockConverter, hf_path=("block",)), + } - else: - raise ValueError(f"Unknown decoder type: {decoder_type}") + +class Apriel2PatternDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = PatternBlockSequenceConfig + hf_type_name = "pattern" @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - if isinstance(config, 
FixedBlockSequenceConfig): - block_config = cls.block_converter_class.export_config(config.block) - return { - "decoder": { - "type": "fixed", - "num_blocks": config.num_blocks, - "block": block_config, - } - } - - elif isinstance(config, PatternBlockSequenceConfig): - blocks = {} - for name, block_config in config.blocks.items(): - blocks[name] = cls.block_converter_class.export_config(block_config) - - return { - "decoder": { - "type": "pattern", - "blocks": blocks, - "pattern": config.pattern, - "num_blocks": config.num_blocks, - } - } + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "pattern": RenameConfigConverter(("pattern",), ("pattern",)), + "blocks": TypedDictContainerConfigConverter( + fast_llm_path=("blocks",), + hf_path=("blocks",), + registry={DecoderBlockConfig: Apriel2BlockConverter}, + ), + } - else: - raise ValueError(f"Unknown decoder config type: {type(config)}") + +APRIEL2_DECODER_REGISTRY: dict = { + FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, + PatternBlockSequenceConfig: Apriel2PatternDecoderConverter, +} + + +class Apriel2DecoderConverter: + """Imperative decoder dispatcher kept for the weight side. + + Config-side conversion is handled declaratively via :class:`Apriel2FixedDecoderConverter` and + :class:`Apriel2PatternDecoderConverter`, dispatched at the base-model level. + """ + + block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter @classmethod def get_converters( @@ -823,9 +805,7 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - converters = [] + converters: list[WeightConverter] = [] if type(config) is FixedBlockSequenceConfig: for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( @@ -848,28 +828,25 @@ def get_converters( return converters -class Apriel2HeadConverter: - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter +class Apriel2HeadConverter(ConfigSectionConverter): + fast_llm_config_class = LanguageModelHeadConfig - @classmethod - def import_config(cls, config: dict) -> dict: - norm_config = config["head"]["normalization"] - return {"normalization": {"type": "rms_norm", "epsilon": norm_config["epsilon"]}} + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.language_model.config import LanguageModelHeadConfig - - Assert.custom(isinstance, config, LanguageModelHeadConfig) + def _create_config_converters(cls) -> dict: return { - "head": { - "normalization": { - "type": "rms_norm", - "epsilon": config.normalization.epsilon, - } - } + "normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), + "output_weight": IgnoredConfigConverter(("output_weight",)), + "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -892,33 +869,35 @@ def get_converters( ] -class Apriel2BaseModelConverter: +class Apriel2BaseModelConverter(ConfigSectionConverter): + fast_llm_config_class = GPTBaseModelConfig + decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = 
Apriel2DecoderConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "decoder": DispatchConfigConverter( + fast_llm_path=("decoder",), + hf_path=("decoder",), + registry=APRIEL2_DECODER_REGISTRY, + ), + "head": NestedConfigConverter(("head",), cls.head_converter_class, hf_path=("head",)), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + "peft": IgnoredConfigConverter(("peft",)), } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: @@ -955,7 +934,7 @@ def get_model_files(cls) -> tuple[str, str, str | None]: @classmethod def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: base_model = config.base_model - exported = safe_merge_dicts( + return safe_merge_dicts( cls.base_model_converter_class.export_config(base_model), { "architectures": [cls.architecture], @@ -967,7 +946,6 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: }, }, ) - return exported @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: From f5103839c6206aa24a0ec8b6a77284c607d73a32 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 10:22:09 -0400 Subject: [PATCH 08/27] Migrate Mistral / Qwen2 / MTP-Llama config converters to declarative primitives Each format inherits Llama's `_create_config_converters` and replaces only the fields that diverge: * Mistral: ConstantImportConfigConverter pinning `add_linear_biases=False` for attention and MLP (HF format has no `attention_bias`/`mlp_bias`); rename `window_size` <-> `sliding_window`. * Qwen2: ConstantImportConfigConverter for `add_linear_biases`; CustomConfigConverter for `head_size` (no `head_dim` HF field, derive on import); CustomConfigConverter for per-layer biases (always Q/K/V=True, dense=False); the head_dim relationship `heads * head_size == hidden_size` moves to `_validate_export` on the base-model converter; the use_mrope guard moves to `import_config`. * MTP-Llama: RenameConfigConverter for `prediction_heads` (Llama blanket-ignores it). 
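For reference, the `ConstantImportConfigConverter` contract these formats lean on is asymmetric: it injects the pinned value into the Fast-LLM dict on import and only validates (writes nothing) on export. A rough usage sketch, calling the primitive directly rather than through the section walker; `attention_config` is a placeholder for any attention config with biases disabled:

    from fast_llm.engine.checkpoint.external import ConstantImportConfigConverter

    converter = ConstantImportConfigConverter(("add_linear_biases",), False)

    # Import: the HF dict has no bias field at all, so the pinned value is injected.
    fast_llm_out: dict = {}
    converter.import_to({}, fast_llm_out)
    assert fast_llm_out == {"add_linear_biases": False}

    # Export: nothing is written to the HF dict; the pinned value is only checked
    # against the config (a config with biases enabled would fail the check).
    hf_out: dict = {}
    converter.export_to(attention_config, hf_out)  # attention_config: placeholder instance
    assert hf_out == {}
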
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/models/gpt/conversion/mistral.py | 47 ++++--------- fast_llm/models/gpt/conversion/mtp_llama.py | 36 +++++----- fast_llm/models/gpt/conversion/qwen2.py | 74 +++++++++++---------- 3 files changed, 69 insertions(+), 88 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 106f1f0cc..7664a195c 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -1,8 +1,7 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.layers.attention.config import AttentionConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.engine.checkpoint.external import ConstantImportConfigConverter, RenameConfigConverter from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -13,45 +12,27 @@ LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) -from fast_llm.utils import safe_merge_dicts class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - return safe_merge_dicts( - super().import_config(config), - {"window_size": config["sliding_window"]}, - ) - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = safe_merge_dicts( - super().export_config(config), - {"sliding_window": config.window_size}, - ) - del out["attention_bias"] - return out - - @classmethod - def _validate_export(cls, config: AttentionConfig) -> None: - # Mistral doesn't support biases. - assert not config.add_linear_biases + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `attention_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "window_size": RenameConfigConverter(("window_size",), ("sliding_window",)), + } class MistralMLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - assert not config.add_linear_biases - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `mlp_bias` HF field; biases are always disabled. 
+ "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class MistralBlockConverter(LlamaBlockConverter): diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index f681c4a24..787ba0220 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -3,9 +3,9 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelHeadConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -20,27 +20,21 @@ class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - **super().import_config(config), - "prediction_heads": config["prediction_heads"], + **super()._create_config_converters(), + # MTP-Llama exposes the prediction-heads count via the HF config; Llama itself blanket-ignores it. + "prediction_heads": RenameConfigConverter(("prediction_heads",), ("prediction_heads",)), } - @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - return safe_merge_dicts( - super().export_config(config), - {"prediction_heads": config.prediction_heads}, - ) - @classmethod def get_converters( cls, config: LanguageModelConfig, exported_config: dict, ) -> list[WeightConverter]: - # Override: map head.final_norm to model.mtp_norms.0 (not model.norm as in standard Llama), - # since MTPLlamaModel uses mtp_norms[0] for the first prediction head. + # MTP-Llama uses ``model.mtp_norms.0`` for the first prediction head's final norm + # instead of the standard ``model.norm``. converters = [ *cls.normalization_converter_class.get_converters( config.head.normalization, @@ -70,19 +64,19 @@ def get_converters( class MTPLlamaDecoderConverter(LlamaDecoderConverter): @classmethod - def import_config(cls, config: dict) -> dict: + def import_config(cls, hf_dict: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "block": cls.block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + def export_config(cls, decoder_config: FixedBlockSequenceConfig) -> dict: # TODO: Support PatternBlockSequenceConfig with compatible configs. 
- Assert.custom(isinstance, config, FixedBlockSequenceConfig) + Assert.custom(isinstance, decoder_config, FixedBlockSequenceConfig) return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, + cls.block_converter_class.export_config(decoder_config.block), + {"num_hidden_layers": decoder_config.num_blocks}, ) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 437431ed1..6fc9a45eb 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,10 +1,13 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + CustomConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -19,35 +22,41 @@ QueryWeightConverter, get_weight_and_bias_converters, ) -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - out = super().import_config(config) - out["query_layer"] = {"bias": {"enabled": True}} - out["key_layer"] = {"bias": {"enabled": True}} - out["value_layer"] = {"bias": {"enabled": True}} - out["dense_layer"] = {"bias": {"enabled": False}} - return out - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = super().export_config(config) - del out["attention_bias"] - # Qwen2Config does not have head_dim as a standard field; it is always - # derivable as hidden_size // num_attention_heads. - del out["head_dim"] + def _create_config_converters(cls) -> dict: + out = super()._create_config_converters() + # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. + out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) + # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. + out["head_size"] = CustomConfigConverter( + fast_llm_paths=(("head_size",),), + export_fn=lambda config: {}, + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ) + # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. + # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. 
+ out["linear_layers"] = CustomConfigConverter( + fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), + export_fn=lambda config: {}, + import_fn=lambda hf: { + ("query_layer",): {"bias": {"enabled": True}}, + ("key_layer",): {"bias": {"enabled": True}}, + ("value_layer",): {"bias": {"enabled": True}}, + ("dense_layer",): {"bias": {"enabled": False}}, + }, + ) return out @classmethod def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) - # There are multiple ways to enable biases on QKV only + # There are multiple ways to enable biases on QKV only. if config.add_linear_biases: Assert.incl(config.query_layer.bias.enabled, (None, True)) Assert.incl(config.key_layer.bias.enabled, (None, True)) @@ -95,15 +104,12 @@ def get_converters( class Qwen2MLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Qwen2 has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class Qwen2BlockConverter(LlamaBlockConverter): @@ -124,12 +130,13 @@ class Qwen2BaseModelConverter(LlamaBaseModelConverter): head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - assert config.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" - return super().import_config(config) + def import_config(cls, hf_dict: dict) -> dict: + assert hf_dict.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" + return super().import_config(hf_dict) @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + super()._validate_export(config) block = ( config.decoder.block if isinstance(config.decoder, FixedBlockSequenceConfig) @@ -141,7 +148,6 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: config.hidden_size, msg="Qwen2 format omits head_dim; requires heads * head_size == hidden_size", ) - return super().export_config(config) class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): From df438c46501809dc245f1f278fe6c5a859096fbb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 10:35:23 -0400 Subject: [PATCH 09/27] Migrate Mixtral config converters to declarative primitives `MixtralMLPConverter` switches its `fast_llm_config_class` to `MoEMLPConfig` so the architecture-coverage check sees MoE-specific fields. The config-side overrides: * `add_linear_biases` -> ConstantImportConfigConverter (Mixtral has no `mlp_bias`). * `experts` <-> `num_local_experts` and `experts_per_token` <-> `num_experts_per_tok` via RenameConfigConverter. * `shared_experts=0` and `routing=topk` pinned via ConstantImportConfigConverter so they round-trip cleanly without an HF representation. * `router` covered by IgnoredConfigConverter (Mixtral's gate is a default `LinearConfig`). 
The Fast-LLM dynamic-type discriminator (`type: "moe"`) is injected via an `import_config` override since the MLP is wrapped via `NestedConfigConverter` rather than `DispatchConfigConverter`. Diffusion-Dream and Diffusion-Llama need no migration: they only override `architecture`, `get_transformers_configuration_class`, and `_export_config` (auto_map). They inherit the declarative converters from their parents (Qwen2 and Llama). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/models/gpt/conversion/mixtral.py | 53 ++++++++++++----------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 6908d2958..7659befa3 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,8 +1,14 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter -from fast_llm.layers.decoder.mlp.config import MoEMLPConfig +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig, RoutingType from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( @@ -12,35 +18,32 @@ MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) -from fast_llm.utils import Assert, safe_merge_dicts class MixtralMLPConverter(LlamaMLPConverter): + fast_llm_config_class = MoEMLPConfig + @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return safe_merge_dicts( - super().import_config(config), - { - "type": "moe", - "experts": config["num_local_experts"], - "experts_per_token": config["num_experts_per_tok"], - }, - ) + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mixtral has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "experts": RenameConfigConverter(("experts",), ("num_local_experts",)), + "experts_per_token": RenameConfigConverter(("experts_per_token",), ("num_experts_per_tok",)), + # Mixtral has no shared experts and uses the topk default; assert on export, inject defaults on import. + "shared_experts": ConstantImportConfigConverter(("shared_experts",), 0), + "routing": ConstantImportConfigConverter(("routing",), RoutingType.topk), + # Mixtral's gate is a default LinearConfig (no bias); blanket-consume so coverage passes. + "router": IgnoredConfigConverter(("router",)), + } @classmethod - def export_config(cls, config: MoEMLPConfig) -> dict: - Assert.custom(isinstance, config, MoEMLPConfig) - assert not config.add_linear_biases - out = super().export_config(config) - del out["mlp_bias"] - return safe_merge_dicts( - out, - { - "num_local_experts": config.experts, - "num_experts_per_tok": config.experts_per_token, - }, - ) + def import_config(cls, hf_dict: dict) -> dict: + # Inject the Fast-LLM dynamic-type discriminator so `from_dict` instantiates `MoEMLPConfig` + # rather than the default `MLPConfig`. 
The MLP is wrapped via `NestedConfigConverter`, so + # there's no surrounding `DispatchConfigConverter` to inject this for us. + return {"type": "moe", **super().import_config(hf_dict)} @classmethod def get_converters( From 1b025dbf1fe8e794aeab237d1bc157c619d8de73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 11:07:53 -0400 Subject: [PATCH 10/27] Migrate Apriel hybrid SSM mixer config converters to declarative primitives `AprielMambaConverter`, `GatedDeltaNetConverter`, and `KimiDeltaAttentionConverter` become `ConfigSectionConverter` subclasses with their HF-side fields nested under the appropriate HF subkey (`ssm_cfg` for Mamba, `linear_attn_config` for GDN/KDA). Mamba's three sibling-default fields (`d_inner`, `d_xb`, `dt_rank`) read the HF root's `hidden_size` directly via `DefaultConfigConverter.hf_default_fn` / `CustomConfigConverter`, removing the need for an explicit `parent_context` plumbing through the framework. The per-layer convolution and dt biases use `CustomConfigConverter` to pick up the mixer-wide `add_linear_biases` fallback when unset; the existing `_check_config` per-layer assertions move to `_validate_export`. `AprielBlockConverter` (the per-block dispatcher) and `AprielDecoderConverter` (the `hybrid_block_layout` driver) stay imperative because Apriel's HF format encodes the mixer type in a parent-level list rather than a per-block discriminator, which `DispatchConfigConverter` doesn't model. The `type: "mamba"`/`"gdn"`/`"kda"` Fast-LLM discriminator is injected via a one-line `import_config` override on each leaf converter (same pattern Mixtral uses). The HF format has no test coverage in `tests/models/test_checkpoint.py` or `tests/models/test_hf_roundtrip.py`, so verification was a synthesized live round-trip covering each mixer leaf plus a hybrid attention+Mamba pattern decoder. 
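The sibling-default mechanism is just `DefaultConfigConverter`'s import-side fallback: when the HF key is missing, `hf_default_fn` is evaluated on the HF dict the leaf converter receives (the full HF root, since the block dispatcher passes it through to each leaf). A rough sketch with illustrative sizes:

    from fast_llm.engine.checkpoint.external import DefaultConfigConverter

    d_xb = DefaultConfigConverter(
        ("d_xb",),
        ("ssm_cfg", "d_xb"),
        hf_default_fn=lambda hf: hf["hidden_size"],
    )

    # An explicit HF value wins ...
    out: dict = {}
    d_xb.import_to({"hidden_size": 4096, "ssm_cfg": {"d_xb": 1024}}, out)
    assert out == {"d_xb": 1024}

    # ... otherwise the default is derived from the HF root.
    out = {}
    d_xb.import_to({"hidden_size": 4096, "ssm_cfg": {}}, out)
    assert out == {"d_xb": 4096}
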
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/models/gpt/conversion/apriel.py | 242 +++++++++++++++-------- 1 file changed, 159 insertions(+), 83 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ac732ba22..efa801799 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -4,7 +4,14 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -22,51 +29,88 @@ from fast_llm.utils import Assert, safe_merge_dicts -class AprielMambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "type": "mamba", - "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), - "add_linear_biases": config["ssm_cfg"]["bias"], - "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], - "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, - "dt_rank": ( - math.ceil(config["hidden_size"] / 16) - if config["ssm_cfg"].get("dt_rank", "auto") == "auto" - else config["ssm_cfg"]["dt_rank"] - ), - "repeat_kv_before_conv": config["ssm_cfg"].get("repeat_kv_before_conv", True), - } +def _resolve_bias_enabled(layer_bias_enabled: bool | None, add_linear_biases: bool) -> bool: + """Per-layer bias falls back to the mixer-wide flag when unset, matching the imperative behaviour.""" + return add_linear_biases if layer_bias_enabled is None else layer_bias_enabled + + +class AprielMambaConverter(ConfigSectionConverter): + """Converts ``MambaConfig`` <-> Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). + + A few of MambaConfig's defaults are derived from the HF root's ``hidden_size`` (``d_inner`` defaults + to ``hidden_size * expand``, ``d_xb`` defaults to ``hidden_size``, ``dt_rank="auto"`` resolves to + ``ceil(hidden_size / 16)``). Those declarations read the root HF dict directly, so each leaf + converter sees the full HF root passed by the parent block dispatcher. 
+ """ + + fast_llm_config_class = MambaConfig @classmethod - def export_config(cls, config: MambaConfig) -> dict: - cls._check_config(config) + def _create_config_converters(cls) -> dict: return { - "ssm_cfg": { - "d_state": config.state_size, - "d_inner": config.d_inner, - "bias": config.add_linear_biases, - "conv_bias": ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), - "d_xb": config.d_xb, - "dt_proj_bias": ( - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled - ), - "dt_rank": config.dt_rank, - "repeat_kv_before_conv": config.repeat_kv_before_conv, - } + "state_size": RenameConfigConverter(("state_size",), ("ssm_cfg", "d_state")), + "d_inner": DefaultConfigConverter( + ("d_inner",), + ("ssm_cfg", "d_inner"), + hf_default_fn=lambda hf: hf["hidden_size"] * hf.get("ssm_cfg", {}).get("expand", 1), + ), + "d_xb": DefaultConfigConverter( + ("d_xb",), + ("ssm_cfg", "d_xb"), + hf_default_fn=lambda hf: hf["hidden_size"], + ), + "dt_rank": CustomConfigConverter( + fast_llm_paths=(("dt_rank",),), + export_fn=lambda c: {("ssm_cfg", "dt_rank"): c.dt_rank}, + import_fn=lambda hf: { + ("dt_rank",): ( + math.ceil(hf["hidden_size"] / 16) + if hf.get("ssm_cfg", {}).get("dt_rank", "auto") == "auto" + else hf["ssm_cfg"]["dt_rank"] + ) + }, + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("ssm_cfg", "bias")), + "repeat_kv_before_conv": DefaultConfigConverter( + ("repeat_kv_before_conv",), + ("ssm_cfg", "repeat_kv_before_conv"), + hf_default_fn=lambda hf: True, + ), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: { + ("ssm_cfg", "conv_bias"): _resolve_bias_enabled( + c.convolution_layer.bias.enabled, c.add_linear_biases + ) + }, + import_fn=lambda hf: { + ("convolution_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("conv_bias", True) + }, + ), + "dt_layer": CustomConfigConverter( + fast_llm_paths=(("dt_layer",),), + export_fn=lambda c: { + ("ssm_cfg", "dt_proj_bias"): _resolve_bias_enabled(c.dt_layer.bias.enabled, c.add_linear_biases) + }, + import_fn=lambda hf: { + ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) + }, + ), + # Per-layer biases that must round-trip implicitly via add_linear_biases (validated below). + "linear_layers": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ), + # Parameter sub-configs Mamba doesn't expose to HF; coverage-only. + "parameters": IgnoredConfigConverter(("d_weight",), ("a_log_weight",)), } @classmethod - def _check_config(cls, config: MambaConfig) -> None: - # Opportunity to make derived classes less constrained. 
- Assert.is_(type(config), MambaConfig) + def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) @@ -74,6 +118,13 @@ def _check_config(cls, config: MambaConfig) -> None: Assert.incl(config.dt_input_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + @classmethod + def import_config(cls, hf_dict: dict) -> dict: + # Inject the Fast-LLM dynamic-type discriminator: the parent (AprielBlockConverter) selects this + # leaf via `hybrid_block_layout`, not via a nested HF discriminator, so DispatchConfigConverter's + # auto-injection isn't in play and we must add `type` manually. + return {"type": "mamba", **super().import_config(hf_dict)} + @classmethod def get_converters( cls, @@ -99,17 +150,13 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dt_proj", f"{hf_prefix}.dt_proj", - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled, + _resolve_bias_enabled(config.dt_layer.bias.enabled, config.add_linear_biases), drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( f"{fast_llm_prefix}.convolution", f"{hf_prefix}.conv1d", - ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), + _resolve_bias_enabled(config.convolution_layer.bias.enabled, config.add_linear_biases), drop_on_export=drop_on_export, ), get_parameter_converter( @@ -131,31 +178,36 @@ def get_converters( ] -class GatedDeltaNetConverter: +class GatedDeltaNetConverter(ConfigSectionConverter): + """Converts ``GatedDeltaNetConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = GatedDeltaNetConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "gdn", - "value_heads": config["linear_attn_config"]["gdn_num_value_heads"], - "key_heads": config["linear_attn_config"]["gdn_num_key_heads"], - "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], - "value_head_dim": config["linear_attn_config"]["gdn_value_head_dim"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], - }, + "value_heads": RenameConfigConverter(("value_heads",), ("linear_attn_config", "gdn_num_value_heads")), + "key_heads": RenameConfigConverter(("key_heads",), ("linear_attn_config", "gdn_num_key_heads")), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("linear_attn_config", "gdn_key_head_dim")), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("linear_attn_config", "gdn_value_head_dim")), + "convolution_kernel_size": RenameConfigConverter( + ("convolution_layer", "kernel_size"), + ("linear_attn_config", "gdn_linear_conv_kernel_size"), + ), + # Sub-configs without HF representation; coverage-only. 
+ "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: - return { - "linear_attn_config": { - "gdn_num_value_heads": config.value_heads, - "gdn_num_key_heads": config.key_heads, - "gdn_key_head_dim": config.key_head_dim, - "gdn_value_head_dim": config.value_head_dim, - "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, - }, - } + def import_config(cls, hf_dict: dict) -> dict: + return {"type": "gdn", **super().import_config(hf_dict)} @classmethod def get_converters( @@ -209,27 +261,40 @@ def get_converters( ] -class KimiDeltaAttentionConverter: +class KimiDeltaAttentionConverter(ConfigSectionConverter): + """Converts ``KimiDeltaAttentionConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = KimiDeltaAttentionConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "kda", - "head_dim": config["linear_attn_config"]["head_dim"], - "heads": config["linear_attn_config"]["num_heads"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], - }, + "head_dim": RenameConfigConverter(("head_dim",), ("linear_attn_config", "head_dim")), + "heads": RenameConfigConverter(("heads",), ("linear_attn_config", "num_heads")), + "convolution_kernel_size": RenameConfigConverter( + ("convolution_layer", "kernel_size"), + ("linear_attn_config", "short_conv_kernel_size"), + ), + # Sub-configs without HF representation; coverage-only. + "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: - return { - "linear_attn_config": { - "head_dim": config.head_dim, - "num_heads": config.heads, - "short_conv_kernel_size": config.convolution_layer.kernel_size, - }, - } + def import_config(cls, hf_dict: dict) -> dict: + return {"type": "kda", **super().import_config(hf_dict)} @classmethod def get_converters( @@ -347,6 +412,11 @@ class AprielGatedDeltaNetBlockConverter(MistralBlockConverter): class AprielBlockConverter: + """Per-block dispatcher: the mixer type is encoded in the parent's ``hybrid_block_layout`` list, + not in a nested HF discriminator, so this dispatcher stays imperative rather than using + :class:`DispatchConfigConverter`. Each branch delegates to a regular declarative block converter. + """ + layout_names = { AttentionConfig: "t", MambaConfig: "m2", @@ -382,6 +452,11 @@ def get_converters( class AprielDecoderConverter(MistralDecoderConverter): + """Pattern-style decoder dispatched via Apriel's ``hybrid_block_layout`` list (one entry per block). + Stays imperative because the layout-list shape doesn't match the declarative ``decoder.type`` + discriminator that Apriel2 uses. 
+ """ + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod @@ -413,7 +488,8 @@ def export_config(cls, config: BlockSequenceConfig) -> dict: pattern_block_configs = [config.blocks[block_name] for block_name in config.pattern] else: raise NotImplementedError() - # There may be all sorts of blocks, but `safe_merge_dicts` ensures they are compatible. + # Each block emits non-overlapping HF keys (attention -> flat, mamba -> ssm_cfg.*, + # gdn/kda -> linear_attn_config.*) so safe_merge_dicts is sufficient to combine them. return safe_merge_dicts( *[cls.block_converter_class.export_config(block_config) for block_config in block_configs], { From 17b91d9a91b552b2430b7ac8883a935f37cc0125 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 11:46:28 -0400 Subject: [PATCH 11/27] Migrate Pixtral normalization and embeddings config converters to declarative primitives `PixtralNormalizationConverter` collapses to a single `_create_config_converters` override that pins `epsilon=1e-5` via `ConstantImportConfigConverter` (asserts on export, injects on import; no HF write). `PixtralEmbeddingsConverter` becomes a `ConfigSectionConverter` with declarations for `patch_height` (rename to `patch_size`), `patch_width` (mirror `patch_size` on import), `num_channels` (export-only constant 3), nested `normalization`, and an `IgnoredConfigConverter` for `patch_embeddings`. The `patch_height == patch_width` and `patch_embeddings.bias.enabled in (None, False)` checks move to `_validate_export`. The remaining Llava and Apriel2 multimodal converters stay imperative: they're cross-section aggregators (vision_config + text_config + top-level merge) whose shape doesn't fit a single ConfigSectionConverter, often with parent-context dependencies (e.g., the adapter's intermediate_size derives from the text model's hidden_size). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../models/multimodal/conversion/llava.py | 71 +++++++++++-------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 2a2e03502..b48b1d042 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -3,13 +3,21 @@ import torch from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import Rotary2DConfig -from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig @@ -31,21 +39,15 @@ class PixtralNormalizationConverter(LlamaNormalizationConverter): - """ - epsilon hard-coded to 1e-5. 
- """ + """RMS norm with HF-side hardcoded epsilon=1e-5 (Pixtral's HF format omits the field).""" @classmethod - def import_config(cls, config: dict) -> dict: - return {"type": "rms_norm", "epsilon": 1e-5} - - @classmethod - def export_config(cls, config: RMSNormalizationConfig) -> dict: - Assert.custom(isinstance, config, RMSNormalizationConfig) - assert not config.zero_centered - # TODO: Too strict? - Assert.eq(config.epsilon, 1e-5) - return {} + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Pin epsilon to 1e-5: assert on export, inject on import. No HF write/read. + "epsilon": ConstantImportConfigConverter(("epsilon",), 1e-5), + } class PixtralAttentionConverter(LlamaAttentionConverter): @@ -117,32 +119,39 @@ def import_weight( ) -class PixtralEmbeddingsConverter: +class PixtralEmbeddingsConverter(ConfigSectionConverter): + """Converts ``PatchEmbeddingsConfig`` <-> Pixtral HF flat fields (``patch_size`` / ``num_channels``). + + Pixtral's HF ``vision_config`` carries a single ``patch_size`` field (height == width); the converter + expands it to both Fast-LLM dimensions on import and validates equality on export. + """ + + fast_llm_config_class = PatchEmbeddingsConfig normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter @classmethod - def import_config(cls, config: dict) -> dict: - Assert.eq(config["num_channels"], 3) + def _create_config_converters(cls) -> dict: return { - "normalization": cls.normalization_converter_class.import_config(config), - "patch_height": config["patch_size"], - "patch_width": config["patch_size"], + "patch_height": RenameConfigConverter(("patch_height",), ("patch_size",)), + # Pixtral has one `patch_size`; mirror it to `patch_width` on import and validate equality on export. + "patch_width": CustomConfigConverter( + fast_llm_paths=(("patch_width",),), + export_fn=lambda c: {}, + import_fn=lambda hf: {("patch_width",): hf["patch_size"]}, + ), + # `input_channels` is a derived cached_property pinned to 3; assert on import, emit on export. + "num_channels": ConstantExportConfigConverter(("num_channels",), 3), + # PixtralNormalizationConverter exports {} (epsilon pinned), so flat-merge is a no-op on export. + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + # patch_embeddings (the AffineLinearConfig) has no HF representation; bias presence validated below. + "patch_embeddings": IgnoredConfigConverter(("patch_embeddings",)), } @classmethod - def export_config(cls, config: PatchEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, PatchEmbeddingsConfig) + def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.eq(config.patch_height, config.patch_width) Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) - return safe_merge_dicts( - { - "patch_size": config.patch_height, - "num_channels": config.input_channels, - }, - cls.normalization_converter_class.export_config(config.normalization), - ) - @classmethod def get_converters( cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str From dc418d1ec2b68aab96ea807d1c7b964eabbc4da5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 12:36:53 -0400 Subject: [PATCH 12/27] Remove unused weight-converter scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `CopyWeightConverter` was defined in `external.py` but never instantiated; deleted. 
`QueryWeightConverter` was a no-op identity (its `export_weight`/`import_weight` just unwrap and rewrap); replaced with the default `WeightConverter` at all three call sites (Llama, Qwen2, Apriel2 attention) and removed the redundant `config` arg. The broader weight-side refactor (declarative `WeightConverter` primitives, walker-driven `drop_on_export` removal) is deferred — out of scope for this PR. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/external.py | 12 ------------ fast_llm/models/gpt/conversion/apriel2.py | 3 --- fast_llm/models/gpt/conversion/llama.py | 19 ------------------- fast_llm/models/gpt/conversion/qwen2.py | 3 --- 4 files changed, 37 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 851ea13b0..a083e4ad5 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -542,18 +542,6 @@ def import_weight( ) -class CopyWeightConverter(WeightConverter): - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.export_name) - 1)] - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.fast_llm_name) - 1)] - - class SplitWeightConverter(WeightConverter): def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 263af4d02..86b4caf4f 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -39,7 +39,6 @@ LlamaEmbeddingsConverter, LlamaNormalizationConverter, MLPLayer2Converter, - QueryWeightConverter, SplitWeightConverter, get_parameter_converter, get_weight_and_bias_converters, @@ -159,8 +158,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", q_bias, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 578098764..1888e6fd3 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -127,23 +127,6 @@ def import_weight( return (merged_weight.t().contiguous(),) -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - class KeyValueWeightConverter(WeightConverter): # Hf uses the real format for rotary embeddings, and keeps the key and value separate. 
_config: AttentionConfig @@ -398,8 +381,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", config.add_linear_biases, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 6fc9a45eb..3d9d6f349 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -19,7 +19,6 @@ LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, - QueryWeightConverter, get_weight_and_bias_converters, ) from fast_llm.utils import Assert, div @@ -81,8 +80,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", True, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( From 8272abf75a7851834c6d004e8b4b31fa5feaa79d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 13:42:49 -0400 Subject: [PATCH 13/27] Self-review fixes - Fix asymmetric round-trip in `Apriel2MambaConverter`: the `aux` declaration's import_fn now reads `d_conv` / `conv_bias` / `dt_proj_bias` back into `convolution_layer.kernel_size`, `convolution_layer.bias.enabled`, and `dt_layer.bias.enabled`. Previously these HF fields were dropped on import, which silently masked HF conv1d/dt_proj bias weights when they diverged from the mixer-wide `add_linear_biases` flag (parallel to the apriel.py mamba migration earlier in this PR). - Drop the stale TODO from `_check_architecture_coverage`'s docstring (the migrations it referred to have all landed in this PR); reword the surrounding comment to describe the current strict-subtype handling. - Combine adjacent f-strings in `DispatchConfigConverter`'s import-error message. - Hoist `StochasticMixerSamplingStrategy` to the module-level import in `apriel2.py`; it was being re-imported on every `_create_config_converters` call. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/external.py | 7 +++---- fast_llm/models/gpt/conversion/apriel2.py | 23 +++++++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index a083e4ad5..c6ad48dcb 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -456,10 +456,9 @@ def _check_architecture_coverage(cls, config: Config, declarations: dict[str, Co check to the nested section's own converter. The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the - config is a strict subclass (e.g. ``MoEMLPConfig`` fed via ``super().export_config()`` from a yet-to-be- - migrated ``MixtralMLPConverter``), the subclass converter is responsible for declaring the additional - fields and running its own check. TODO: Once Mixtral/Apriel/Apriel2 migrate, the safety net for - ``MoEMLPConfig``/``MambaConfig``/etc. is gated on those migrations landing. + config is a strict subclass (e.g. ``MoEMLPConfig`` fed through ``LlamaMLPConverter`` declarations + before the dispatching ``MixtralMLPConverter`` overrides ``fast_llm_config_class``), the subclass + converter is responsible for declaring the additional fields and running its own check. 
""" if type(config) is not cls.fast_llm_config_class: return diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 86b4caf4f..8d177467c 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -28,7 +28,7 @@ NoNormalizationConfig, RMSNormalizationConfig, ) -from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig, StochasticMixerSamplingStrategy from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig @@ -187,6 +187,21 @@ def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: } +def _apriel2_mamba_aux_import(hf_dict: dict) -> dict: + """Reverse of :func:`_apriel2_mamba_aux_export`. ``conv_bias`` / ``dt_proj_bias`` can diverge from the + mixer-wide ``add_linear_biases`` flag, so they must populate the per-layer ``bias.enabled`` directly; + dropping them on import would silently mask HF bias weights when the weight loader checks the + per-layer flag.""" + out: dict = {} + if "d_conv" in hf_dict: + out[("convolution_layer", "kernel_size")] = hf_dict["d_conv"] + if "conv_bias" in hf_dict: + out[("convolution_layer", "bias", "enabled")] = hf_dict["conv_bias"] + if "dt_proj_bias" in hf_dict: + out[("dt_layer", "bias", "enabled")] = hf_dict["dt_proj_bias"] + return out + + class Apriel2MambaConverter(ConfigSectionConverter): fast_llm_config_class = MambaConfig hf_type_name = "mamba" @@ -202,9 +217,7 @@ def _create_config_converters(cls) -> dict: "aux": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("dt_layer",)), export_fn=_apriel2_mamba_aux_export, - # The d_conv/conv_bias/dt_proj_bias HF fields are not reflected in the Fast-LLM mamba dict — - # current Apriel2 import simply ignores them and lets Fast-LLM use its own defaults. - import_fn=lambda hf: {}, + import_fn=_apriel2_mamba_aux_import, ), # Architecture fields with no HF counterpart; they round-trip at their Fast-LLM defaults. "layers_unmapped": IgnoredConfigConverter( @@ -511,8 +524,6 @@ class Apriel2StochasticMixerConverter(ConfigSectionConverter): @classmethod def _create_config_converters(cls) -> dict: - from fast_llm.layers.decoder.config import StochasticMixerSamplingStrategy - return { "mixers": TypedDictContainerConfigConverter( fast_llm_path=("mixers",), From 8314f1225507c828ebabed29c4652015b2671b8e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 17:23:24 -0400 Subject: [PATCH 14/27] Address review feedback - Recursive architecture-coverage walker (item 1): the section-level check now collects every architecture-hint path under the active config tree and matches each against the declarations. Recursive primitives (Nested/Dispatch/TypedDictContainer/Ignored, plus Custom/ImportOnly when the author opts in) cover whole subtrees by prefix; non-recursive ones must list every leaf they consume. Fixes the silent-drop class of bug previously masked for any sub-config field claimed by a flat CustomConfigConverter. - Apriel2 rotary export bug fix (motivating leak for item 1): the export now emits the Llama3/Yarn scale parameters that round-trip via the pass-through import, instead of silently dropping them. 
- Pixtral attention migrated to declarative form (item 3): _create_config_converters overrides instead of an imperative export_config that bypassed the coverage check. - Apriel2 weight side cleanup (items 5, 6, 12): Apriel2MLPConverter owns its weight converters and the block delegates; the imperative Apriel2DecoderConverter is gone, replaced by per-shape get_converters on Apriel2FixedDecoderConverter / Apriel2PatternDecoderConverter dispatched via APRIEL2_DECODER_REGISTRY. - ImportOnlyConfigConverter primitive (item 11) collapses three asymmetric CustomConfigConverter sites in qwen2.py and llava.py. - Helper consolidation: drop external.py's _get_nested/_has_nested in favour of fast_llm.config.get_nested_dict_value (item 7); share assert_no_peft between Llama and Apriel2 base-model converters (item 10). Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 191 +++++++++----- fast_llm/models/gpt/conversion/apriel.py | 2 + fast_llm/models/gpt/conversion/apriel2.py | 236 ++++++++++-------- fast_llm/models/gpt/conversion/llama.py | 14 +- fast_llm/models/gpt/conversion/qwen2.py | 9 +- .../models/multimodal/conversion/apriel2.py | 10 +- .../models/multimodal/conversion/llava.py | 59 +++-- 7 files changed, 322 insertions(+), 199 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index c6ad48dcb..17341c33c 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -7,7 +7,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import Config, FieldHint, set_nested_dict_value +from fast_llm.config import Config, FieldHint, get_nested_dict_value, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -19,29 +19,6 @@ logger = logging.getLogger(__name__) -_MISSING = object() - - -def _get_nested(d: dict, path: tuple[str, ...], default=_MISSING): - cur = d - for key in path: - if not isinstance(cur, dict) or key not in cur: - if default is _MISSING: - raise KeyError(f"Missing key {'.'.join(path)} in HF config dict") - return default - cur = cur[key] - return cur - - -def _has_nested(d: dict, path: tuple[str, ...]) -> bool: - cur = d - for key in path: - if not isinstance(cur, dict) or key not in cur: - return False - cur = cur[key] - return True - - def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: cur = config for name in path: @@ -49,6 +26,30 @@ def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: return cur +def _collect_architecture_paths(config: Config) -> list[tuple[str, ...]]: + """Walk ``config`` and return every architecture-hint field path reachable from it. + + Descends into any field whose runtime value is itself a :class:`Config`, so the path list reflects the + actual instance (e.g. only ``Llama3RotaryConfig`` fields when ``config.rotary`` is one). The caller uses + the list to verify each path is claimed by some declaration. 
+ """ + paths: list[tuple[str, ...]] = [] + + def walk(node: Config, prefix: tuple[str, ...]) -> None: + for name, field in type(node).fields(): + if field._field_type != dataclasses._FIELD or not field.init: + continue + full = prefix + (name,) + if field.hint == FieldHint.architecture: + paths.append(full) + value = getattr(node, name) + if isinstance(value, Config): + walk(value, full) + + walk(config, ()) + return paths + + # ============================================================ # Config conversion primitives (declarative) # ============================================================ @@ -60,18 +61,21 @@ class ConfigConverter(abc.ABC): Each primitive owns a set of ``fast_llm_paths`` (tuples of attribute names rooted at the section's config) and ``hf_paths`` (tuples of dict keys rooted at the section's HF subdict). The walker calls ``export_to`` to produce HF entries from a Fast-LLM config object, and ``import_to`` to produce a Fast-LLM config dict from an HF dict. + + ``recurses`` controls how :meth:`ConfigSectionConverter._check_architecture_coverage` interprets the paths: + + * ``recurses = False`` (default) — paths are exact-match leaves. Every architecture-hint field at every depth + under the section's config class must be exactly listed by some declaration. + * ``recurses = True`` — paths are recursive prefixes covering the entire subtree. Used by primitives that + delegate to a sub-converter that runs its own coverage check (Nested/Dispatch/TypedDictContainer), by + :class:`IgnoredConfigConverter` (the format intentionally does not represent the subtree), and by + :class:`CustomConfigConverter` when its author opts in (escape hatch for cases like rotary that don't + decompose into per-leaf renames). """ fast_llm_paths: tuple[tuple[str, ...], ...] = () hf_paths: tuple[tuple[str, ...], ...] = () - - @property - def consumed_fast_llm_fields(self) -> set[str]: - """Top-level Fast-LLM field names this primitive consumes at the current section level. - - Used by the section walker for the architecture-hint coverage check. - """ - return {path[0] for path in self.fast_llm_paths if path} + recurses: typing.ClassVar[bool] = False @abc.abstractmethod def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: ... 
@@ -92,7 +96,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], value) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - value = _get_nested(hf_dict, self.hf_paths[0]) + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) @@ -110,9 +114,11 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], self._value) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - if _has_nested(hf_dict, self.hf_paths[0]): - actual = _get_nested(hf_dict, self.hf_paths[0]) - Assert.eq(actual, self._value) + try: + actual = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: + return + Assert.eq(actual, self._value) class ConstantImportConfigConverter(ConfigConverter): @@ -155,9 +161,9 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], value) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - if _has_nested(hf_dict, self.hf_paths[0]): - value = _get_nested(hf_dict, self.hf_paths[0]) - else: + try: + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: value = self._hf_default_fn(hf_dict) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) @@ -179,19 +185,25 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], value) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - if _has_nested(hf_dict, self.hf_paths[0]): - value = _get_nested(hf_dict, self.hf_paths[0]) - if value != self._sentinel: - set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + try: + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: + return + if value != self._sentinel: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) class IgnoredConfigConverter(ConfigConverter): """Declares Fast-LLM architecture fields as intentionally not converted by this format. Use when the HF format has no representation for the field and the Fast-LLM default round-trips correctly. - Acts as a no-op on both directions while satisfying the architecture-coverage check. + Acts as a no-op on both directions while satisfying the architecture-coverage check. The claim covers the + entire subtree under each listed path: deeper architecture fields are also implicitly ignored, on the + assumption that a format which does not represent the parent likewise does not represent its children. """ + recurses: typing.ClassVar[bool] = True + def __init__(self, *fast_llm_paths: tuple[str, ...]): self.fast_llm_paths = fast_llm_paths self.hf_paths = () @@ -210,6 +222,11 @@ class CustomConfigConverter(ConfigConverter): not declared — there is no symmetric HF-side coverage check yet, so an ``hf_paths`` argument would be cosmetic. Cross-field validators that produce nothing on the HF side belong on :py:meth:`ConfigSectionConverter._validate_export` instead; this primitive is for shape-changing transforms. + + Pass ``recurses=True`` when the converter genuinely owns a sub-config subtree (e.g. rotary, per-layer biases) — + the listed paths then act as recursive prefixes and the architecture-coverage check stops at them. The author + is trusted to handle every architecture field of the claimed subtree; prefer Nested/Dispatch primitives when + the subtree decomposes cleanly. 
""" def __init__( @@ -217,10 +234,12 @@ def __init__( fast_llm_paths: tuple[tuple[str, ...], ...], export_fn: typing.Callable[[Config], dict], import_fn: typing.Callable[[dict], dict], + recurses: bool = False, ): self.fast_llm_paths = fast_llm_paths self._export_fn = export_fn self._import_fn = import_fn + self.recurses = recurses def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: produced = self._export_fn(fast_llm_config) @@ -233,6 +252,41 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: set_nested_dict_value(fast_llm_out, path, value) +class ImportOnlyConfigConverter(ConfigConverter): + """One-way mapping that runs only on import; emits nothing on export. + + Used when the HF format derives a Fast-LLM field from sibling fields (e.g. ``head_size`` from + ``hidden_size // num_attention_heads`` in Qwen2) or implies a value the Fast-LLM side stores + explicitly (e.g. Qwen2's hardcoded Q/K/V biases, Pixtral's mirrored ``patch_size`` ↔ ``patch_width``). + On export the field is redundant and validated through ``_validate_export``; on import the + ``import_fn`` produces the Fast-LLM dict entries. The fast_llm_paths still register as consumed + for the architecture-coverage check. + + Pass ``recurses=True`` when the converter populates a sub-config subtree (e.g. Qwen2's per-layer + biases that target ``query_layer``/``key_layer``/...). Same trade-off as + :class:`CustomConfigConverter`: the listed paths become recursive prefixes and the framework no + longer enforces leaf coverage under them. + """ + + def __init__( + self, + fast_llm_paths: tuple[tuple[str, ...], ...], + import_fn: typing.Callable[[dict], dict], + recurses: bool = False, + ): + self.fast_llm_paths = fast_llm_paths + self._import_fn = import_fn + self.recurses = recurses + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + return + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + produced = self._import_fn(hf_dict) + for path, value in produced.items(): + set_nested_dict_value(fast_llm_out, path, value) + + class NestedConfigConverter(ConfigConverter): """Recurse into a fixed-typed sub-config field via another section converter class. @@ -243,6 +297,8 @@ class NestedConfigConverter(ConfigConverter): that mirror Fast-LLM's modular layout (e.g. Apriel2's ``"decoder": {...}`` and ``"head": {...}`` blocks). """ + recurses: typing.ClassVar[bool] = True + def __init__( self, fast_llm_path: tuple[str, ...], @@ -266,7 +322,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self._hf_path, sub_hf) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - sub_hf = _get_nested(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict + sub_hf = get_nested_dict_value(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict sub_fast_llm = self._converter_class.import_config(sub_hf) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) @@ -278,6 +334,8 @@ class DispatchConfigConverter(ConfigConverter): Both registries (Fast-LLM type → converter class, HF discriminator → converter class) must agree at runtime. 
""" + recurses: typing.ClassVar[bool] = True + def __init__( self, fast_llm_path: tuple[str, ...], @@ -308,7 +366,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: hf_out[key] = value def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - sub_hf = _get_nested(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict + sub_hf = get_nested_dict_value(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict type_name = sub_hf.get(self._hf_discriminator_key) converter_class = self._hf_to_class.get(type_name) if converter_class is None: @@ -334,6 +392,8 @@ class TypedDictContainerConfigConverter(ConfigConverter): is then omitted on export and ignored on import. """ + recurses: typing.ClassVar[bool] = True + def __init__( self, fast_llm_path: tuple[str, ...], @@ -370,7 +430,7 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: set_nested_dict_value(hf_out, self.hf_paths[0], out) def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - sub_hf_dict = _get_nested(hf_dict, self.hf_paths[0]) + sub_hf_dict = get_nested_dict_value(hf_dict, self.hf_paths[0]) out: dict = {} for name, sub_hf in sub_hf_dict.items(): if self._homogeneous: @@ -448,12 +508,19 @@ def import_config(cls, hf_dict: dict) -> dict: @classmethod def _check_architecture_coverage(cls, config: Config, declarations: dict[str, ConfigConverter]) -> None: - """Raise if any architecture-hint field on the section's declared config class is not consumed. + """Raise if any architecture-hint field reachable from the section's config (recursively) is not consumed. + + Coverage is structural (based on field hints), not value-based: every architecture field at every depth + must be accounted for, even when it currently holds its Fast-LLM default. The walker descends into any + field whose runtime value is a :class:`Config`, collecting an architecture-leaf list, and matches each + leaf against the section's declarations: - Coverage is structural (based on field hints), not value-based: every architecture field must be - explicitly accounted for, even if it currently holds its Fast-LLM default. Sub-config fields are - consumed by ``NestedConfigConverter``/``DispatchConfigConverter``, which delegate the deeper coverage - check to the nested section's own converter. + * Recursive declarations (``recurses = True`` — Nested/Dispatch/TypedDictContainer/Ignored, plus Custom + when its author opts in) cover the entire subtree under each listed prefix. Nested/Dispatch/TypedDict + delegate to a sub-converter that runs its own coverage check; Ignored and recursive Custom assume the + author has handled the subtree. + * Non-recursive declarations (Rename, ConstantImport/Export, Default, Optional, ImportOnly, Custom by + default) must list every architecture leaf they consume by exact path. The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the config is a strict subclass (e.g. 
``MoEMLPConfig`` fed through ``LlamaMLPConverter`` declarations @@ -462,20 +529,20 @@ def _check_architecture_coverage(cls, config: Config, declarations: dict[str, Co """ if type(config) is not cls.fast_llm_config_class: return - consumed: set[str] = set() + explicit_paths: set[tuple[str, ...]] = set() + recursive_prefixes: list[tuple[str, ...]] = [] for converter in declarations.values(): - consumed |= converter.consumed_fast_llm_fields + if converter.recurses: + recursive_prefixes.extend(converter.fast_llm_paths) + else: + explicit_paths.update(converter.fast_llm_paths) missing: list[str] = [] - for name, field in type(config).fields(): - if field._field_type != dataclasses._FIELD: - continue - if not field.init: - continue - if field.hint != FieldHint.architecture: + for path in _collect_architecture_paths(config): + if path in explicit_paths: continue - if name in consumed: + if any(len(prefix) <= len(path) and path[: len(prefix)] == prefix for prefix in recursive_prefixes): continue - missing.append(name) + missing.append(".".join(path)) if missing: raise ValueError( f"{cls.__name__}: architecture-hint fields on {type(config).__name__} " diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index efa801799..f483808c0 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -86,6 +86,7 @@ def _create_config_converters(cls) -> dict: import_fn=lambda hf: { ("convolution_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("conv_bias", True) }, + recurses=True, ), "dt_layer": CustomConfigConverter( fast_llm_paths=(("dt_layer",),), @@ -95,6 +96,7 @@ def _create_config_converters(cls) -> dict: import_fn=lambda hf: { ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) }, + recurses=True, ), # Per-layer biases that must round-trip implicitly via add_linear_biases (validated below). "linear_layers": IgnoredConfigConverter( diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 8d177467c..1f873a873 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -40,11 +40,12 @@ LlamaNormalizationConverter, MLPLayer2Converter, SplitWeightConverter, + assert_no_peft, get_parameter_converter, get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel -from fast_llm.utils import Assert, safe_merge_dicts +from fast_llm.utils import safe_merge_dicts # ============================================================ # Helpers @@ -79,18 +80,36 @@ def _apriel2_attention_rotary_export(config: AttentionConfig) -> dict: """Emit Apriel2's typed rotary subdict. Asymmetric with the Fast-LLM type only for the default→``mistral_1d`` rename; ``llama3``/``yarn`` round-trip - by name. Mirrors current behavior: only ``type`` and ``theta`` are emitted (scale fields are dropped). + by name. The scale parameters of ``llama3``/``yarn`` are emitted under their Fast-LLM field names since + the matching :func:`_apriel2_attention_rotary_import` is a wholesale pass-through of ``hf_dict["rotary"]``. 
""" rotary = config.rotary if type(rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - elif type(rotary) is Llama3RotaryConfig: - rotary_type = "llama3" - elif type(rotary) is YarnRotaryConfig: - rotary_type = "yarn" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") - return {("rotary",): {"type": rotary_type, "theta": rotary.theta}} + return {("rotary",): {"type": "mistral_1d", "theta": rotary.theta}} + if type(rotary) is Llama3RotaryConfig: + return { + ("rotary",): { + "type": "llama3", + "theta": rotary.theta, + "scale_factor": rotary.scale_factor, + "low_frequency_factor": rotary.low_frequency_factor, + "high_frequency_factor": rotary.high_frequency_factor, + "original_context_length": rotary.original_context_length, + } + } + if type(rotary) is YarnRotaryConfig: + return { + ("rotary",): { + "type": "yarn", + "theta": rotary.theta, + "scale_factor": rotary.scale_factor, + "attention_factor": rotary.attention_factor, + "beta_fast": rotary.beta_fast, + "beta_slow": rotary.beta_slow, + "original_context_length": rotary.original_context_length, + } + } + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") def _apriel2_attention_rotary_import(hf_dict: dict) -> dict: @@ -115,6 +134,7 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=(("rotary",),), export_fn=_apriel2_attention_rotary_export, import_fn=_apriel2_attention_rotary_import, + recurses=True, ), # Apriel2 emits add_linear_biases only when False; the True default is implicit. "add_linear_biases": OptionalConfigConverter( @@ -125,6 +145,7 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=tuple((name,) for name in layer_names), export_fn=lambda c: _per_layer_bias_export(c, layer_names), import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + recurses=True, ), "causal": IgnoredConfigConverter(("causal",)), "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), @@ -218,6 +239,7 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=(("convolution_layer",), ("dt_layer",)), export_fn=_apriel2_mamba_aux_export, import_fn=_apriel2_mamba_aux_import, + recurses=True, ), # Architecture fields with no HF counterpart; they round-trip at their Fast-LLM defaults. "layers_unmapped": IgnoredConfigConverter( @@ -304,6 +326,7 @@ def _create_config_converters(cls) -> dict: import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} ), + recurses=True, ), # Architecture fields not surfaced in HF; round-trip at default. "layers_unmapped": IgnoredConfigConverter( @@ -385,11 +408,13 @@ def _create_config_converters(cls) -> dict: import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} ), + recurses=True, ), "normalization": CustomConfigConverter( fast_llm_paths=(("normalization",),), export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), + recurses=True, ), # Architecture fields not surfaced in HF; round-trip at default. 
"layers_unmapped": IgnoredConfigConverter( @@ -647,9 +672,58 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=tuple((name,) for name in layer_names), export_fn=lambda c: _per_layer_bias_export(c, layer_names), import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + recurses=True, ), } + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + layer_1_bias = ( + config.layer_1.bias.enabled if config.layer_1.bias.enabled is not None else config.add_linear_biases + ) + layer_2_bias = ( + config.layer_2.bias.enabled if config.layer_2.bias.enabled is not None else config.add_linear_biases + ) + if config.gated: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + f"{hf_prefix}.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + class Apriel2BlockConverter(ConfigSectionConverter): fast_llm_config_class = DecoderBlockConfig @@ -680,68 +754,25 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) - if converter_class is None: + mixer_converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) + if mixer_converter_class is None: raise ValueError(f"Unknown mixer type: {type(config.mixer)}") converters: list[WeightConverter] = list( - converter_class.get_converters( + mixer_converter_class.get_converters( config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.mixer", drop_on_export=drop_on_export, ) ) - - layer_1_bias = ( - config.mlp.layer_1.bias.enabled - if config.mlp.layer_1.bias.enabled is not None - else config.mlp.add_linear_biases - ) - layer_2_bias = ( - config.mlp.layer_2.bias.enabled - if config.mlp.layer_2.bias.enabled is not None - else config.mlp.add_linear_biases - ) - - if config.mlp.gated: - converters.extend( - [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - layer_1_bias, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - ) - else: - converters.extend( - [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - f"{hf_prefix}.mlp.up_proj", - layer_1_bias, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + converters.extend( + Apriel2MLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, ) - + ) converters.extend( [ *LlamaNormalizationConverter.get_converters( @@ 
-772,6 +803,24 @@ def _create_config_converters(cls) -> dict: "block": NestedConfigConverter(("block",), Apriel2BlockConverter, hf_path=("block",)), } + @classmethod + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters: list[WeightConverter] = [] + for block_index in range(config.num_blocks): + converters += Apriel2BlockConverter.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) + return converters + class Apriel2PatternDecoderConverter(ConfigSectionConverter): fast_llm_config_class = PatternBlockSequenceConfig @@ -789,53 +838,32 @@ def _create_config_converters(cls) -> dict: ), } - -APRIEL2_DECODER_REGISTRY: dict = { - FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, - PatternBlockSequenceConfig: Apriel2PatternDecoderConverter, -} - - -class Apriel2DecoderConverter: - """Imperative decoder dispatcher kept for the weight side. - - Config-side conversion is handled declaratively via :class:`Apriel2FixedDecoderConverter` and - :class:`Apriel2PatternDecoderConverter`, dispatched at the base-model level. - """ - - block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter - @classmethod def get_converters( cls, - config, + config: PatternBlockSequenceConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: converters: list[WeightConverter] = [] - if type(config) is FixedBlockSequenceConfig: - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - elif type(config) is PatternBlockSequenceConfig: - for block_index in range(config.num_blocks): - block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - else: - raise NotImplementedError(f"Unsupported config type: {type(config).__name__}") + for block_index in range(config.num_blocks): + block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] + converters += Apriel2BlockConverter.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) return converters +APRIEL2_DECODER_REGISTRY: dict = { + FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, + PatternBlockSequenceConfig: Apriel2PatternDecoderConverter, +} + + class Apriel2HeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig @@ -880,7 +908,6 @@ def get_converters( class Apriel2BaseModelConverter(ConfigSectionConverter): fast_llm_config_class = GPTBaseModelConfig - decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = Apriel2DecoderConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @@ -901,17 +928,18 @@ def _create_config_converters(cls) -> dict: @classmethod def _validate_export(cls, config: GPTBaseModelConfig) -> None: - from fast_llm.layers.common.peft.config import NoPeftConfig - - Assert.custom(isinstance, config.peft, NoPeftConfig) + 
assert_no_peft(config) # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + decoder_converter_class = APRIEL2_DECODER_REGISTRY.get(type(config.decoder)) + if decoder_converter_class is None: + raise NotImplementedError(f"Unsupported decoder type: {type(config.decoder).__name__}") return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks"), + *decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks"), *cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 1888e6fd3..c8a41e8e2 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -45,6 +45,14 @@ logger = logging.getLogger(__name__) +def assert_no_peft(config: GPTBaseModelConfig) -> None: + """Reject any non-trivial PEFT config: HuggingFace formats serialize the base weights only, + so a configured LoRA (or other adapter) would be silently dropped on export.""" + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + + # ============================================================ # Weight converters (imperative — kept as-is during config migration) # ============================================================ @@ -351,6 +359,7 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=(("rotary",),), export_fn=_llama_rotary_export, import_fn=_llama_rotary_import, + recurses=True, ), } @@ -608,6 +617,7 @@ def _decoder_import(hf_dict: dict) -> dict: fast_llm_paths=(("decoder",),), export_fn=_decoder_export, import_fn=_decoder_import, + recurses=True, ), "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), @@ -618,9 +628,7 @@ def _decoder_import(hf_dict: dict) -> dict: @classmethod def _validate_export(cls, config: GPTBaseModelConfig) -> None: - from fast_llm.layers.common.peft.config import NoPeftConfig - - Assert.custom(isinstance, config.peft, NoPeftConfig) + assert_no_peft(config) # --- weight side (imperative) --- diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 3d9d6f349..e8719a44e 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -3,7 +3,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantImportConfigConverter, - CustomConfigConverter, + ImportOnlyConfigConverter, WeightConverter, ) from fast_llm.layers.attention.config import AttentionConfig @@ -33,22 +33,21 @@ def _create_config_converters(cls) -> dict: # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. 
- out["head_size"] = CustomConfigConverter( + out["head_size"] = ImportOnlyConfigConverter( fast_llm_paths=(("head_size",),), - export_fn=lambda config: {}, import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, ) # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. - out["linear_layers"] = CustomConfigConverter( + out["linear_layers"] = ImportOnlyConfigConverter( fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), - export_fn=lambda config: {}, import_fn=lambda hf: { ("query_layer",): {"bias": {"enabled": True}}, ("key_layer",): {"bias": {"enabled": True}}, ("value_layer",): {"bias": {"enabled": True}}, ("dense_layer",): {"bias": {"enabled": False}}, }, + recurses=True, ) return out diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 8a947baaa..58d267e3c 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -11,8 +11,8 @@ from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( + APRIEL2_DECODER_REGISTRY, Apriel2BaseModelConverter, - Apriel2DecoderConverter, Apriel2HeadConverter, ) from fast_llm.models.gpt.conversion.llama import ( @@ -318,7 +318,6 @@ def get_converters( class Apriel2MultimodalBaseModelConverter: vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter - decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = Apriel2DecoderConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @@ -352,13 +351,14 @@ def export_config(cls, config: MultiModalBaseModelConfig) -> dict: @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + decoder_converter_class = APRIEL2_DECODER_REGISTRY.get(type(config.decoder)) + if decoder_converter_class is None: + raise NotImplementedError(f"Unsupported decoder type: {type(config.decoder).__name__}") converters = [] if config.vision_encoder is not None: converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) - converters.extend( - cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks") - ) + converters.extend(decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks")) converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) return converters diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index b48b1d042..1fdb32378 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -9,6 +9,7 @@ ConstantImportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, + ImportOnlyConfigConverter, NestedConfigConverter, RenameConfigConverter, WeightConverter, @@ -22,6 +23,7 @@ from fast_llm.layers.language_model.config import 
LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.llama import ( + _TRANSFORMERS_V4, LlamaAttentionConverter, LlamaBlockConverter, LlamaDecoderConverter, @@ -50,31 +52,49 @@ def _create_config_converters(cls) -> dict: } +def _pixtral_rotary_export(config: AttentionConfig) -> dict: + if _TRANSFORMERS_V4: + return {("rope_theta",): config.rotary.theta} + return {("rope_parameters",): {"rope_theta": config.rotary.theta, "rope_type": "default"}} + + +def _pixtral_rotary_import(hf_dict: dict) -> dict: + if "rope_parameters" in hf_dict: + theta = hf_dict["rope_parameters"]["rope_theta"] + else: + theta = hf_dict["rope_theta"] + return {("rotary",): {"type": "default_2d", "theta": theta}} + + class PixtralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["num_key_value_heads"] = config["num_attention_heads"] - config["attention_bias"] = False - out = super().import_config(config) - out["rotary"]["type"] = "default_2d" - out["causal"] = False + def _create_config_converters(cls) -> dict: + out = super()._create_config_converters() + # PixtralConfig hardcodes Q/K/V/O biases off and does not surface ``attention_bias``. + out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) + # Pixtral attention is non-causal (vision encoder). + out["causal"] = ConstantImportConfigConverter(("causal",), False) + # No GQA in Pixtral; ``head_groups`` derives from ``num_attention_heads`` on import and is redundant + # on export (``_validate_export`` enforces equality with ``heads``). + out["head_groups"] = ImportOnlyConfigConverter( + fast_llm_paths=(("head_groups",),), + import_fn=lambda hf: {("head_groups",): hf["num_attention_heads"]}, + ) + # Pixtral always uses 2D rotary; only ``theta`` round-trips. The flat (v4) vs ``rope_parameters`` (v5) + # layout follows the active transformers major version, mirroring the Llama parent. + out["rotary"] = CustomConfigConverter( + fast_llm_paths=(("rotary",),), + export_fn=_pixtral_rotary_export, + import_fn=_pixtral_rotary_import, + recurses=True, + ) return out @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - cls._validate_export(config) - Assert.eq(config.softmax_scale_power, 0.5) + def _validate_export(cls, config: AttentionConfig) -> None: + super()._validate_export(config) Assert.is_(type(config.rotary), Rotary2DConfig) - assert not config.add_linear_biases - assert not config.causal Assert.eq(config.head_groups, config.heads) - return { - "num_attention_heads": config.heads, - "attention_dropout": config.dropout, - "rope_theta": config.rotary.theta, - # Not in PixtralConfig, but needed for consistency check in LlavaVisionModelConverter. - "head_dim": config.head_size, - } class PixtralBlockConverter(LlamaBlockConverter): @@ -134,9 +154,8 @@ def _create_config_converters(cls) -> dict: return { "patch_height": RenameConfigConverter(("patch_height",), ("patch_size",)), # Pixtral has one `patch_size`; mirror it to `patch_width` on import and validate equality on export. - "patch_width": CustomConfigConverter( + "patch_width": ImportOnlyConfigConverter( fast_llm_paths=(("patch_width",),), - export_fn=lambda c: {}, import_fn=lambda hf: {("patch_width",): hf["patch_size"]}, ), # `input_channels` is a derived cached_property pinned to 3; assert on import, emit on export. 
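A self-contained sketch of how an import-only declaration composes with a symmetric rename. `Rename` and `ImportOnly` are simplified stand-ins (configs are plain dicts here), not the fast_llm classes; the `patch_size`/`patch_width` names follow the Pixtral mapping above.

    def set_nested(d, path, value):
        for key in path[:-1]:
            d = d.setdefault(key, {})
        d[path[-1]] = value

    class Rename:
        """Symmetric one-to-one mapping between a Fast-LLM path and an HF key."""

        def __init__(self, fast_llm_path, hf_path):
            self.fast_llm_path, self.hf_path = fast_llm_path, hf_path

        def export_to(self, fast_llm, hf_out):
            set_nested(hf_out, self.hf_path, fast_llm[self.fast_llm_path[0]])

        def import_to(self, hf, fast_llm_out):
            set_nested(fast_llm_out, self.fast_llm_path, hf[self.hf_path[0]])

    class ImportOnly:
        """One-way mapping: derives Fast-LLM entries on import, emits nothing on export."""

        def __init__(self, import_fn):
            self.import_fn = import_fn

        def export_to(self, fast_llm, hf_out):
            return  # redundant on export; consistency belongs to a separate export validator

        def import_to(self, hf, fast_llm_out):
            for path, value in self.import_fn(hf).items():
                set_nested(fast_llm_out, path, value)

    declarations = {
        "patch_height": Rename(("patch_height",), ("patch_size",)),
        "patch_width": ImportOnly(lambda hf: {("patch_width",): hf["patch_size"]}),
    }

    imported = {}
    for converter in declarations.values():
        converter.import_to({"patch_size": 16}, imported)
    assert imported == {"patch_height": 16, "patch_width": 16}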
From eb9f179fc2e5ed24e5463b29d74064f35b0615cb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 14:10:45 -0400 Subject: [PATCH 15/27] Address second review round MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Subtree drops are now visible at the declaration site (review item 1). Five Custom converters previously claimed a sub-config subtree via ``recurses=True`` while only round-tripping a fraction of its architecture leaves; each is now non-recursive (lists every leaf it actually round-trips) with sibling ``IgnoredConfigConverter`` entries for the leaves the format drops on purpose. Sites: Apriel mamba ``convolution_layer`` and ``dt_layer``, Apriel2 GDN ``convolution_layer``, Apriel2 KDA ``convolution_layer`` and ``normalization``. - Architecture-coverage walker now descends into ``dict[str, Config]`` and list/tuple-of-Config fields (item 2). Previously masked by ``TypedDictContainerConfigConverter.recurses=True``; the walker now matches what the docstring claims. - Coverage error gains a hint when missing paths share a top-level prefix that is claimed non-recursively (item 3 — message half only): suggests Nested/Dispatch or ``recurses=True`` on Custom/ImportOnly. No new ``recurses`` kwarg on the base primitives. - Single ``effective_bias(layer_config, default)`` helper in llama.py replaces three near-duplicates (item 4): ``_resolve_bias_enabled`` in apriel.py, ``_get_effective_bias`` in apriel2.py, and the inline ternary in ``Apriel2MLPConverter``. - Apriel2 decoder dispatch lookup lifted into module-level ``get_apriel2_decoder_converter(decoder)`` (item 6); used by both the text and multimodal base-model converters. Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 37 +++++++--- fast_llm/models/gpt/conversion/apriel.py | 42 +++++------ fast_llm/models/gpt/conversion/apriel2.py | 71 +++++++++++-------- fast_llm/models/gpt/conversion/llama.py | 5 ++ .../models/multimodal/conversion/apriel2.py | 11 +-- 5 files changed, 103 insertions(+), 63 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 17341c33c..3d31f156f 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -29,12 +29,22 @@ def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: def _collect_architecture_paths(config: Config) -> list[tuple[str, ...]]: """Walk ``config`` and return every architecture-hint field path reachable from it. - Descends into any field whose runtime value is itself a :class:`Config`, so the path list reflects the - actual instance (e.g. only ``Llama3RotaryConfig`` fields when ``config.rotary`` is one). The caller uses - the list to verify each path is claimed by some declaration. + Descends into any field whose runtime value is a :class:`Config`, a ``dict[str, Config]`` + (paths are extended with the entry's string key), or a ``list[Config]`` (paths are extended + with the entry's index as a string), so the path list reflects the actual instance. 
""" paths: list[tuple[str, ...]] = [] + def descend(value: typing.Any, prefix: tuple[str, ...]) -> None: + if isinstance(value, Config): + walk(value, prefix) + elif isinstance(value, dict): + for key, sub in value.items(): + descend(sub, prefix + (str(key),)) + elif isinstance(value, (list, tuple)): + for index, sub in enumerate(value): + descend(sub, prefix + (str(index),)) + def walk(node: Config, prefix: tuple[str, ...]) -> None: for name, field in type(node).fields(): if field._field_type != dataclasses._FIELD or not field.init: @@ -42,9 +52,7 @@ def walk(node: Config, prefix: tuple[str, ...]) -> None: full = prefix + (name,) if field.hint == FieldHint.architecture: paths.append(full) - value = getattr(node, name) - if isinstance(value, Config): - walk(value, full) + descend(getattr(node, name), full) walk(config, ()) return paths @@ -536,17 +544,28 @@ def _check_architecture_coverage(cls, config: Config, declarations: dict[str, Co recursive_prefixes.extend(converter.fast_llm_paths) else: explicit_paths.update(converter.fast_llm_paths) - missing: list[str] = [] + missing: list[tuple[str, ...]] = [] for path in _collect_architecture_paths(config): if path in explicit_paths: continue if any(len(prefix) <= len(path) and path[: len(prefix)] == prefix for prefix in recursive_prefixes): continue - missing.append(".".join(path)) + missing.append(path) if missing: + # If every missing path shares a top-level prefix that IS claimed (just non-recursively), + # the contributor likely needs a recursive primitive there — surface that as a hint. + shared_prefixes = {path[:1] for path in missing if path[:1] in explicit_paths} + hint = "" + if shared_prefixes: + names = sorted(prefix[0] for prefix in shared_prefixes) + hint = ( + f" (declarations for {names} claim the parent path non-recursively; " + f"either list every architecture sub-field or switch to Nested/Dispatch — " + f"or pass ``recurses=True`` to a Custom/ImportOnly converter when claiming the whole subtree)" + ) raise ValueError( f"{cls.__name__}: architecture-hint fields on {type(config).__name__} " - f"have no converter declaration: {missing}" + f"have no converter declaration: {[ '.'.join(p) for p in missing ]}{hint}" ) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index f483808c0..1e9c6a6c0 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -18,7 +18,11 @@ from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.llama import ( + effective_bias, + get_parameter_converter, + get_weight_and_bias_converters, +) from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, @@ -29,11 +33,6 @@ from fast_llm.utils import Assert, safe_merge_dicts -def _resolve_bias_enabled(layer_bias_enabled: bool | None, add_linear_biases: bool) -> bool: - """Per-layer bias falls back to the mixer-wide flag when unset, matching the imperative behaviour.""" - return add_linear_biases if layer_bias_enabled is None else layer_bias_enabled - - class AprielMambaConverter(ConfigSectionConverter): """Converts ``MambaConfig`` <-> Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). 
@@ -76,28 +75,31 @@ def _create_config_converters(cls) -> dict: ("ssm_cfg", "repeat_kv_before_conv"), hf_default_fn=lambda hf: True, ), - "convolution_layer": CustomConfigConverter( - fast_llm_paths=(("convolution_layer",),), + "convolution_layer_bias": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "bias")), export_fn=lambda c: { - ("ssm_cfg", "conv_bias"): _resolve_bias_enabled( - c.convolution_layer.bias.enabled, c.add_linear_biases - ) + ("ssm_cfg", "conv_bias"): effective_bias(c.convolution_layer, c.add_linear_biases) }, import_fn=lambda hf: { ("convolution_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("conv_bias", True) }, - recurses=True, ), - "dt_layer": CustomConfigConverter( - fast_llm_paths=(("dt_layer",),), - export_fn=lambda c: { - ("ssm_cfg", "dt_proj_bias"): _resolve_bias_enabled(c.dt_layer.bias.enabled, c.add_linear_biases) - }, + # CausalConv1dConfig fields not represented in Apriel HF: weight rides the tensor side via + # ``conv1d.weight``; kernel_size/activation round-trip implicitly at the Fast-LLM defaults. + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "kernel_size"), + ("convolution_layer", "activation"), + ), + "dt_layer_bias": CustomConfigConverter( + fast_llm_paths=(("dt_layer",), ("dt_layer", "bias")), + export_fn=lambda c: {("ssm_cfg", "dt_proj_bias"): effective_bias(c.dt_layer, c.add_linear_biases)}, import_fn=lambda hf: { ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) }, - recurses=True, ), + # AffineLinearConfig.weight rides the tensor side via ``dt_proj.weight``. + "dt_layer_unmapped": IgnoredConfigConverter(("dt_layer", "weight")), # Per-layer biases that must round-trip implicitly via add_linear_biases (validated below). 
"linear_layers": IgnoredConfigConverter( ("z_layer",), @@ -152,13 +154,13 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dt_proj", f"{hf_prefix}.dt_proj", - _resolve_bias_enabled(config.dt_layer.bias.enabled, config.add_linear_biases), + effective_bias(config.dt_layer, config.add_linear_biases), drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( f"{fast_llm_prefix}.convolution", f"{hf_prefix}.conv1d", - _resolve_bias_enabled(config.convolution_layer.bias.enabled, config.add_linear_biases), + effective_bias(config.convolution_layer, config.add_linear_biases), drop_on_export=drop_on_export, ), get_parameter_converter( diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 1f873a873..cc94913ae 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -41,6 +41,7 @@ MLPLayer2Converter, SplitWeightConverter, assert_no_peft, + effective_bias, get_parameter_converter, get_weight_and_bias_converters, ) @@ -153,12 +154,6 @@ def _create_config_converters(cls) -> dict: # --- weight side (imperative) --- - @classmethod - def _get_effective_bias(cls, layer_config, default: bool) -> bool: - if layer_config.bias.enabled is not None: - return layer_config.bias.enabled - return default - @classmethod def get_converters( cls, @@ -167,10 +162,10 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) - k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) - v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) - o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) + q_bias = effective_bias(config.query_layer, config.add_linear_biases) + k_bias = effective_bias(config.key_layer, config.add_linear_biases) + v_bias = effective_bias(config.value_layer, config.add_linear_biases) + o_bias = effective_bias(config.dense_layer, config.add_linear_biases) # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. kv_bias = k_bias and v_bias @@ -320,13 +315,19 @@ def _create_config_converters(cls) -> dict: "key_heads": RenameConfigConverter(("key_heads",), ("key_heads",)), "key_head_dim": RenameConfigConverter(("key_head_dim",), ("key_head_dim",)), "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), - "convolution_layer": CustomConfigConverter( - fast_llm_paths=(("convolution_layer",),), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} ), - recurses=True, + ), + # CausalConv1dConfig sub-fields the Apriel2 HF format does not surface (weight rides the tensor + # side; bias/activation round-trip at their Fast-LLM defaults). + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), ), # Architecture fields not surfaced in HF; round-trip at default. 
"layers_unmapped": IgnoredConfigConverter( @@ -402,19 +403,28 @@ def _create_config_converters(cls) -> dict: return { "heads": RenameConfigConverter(("heads",), ("heads",)), "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), - "convolution_layer": CustomConfigConverter( - fast_llm_paths=(("convolution_layer",),), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} ), - recurses=True, ), - "normalization": CustomConfigConverter( - fast_llm_paths=(("normalization",),), + # CausalConv1dConfig sub-fields not surfaced in HF (same as :class:`Apriel2GatedDeltaNetConverter`). + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), + ), + "normalization_epsilon": CustomConfigConverter( + fast_llm_paths=(("normalization",), ("normalization", "epsilon")), export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), - recurses=True, + ), + # Other GatedRMSNormalizationConfig architecture fields are dropped on the HF side. + "normalization_unmapped": IgnoredConfigConverter( + ("normalization", "weight"), + ("normalization", "zero_centered"), ), # Architecture fields not surfaced in HF; round-trip at default. "layers_unmapped": IgnoredConfigConverter( @@ -684,12 +694,8 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - layer_1_bias = ( - config.layer_1.bias.enabled if config.layer_1.bias.enabled is not None else config.add_linear_biases - ) - layer_2_bias = ( - config.layer_2.bias.enabled if config.layer_2.bias.enabled is not None else config.add_linear_biases - ) + layer_1_bias = effective_bias(config.layer_1, config.add_linear_biases) + layer_2_bias = effective_bias(config.layer_2, config.add_linear_biases) if config.gated: return [ *get_weight_and_bias_converters( @@ -864,6 +870,14 @@ def get_converters( } +def get_apriel2_decoder_converter(decoder_config) -> "type[ConfigSectionConverter]": + """Look up the Apriel2 per-shape decoder converter for a given decoder config instance.""" + converter_class = APRIEL2_DECODER_REGISTRY.get(type(decoder_config)) + if converter_class is None: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return converter_class + + class Apriel2HeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig @@ -934,12 +948,11 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder_converter_class = APRIEL2_DECODER_REGISTRY.get(type(config.decoder)) - if decoder_converter_class is None: - raise NotImplementedError(f"Unsupported decoder type: {type(config.decoder).__name__}") return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks"), + *get_apriel2_decoder_converter(config.decoder).get_converters( + config.decoder, "decoder", "model.decoder.blocks" + ), 
*cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index c8a41e8e2..762d38c4d 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -53,6 +53,11 @@ def assert_no_peft(config: GPTBaseModelConfig) -> None: Assert.custom(isinstance, config.peft, NoPeftConfig) +def effective_bias(layer_config, default: bool) -> bool: + """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default.""" + return default if layer_config.bias.enabled is None else layer_config.bias.enabled + + # ============================================================ # Weight converters (imperative — kept as-is during config migration) # ============================================================ diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 58d267e3c..225863fcc 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -11,9 +11,9 @@ from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( - APRIEL2_DECODER_REGISTRY, Apriel2BaseModelConverter, Apriel2HeadConverter, + get_apriel2_decoder_converter, ) from fast_llm.models.gpt.conversion.llama import ( LlamaEmbeddingsConverter, @@ -351,14 +351,15 @@ def export_config(cls, config: MultiModalBaseModelConfig) -> dict: @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder_converter_class = APRIEL2_DECODER_REGISTRY.get(type(config.decoder)) - if decoder_converter_class is None: - raise NotImplementedError(f"Unsupported decoder type: {type(config.decoder).__name__}") converters = [] if config.vision_encoder is not None: converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) - converters.extend(decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks")) + converters.extend( + get_apriel2_decoder_converter(config.decoder).get_converters( + config.decoder, "decoder", "model.decoder.blocks" + ) + ) converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) return converters From 0588262fcbc630a14c17870725a04cd63fc36af4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 May 2026 16:14:38 -0400 Subject: [PATCH 16/27] Address third review round MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Apriel2 decoder converters expose `block_converter_class` ClassVar so subclasses can swap the block converter, mirroring the LlamaDecoderConverter polymorphism pattern. * `_create_config_converters` is memoized via `functools.cache` (keyed by cls), so per-class declarations are built once. Convert two `out = super(); out[k] = v` mutation patterns (qwen2, llava) to spread+new-dict so the cached parent dict is never mutated. * `NestedConfigConverter` auto-injects the HF `type` discriminator from the target converter's `hf_type_name`, mirroring `DispatchConfigConverter`/`TypedDictContainer`. Drops a manual `ConstantExportConfigConverter` from `Apriel2MLPConverter`. 
* Move architecture-coverage check to `tests/models/test_converters.py`, parametrized per-format. Walks each `HuggingfaceStateDictCheckpointHandler.base_model_converter_class` through the modular converter tree (Nested/Dispatch/TypedDict + `*_converter_class` ClassVars) and runs `check_architecture_coverage` on each `ConfigSectionConverter` node. The per-export runtime invocation is removed. * Same test verifies `OptionalConfigConverter` sentinels match the resolved field default — catches silent round-trip drift if a Fast-LLM default changes. * Two latent bugs surfaced and fixed by the new test: * `apriel.py` GDN/KDA converters were missing `convolution_layer` architecture claims. * `Apriel2MambaConverter.d_xb`/`dt_rank` misused `OptionalConfigConverter` (sentinel=None on a non-Optional int) - converted to `RenameConfigConverter`. Deferred to follow-up commit: HF-side coverage check on every import (item 10) - needs `hf_paths` audit across ~20 Custom/ImportOnly call sites and a flat-merge-aware walker. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 37 +++-- fast_llm/models/gpt/conversion/apriel.py | 32 ++++- fast_llm/models/gpt/conversion/apriel2.py | 18 ++- fast_llm/models/gpt/conversion/qwen2.py | 43 +++--- .../models/multimodal/conversion/llava.py | 41 +++--- tests/models/test_converters.py | 129 ++++++++++++++++++ 6 files changed, 231 insertions(+), 69 deletions(-) create mode 100644 tests/models/test_converters.py diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 3d31f156f..bee31a509 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,5 +1,6 @@ import abc import dataclasses +import functools import logging import pathlib import typing @@ -303,6 +304,10 @@ class NestedConfigConverter(ConfigConverter): With ``hf_path`` set: the sub-converter's output is placed under that nested key. Use this for HF formats that mirror Fast-LLM's modular layout (e.g. Apriel2's ``"decoder": {...}`` and ``"head": {...}`` blocks). + + When the target ``converter_class`` declares ``hf_type_name``, an HF discriminator (``"type"`` by default) + is auto-injected on export and validated/stripped on import — matching DispatchConfigConverter's behavior + for homogeneous (single-target) cases. """ recurses: typing.ClassVar[bool] = True @@ -312,14 +317,18 @@ def __init__( fast_llm_path: tuple[str, ...], converter_class: "type[ConfigSectionConverter]", hf_path: tuple[str, ...] 
| None = None, + hf_discriminator_key: str = "type", ): self.fast_llm_paths = (fast_llm_path,) self._converter_class = converter_class self._hf_path = hf_path + self._hf_discriminator_key = hf_discriminator_key def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) sub_hf = self._converter_class.export_config(sub_config) + if self._converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: self._converter_class.hf_type_name, **sub_hf} if self._hf_path is None: for key, value in sub_hf.items(): if key in hf_out: @@ -331,6 +340,9 @@ def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: sub_hf = get_nested_dict_value(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict + if self._converter_class.hf_type_name is not None and self._hf_discriminator_key in sub_hf: + Assert.eq(sub_hf[self._hf_discriminator_key], self._converter_class.hf_type_name) + sub_hf = {key: value for key, value in sub_hf.items() if key != self._hf_discriminator_key} sub_fast_llm = self._converter_class.import_config(sub_hf) set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) @@ -480,9 +492,13 @@ class ConfigSectionConverter(abc.ABC): hf_type_name: typing.ClassVar[str | None] = None @classmethod - @abc.abstractmethod + @functools.cache def _create_config_converters(cls) -> dict[str, ConfigConverter]: - """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key.""" + """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key. + + Cached per class — declarations are immutable and depend only on ``cls``. + """ + raise NotImplementedError @classmethod def _validate_export(cls, config: Config) -> None: @@ -498,11 +514,9 @@ def _validate_export(cls, config: Config) -> None: @classmethod def export_config(cls, config: Config) -> dict: """Convert a Fast-LLM config object to an HF config dict via this section's declarations.""" - declarations = cls._create_config_converters() - cls._check_architecture_coverage(config, declarations) cls._validate_export(config) out: dict = {} - for converter in declarations.values(): + for converter in cls._create_config_converters().values(): converter.export_to(config, out) return out @@ -515,7 +529,7 @@ def import_config(cls, hf_dict: dict) -> dict: return out @classmethod - def _check_architecture_coverage(cls, config: Config, declarations: dict[str, ConfigConverter]) -> None: + def check_architecture_coverage(cls, config: Config) -> None: """Raise if any architecture-hint field reachable from the section's config (recursively) is not consumed. Coverage is structural (based on field hints), not value-based: every architecture field at every depth @@ -530,13 +544,12 @@ def _check_architecture_coverage(cls, config: Config, declarations: dict[str, Co * Non-recursive declarations (Rename, ConstantImport/Export, Default, Optional, ImportOnly, Custom by default) must list every architecture leaf they consume by exact path. - The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the - config is a strict subclass (e.g. 
``MoEMLPConfig`` fed through ``LlamaMLPConverter`` declarations - before the dispatching ``MixtralMLPConverter`` overrides ``fast_llm_config_class``), the subclass - converter is responsible for declaring the additional fields and running its own check. + Invoked from a test fixture (``tests/models/test_converters.py``) — not from the production + export/import paths. Architecture coverage is a structural invariant of the converter declarations, + so it only needs to hold once per (converter, config-class) pair, not on every save. """ - if type(config) is not cls.fast_llm_config_class: - return + Assert.is_(type(config), cls.fast_llm_config_class) + declarations = cls._create_config_converters() explicit_paths: set[tuple[str, ...]] = set() recursive_prefixes: list[tuple[str, ...]] = [] for converter in declarations.values(): diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 1e9c6a6c0..75c6e1605 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -194,9 +194,19 @@ def _create_config_converters(cls) -> dict: "key_heads": RenameConfigConverter(("key_heads",), ("linear_attn_config", "gdn_num_key_heads")), "key_head_dim": RenameConfigConverter(("key_head_dim",), ("linear_attn_config", "gdn_key_head_dim")), "value_head_dim": RenameConfigConverter(("value_head_dim",), ("linear_attn_config", "gdn_value_head_dim")), - "convolution_kernel_size": RenameConfigConverter( - ("convolution_layer", "kernel_size"), - ("linear_attn_config", "gdn_linear_conv_kernel_size"), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + export_fn=lambda c: { + ("linear_attn_config", "gdn_linear_conv_kernel_size"): c.convolution_layer.kernel_size + }, + import_fn=lambda hf: { + ("convolution_layer", "kernel_size"): hf["linear_attn_config"]["gdn_linear_conv_kernel_size"] + }, + ), + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), ), # Sub-configs without HF representation; coverage-only. "sub_configs": IgnoredConfigConverter( @@ -275,9 +285,19 @@ def _create_config_converters(cls) -> dict: return { "head_dim": RenameConfigConverter(("head_dim",), ("linear_attn_config", "head_dim")), "heads": RenameConfigConverter(("heads",), ("linear_attn_config", "num_heads")), - "convolution_kernel_size": RenameConfigConverter( - ("convolution_layer", "kernel_size"), - ("linear_attn_config", "short_conv_kernel_size"), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + export_fn=lambda c: { + ("linear_attn_config", "short_conv_kernel_size"): c.convolution_layer.kernel_size + }, + import_fn=lambda hf: { + ("convolution_layer", "kernel_size"): hf["linear_attn_config"]["short_conv_kernel_size"] + }, + ), + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), ), # Sub-configs without HF representation; coverage-only. 
"sub_configs": IgnoredConfigConverter( diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index cc94913ae..311e05202 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -7,7 +7,6 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConfigSectionConverter, - ConstantExportConfigConverter, ConstantImportConfigConverter, CustomConfigConverter, DispatchConfigConverter, @@ -228,8 +227,8 @@ def _create_config_converters(cls) -> dict: "state_size": RenameConfigConverter(("state_size",), ("state_size",)), "d_inner": RenameConfigConverter(("d_inner",), ("d_inner",)), "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), - "d_xb": OptionalConfigConverter(("d_xb",), ("d_xb",)), - "dt_rank": OptionalConfigConverter(("dt_rank",), ("dt_rank",)), + "d_xb": RenameConfigConverter(("d_xb",), ("d_xb",)), + "dt_rank": RenameConfigConverter(("dt_rank",), ("dt_rank",)), "aux": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("dt_layer",)), export_fn=_apriel2_mamba_aux_export, @@ -667,9 +666,6 @@ class Apriel2MLPConverter(ConfigSectionConverter): def _create_config_converters(cls) -> dict: layer_names = ("layer_1", "layer_2") return { - # MLP is wrapped via NestedConfigConverter (no Dispatch discriminator), so emit the HF - # ``"type": "mlp"`` discriminator from inside this converter. - "hf_type": ConstantExportConfigConverter(("type",), "mlp"), "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), "gated": RenameConfigConverter(("gated",), ("gated",)), "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), @@ -801,12 +797,13 @@ def get_converters( class Apriel2FixedDecoderConverter(ConfigSectionConverter): fast_llm_config_class = FixedBlockSequenceConfig hf_type_name = "fixed" + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @classmethod def _create_config_converters(cls) -> dict: return { "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), - "block": NestedConfigConverter(("block",), Apriel2BlockConverter, hf_path=("block",)), + "block": NestedConfigConverter(("block",), cls.block_converter_class, hf_path=("block",)), } @classmethod @@ -819,7 +816,7 @@ def get_converters( ) -> list[WeightConverter]: converters: list[WeightConverter] = [] for block_index in range(config.num_blocks): - converters += Apriel2BlockConverter.get_converters( + converters += cls.block_converter_class.get_converters( config.block, f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", @@ -831,6 +828,7 @@ def get_converters( class Apriel2PatternDecoderConverter(ConfigSectionConverter): fast_llm_config_class = PatternBlockSequenceConfig hf_type_name = "pattern" + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @classmethod def _create_config_converters(cls) -> dict: @@ -840,7 +838,7 @@ def _create_config_converters(cls) -> dict: "blocks": TypedDictContainerConfigConverter( fast_llm_path=("blocks",), hf_path=("blocks",), - registry={DecoderBlockConfig: Apriel2BlockConverter}, + registry={DecoderBlockConfig: cls.block_converter_class}, ), } @@ -855,7 +853,7 @@ def get_converters( converters: list[WeightConverter] = [] for block_index in range(config.num_blocks): block_config = config.blocks[config.pattern[block_index % 
len(config.pattern)]] - converters += Apriel2BlockConverter.get_converters( + converters += cls.block_converter_class.get_converters( block_config, f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index e8719a44e..821619c46 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -29,27 +29,28 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): @classmethod def _create_config_converters(cls) -> dict: - out = super()._create_config_converters() - # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. - out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) - # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. - out["head_size"] = ImportOnlyConfigConverter( - fast_llm_paths=(("head_size",),), - import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, - ) - # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. - # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. - out["linear_layers"] = ImportOnlyConfigConverter( - fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), - import_fn=lambda hf: { - ("query_layer",): {"bias": {"enabled": True}}, - ("key_layer",): {"bias": {"enabled": True}}, - ("value_layer",): {"bias": {"enabled": True}}, - ("dense_layer",): {"bias": {"enabled": False}}, - }, - recurses=True, - ) - return out + return { + **super()._create_config_converters(), + # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. + "head_size": ImportOnlyConfigConverter( + fast_llm_paths=(("head_size",),), + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ), + # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. + # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. + "linear_layers": ImportOnlyConfigConverter( + fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), + import_fn=lambda hf: { + ("query_layer",): {"bias": {"enabled": True}}, + ("key_layer",): {"bias": {"enabled": True}}, + ("value_layer",): {"bias": {"enabled": True}}, + ("dense_layer",): {"bias": {"enabled": False}}, + }, + recurses=True, + ), + } @classmethod def _validate_export(cls, config: AttentionConfig) -> None: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 1fdb32378..5ad0abb5d 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -69,26 +69,27 @@ def _pixtral_rotary_import(hf_dict: dict) -> dict: class PixtralAttentionConverter(LlamaAttentionConverter): @classmethod def _create_config_converters(cls) -> dict: - out = super()._create_config_converters() - # PixtralConfig hardcodes Q/K/V/O biases off and does not surface ``attention_bias``. 
- out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) - # Pixtral attention is non-causal (vision encoder). - out["causal"] = ConstantImportConfigConverter(("causal",), False) - # No GQA in Pixtral; ``head_groups`` derives from ``num_attention_heads`` on import and is redundant - # on export (``_validate_export`` enforces equality with ``heads``). - out["head_groups"] = ImportOnlyConfigConverter( - fast_llm_paths=(("head_groups",),), - import_fn=lambda hf: {("head_groups",): hf["num_attention_heads"]}, - ) - # Pixtral always uses 2D rotary; only ``theta`` round-trips. The flat (v4) vs ``rope_parameters`` (v5) - # layout follows the active transformers major version, mirroring the Llama parent. - out["rotary"] = CustomConfigConverter( - fast_llm_paths=(("rotary",),), - export_fn=_pixtral_rotary_export, - import_fn=_pixtral_rotary_import, - recurses=True, - ) - return out + return { + **super()._create_config_converters(), + # PixtralConfig hardcodes Q/K/V/O biases off and does not surface ``attention_bias``. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + # Pixtral attention is non-causal (vision encoder). + "causal": ConstantImportConfigConverter(("causal",), False), + # No GQA in Pixtral; ``head_groups`` derives from ``num_attention_heads`` on import and is redundant + # on export (``_validate_export`` enforces equality with ``heads``). + "head_groups": ImportOnlyConfigConverter( + fast_llm_paths=(("head_groups",),), + import_fn=lambda hf: {("head_groups",): hf["num_attention_heads"]}, + ), + # Pixtral always uses 2D rotary; only ``theta`` round-trips. The flat (v4) vs ``rope_parameters`` (v5) + # layout follows the active transformers major version, mirroring the Llama parent. + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + export_fn=_pixtral_rotary_export, + import_fn=_pixtral_rotary_import, + recurses=True, + ), + } @classmethod def _validate_export(cls, config: AttentionConfig) -> None: diff --git a/tests/models/test_converters.py b/tests/models/test_converters.py new file mode 100644 index 000000000..9ce8b0893 --- /dev/null +++ b/tests/models/test_converters.py @@ -0,0 +1,129 @@ +"""Static checks on every checkpoint format's converter tree. + +For each registered ``HuggingfaceStateDictCheckpointHandler``, walk its modular converter structure — +``base_model_converter_class`` and the ``ConfigSectionConverter`` classes reached transitively through +``Nested``/``Dispatch``/``TypedDictContainer`` declarations — and verify, at every node: + +* Architecture-hint fields on ``cls.fast_llm_config_class`` are all consumed by some declaration. +* OptionalConfigConverter sentinels match the resolved field default. Otherwise an exported value equal + to the sentinel becomes absent on disk and re-imports as a different default, silently breaking round-trip. + +Replaces the per-export ``check_architecture_coverage`` invocation that used to happen on every save. +""" + +import typing + +import pytest + +# Force registration of every format handler. 
+import fast_llm.models.gpt.conversion.auto # noqa: F401 +import fast_llm.models.multimodal.conversion.auto # noqa: F401 +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + DispatchConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + TypedDictContainerConfigConverter, + _get_attr_path, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import PatternBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig + +# Configs that don't default-construct cleanly need a minimal-valid factory. +_DEFAULT_FACTORIES: dict[type, typing.Callable[[], typing.Any]] = { + PatternBlockSequenceConfig: lambda: PatternBlockSequenceConfig( + blocks={"x": DecoderBlockConfig()}, + pattern=("x",), + ), + StochasticMixerConfig: lambda: StochasticMixerConfig( + mixers={"x": AttentionConfig()}, + main_mixer_name="x", + ), +} + + +def _default_instance(config_class: type) -> typing.Any: + factory = _DEFAULT_FACTORIES.get(config_class) + return factory() if factory is not None else config_class() + + +def _all_format_handlers() -> list[type[HuggingfaceStateDictCheckpointHandler]]: + seen: set[type[HuggingfaceStateDictCheckpointHandler]] = set() + out: list[type[HuggingfaceStateDictCheckpointHandler]] = [] + + def visit(cls: type) -> None: + for sub in cls.__subclasses__(): + if sub in seen: + continue + seen.add(sub) + # Concrete handlers declare a ``base_model_converter_class``; abstract intermediaries don't. + if getattr(sub, "base_model_converter_class", None) is not None: + out.append(sub) + visit(sub) + + visit(HuggingfaceStateDictCheckpointHandler) + return out + + +def _children(node: type) -> list[type]: + """Return every sub-converter class reachable from ``node``. + + Picks up two complementary structures: + * ``ConfigSectionConverter`` declarations — the ``_converter_class`` on each Nested/Dispatch/TypedDict. + * ``*_converter_class`` ClassVars — the polymorphism extension points used by aggregator nodes + (e.g. ``LlavaBaseModelConverter`` is not itself a section converter but exposes + ``vision_model_converter_class`` and ``language_model_converter_class``). 
+ """ + out: list[type] = [] + if isinstance(node, type) and issubclass(node, ConfigSectionConverter): + for declaration in node._create_config_converters().values(): + if isinstance(declaration, NestedConfigConverter): + out.append(declaration._converter_class) + elif isinstance(declaration, (DispatchConfigConverter, TypedDictContainerConfigConverter)): + out.extend(declaration._registry.values()) + for name in dir(node): + if not name.endswith("_converter_class") or name == "base_model_converter_class": + continue + attr = getattr(node, name, None) + if isinstance(attr, type): + out.append(attr) + return out + + +def _walk(root: type) -> typing.Iterator[type]: + """Yield ``root`` and every converter class reachable from it (each at most once).""" + seen: set[type] = set() + stack: list[type] = [root] + while stack: + node = stack.pop() + if node in seen: + continue + seen.add(node) + yield node + stack.extend(_children(node)) + + +_HANDLERS = _all_format_handlers() + + +@pytest.mark.parametrize("handler_class", _HANDLERS, ids=lambda h: h.__name__) +def test_format_converter_tree(handler_class: type[HuggingfaceStateDictCheckpointHandler]) -> None: + """Walk the format's converter tree from ``base_model_converter_class``; check every section node.""" + for converter_class in _walk(handler_class.base_model_converter_class): + if not (isinstance(converter_class, type) and issubclass(converter_class, ConfigSectionConverter)): + continue + if getattr(converter_class, "fast_llm_config_class", None) is None: + continue + config = _default_instance(converter_class.fast_llm_config_class) + converter_class.check_architecture_coverage(config) + for name, declaration in converter_class._create_config_converters().items(): + if not isinstance(declaration, OptionalConfigConverter): + continue + path = declaration.fast_llm_paths[0] + default = _get_attr_path(config, path) + assert declaration._sentinel == default, ( + f"{converter_class.__name__}.{name}: sentinel {declaration._sentinel!r} " + f"does not match field default {default!r} at path {'.'.join(path)}" + ) From 7d013e1831329e3ea9be5c5e05caebf4fb767db0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 16:31:34 -0400 Subject: [PATCH 17/27] Add HF-side coverage check on import Symmetric counterpart to the architecture-coverage check (already a test). Walks the HF config dict at the import boundary and raises on any key not consumed by some declaration in the converter tree. Catches transformers-version drift, manual edits, and corrupted configs at the point of import rather than as cryptic downstream failures. * ``ConfigConverter`` primitives gain a recursive ``_consumed_hf_paths`` walker. Nested/ Dispatch/TypedDictContainer with a fixed ``hf_path`` claim it as a subtree prefix; their flat-merge variants (``hf_path=None``) pull the sub-converter's claims up to the current level so a parent's check sees them. * ``CustomConfigConverter`` / ``ImportOnlyConfigConverter`` gain an ``hf_paths`` kwarg; every existing call site is audited and populated. ``IgnoredConfigConverter`` gains an ``hf_paths`` kwarg used for HF-only fields Fast-LLM intentionally does not consume (Mixtral router toggles, Qwen2 sliding-window machinery, Apriel2's default-injected ``embeddings`` subdict from ``Apriel2TextConfig``). * ``HuggingfaceStateDictCheckpointHandler`` runs the check from ``_import_config`` against the base-model converter. 
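The walker itself lands in ``ConfigSectionConverter.check_hf_coverage`` (diff below). As a rough standalone illustration of the claimed-prefix semantics, not the library code (it collects offending paths instead of raising, and the config dict, claims, and allowlist here are invented):

    def find_unclaimed(hf_dict: dict, claimed_prefixes: set[tuple[str, ...]], allowlist: set[str]) -> list[str]:
        """Collect dotted paths of keys that no claimed prefix covers (illustration only)."""
        unclaimed: list[str] = []

        def walk(value, path: tuple[str, ...]) -> None:
            # A path counts as covered as soon as any of its prefixes was claimed by a declaration.
            if any(path[:length] in claimed_prefixes for length in range(1, len(path) + 1)):
                return
            if len(path) == 1 and path[0] in allowlist:
                return
            if isinstance(value, dict):
                for key, sub_value in value.items():
                    walk(sub_value, path + (key,))
                return
            unclaimed.append(".".join(path))

        for key, value in hf_dict.items():
            walk(value, (key,))
        return unclaimed

    # The ("rope_scaling",) claim covers the whole subdict; only the unexpected key is reported.
    assert find_unclaimed(
        {"hidden_size": 4096, "rope_scaling": {"rope_type": "yarn", "factor": 8.0}, "unexpected_key": 1},
        claimed_prefixes={("hidden_size",), ("rope_scaling",)},
        allowlist={"model_type"},
    ) == ["unexpected_key"]
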
A class-level allowlist covers transformers' generic ``PretrainedConfig`` fields and inference-only metadata that's always permitted. The ``Apriel2`` text handler's override is updated to call the shared ``_check_hf_coverage`` helper. Non-``ConfigSectionConverter`` base-model converters (Llava aggregators) skip the check transparently. * ``LlamaBaseModelConverter``'s decoder Custom - which wraps the imperative ``LlamaDecoderConverter`` - auto-extends its ``hf_paths`` from the block converter's ``_consumed_hf_paths``, so Mistral/Mixtral/Qwen2/MTPLlama/Apriel inherit correct coverage. ``AprielBlockConverter`` (per-block-type dispatcher, also imperative) gets its own ``_consumed_hf_paths`` that unions across registered per-mixer block converters. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 98 ++++++++++++++++--- fast_llm/engine/checkpoint/huggingface.py | 55 ++++++++++- fast_llm/models/gpt/conversion/apriel.py | 14 +++ fast_llm/models/gpt/conversion/apriel2.py | 14 +++ fast_llm/models/gpt/conversion/llama.py | 11 +++ fast_llm/models/gpt/conversion/mixtral.py | 5 + fast_llm/models/gpt/conversion/qwen2.py | 5 + .../models/multimodal/conversion/llava.py | 1 + 8 files changed, 191 insertions(+), 12 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index bee31a509..281bddf97 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -203,19 +203,20 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: class IgnoredConfigConverter(ConfigConverter): - """Declares Fast-LLM architecture fields as intentionally not converted by this format. + """Declares Fast-LLM architecture fields and/or HF dict keys as intentionally not converted by this format. - Use when the HF format has no representation for the field and the Fast-LLM default round-trips correctly. - Acts as a no-op on both directions while satisfying the architecture-coverage check. The claim covers the - entire subtree under each listed path: deeper architecture fields are also implicitly ignored, on the - assumption that a format which does not represent the parent likewise does not represent its children. + Use ``fast_llm_paths`` (positional) when Fast-LLM has architecture fields with no HF representation; the + Fast-LLM default round-trips. Use ``hf_paths`` (kw-only) when the HF format carries fields Fast-LLM does + not consume (generation-only toggles like Mixtral's ``router_aux_loss_coef``, Qwen2's ``sliding_window``). + Both kinds of claim are no-ops at conversion time and serve only the per-side coverage checks. The claim + covers the entire subtree under each listed path on the side it applies to. """ recurses: typing.ClassVar[bool] = True - def __init__(self, *fast_llm_paths: tuple[str, ...]): + def __init__(self, *fast_llm_paths: tuple[str, ...], hf_paths: tuple[tuple[str, ...], ...] = ()): self.fast_llm_paths = fast_llm_paths - self.hf_paths = () + self.hf_paths = hf_paths def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: return @@ -227,8 +228,7 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: class CustomConfigConverter(ConfigConverter): """Escape hatch for cross-field transforms (e.g., rotary, where one HF blob ↔ several Fast-LLM fields). - ``fast_llm_paths`` is declared so the coverage check sees the fields as consumed. 
The HF side is intentionally - not declared — there is no symmetric HF-side coverage check yet, so an ``hf_paths`` argument would be cosmetic. + ``fast_llm_paths`` and ``hf_paths`` are declared so the per-side coverage checks see the fields as consumed. Cross-field validators that produce nothing on the HF side belong on :py:meth:`ConfigSectionConverter._validate_export` instead; this primitive is for shape-changing transforms. @@ -243,9 +243,11 @@ def __init__( fast_llm_paths: tuple[tuple[str, ...], ...], export_fn: typing.Callable[[Config], dict], import_fn: typing.Callable[[dict], dict], + hf_paths: tuple[tuple[str, ...], ...] = (), recurses: bool = False, ): self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths self._export_fn = export_fn self._import_fn = import_fn self.recurses = recurses @@ -268,8 +270,8 @@ class ImportOnlyConfigConverter(ConfigConverter): ``hidden_size // num_attention_heads`` in Qwen2) or implies a value the Fast-LLM side stores explicitly (e.g. Qwen2's hardcoded Q/K/V biases, Pixtral's mirrored ``patch_size`` ↔ ``patch_width``). On export the field is redundant and validated through ``_validate_export``; on import the - ``import_fn`` produces the Fast-LLM dict entries. The fast_llm_paths still register as consumed - for the architecture-coverage check. + ``import_fn`` produces the Fast-LLM dict entries. The fast_llm_paths register as consumed for the + architecture-coverage check; ``hf_paths`` register as consumed for the HF-side check. Pass ``recurses=True`` when the converter populates a sub-config subtree (e.g. Qwen2's per-layer biases that target ``query_layer``/``key_layer``/...). Same trade-off as @@ -281,9 +283,11 @@ def __init__( self, fast_llm_paths: tuple[tuple[str, ...], ...], import_fn: typing.Callable[[dict], dict], + hf_paths: tuple[tuple[str, ...], ...] = (), recurses: bool = False, ): self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths self._import_fn = import_fn self.recurses = recurses @@ -528,6 +532,78 @@ def import_config(cls, hf_dict: dict) -> dict: converter.import_to(hf_dict, out) return out + @classmethod + @functools.cache + def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: + """Set of HF dict path prefixes consumed by this section's declaration tree. + + Each entry is a tuple-of-keys from the section's HF subdict root. The + :meth:`check_hf_coverage` walker treats every entry as a *recursive prefix* — once an input + path matches any prefix, descent into deeper sub-dicts stops. + + Recurses through: + * :class:`NestedConfigConverter` with ``hf_path=None`` (flat-merge): the sub-converter shares + the parent's HF namespace, so its claims are pulled up to this level. + * :class:`DispatchConfigConverter` with ``hf_paths=()`` (flat-merge): every registered class + contributes its claims at this level (the union covers all possible runtime types). + Otherwise the primitive's own ``hf_paths`` are used. 
+ """ + paths: set[tuple[str, ...]] = set() + for declaration in cls._create_config_converters().values(): + if isinstance(declaration, NestedConfigConverter): + if declaration._hf_path is None: + paths |= declaration._converter_class._consumed_hf_paths() + if declaration._converter_class.hf_type_name is not None: + paths.add((declaration._hf_discriminator_key,)) + else: + paths.add(declaration._hf_path) + elif isinstance(declaration, DispatchConfigConverter): + if declaration.hf_paths: + paths.add(declaration.hf_paths[0]) + else: + paths.add((declaration._hf_discriminator_key,)) + for sub_class in declaration._registry.values(): + paths |= sub_class._consumed_hf_paths() + elif isinstance(declaration, TypedDictContainerConfigConverter): + paths.add(declaration.hf_paths[0]) + else: + for path in declaration.hf_paths: + if path: + paths.add(path) + return frozenset(paths) + + @classmethod + def check_hf_coverage(cls, hf_dict: dict, *, allowlist: frozenset[str] = frozenset()) -> None: + """Raise :class:`ValueError` if the input HF dict carries keys not consumed by any declaration. + + Walks ``hf_dict`` recursively. A path is considered covered if it (or any of its prefixes) is in + :meth:`_consumed_hf_paths`, or — for top-level keys — appears in ``allowlist``. Uncovered leaves + raise; uncovered sub-dicts trigger descent into their entries to surface the offending leaf path. + + Catches transformers-version drift, manual edits, and corrupted configs at the import boundary — + the symmetric counterpart to the architecture-coverage check (which is statically verified by + ``tests/models/test_converters.py``). + """ + prefixes = cls._consumed_hf_paths() + + def walk(value: typing.Any, path: tuple[str, ...]) -> None: + for length in range(1, len(path) + 1): + if path[:length] in prefixes: + return + if len(path) == 1 and path[0] in allowlist: + return + if isinstance(value, dict): + for key, sub in value.items(): + walk(sub, path + (key,)) + return + raise ValueError( + f"{cls.__name__}: HF config has unknown key '{'.'.join(path)}' (value: {value!r}). " + "Possible transformers-version mismatch, manual edit, or corrupted config." + ) + + for key, value in hf_dict.items(): + walk(value, (key,)) + @classmethod def check_architecture_coverage(cls, config: Config) -> None: """Raise if any architecture-hint field reachable from the section's config (recursively) is not consumed. diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 8cdb779dd..9074e72fc 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -9,7 +9,12 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig -from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler, WeightConverter, logger +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ExternalStateDictCheckpointHandler, + WeightConverter, + logger, +) from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, safe_merge_dicts @@ -120,10 +125,58 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: }, ) + # Top-level HF metadata keys that are always permitted, regardless of the converter tree. 
+ # Covers transformers' generic ``PretrainedConfig`` fields (always present after ``to_dict()``) + # plus a handful of widely-shared metadata that Fast-LLM intentionally does not store. + _HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset( + { + # transformers PretrainedConfig + "_name_or_path", + "architectures", + "auto_map", + "chunk_size_feed_forward", + "dtype", + "id2label", + "is_encoder_decoder", + "label2id", + "model_type", + "output_attentions", + "output_hidden_states", + "problem_type", + "return_dict", + "torch_dtype", + "transformers_version", + "use_cache", + # Token ids — generation/inference, not architecture. + "bos_token_id", + "decoder_start_token_id", + "eos_token_id", + "pad_token_id", + "sep_token_id", + # Initialization / pretraining metadata Fast-LLM does not consume. + "initializer_range", + "max_position_embeddings", + "pretraining_tp", + } + ) + + @classmethod + def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: + """Run the HF-side coverage check at the import boundary. + + Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter` + (e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``). + Subclasses that override :meth:`_import_config` should call this explicitly to keep the check + active. + """ + if issubclass(cls.base_model_converter_class, ConfigSectionConverter): + cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) + @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["model_type"], cls.get_huggingface_model_type()) Assert.eq(config["architectures"], [cls.architecture]) + cls._check_hf_coverage(config) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 75c6e1605..00fc52182 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -60,6 +60,7 @@ def _create_config_converters(cls) -> dict: ), "dt_rank": CustomConfigConverter( fast_llm_paths=(("dt_rank",),), + hf_paths=(("ssm_cfg", "dt_rank"),), export_fn=lambda c: {("ssm_cfg", "dt_rank"): c.dt_rank}, import_fn=lambda hf: { ("dt_rank",): ( @@ -77,6 +78,7 @@ def _create_config_converters(cls) -> dict: ), "convolution_layer_bias": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("convolution_layer", "bias")), + hf_paths=(("ssm_cfg", "conv_bias"),), export_fn=lambda c: { ("ssm_cfg", "conv_bias"): effective_bias(c.convolution_layer, c.add_linear_biases) }, @@ -93,6 +95,7 @@ def _create_config_converters(cls) -> dict: ), "dt_layer_bias": CustomConfigConverter( fast_llm_paths=(("dt_layer",), ("dt_layer", "bias")), + hf_paths=(("ssm_cfg", "dt_proj_bias"),), export_fn=lambda c: {("ssm_cfg", "dt_proj_bias"): effective_bias(c.dt_layer, c.add_linear_biases)}, import_fn=lambda hf: { ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) @@ -196,6 +199,7 @@ def _create_config_converters(cls) -> dict: "value_head_dim": RenameConfigConverter(("value_head_dim",), ("linear_attn_config", "gdn_value_head_dim")), "convolution_layer_kernel": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("linear_attn_config", "gdn_linear_conv_kernel_size"),), export_fn=lambda 
c: { ("linear_attn_config", "gdn_linear_conv_kernel_size"): c.convolution_layer.kernel_size }, @@ -287,6 +291,7 @@ def _create_config_converters(cls) -> dict: "heads": RenameConfigConverter(("heads",), ("linear_attn_config", "num_heads")), "convolution_layer_kernel": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("linear_attn_config", "short_conv_kernel_size"),), export_fn=lambda c: { ("linear_attn_config", "short_conv_kernel_size"): c.convolution_layer.kernel_size }, @@ -462,6 +467,15 @@ def import_config(cls, config: dict, layout_name: str = "t") -> dict: def export_config(cls, config) -> dict: return cls._converter_classes[type(config.mixer)].export_config(config) + @classmethod + def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: + """Union of consumed HF paths across every per-mixer-type block converter — used by the parent's + decoder Custom to pre-claim Apriel's flat top-level keys for the HF coverage check.""" + paths: set[tuple[str, ...]] = set() + for sub in cls._converter_classes.values(): + paths |= sub._consumed_hf_paths() + return frozenset(paths) + @classmethod def get_converters( cls, diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 311e05202..7c7f3e383 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -132,6 +132,7 @@ def _create_config_converters(cls) -> dict: "head_size": RenameConfigConverter(("head_size",), ("head_size",)), "rotary": CustomConfigConverter( fast_llm_paths=(("rotary",),), + hf_paths=(("rotary",),), export_fn=_apriel2_attention_rotary_export, import_fn=_apriel2_attention_rotary_import, recurses=True, @@ -143,6 +144,7 @@ def _create_config_converters(cls) -> dict: "window_size": OptionalConfigConverter(("window_size",), ("window_size",)), "linear_layers": CustomConfigConverter( fast_llm_paths=tuple((name,) for name in layer_names), + hf_paths=tuple((name,) for name in layer_names), export_fn=lambda c: _per_layer_bias_export(c, layer_names), import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), recurses=True, @@ -231,6 +233,7 @@ def _create_config_converters(cls) -> dict: "dt_rank": RenameConfigConverter(("dt_rank",), ("dt_rank",)), "aux": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("dt_layer",)), + hf_paths=(("d_conv",), ("conv_bias",), ("dt_proj_bias",)), export_fn=_apriel2_mamba_aux_export, import_fn=_apriel2_mamba_aux_import, recurses=True, @@ -316,6 +319,7 @@ def _create_config_converters(cls) -> dict: "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), "convolution_layer_kernel": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("convolution_layer",),), export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} @@ -404,6 +408,7 @@ def _create_config_converters(cls) -> dict: "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), "convolution_layer_kernel": CustomConfigConverter( fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("convolution_layer",),), export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, import_fn=lambda hf: ( {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else 
{} @@ -417,6 +422,7 @@ def _create_config_converters(cls) -> dict: ), "normalization_epsilon": CustomConfigConverter( fast_llm_paths=(("normalization",), ("normalization", "epsilon")), + hf_paths=(("normalization",),), export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), ), @@ -671,11 +677,13 @@ def _create_config_converters(cls) -> dict: "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), "activation": CustomConfigConverter( fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), export_fn=lambda c: {("activation",): c.activation.hf_name}, import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, ), "layers": CustomConfigConverter( fast_llm_paths=tuple((name,) for name in layer_names), + hf_paths=tuple((name,) for name in layer_names), export_fn=lambda c: _per_layer_bias_export(c, layer_names), import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), recurses=True, @@ -936,6 +944,11 @@ def _create_config_converters(cls) -> dict: "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), "peft": IgnoredConfigConverter(("peft",)), + # ``Apriel2TextConfig`` default-injects an ``embeddings`` HF subdict + # (``{"max_position_embeddings": 2048}``) the Fast-LLM converter doesn't use — vocab_size + # rides at top level via the flat-merged ``LlamaEmbeddingsConverter``. Claim the injected + # subdict so the HF coverage check doesn't flag it. + "embeddings_subdict_unmapped": IgnoredConfigConverter(hf_paths=(("embeddings",),)), } @classmethod @@ -996,6 +1009,7 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: + cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} @classmethod diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 762d38c4d..31161680e 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -288,6 +288,7 @@ def _create_config_converters(cls) -> dict: "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("mlp_bias",)), "activation": CustomConfigConverter( fast_llm_paths=(("activation",),), + hf_paths=(("hidden_act",),), export_fn=lambda c: {("hidden_act",): c.activation.hf_name}, import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, ), @@ -362,6 +363,7 @@ def _create_config_converters(cls) -> dict: ), "rotary": CustomConfigConverter( fast_llm_paths=(("rotary",),), + hf_paths=(("rope_theta",), ("rope_scaling",), ("rope_parameters",)), export_fn=_llama_rotary_export, import_fn=_llama_rotary_import, recurses=True, @@ -620,6 +622,15 @@ def _decoder_import(hf_dict: dict) -> dict: "head": NestedConfigConverter(("head",), cls.head_converter_class), "decoder": CustomConfigConverter( fast_llm_paths=(("decoder",),), + # The Custom wraps the imperative LlamaDecoderConverter, which delegates to + # cls.decoder_converter_class.block_converter_class (a ConfigSectionConverter). The + # block converter's flat-merge declarations claim all per-block top-level keys; pull + # them up here so the HF coverage check sees them as covered. 
``num_hidden_layers`` + # is consumed by LlamaDecoderConverter itself. + hf_paths=( + ("num_hidden_layers",), + *cls.decoder_converter_class.block_converter_class._consumed_hf_paths(), + ), export_fn=_decoder_export, import_fn=_decoder_import, recurses=True, diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 7659befa3..02226d040 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -36,6 +36,11 @@ def _create_config_converters(cls) -> dict: "routing": ConstantImportConfigConverter(("routing",), RoutingType.topk), # Mixtral's gate is a default LinearConfig (no bias); blanket-consume so coverage passes. "router": IgnoredConfigConverter(("router",)), + # Router / inference toggles surfaced by HF but not consumed by Fast-LLM's MoEMLPConfig + # (auxiliary_loss_coefficient and jitter_eps are FieldHint.feature, not architecture). + "router_runtime_unsupported": IgnoredConfigConverter( + hf_paths=(("router_aux_loss_coef",), ("router_jitter_noise",), ("output_router_logits",)), + ), } @classmethod diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 821619c46..b785fb0a6 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -3,6 +3,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantImportConfigConverter, + IgnoredConfigConverter, ImportOnlyConfigConverter, WeightConverter, ) @@ -50,6 +51,10 @@ def _create_config_converters(cls) -> dict: }, recurses=True, ), + # Sliding-window machinery surfaced by Qwen2 HF but not yet supported here (see TODO above). + "sliding_window_unsupported": IgnoredConfigConverter( + hf_paths=(("sliding_window",), ("use_sliding_window",), ("max_window_layers",), ("layer_types",)), + ), } @classmethod diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 5ad0abb5d..82f1ea7ff 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -85,6 +85,7 @@ def _create_config_converters(cls) -> dict: # layout follows the active transformers major version, mirroring the Llama parent. "rotary": CustomConfigConverter( fast_llm_paths=(("rotary",),), + hf_paths=(("rope_theta",), ("rope_parameters",)), export_fn=_pixtral_rotary_export, import_fn=_pixtral_rotary_import, recurses=True, From 536d5482776be82bc98338986208b0f47aa24849 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 19:05:40 -0400 Subject: [PATCH 18/27] Address fourth review round - Apriel2BlockConverter._validate_export asserts type(config.mlp) is MLPConfig, restoring the pre-PR rejection of MoEMLPConfig that NestedConfigConverter would otherwise silently descend through (dropping experts/routing/router). - _consumed_hf_paths now expands a nested sub-converter's claims under its hf_path prefix (NestedConfigConverter/DispatchConfigConverter with hf_path set) so check_hf_coverage descends and flags unknown keys deep inside apriel2's head/decoder, llava's vision_config, etc. - Pin prediction_heads to 1 in Llama and Apriel2 head converters via ConstantImportConfigConverter so non-default values fail on export instead of silently dropping (MTP-Llama overrides the entry with Rename). - Document the cache-mutation hazard on _create_config_converters: subclasses must spread the parent's dict, never mutate it in place. 
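A minimal sketch of the documented pattern; the converter names and fields are invented, and the import path assumes the primitives live in ``fast_llm.engine.checkpoint.external`` as in this PR:

    from fast_llm.engine.checkpoint.external import ConfigSectionConverter, RenameConfigConverter

    class ParentConverter(ConfigSectionConverter):
        @classmethod
        def _create_config_converters(cls) -> dict:
            return {"hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",))}

    class ChildConverter(ParentConverter):
        @classmethod
        def _create_config_converters(cls) -> dict:
            # Spread into a fresh dict. The base class caches declarations per class, so an in-place
            # mutation (out = super()._create_config_converters(); out["window_size"] = ...) would
            # poison the cached parent declarations for every later caller.
            return {
                **super()._create_config_converters(),
                "window_size": RenameConfigConverter(("window_size",), ("sliding_window",)),
            }
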
- Narrow Apriel2BaseModelConverter's HF embeddings Ignored to the single injected leaf so future transformers fields in the same subdict trip the coverage check. - Tighten Mixtral router Ignored comment to record the structural rationale (router.weight has no architecture sub-fields, so the blanket claim is equivalent to the narrowest possible non-recursive claim). Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 52 +++++++++++++++-------- fast_llm/models/gpt/conversion/apriel2.py | 25 ++++++++--- fast_llm/models/gpt/conversion/llama.py | 9 ++-- fast_llm/models/gpt/conversion/mixtral.py | 6 ++- 4 files changed, 63 insertions(+), 29 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 281bddf97..72bc957fc 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -500,7 +500,9 @@ class ConfigSectionConverter(abc.ABC): def _create_config_converters(cls) -> dict[str, ConfigConverter]: """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key. - Cached per class — declarations are immutable and depend only on ``cls``. + Cached per class — declarations are immutable and depend only on ``cls``. Subclasses must build + and return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating + the returned dict in place would corrupt the parent's cache entry for every subsequent caller. """ raise NotImplementedError @@ -535,31 +537,47 @@ def import_config(cls, hf_dict: dict) -> dict: @classmethod @functools.cache def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: - """Set of HF dict path prefixes consumed by this section's declaration tree. - - Each entry is a tuple-of-keys from the section's HF subdict root. The - :meth:`check_hf_coverage` walker treats every entry as a *recursive prefix* — once an input - path matches any prefix, descent into deeper sub-dicts stops. - - Recurses through: - * :class:`NestedConfigConverter` with ``hf_path=None`` (flat-merge): the sub-converter shares - the parent's HF namespace, so its claims are pulled up to this level. - * :class:`DispatchConfigConverter` with ``hf_paths=()`` (flat-merge): every registered class - contributes its claims at this level (the union covers all possible runtime types). - Otherwise the primitive's own ``hf_paths`` are used. + """Set of HF dict paths consumed by this section's declaration tree. + + Each entry is a tuple-of-keys from the section's HF subdict root. The :meth:`check_hf_coverage` + walker treats every entry as a *recursive prefix* — once an input path matches any prefix, + descent into deeper sub-dicts stops. + + Nested sub-converters (``NestedConfigConverter``/``DispatchConfigConverter`` with ``hf_path`` + set) expand their sub-converter's claims under the nested prefix instead of contributing the + bare prefix, so the walker descends into the subdict and flags unknown keys inside it (e.g. + ``head.normalization.surprise_field``). + + :class:`TypedDictContainerConfigConverter` keeps a blanket prefix because its per-entry sub-dicts + are user-named (pattern keys); we can't statically enumerate which entries will appear or what + keys those entries should claim. 
""" paths: set[tuple[str, ...]] = set() for declaration in cls._create_config_converters().values(): if isinstance(declaration, NestedConfigConverter): + sub_class = declaration._converter_class if declaration._hf_path is None: - paths |= declaration._converter_class._consumed_hf_paths() - if declaration._converter_class.hf_type_name is not None: + # Flat-merge: sub-converter shares the parent's HF namespace. + paths |= sub_class._consumed_hf_paths() + if sub_class.hf_type_name is not None: paths.add((declaration._hf_discriminator_key,)) else: - paths.add(declaration._hf_path) + # Nested: prepend hf_path to each sub-claim so the walker recurses into the subdict. + prefix = declaration._hf_path + for sub_path in sub_class._consumed_hf_paths(): + paths.add(prefix + sub_path) + if sub_class.hf_type_name is not None: + paths.add(prefix + (declaration._hf_discriminator_key,)) elif isinstance(declaration, DispatchConfigConverter): if declaration.hf_paths: - paths.add(declaration.hf_paths[0]) + # Nested dispatch: union of all registered sub-classes' claims under the shared + # hf_path prefix. At runtime only one sub-class fires; the static union is a safe + # over-claim (we only need to *not* flag known keys, never to flag missing ones). + prefix = declaration.hf_paths[0] + paths.add(prefix + (declaration._hf_discriminator_key,)) + for sub_class in declaration._registry.values(): + for sub_path in sub_class._consumed_hf_paths(): + paths.add(prefix + sub_path) else: paths.add((declaration._hf_discriminator_key,)) for sub_class in declaration._registry.values(): diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 7c7f3e383..5425a932b 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -45,7 +45,7 @@ get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel -from fast_llm.utils import safe_merge_dicts +from fast_llm.utils import Assert, safe_merge_dicts # ============================================================ # Helpers @@ -754,6 +754,13 @@ def _create_config_converters(cls) -> dict: ), } + @classmethod + def _validate_export(cls, config: DecoderBlockConfig) -> None: + # Apriel2 HF format only represents plain MLP. ``NestedConfigConverter`` dispatches by fixed class + # (``Apriel2MLPConverter`` registered against ``MLPConfig``) and would silently descend into a + # ``MoEMLPConfig`` via MRO, dropping every MoE-specific architecture field. + Assert.is_(type(config.mlp), MLPConfig) + # --- weight side (imperative) --- @classmethod @@ -898,7 +905,9 @@ def _create_config_converters(cls) -> dict: registry=APRIEL2_NORM_REGISTRY, ), "output_weight": IgnoredConfigConverter(("output_weight",)), - "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), + # Apriel2 HF format does not support multi-token prediction; pin to 1 so any non-default value + # fails on export instead of silently round-tripping. 
+ "prediction_heads": ConstantImportConfigConverter(("prediction_heads",), 1), } # --- weight side (imperative) --- @@ -944,11 +953,13 @@ def _create_config_converters(cls) -> dict: "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), "peft": IgnoredConfigConverter(("peft",)), - # ``Apriel2TextConfig`` default-injects an ``embeddings`` HF subdict - # (``{"max_position_embeddings": 2048}``) the Fast-LLM converter doesn't use — vocab_size - # rides at top level via the flat-merged ``LlamaEmbeddingsConverter``. Claim the injected - # subdict so the HF coverage check doesn't flag it. - "embeddings_subdict_unmapped": IgnoredConfigConverter(hf_paths=(("embeddings",),)), + # ``Apriel2TextConfig`` default-injects ``{"embeddings": {"max_position_embeddings": 2048}}`` + # the Fast-LLM converter doesn't use — vocab_size rides at top level via the flat-merged + # ``LlamaEmbeddingsConverter``. Claim only the specific injected leaf so any future field + # transformers adds to the same subdict trips the HF coverage check. + "embeddings_subdict_unmapped": IgnoredConfigConverter( + hf_paths=(("embeddings", "max_position_embeddings"),) + ), } @classmethod diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 31161680e..32f1c79ed 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -568,10 +568,11 @@ def _create_config_converters(cls) -> dict: return { "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), "output_weight": IgnoredConfigConverter(("output_weight",)), - # ``prediction_heads`` is architecture (>1 enables multi-token prediction); Llama HF format does - # not represent it. We don't pin it to 1 here so MTP-Llama (a Llama-derived format) can override - # the declaration with a Rename without first hitting an assertion in the inherited path. - "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), + # Llama HF format does not represent ``prediction_heads``; pin to 1 so any non-default value + # fails on export instead of silently round-tripping. MTP-Llama overrides this entry with a + # ``RenameConfigConverter`` (the override replaces the parent's declaration in the returned + # dict, so this ConstantImport never fires for MTP-Llama configs). + "prediction_heads": ConstantImportConfigConverter(("prediction_heads",), 1), } # --- weight side (imperative) --- diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 02226d040..5d9deb4d5 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -34,7 +34,11 @@ def _create_config_converters(cls) -> dict: # Mixtral has no shared experts and uses the topk default; assert on export, inject defaults on import. "shared_experts": ConstantImportConfigConverter(("shared_experts",), 0), "routing": ConstantImportConfigConverter(("routing",), RoutingType.topk), - # Mixtral's gate is a default LinearConfig (no bias); blanket-consume so coverage passes. + # Mixtral has no HF representation for the router sub-config. The blanket consume satisfies + # architecture coverage; non-architecture fields (lr_scale, apply_peft, weight.initialization, + # weight.lr_scale) cannot round-trip through the HF format by design — Fast-LLM keeps them on + # the in-memory config independently. 
The only architecture-hint sub-field is ``router.weight``, + # a ParameterConfig with no architecture sub-fields, so the blanket carries no real risk. "router": IgnoredConfigConverter(("router",)), # Router / inference toggles surfaced by HF but not consumed by Fast-LLM's MoEMLPConfig # (auxiliary_loss_coefficient and jitter_eps are FieldHint.feature, not architecture). From d34aac42ad35a088624f9496206bbd450b6fa952 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 19:54:32 -0400 Subject: [PATCH 19/27] Address fifth review round MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Surface cleanups from the fine-pass review: rename ``cur`` → ``current`` in ``_get_attr_path``, merge an unintentionally split f-string in the ``DispatchConfigConverter`` error path, switch bare ``return`` to ``pass`` in empty ``-> None`` converter bodies, type-annotate ``_per_layer_bias_export`` and ``get_apriel2_decoder_converter`` (dropping a redundant forward-ref quote), and replace ``<->`` with ``↔`` in the remaining converter docstrings for consistency across the migration. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 14 +++++++------- fast_llm/models/gpt/conversion/apriel.py | 6 +++--- fast_llm/models/gpt/conversion/apriel2.py | 7 +++++-- fast_llm/models/multimodal/conversion/llava.py | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72bc957fc..b7d443854 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -21,10 +21,10 @@ def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: - cur = config + current = config for name in path: - cur = getattr(cur, name) - return cur + current = getattr(current, name) + return current def _collect_architecture_paths(config: Config) -> list[tuple[str, ...]]: @@ -219,10 +219,10 @@ def __init__(self, *fast_llm_paths: tuple[str, ...], hf_paths: tuple[tuple[str, self.hf_paths = hf_paths def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: - return + pass def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: - return + pass class CustomConfigConverter(ConfigConverter): @@ -292,7 +292,7 @@ def __init__( self.recurses = recurses def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: - return + pass def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: produced = self._import_fn(hf_dict) @@ -395,7 +395,7 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: converter_class = self._hf_to_class.get(type_name) if converter_class is None: raise NotImplementedError( - f"No converter registered for HF discriminator {type_name!r} at " f"{'.'.join(self.fast_llm_paths[0])}" + f"No converter registered for HF discriminator {type_name!r} at {'.'.join(self.fast_llm_paths[0])}" ) sub_fast_llm = converter_class.import_config(sub_hf) # Inject the Fast-LLM dynamic-type discriminator so the parent's `from_dict` dispatches to the diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 00fc52182..6f6924260 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -34,7 +34,7 @@ class AprielMambaConverter(ConfigSectionConverter): - """Converts ``MambaConfig`` <-> Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). 
+ """Converts ``MambaConfig`` ↔ Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). A few of MambaConfig's defaults are derived from the HF root's ``hidden_size`` (``d_inner`` defaults to ``hidden_size * expand``, ``d_xb`` defaults to ``hidden_size``, ``dt_rank="auto"`` resolves to @@ -186,7 +186,7 @@ def get_converters( class GatedDeltaNetConverter(ConfigSectionConverter): - """Converts ``GatedDeltaNetConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + """Converts ``GatedDeltaNetConfig`` ↔ Apriel HF ``linear_attn_config`` subdict.""" fast_llm_config_class = GatedDeltaNetConfig @@ -280,7 +280,7 @@ def get_converters( class KimiDeltaAttentionConverter(ConfigSectionConverter): - """Converts ``KimiDeltaAttentionConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + """Converts ``KimiDeltaAttentionConfig`` ↔ Apriel HF ``linear_attn_config`` subdict.""" fast_llm_config_class = KimiDeltaAttentionConfig diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 5425a932b..c18d13707 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -4,6 +4,7 @@ from transformers import PretrainedConfig +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConfigSectionConverter, @@ -52,7 +53,7 @@ # ============================================================ -def _per_layer_bias_export(config, layer_names: tuple[str, ...]) -> dict: +def _per_layer_bias_export(config: Config, layer_names: tuple[str, ...]) -> dict: """Emit per-layer ``{layer: {"bias": {"enabled": bool}}}`` only for layers whose bias is explicitly set.""" out: dict = {} for layer_name in layer_names: @@ -883,7 +884,9 @@ def get_converters( } -def get_apriel2_decoder_converter(decoder_config) -> "type[ConfigSectionConverter]": +def get_apriel2_decoder_converter( + decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, +) -> type[ConfigSectionConverter]: """Look up the Apriel2 per-shape decoder converter for a given decoder config instance.""" converter_class = APRIEL2_DECODER_REGISTRY.get(type(decoder_config)) if converter_class is None: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 82f1ea7ff..40f22616b 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -142,7 +142,7 @@ def import_weight( class PixtralEmbeddingsConverter(ConfigSectionConverter): - """Converts ``PatchEmbeddingsConfig`` <-> Pixtral HF flat fields (``patch_size`` / ``num_channels``). + """Converts ``PatchEmbeddingsConfig`` ↔ Pixtral HF flat fields (``patch_size`` / ``num_channels``). Pixtral's HF ``vision_config`` carries a single ``patch_size`` field (height == width); the converter expands it to both Fast-LLM dimensions on import and validates equality on export. From 58581603fae5ef41cff8b8a6aa9fceea27d73b82 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 May 2026 20:34:17 -0400 Subject: [PATCH 20/27] Address sixth review round MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Round 6 picks up one latent correctness bug, consolidates duplicated declarations into framework primitives, and tidies several surface items. 
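A small sketch of what the ``import_config`` consolidation described in the bullets below means for a section converter; class names and the ``dynamic_type_name`` value are invented, and the expected result in the final comment follows from the base-class behavior added in the ``external.py`` hunk further down:

    from fast_llm.engine.checkpoint.external import ConfigSectionConverter, RenameConfigConverter

    class ExampleMixerConfig:
        # Stand-in for a registered dynamic-type Config subclass; only this attribute matters here.
        dynamic_type_name = "example_mixer"

    class ExampleMixerConverter(ConfigSectionConverter):
        fast_llm_config_class = ExampleMixerConfig

        @classmethod
        def _create_config_converters(cls) -> dict:
            return {"heads": RenameConfigConverter(("heads",), ("num_heads",))}

    # The base class now prepends the discriminator itself, so no per-converter override is needed:
    # ExampleMixerConverter.import_config({"num_heads": 16}) -> {"type": "example_mixer", "heads": 16}
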
* ``Apriel2HeadConverter._validate_export`` now asserts ``RMSNormalizationConfig``: the config side dispatches normalization through ``APRIEL2_NORM_REGISTRY`` while the weight side hardcoded RMS, so a LayerNorm/NoNorm head would have silently dropped its bias on convert. * ``ConfigSectionConverter.import_config`` injects ``{"type": }`` from ``fast_llm_config_class`` automatically, removing the redundant injection from ``NestedConfigConverter`` / ``TypedDictContainerConfigConverter`` and collapsing four hand-rolled overrides (Apriel mamba/gdn/kda + Mixtral moe). * Deleted ``MTPLlamaDecoderConverter`` — its overrides were byte-identical to the parent's after the migration, with the only diff being a Pattern restriction that the parent now handles correctly through the multi-block-equality branch. * Extracted ``_per_layer_bias_converter`` and ``_apriel2_conv_kernel_converter`` helpers in apriel2.py to collapse pairs of byte-identical CustomConfigConverter declarations. * ``AprielBlockConverter._consumed_hf_paths`` gets ``@functools.cache`` for parity with the base ``ConfigSectionConverter._consumed_hf_paths``. * ``effective_bias`` typed as ``AffineLinearConfig``; ``NoPeftConfig`` import moved to top of llama.py (the module is not a config module subject to the heavy-import rule); stale ``# TODO: Peft?`` removed. * CLAUDE.md naming convention clarified: single underscore covers non-public (private or protected), matching the project's actual usage. Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 2 +- fast_llm/engine/checkpoint/external.py | 19 +++--- fast_llm/models/gpt/conversion/apriel.py | 17 +----- fast_llm/models/gpt/conversion/apriel2.py | 65 +++++++++++---------- fast_llm/models/gpt/conversion/llama.py | 7 +-- fast_llm/models/gpt/conversion/mixtral.py | 7 --- fast_llm/models/gpt/conversion/mtp_llama.py | 23 +------- 7 files changed, 51 insertions(+), 89 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 836b2ba7b..4f97d071f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -174,7 +174,7 @@ Tests live in `tests/`. The following patterns work well in this codebase. - **Comments**: Write no comments by default. Only add one when the *why* is non-obvious — a hidden constraint, a subtle invariant, a workaround for a specific bug, behavior that would surprise a reader. Never restate what the code already says; well-named identifiers do that. - **Imports**: Third-party → `import package.module` (keep fully qualified). First-party → `from fast_llm.module import Thing`. No relative imports. Optional/slow imports inside methods or under `if typing.TYPE_CHECKING:`. -- **Naming**: No abbreviations (use `batch_size` not `bs`). Private members get a single `_` prefix; never use `__`. Keep public interfaces lean. +- **Naming**: No abbreviations (use `batch_size` not `bs`). Non-public members (private or protected) get a single `_` prefix; never use `__`. Keep public interfaces lean. - **Types**: Always type-hint public interfaces. Use modern syntax (`X | Y`, `list[T]` not `List[T]`, PEP 695 generics like `class X[T: Bound]:` instead of `typing.TypeVar`). - **Assert**: Use the `Assert` namespace from `fast_llm.utils` for contract checks (`Assert.eq`, `Assert.geq`, `Assert.incl`, `Assert.custom`, etc.) — error messages auto-format with actual values. Bare `assert` is reserved for internal state-validity invariants (`assert self._is_setup`). - **Exceptions**: Raise stdlib exceptions for runtime errors (`ValueError`, `RuntimeError`, `NotImplementedError`). 
Custom exception classes are rare — only `ValidationError`, `NestedValidationError`, `FieldTypeError` in `config.py`. diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index b7d443854..ac87589bd 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -398,12 +398,6 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: f"No converter registered for HF discriminator {type_name!r} at {'.'.join(self.fast_llm_paths[0])}" ) sub_fast_llm = converter_class.import_config(sub_hf) - # Inject the Fast-LLM dynamic-type discriminator so the parent's `from_dict` dispatches to the - # correct subclass. Reads from the registered Config class rather than the HF discriminator so - # mismatched Fast-LLM/HF type names work too. - fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) - if fast_llm_type is not None: - sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) @@ -468,9 +462,6 @@ def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: f"{'.'.join(self.hf_paths[0])}[{name!r}]" ) sub_fast_llm = converter_class.import_config(sub_hf) - fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) - if fast_llm_type is not None: - sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} out[name] = sub_fast_llm set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], out) @@ -528,10 +519,18 @@ def export_config(cls, config: Config) -> dict: @classmethod def import_config(cls, hf_dict: dict) -> dict: - """Convert an HF config dict to a Fast-LLM config dict via this section's declarations.""" + """Convert an HF config dict to a Fast-LLM config dict via this section's declarations. + + When ``fast_llm_config_class`` carries a ``dynamic_type_name`` (i.e. the target is a registered + dynamic-type subclass), inject ``"type": `` so the caller's ``from_dict`` dispatches to the + correct subclass without each section converter having to prepend it manually. + """ out: dict = {} for converter in cls._create_config_converters().values(): converter.import_to(hf_dict, out) + fast_llm_type = getattr(cls.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + out = {"type": fast_llm_type, **out} return out @classmethod diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 6f6924260..befabad4d 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -1,3 +1,4 @@ +import functools import math import typing @@ -125,13 +126,6 @@ def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.dt_input_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) - @classmethod - def import_config(cls, hf_dict: dict) -> dict: - # Inject the Fast-LLM dynamic-type discriminator: the parent (AprielBlockConverter) selects this - # leaf via `hybrid_block_layout`, not via a nested HF discriminator, so DispatchConfigConverter's - # auto-injection isn't in play and we must add `type` manually. 
- return {"type": "mamba", **super().import_config(hf_dict)} - @classmethod def get_converters( cls, @@ -223,10 +217,6 @@ def _create_config_converters(cls) -> dict: ), } - @classmethod - def import_config(cls, hf_dict: dict) -> dict: - return {"type": "gdn", **super().import_config(hf_dict)} - @classmethod def get_converters( cls, @@ -321,10 +311,6 @@ def _create_config_converters(cls) -> dict: ), } - @classmethod - def import_config(cls, hf_dict: dict) -> dict: - return {"type": "kda", **super().import_config(hf_dict)} - @classmethod def get_converters( cls, @@ -468,6 +454,7 @@ def export_config(cls, config) -> dict: return cls._converter_classes[type(config.mixer)].export_config(config) @classmethod + @functools.cache def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: """Union of consumed HF paths across every per-mixer-type block converter — used by the parent's decoder Custom to pre-claim Apriel's flat top-level keys for the HF coverage check.""" diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index c18d13707..16ab133b2 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -72,6 +72,30 @@ def _per_layer_bias_import(hf_dict: dict, layer_names: tuple[str, ...]) -> dict: return out +def _per_layer_bias_converter(layer_names: tuple[str, ...]) -> CustomConfigConverter: + """Per-layer ``bias.enabled`` round-trip for the named sub-layers of an attention or MLP config: + emits/consumes the HF ``{layer: {"bias": {"enabled": ...}}}`` tree.""" + return CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + hf_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + recurses=True, + ) + + +def _apriel2_conv_kernel_converter() -> CustomConfigConverter: + """Round-trip Apriel2's flat ``convolution_layer.kernel_size`` against the Fast-LLM + ``convolution_layer`` sub-config. 
Shared between :class:`Apriel2GatedDeltaNetConverter` and + :class:`Apriel2KimiDeltaAttentionConverter`.""" + return CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ({("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {}), + ) + + # ============================================================ # Mixer converters # ============================================================ @@ -143,13 +167,7 @@ def _create_config_converters(cls) -> dict: ("add_linear_biases",), ("add_linear_biases",), sentinel=True ), "window_size": OptionalConfigConverter(("window_size",), ("window_size",)), - "linear_layers": CustomConfigConverter( - fast_llm_paths=tuple((name,) for name in layer_names), - hf_paths=tuple((name,) for name in layer_names), - export_fn=lambda c: _per_layer_bias_export(c, layer_names), - import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), - recurses=True, - ), + "linear_layers": _per_layer_bias_converter(layer_names), "causal": IgnoredConfigConverter(("causal",)), "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), } @@ -318,14 +336,7 @@ def _create_config_converters(cls) -> dict: "key_heads": RenameConfigConverter(("key_heads",), ("key_heads",)), "key_head_dim": RenameConfigConverter(("key_head_dim",), ("key_head_dim",)), "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), - "convolution_layer_kernel": CustomConfigConverter( - fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), - hf_paths=(("convolution_layer",),), - export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, - import_fn=lambda hf: ( - {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} - ), - ), + "convolution_layer_kernel": _apriel2_conv_kernel_converter(), # CausalConv1dConfig sub-fields the Apriel2 HF format does not surface (weight rides the tensor # side; bias/activation round-trip at their Fast-LLM defaults). "convolution_layer_unmapped": IgnoredConfigConverter( @@ -407,14 +418,7 @@ def _create_config_converters(cls) -> dict: return { "heads": RenameConfigConverter(("heads",), ("heads",)), "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), - "convolution_layer_kernel": CustomConfigConverter( - fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), - hf_paths=(("convolution_layer",),), - export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, - import_fn=lambda hf: ( - {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} - ), - ), + "convolution_layer_kernel": _apriel2_conv_kernel_converter(), # CausalConv1dConfig sub-fields not surfaced in HF (same as :class:`Apriel2GatedDeltaNetConverter`). 
"convolution_layer_unmapped": IgnoredConfigConverter( ("convolution_layer", "weight"), @@ -682,13 +686,7 @@ def _create_config_converters(cls) -> dict: export_fn=lambda c: {("activation",): c.activation.hf_name}, import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, ), - "layers": CustomConfigConverter( - fast_llm_paths=tuple((name,) for name in layer_names), - hf_paths=tuple((name,) for name in layer_names), - export_fn=lambda c: _per_layer_bias_export(c, layer_names), - import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), - recurses=True, - ), + "layers": _per_layer_bias_converter(layer_names), } @classmethod @@ -913,6 +911,13 @@ def _create_config_converters(cls) -> dict: "prediction_heads": ConstantImportConfigConverter(("prediction_heads",), 1), } + @classmethod + def _validate_export(cls, config: LanguageModelHeadConfig) -> None: + # The config side dispatches normalization through APRIEL2_NORM_REGISTRY (RMS/Layer/None), but the + # weight side below hardcodes ``normalization_converter_class`` (RMSNorm-only). Fail loudly here so a + # LayerNorm/NoNorm head config doesn't silently round-trip through the wrong weight conversion. + Assert.is_(type(config.normalization), RMSNormalizationConfig) + # --- weight side (imperative) --- @classmethod diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 32f1c79ed..af4cc9ef0 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -26,7 +26,9 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.common.peft.config import NoPeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import ( @@ -48,12 +50,10 @@ def assert_no_peft(config: GPTBaseModelConfig) -> None: """Reject any non-trivial PEFT config: HuggingFace formats serialize the base weights only, so a configured LoRA (or other adapter) would be silently dropped on export.""" - from fast_llm.layers.common.peft.config import NoPeftConfig - Assert.custom(isinstance, config.peft, NoPeftConfig) -def effective_bias(layer_config, default: bool) -> bool: +def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default.""" return default if layer_config.bias.enabled is None else layer_config.bias.enabled @@ -603,7 +603,6 @@ class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConver fast_llm_config_class = GPTBaseModelConfig - # TODO: Peft? 
decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter embeddings_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaHeadConverter diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 5d9deb4d5..d1ead7309 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -47,13 +47,6 @@ def _create_config_converters(cls) -> dict: ), } - @classmethod - def import_config(cls, hf_dict: dict) -> dict: - # Inject the Fast-LLM dynamic-type discriminator so `from_dict` instantiates `MoEMLPConfig` - # rather than the default `MLPConfig`. The MLP is wrapped via `NestedConfigConverter`, so - # there's no surrounding `DispatchConfigConverter` to inject this for us. - return {"type": "moe", **super().import_config(hf_dict)} - @classmethod def get_converters( cls, diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 787ba0220..6f6d9e88a 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -4,18 +4,16 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter -from fast_llm.layers.block.config import FixedBlockSequenceConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, get_parameter_converter, ) -from fast_llm.utils import Assert, safe_merge_dicts +from fast_llm.utils import safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): @@ -62,26 +60,7 @@ def get_converters( return converters -class MTPLlamaDecoderConverter(LlamaDecoderConverter): - @classmethod - def import_config(cls, hf_dict: dict) -> dict: - return { - "block": cls.block_converter_class.import_config(hf_dict), - "num_blocks": hf_dict["num_hidden_layers"], - } - - @classmethod - def export_config(cls, decoder_config: FixedBlockSequenceConfig) -> dict: - # TODO: Support PatternBlockSequenceConfig with compatible configs. - Assert.custom(isinstance, decoder_config, FixedBlockSequenceConfig) - return safe_merge_dicts( - cls.block_converter_class.export_config(decoder_config.block), - {"num_hidden_layers": decoder_config.num_blocks}, - ) - - class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[MTPLlamaDecoderConverter]] = MTPLlamaDecoderConverter head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter From 807fbe1c52dfc14656caec351c7887ae17a8b5f5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:25:48 -0400 Subject: [PATCH 21/27] Migrate Apriel2 multimodal config converters to declarative Apriel2VisionAttention/Block/MLP/Encoder/Embeddings/Adapter/Model and the top-level Apriel2MultimodalBaseModelConverter become ConfigSectionConverter subclasses. The vision branch keeps inheriting weight-side get_converters from Pixtral/Llava bases via MRO; only the config side is declarative. 
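As a shape illustration only (a hedged sketch patterned on the converters in the diff below; ``ExampleVisionMLPConverter`` and its field choices are hypothetical), a migrated section converter declares the config mapping as named primitives and leaves the weight side imperative:

    from fast_llm.engine.checkpoint.external import (
        ConfigSectionConverter,
        CustomConfigConverter,
        IgnoredConfigConverter,
        RenameConfigConverter,
    )
    from fast_llm.functional.config import ActivationType
    from fast_llm.layers.decoder.mlp.config import MLPConfig


    class ExampleVisionMLPConverter(ConfigSectionConverter):
        fast_llm_config_class = MLPConfig
        hf_type_name = "mlp"

        @classmethod
        def _create_config_converters(cls) -> dict:
            return {
                # 1:1 renames between the Fast-LLM and HF dicts.
                "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)),
                "gated": RenameConfigConverter(("gated",), ("gated",)),
                # Value translation in both directions.
                "activation": CustomConfigConverter(
                    fast_llm_paths=(("activation",),),
                    hf_paths=(("activation",),),
                    export_fn=lambda c: {("activation",): c.activation.hf_name},
                    import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])},
                ),
                # Architecture sub-trees with no HF representation are claimed explicitly.
                "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)),
            }

``import_config``/``export_config`` then walk these declarations (injecting the ``type`` discriminator from ``fast_llm_config_class.dynamic_type_name`` when one is registered), while ``get_converters`` stays imperative and is typically inherited from the pre-existing Pixtral/Llava weight converters.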
Cross-section rotary metadata (patch_size/max_image_size derived from embeddings.patch_height) is injected at the vision-model level via a Custom, which is the smallest scope that sees both halves. Co-Authored-By: Claude Opus 4.7 --- .../models/multimodal/conversion/apriel2.py | 576 ++++++++++++------ 1 file changed, 373 insertions(+), 203 deletions(-) diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 225863fcc..f808714eb 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -3,16 +3,29 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + RenameConfigConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Rotary2DConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( Apriel2BaseModelConverter, Apriel2HeadConverter, + Apriel2RMSNormConverter, get_apriel2_decoder_converter, ) from fast_llm.models.gpt.conversion.llama import ( @@ -25,170 +38,282 @@ from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat from fast_llm.models.multimodal.conversion.llava import ( LlavaVisionAdapterConverter, - LlavaVisionModelConverter, PatchEmbeddingWeightConverter, PixtralAttentionConverter, - PixtralBlockConverter, - PixtralEncoderConverter, ) from fast_llm.models.multimodal.model import MultiModalModel -from fast_llm.utils import Assert, safe_merge_dicts +from fast_llm.utils import Assert + + +def _apriel2_vision_attention_rotary_export(config: AttentionConfig) -> dict: + """Emit the Apriel2-vision rotary subdict. Two rotary types are supported: + :class:`Rotary2DConfig` (HF ``pixtral_2d``) and :class:`DefaultRotaryConfig` (HF ``mistral_1d``). 
+ ``patch_size``/``max_image_size`` HF metadata is injected by the parent vision-model converter + (it derives from ``embeddings.patch_height``, outside this scope).""" + rotary = config.rotary + if type(rotary) is Rotary2DConfig: + return {("rotary",): {"type": "pixtral_2d", "theta": rotary.theta}} + if type(rotary) is DefaultRotaryConfig: + return {("rotary",): {"type": "mistral_1d", "theta": rotary.theta}} + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + +def _apriel2_vision_attention_rotary_import(hf_dict: dict) -> dict: + rotary = dict(hf_dict["rotary"]) + if rotary.get("type") == "pixtral_2d": + rotary["type"] = "default_2d" + elif rotary.get("type") == "mistral_1d": + rotary["type"] = "default" + rotary.pop("patch_size", None) + rotary.pop("max_image_size", None) + return {("rotary",): rotary} class Apriel2VisionAttentionConverter(PixtralAttentionConverter): + """Converts :class:`AttentionConfig` ↔ Apriel2 vision attention HF subdict (typed ``"attention"``). + + Apriel2's vision attention shape uses Apriel2-native field names (``heads``, ``head_groups``, ``head_size``, + ``add_linear_biases``, ``causal``) plus an explicit ``cross_document_attention=False`` flag and a nested + typed ``rotary`` block. Differs from the text :class:`Apriel2AttentionConverter` mainly in the rotary type + set (``pixtral_2d``/``mistral_1d`` instead of ``mistral_1d``/``llama3``/``yarn``) and the lack of + per-layer-bias and ``window_size`` representations. + + Inherits :meth:`get_converters` from :class:`PixtralAttentionConverter` (Llama-style q/k/v/o weight layout). + """ + + hf_type_name = "attention" + @classmethod - def import_config(cls, config: dict) -> dict: - rotary = config["rotary"].copy() - # Map Apriel2 HuggingFace rotary type to Fast-LLM internal type - if rotary.get("type") == "pixtral_2d": - rotary["type"] = "default_2d" - # Strip HF-specific fields not needed by Fast-LLM's Rotary2DConfig - # (Fast-LLM computes patch_positions dynamically from actual image patches) - rotary.pop("max_image_size", None) - rotary.pop("patch_size", None) + def _create_config_converters(cls) -> dict: + # Replace Pixtral's declarations wholesale: Apriel2 vision uses Apriel2-native field names, allows GQA + # and both Rotary2D + DefaultRotary, and has no HF representation for per-layer biases or window_size. return { - "rotary": rotary, - "heads": config["heads"], - "head_groups": config["head_groups"], - "head_size": config["head_size"], - "add_linear_biases": config["add_linear_biases"], - "causal": config["causal"], + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("head_groups",)), + "head_size": RenameConfigConverter(("head_size",), ("head_size",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "causal": RenameConfigConverter(("causal",), ("causal",)), + "cross_document_attention": ConstantExportConfigConverter(("cross_document_attention",), False), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rotary",),), + export_fn=_apriel2_vision_attention_rotary_export, + import_fn=_apriel2_vision_attention_rotary_import, + recurses=True, + ), + # Apriel2 vision attention has no per-layer bias representation; the Fast-LLM defaults round-trip. 
+ "linear_layers": IgnoredConfigConverter( + ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",) + ), + "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), } @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Rotary2DConfig + def _validate_export(cls, config: AttentionConfig) -> None: + # Replace Pixtral's Rotary2D-only + head_groups==heads checks (Apriel2 vision allows both rotary types + # and supports GQA). Keep the per-layer bias consistency check from the Llama base. + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + +class Apriel2VisionMLPConverter(ConfigSectionConverter): + """The vision-side MLP shape ``{type: mlp, intermediate_size, activation, gated, add_linear_biases}``. + + Distinct from the text :class:`Apriel2MLPConverter` only in lacking the per-layer-bias declaration: the + Apriel2 vision MLP HF shape has no representation for per-layer ``bias.enabled`` overrides, so the + Fast-LLM defaults are dropped on export (declared :class:`IgnoredConfigConverter`) and re-defaulted on + import. Weight-side ``get_converters`` is shared with the text MLP. + """ - if type(config.rotary) is Rotary2DConfig: - rotary_type = "pixtral_2d" - elif type(config.rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod + def _create_config_converters(cls) -> dict: return { - "type": "attention", - "heads": config.heads, - "head_groups": config.head_groups, - "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, - "causal": config.causal, - "cross_document_attention": False, - "rotary": { - "type": rotary_type, - "theta": config.rotary.theta, - }, + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), } + @classmethod + def get_converters( + cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + from fast_llm.models.gpt.conversion.apriel2 import Apriel2MLPConverter + + return Apriel2MLPConverter.get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export) + + +class Apriel2VisionBlockConverter(ConfigSectionConverter): + """Converts a vision :class:`DecoderBlockConfig` ↔ Apriel2's nested ``{mixer, mlp, normalization}`` block. + + Distinct from :class:`PixtralBlockConverter` (which flat-merges its children into the parent's HF dict) + because the Apriel2 vision format nests each sub-section under a typed sub-key, matching the Apriel2 text + decoder shape. 
+ """ + + fast_llm_config_class = DecoderBlockConfig -class Apriel2VisionBlockConverter(PixtralBlockConverter): mixer_converter_class: typing.ClassVar[type[Apriel2VisionAttentionConverter]] = Apriel2VisionAttentionConverter + mlp_converter_class: typing.ClassVar[type[Apriel2VisionMLPConverter]] = Apriel2VisionMLPConverter + # Config-side: the Apriel2 HF format nests normalization as ``{"type": "rms_norm", "epsilon": ...}``; + # ``Apriel2RMSNormConverter`` handles the typed shape. Weight side uses LlamaNormalizationConverter + # directly (flat parameter names — independent of how the surrounding HF config is structured). + normalization_converter_class: typing.ClassVar[type[Apriel2RMSNormConverter]] = Apriel2RMSNormConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" hf_mlp_name: typing.ClassVar[str] = "mlp" hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict, block_config: dict) -> dict: - mixer_config = block_config["mixer"] - mlp_config = block_config["mlp"] - norm_config = block_config["normalization"] - + def _create_config_converters(cls) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(mixer_config), - "mlp": { - "type": "mlp", - "intermediate_size": mlp_config["intermediate_size"], - "activation": ActivationType.from_hf_name(mlp_config["activation"]), - "gated": mlp_config["gated"], - "add_linear_biases": mlp_config["add_linear_biases"], - }, - "normalization": cls.normalization_converter_class.import_config(norm_config), + "mixer": NestedConfigConverter(("mixer",), cls.mixer_converter_class, hf_path=("mixer",)), + "mlp": NestedConfigConverter(("mlp",), cls.mlp_converter_class, hf_path=("mlp",)), + "normalization": NestedConfigConverter( + ("normalization",), cls.normalization_converter_class, hf_path=("normalization",) + ), } + # --- weight side (imperative) --- + @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.decoder.config import DecoderBlockConfig + def get_converters( + cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + return [ + *cls.mixer_converter_class.get_converters( + config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.{cls.hf_mixer_name}", drop_on_export + ), + *cls.mlp_converter_class.get_converters( + config.mlp, f"{fast_llm_prefix}.mlp", f"{hf_prefix}.{cls.hf_mlp_name}", drop_on_export + ), + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.{cls.hf_norm_1_name}", + drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.{cls.hf_norm_2_name}", + drop_on_export, + ), + ] - Assert.custom(isinstance, config, DecoderBlockConfig) - return { - "mixer": cls.mixer_converter_class.export_config(config.mixer), - "mlp": { - "type": "mlp", - "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.hf_name, - "gated": config.mlp.gated, - "add_linear_biases": config.mlp.add_linear_biases, - }, - "normalization": { - "type": "rms_norm", - "epsilon": config.normalization.epsilon, - }, - } +class Apriel2VisionEncoderConverter(ConfigSectionConverter): + """Converts a :class:`FixedBlockSequenceConfig` (vision encoder) ↔ Apriel2 HF ``encoder`` subdict + the + flat ``num_hidden_layers`` mirror that the HF format also requires at the surrounding 
vision_config level. + + No ``hf_type_name`` is set: the ``type: "fixed"`` discriminator lives *inside* the ``encoder`` subdict + (emitted by the Custom's export_fn), not at the parent vision_config level. The Fast-LLM-side ``type`` + is auto-injected by :meth:`ConfigSectionConverter.import_config` via ``fast_llm_config_class.dynamic_type_name``. + """ + + fast_llm_config_class = FixedBlockSequenceConfig -class Apriel2VisionEncoderConverter(PixtralEncoderConverter): block_converter_class: typing.ClassVar[type[Apriel2VisionBlockConverter]] = Apriel2VisionBlockConverter @classmethod - def import_config(cls, config: dict) -> dict: - encoder_config = config["encoder"] - num_blocks = encoder_config["num_blocks"] - block_config = encoder_config["block"] - + def _create_config_converters(cls) -> dict: return { - "type": "fixed", - "num_blocks": num_blocks, - "block": cls.block_converter_class.import_config(config, block_config), + "encoder": CustomConfigConverter( + fast_llm_paths=(("num_blocks",), ("block",)), + hf_paths=(("encoder",),), + export_fn=lambda c: { + ("encoder",): { + "type": "fixed", + "num_blocks": c.num_blocks, + "block": cls.block_converter_class.export_config(c.block), + }, + }, + import_fn=lambda hf: { + ("num_blocks",): hf["encoder"]["num_blocks"], + ("block",): cls.block_converter_class.import_config(hf["encoder"]["block"]), + }, + recurses=True, + ), + "num_hidden_layers_mirror": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("num_hidden_layers",),), + export_fn=lambda c: {("num_hidden_layers",): c.num_blocks}, + import_fn=lambda hf: {}, + ), } + # --- weight side (imperative) --- + @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.block.config import FixedBlockSequenceConfig + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters: list[WeightConverter] = [] + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters - Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return { - "encoder": { - "type": "fixed", - "num_blocks": config.num_blocks, - "block": cls.block_converter_class.export_config(config.block), - }, - "num_hidden_layers": config.num_blocks, - } +class Apriel2EmbeddingsConverter(ConfigSectionConverter): + """Converts :class:`PatchEmbeddingsConfig` ↔ Apriel2 HF ``embeddings`` subdict, with top-level + ``patch_size``/``num_channels`` mirrors that the Apriel2 vision_config also requires.""" -class Apriel2EmbeddingsConverter: - """Converts between Fast-LLM PatchEmbeddingsConfig and Apriel2 HF embeddings format.""" + fast_llm_config_class = PatchEmbeddingsConfig - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + normalization_converter_class: typing.ClassVar[type[Apriel2RMSNormConverter]] = Apriel2RMSNormConverter @classmethod - def import_config(cls, config: dict) -> dict: - embeddings_config = config["embeddings"] - Assert.eq(embeddings_config["input_channels"], 3) + def _create_config_converters(cls) -> dict: return { - "normalization": embeddings_config["normalization"], - "patch_height": embeddings_config["patch_height"], - "patch_width": embeddings_config["patch_width"], + "patch_height": RenameConfigConverter(("patch_height",), ("embeddings", 
"patch_height")), + "patch_width": RenameConfigConverter(("patch_width",), ("embeddings", "patch_width")), + "normalization": NestedConfigConverter( + ("normalization",), + cls.normalization_converter_class, + hf_path=("embeddings", "normalization"), + ), + # ``patch_embeddings`` (AffineLinearConfig) carries no HF architecture info; bias presence validated below. + "patch_embeddings": IgnoredConfigConverter(("patch_embeddings",)), + # ``input_channels`` is a cached_property pinned to 3 on the Fast-LLM side; HF emits it under + # ``embeddings`` and again as a top-level ``num_channels`` mirror. + "embeddings_input_channels": ConstantExportConfigConverter(("embeddings", "input_channels"), 3), + "num_channels_mirror": ConstantExportConfigConverter(("num_channels",), 3), + # ``patch_size`` HF top-level mirror of ``embeddings.patch_height`` — emit on export, ignored on + # import (the under-``embeddings`` path is the authoritative source). + "patch_size_mirror": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("patch_size",),), + export_fn=lambda c: {("patch_size",): c.patch_height}, + import_fn=lambda hf: {}, + ), } @classmethod - def export_config(cls, config: PatchEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, PatchEmbeddingsConfig) + def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.eq(config.patch_height, config.patch_width) Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) - return { - "embeddings": { - "patch_height": config.patch_height, - "patch_width": config.patch_width, - "input_channels": config.input_channels, - "normalization": {"type": "rms_norm", "epsilon": config.normalization.epsilon}, - }, - "patch_size": config.patch_height, - "num_channels": config.input_channels, - } - @classmethod def get_converters( cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str @@ -201,82 +326,98 @@ def get_converters( PatchEmbeddingWeightConverter, config, ), - *cls.normalization_converter_class.get_converters( + *LlamaNormalizationConverter.get_converters( config.normalization, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.normalization" ), ] -class Apriel2VisionAdapterConverter(LlavaVisionAdapterConverter): +class Apriel2VisionAdapterConverter(ConfigSectionConverter, LlavaVisionAdapterConverter): + """Converts :class:`MLPConfig` (adapter) ↔ Apriel2 HF ``adapter`` subdict. + + Apriel2 nests the adapter shape under ``adapter`` and uses the typed ``{type: mlp, ...}`` dict-of-fields + layout, distinct from Llava's flat top-level ``projector_hidden_act``/``multimodal_projector_bias`` shape. + + MRO ordering: :class:`ConfigSectionConverter` (declarative) comes before :class:`LlavaVisionAdapterConverter` + (imperative) so the declarative ``import_config``/``export_config`` are picked up. ``get_converters`` is + inherited from Llava (weight-side imperative). 
+ """ + + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod - def import_config(cls, config: dict) -> dict: - adapter_config = config["adapter"] + def _create_config_converters(cls) -> dict: return { - "intermediate_size": adapter_config["intermediate_size"], - "add_linear_biases": adapter_config["add_linear_biases"], - "gated": adapter_config["gated"], - "activation": ActivationType.from_hf_name(adapter_config["activation"]), + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - return { - "adapter": { - "type": "mlp", - "intermediate_size": config.intermediate_size, - "activation": config.activation.hf_name, - "add_linear_biases": config.add_linear_biases, - "gated": config.gated, - }, - } +class Apriel2VisionModelConverter(ConfigSectionConverter): + """Top-level vision-encoder converter. The HF representation lives under a single ``vision_encoder`` key, + so declarations are written relative to that nested subdict. + + ``patch_size``/``max_image_size`` rotary metadata is injected here (cross-section reference to + ``embeddings.patch_height``) — the attention converter cannot see it from its own scope. 
+ """ + + fast_llm_config_class = VisionEncoderConfig -class Apriel2VisionModelConverter(LlavaVisionModelConverter): + embeddings_converter_class: typing.ClassVar[type[Apriel2EmbeddingsConverter]] = Apriel2EmbeddingsConverter + encoder_converter_class: typing.ClassVar[type[Apriel2VisionEncoderConverter]] = Apriel2VisionEncoderConverter vision_adapter_converter_class: typing.ClassVar[type[Apriel2VisionAdapterConverter]] = ( Apriel2VisionAdapterConverter ) - embeddings_converter_class: typing.ClassVar[type[Apriel2EmbeddingsConverter]] = Apriel2EmbeddingsConverter - encoder_converter_class: typing.ClassVar[type[Apriel2VisionEncoderConverter]] = Apriel2VisionEncoderConverter - # HF path prefixes for Apriel2 (external HF model format) hf_embeddings_prefix: typing.ClassVar[str] = "model.vision_encoder.embeddings" hf_encoder_prefix: typing.ClassVar[str] = "model.vision_encoder.encoder.blocks" hf_adapter_prefix: typing.ClassVar[str] = "model.vision_encoder.adapter" @classmethod - def import_config(cls, config: dict) -> dict: - vision_config = config["vision_encoder"] + def _create_config_converters(cls) -> dict: return { - "embeddings": cls.embeddings_converter_class.import_config(vision_config), - "encoder": cls.encoder_converter_class.import_config(vision_config), - "adapter": cls.vision_adapter_converter_class.import_config(vision_config), - "hidden_size": vision_config["hidden_size"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "encoder": NestedConfigConverter(("encoder",), cls.encoder_converter_class), + "adapter": NestedConfigConverter(("adapter",), cls.vision_adapter_converter_class, hf_path=("adapter",)), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + # Cross-section rotary metadata: the Apriel2 HF format requires patch_size + max_image_size inside + # ``encoder.block.mixer.rotary`` (for ``pixtral_2d``), derived from embeddings.patch_height plus a + # constant 1024. Written here because this converter is the smallest scope that sees both. + # No fast_llm_paths/hf_paths claims — the encoder's recursive rotary claim covers HF coverage; the + # values land on import via the same recursive claim and are stripped by the attention import_fn. 
+ "rotary_metadata": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(), + export_fn=cls._inject_rotary_metadata, + import_fn=lambda hf: {}, + ), } - @classmethod - def export_config(cls, config: VisionEncoderConfig) -> dict: - Assert.custom(isinstance, config, VisionEncoderConfig) - - vision_config = safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.encoder_converter_class.export_config(config.encoder), - cls.vision_adapter_converter_class.export_config(config.adapter), - {"hidden_size": config.hidden_size}, - ) + @staticmethod + def _inject_rotary_metadata(config: VisionEncoderConfig) -> dict: + rotary = config.encoder.block.mixer.rotary + if type(rotary) is Rotary2DConfig: + return { + ("encoder", "block", "mixer", "rotary", "patch_size"): config.embeddings.patch_height, + ("encoder", "block", "mixer", "rotary", "max_image_size"): 1024, + } + return {} - # Add patch_size and max_image_size to rotary config for pixtral_2d - patch_size = config.embeddings.patch_height - encoder_block = vision_config["encoder"]["block"] - rotary = encoder_block["mixer"]["rotary"] - if rotary["type"] == "pixtral_2d": - rotary["patch_size"] = patch_size - rotary["max_image_size"] = 1024 # Standard max image size for Pixtral - - return {"vision_encoder": vision_config} + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: @@ -316,42 +457,76 @@ def get_converters( ] -class Apriel2MultimodalBaseModelConverter: +class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for Apriel2 multimodal. Composes the Apriel2 text base (flat-merged into the HF + top-level dict) with an optional vision encoder (under HF key ``vision_encoder``) and an optional + ``image_token_index`` field. + + Architecturally the Fast-LLM config (:class:`MultiModalBaseModelConfig`) multi-inherits from both + :class:`GPTBaseModelConfig` (text) and :class:`VisionMultiModalModelConfig` (vision/image_token_index), + so a single declaration set drives both halves. 
+ """ + + fast_llm_config_class = MultiModalBaseModelConfig + + text_base_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BaseModelConverter vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter - embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - text_config = Apriel2BaseModelConverter.import_config(config) - vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None + def _create_config_converters(cls) -> dict: + text_base_cls = cls.text_base_converter_class + vision_cls = cls.vision_model_converter_class - result = safe_merge_dicts( - text_config, - {"vision_encoder": vision_config}, - ) - if "image_token_index" in config: - result["image_token_index"] = config["image_token_index"] - return result + def _vision_export(config: MultiModalBaseModelConfig) -> dict: + if config.vision_encoder is None: + return {} + return {("vision_encoder",): vision_cls.export_config(config.vision_encoder)} - @classmethod - def export_config(cls, config: MultiModalBaseModelConfig) -> dict: - Assert.custom(isinstance, config, MultiModalBaseModelConfig) - exported = Apriel2BaseModelConverter.export_config(config) - if config.vision_encoder is not None: - exported = safe_merge_dicts( - exported, - cls.vision_model_converter_class.export_config(config.vision_encoder), - ) + def _vision_import(hf_dict: dict) -> dict: + if "vision_encoder" not in hf_dict: + return {} + return {("vision_encoder",): vision_cls.import_config(hf_dict["vision_encoder"])} - if config.image_token_index is not None: - exported["image_token_index"] = config.image_token_index + return { + # Flat-merge the Apriel2 text base into the top-level HF dict. The text base claims every + # GPTBaseModelConfig architecture leaf via its own declarations; we mark them recursively + # consumed here and forward HF coverage via the text base's ``_consumed_hf_paths``. + "text_base": CustomConfigConverter( + fast_llm_paths=( + ("embeddings",), + ("decoder",), + ("head",), + ("hidden_size",), + ("tied_embedding_weight",), + ("peft",), + ), + hf_paths=tuple(text_base_cls._consumed_hf_paths()), + export_fn=lambda c: {(k,): v for k, v in text_base_cls.export_config(c).items()}, + import_fn=lambda hf: {(k,): v for k, v in text_base_cls.import_config(hf).items()}, + recurses=True, + ), + # Optional vision encoder. The Fast-LLM ``vision_encoder`` field is architecture-hint and + # ``None`` by default; the HF ``vision_encoder`` key is absent for text-only models. + "vision_encoder": CustomConfigConverter( + fast_llm_paths=(("vision_encoder",),), + hf_paths=(("vision_encoder",),), + export_fn=_vision_export, + import_fn=_vision_import, + recurses=True, + ), + # ``image_token_index`` is FieldHint.optional so it's not in the architecture-coverage set, + # but it does live on the HF dict for vision-enabled checkpoints. 
+ "image_token_index": OptionalConfigConverter(("image_token_index",), ("image_token_index",)), + } - return exported + # --- weight side (imperative) --- + + embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - converters = [] + converters: list[WeightConverter] = [] if config.vision_encoder is not None: converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) @@ -361,7 +536,6 @@ def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict ) ) converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) - return converters @@ -392,21 +566,17 @@ def get_model_files(cls) -> tuple[str, str, str | None]: @classmethod def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: - base_model = config.base_model - exported = safe_merge_dicts( - cls.base_model_converter_class.export_config(base_model), - { - "architectures": [cls.architecture], - "model_type": cls.get_huggingface_model_type(), - "auto_map": { - "AutoConfig": "configuration_apriel2.Apriel2Config", - "AutoModel": "modeling_apriel2.Apriel2Model", - "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration", - "AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration", - }, + return { + **cls.base_model_converter_class.export_config(config.base_model), + "architectures": [cls.architecture], + "model_type": cls.get_huggingface_model_type(), + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2Config", + "AutoModel": "modeling_apriel2.Apriel2Model", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration", + "AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration", }, - ) - return exported + } @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: From 6cfcc2c547e7fbf49e1e422296cffffeb2e78de5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:47:09 -0400 Subject: [PATCH 22/27] Migrate Llava multimodal config converters to declarative MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LlavaVisionAdapter/VisionModel/Base converters become ConfigSectionConverter subclasses. The adapter is declared at the LlavaBaseModelConverter scope (not inside VisionModelConverter) because its intermediate_size derives from text_config.hidden_size — a cross-section reference reachable only at the top-level HF dict. PixtralAttentionConverter's head_size declaration changes from DefaultConfig (emits head_dim) to ImportOnly (derives from hidden_size / num_attention_heads). The previous head_dim popping in the imperative LlavaVisionModelConverter is replaced by a head_size invariant check on the new declarative converter's _validate_export. Apriel2VisionAdapterConverter loses its MRO trick (ConfigSectionConverter + LlavaVisionAdapterConverter) and inherits cleanly from Llava — now that Llava is also a ConfigSectionConverter, the trick would produce an inconsistent MRO. 
Co-Authored-By: Claude Opus 4.7 --- .../models/multimodal/conversion/apriel2.py | 8 +- .../models/multimodal/conversion/llava.py | 209 +++++++++++++----- 2 files changed, 153 insertions(+), 64 deletions(-) diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index f808714eb..250b95f59 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -332,15 +332,15 @@ def get_converters( ] -class Apriel2VisionAdapterConverter(ConfigSectionConverter, LlavaVisionAdapterConverter): +class Apriel2VisionAdapterConverter(LlavaVisionAdapterConverter): """Converts :class:`MLPConfig` (adapter) ↔ Apriel2 HF ``adapter`` subdict. Apriel2 nests the adapter shape under ``adapter`` and uses the typed ``{type: mlp, ...}`` dict-of-fields layout, distinct from Llava's flat top-level ``projector_hidden_act``/``multimodal_projector_bias`` shape. - MRO ordering: :class:`ConfigSectionConverter` (declarative) comes before :class:`LlavaVisionAdapterConverter` - (imperative) so the declarative ``import_config``/``export_config`` are picked up. ``get_converters`` is - inherited from Llava (weight-side imperative). + Inherits declarative ``import_config``/``export_config`` from :class:`ConfigSectionConverter` via + :class:`LlavaVisionAdapterConverter`, and weight-side ``get_converters`` from Llava (same ``linear_1`` / + ``linear_2`` weight names as Llava). """ fast_llm_config_class = MLPConfig diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 40f22616b..b504bc033 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -37,7 +37,7 @@ from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div, safe_merge_dicts +from fast_llm.utils import Assert, div class PixtralNormalizationConverter(LlamaNormalizationConverter): @@ -81,6 +81,15 @@ def _create_config_converters(cls) -> dict: fast_llm_paths=(("head_groups",),), import_fn=lambda hf: {("head_groups",): hf["num_attention_heads"]}, ), + # Llava's PixtralVisionConfig has no ``head_dim`` field — it is derived as ``hidden_size // + # num_attention_heads``. Don't emit head_dim on export (would otherwise need to be popped + # downstream); on import, derive head_size from the same expression. Invariant validated by + # :class:`LlavaVisionModelConverter._validate_export`, which has access to the parent's + # ``hidden_size``. + "head_size": ImportOnlyConfigConverter( + fast_llm_paths=(("head_size",),), + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ), # Pixtral always uses 2D rotary; only ``theta`` round-trips. The flat (v4) vs ``rope_parameters`` (v5) # layout follows the active transformers major version, mirroring the Llama parent. "rotary": CustomConfigConverter( @@ -191,27 +200,46 @@ def get_converters( ] -class LlavaVisionAdapterConverter: +class LlavaVisionAdapterConverter(ConfigSectionConverter): + """Converts the vision adapter :class:`MLPConfig` ↔ Llava's flat top-level adapter fields + (``projector_hidden_act``, ``multimodal_projector_bias``). + + Wrinkle: the adapter's ``intermediate_size`` derives from the **text** half of the model + (``text_config["hidden_size"]``). 
The cross-section reference is reachable because this converter is + flat-merged at the :class:`LlavaBaseModelConverter` scope, where ``text_config`` lives as a sibling + HF top-level key. + """ + + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "intermediate_size": config["text_config"]["hidden_size"], - "add_linear_biases": config["multimodal_projector_bias"], - "gated": False, - "activation": ActivationType.from_hf_name(config["projector_hidden_act"]), + # Cross-section: imported from text_config.hidden_size. No HF claim — text_config is claimed + # by the language model converter at the base level. + "intermediate_size": ImportOnlyConfigConverter( + fast_llm_paths=(("intermediate_size",),), + import_fn=lambda hf: {("intermediate_size",): hf["text_config"]["hidden_size"]}, + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("multimodal_projector_bias",)), + "gated": ConstantImportConfigConverter(("gated",), False), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("projector_hidden_act",),), + export_fn=lambda c: {("projector_hidden_act",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["projector_hidden_act"])}, + ), + # Per-layer ``bias.enabled`` has no HF representation; defaults round-trip. Validated below. + "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - assert not config.gated - return { - "projector_hidden_act": config.activation.hf_name, - "multimodal_projector_bias": config.add_linear_biases, - } + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -231,39 +259,63 @@ def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) ] -class LlavaVisionModelConverter: - vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = LlavaVisionAdapterConverter +class LlavaVisionModelConverter(ConfigSectionConverter): + """Converts :class:`VisionEncoderConfig` ↔ Llava's ``vision_config`` HF subdict. + + Declarations operate relative to ``vision_config`` (parent nests this converter via + ``NestedConfigConverter(hf_path=("vision_config",))``). The adapter is *not* declared here — it + lives at the base level because its Fast-LLM intermediate_size derives from text_config.hidden_size, + a cross-section reference only visible at the top of the HF dict. 
+ """ + + fast_llm_config_class = VisionEncoderConfig + embeddings_converter_class: typing.ClassVar[type[PixtralEmbeddingsConverter]] = PixtralEmbeddingsConverter encoder_converter_class: typing.ClassVar[type[PixtralEncoderConverter]] = PixtralEncoderConverter model_type: typing.ClassVar[str] = "pixtral" @classmethod - def import_config(cls, config: dict) -> dict: - Assert.eq(config["vision_config"]["model_type"], cls.model_type) + def _create_config_converters(cls) -> dict: + encoder_cls = cls.encoder_converter_class + + def _encoder_export(config: VisionEncoderConfig) -> dict: + return {(k,): v for k, v in encoder_cls.export_config(config.encoder).items()} + + def _encoder_import(hf_dict: dict) -> dict: + return {("encoder",): encoder_cls.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config["vision_config"]), - "encoder": cls.encoder_converter_class.import_config(config["vision_config"]), - "adapter": cls.vision_adapter_converter_class.import_config(config), - "hidden_size": config["vision_config"]["hidden_size"], + # Flat-merged into vision_config: embeddings (PatchEmbeddingsConverter writes patch_size/etc), + # encoder (LlamaDecoderConverter dispatch — Custom-wrapped since it stays imperative). + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "encoder": CustomConfigConverter( + fast_llm_paths=(("encoder",),), + hf_paths=( + ("num_hidden_layers",), + *encoder_cls.block_converter_class._consumed_hf_paths(), + ), + export_fn=_encoder_export, + import_fn=_encoder_import, + recurses=True, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + # Llava's vision_config carries a literal ``model_type: "pixtral"``; + # ``ConstantExportConfigConverter`` emits on export and asserts equality on import. + "model_type": ConstantExportConfigConverter(("model_type",), cls.model_type), + # Adapter is handled at LlavaBaseModelConverter scope (sees text_config). Mark recursively + # consumed here so the architecture walker sees the sub-tree as claimed at this level too. + "adapter": IgnoredConfigConverter(("adapter",)), } @classmethod - def export_config(cls, config: VisionEncoderConfig) -> dict: - Assert.custom(isinstance, config, VisionEncoderConfig) - vision_config = safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.encoder_converter_class.export_config(config.encoder), - {"hidden_size": config.hidden_size, "model_type": cls.model_type}, - ) - - Assert.eq( - vision_config.pop("head_dim"), div(vision_config["hidden_size"], vision_config["num_attention_heads"]) - ) + def _validate_export(cls, config: VisionEncoderConfig) -> None: + # Llava's PixtralVisionConfig does not carry head_dim — it is derived as ``hidden_size // + # num_attention_heads``. Validate the Fast-LLM head_size satisfies this invariant. 
+ mixer = config.encoder.block.mixer + if isinstance(mixer, AttentionConfig): + Assert.eq(mixer.head_size * mixer.heads, config.hidden_size) - return safe_merge_dicts( - {"vision_config": vision_config}, - cls.vision_adapter_converter_class.export_config(config.adapter), - ) + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: @@ -274,7 +326,7 @@ def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: *cls.encoder_converter_class.get_converters( config.encoder, "vision_encoder.encoder", "vision_tower.transformer.layers" ), - *cls.vision_adapter_converter_class.get_converters( + *LlavaVisionAdapterConverter.get_converters( config.adapter, "vision_encoder.adapter", "multi_modal_projector" ), ] @@ -305,36 +357,73 @@ class LlavaLanguageModelConverter(MistralBaseModelConverter): head_converter_class: typing.ClassVar[type[LlavaHeadConverter]] = LlavaHeadConverter -class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): +class LlavaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for Llava. Composes: + + * ``text_config`` HF subdict ← :class:`LlavaLanguageModelConverter` (Mistral text base). + * ``vision_config`` HF subdict ← :class:`LlavaVisionModelConverter` (Pixtral vision encoder). + * Top-level adapter fields (``projector_hidden_act``, ``multimodal_projector_bias``) ← + :class:`LlavaVisionAdapterConverter`, flat-merged because the adapter's ``intermediate_size`` + derives from ``text_config.hidden_size``. + * Top-level multimodal metadata (``image_token_index``, ``vision_feature_select_strategy``, + ``vision_feature_layer``). + """ + + fast_llm_config_class = MultiModalBaseModelConfig + vision_model_converter_class: typing.ClassVar[type[LlavaVisionModelConverter]] = LlavaVisionModelConverter + vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = LlavaVisionAdapterConverter # TODO: Make it flexible? language_model_converter_class: typing.ClassVar[type[LlavaLanguageModelConverter]] = LlavaLanguageModelConverter # TODO: Is tie_word_embeddings supported? @classmethod - def import_config(cls, config: dict) -> dict: - return safe_merge_dicts( - { - "vision_encoder": cls.vision_model_converter_class.import_config(config), - "image_token_index": config["image_token_index"], - }, - cls.language_model_converter_class.import_config(config["text_config"]), - ) + def _create_config_converters(cls) -> dict: + text_base_cls = cls.language_model_converter_class + vision_cls = cls.vision_model_converter_class + adapter_cls = cls.vision_adapter_converter_class + + # The Fast-LLM ``MultiModalBaseModelConfig`` IS-A ``GPTBaseModelConfig`` (multi-inherits via + # ``VisionMultiModalModelConfig``), so ``text_base_cls.export_config(config)`` works directly on + # the multimodal config: its declarations only touch GPTBaseModelConfig fields, which exist here. 
+ def _text_export(config: MultiModalBaseModelConfig) -> dict: + return {("text_config",): text_base_cls.export_config(config)} + + def _text_import(hf_dict: dict) -> dict: + return {(k,): v for k, v in text_base_cls.import_config(hf_dict["text_config"]).items()} + + return { + "text_base": CustomConfigConverter( + fast_llm_paths=( + ("embeddings",), + ("decoder",), + ("head",), + ("hidden_size",), + ("tied_embedding_weight",), + ("peft",), + ), + hf_paths=(("text_config",),), + export_fn=_text_export, + import_fn=_text_import, + recurses=True, + ), + "vision_encoder": NestedConfigConverter(("vision_encoder",), vision_cls, hf_path=("vision_config",)), + # Adapter flat-merged at top level: its import sees text_config.hidden_size as a sibling key. + "adapter": NestedConfigConverter(("vision_encoder", "adapter"), adapter_cls, hf_path=None), + "image_token_index": RenameConfigConverter(("image_token_index",), ("image_token_index",)), + "vision_feature_select_strategy": ConstantExportConfigConverter( + ("vision_feature_select_strategy",), "full" + ), + "vision_feature_layer": ConstantExportConfigConverter(("vision_feature_layer",), -1), + } @classmethod - def export_config(cls, config: MultiModalBaseModelConfig) -> dict: - Assert.custom(isinstance, config, MultiModalBaseModelConfig) - assert config.image_token_index is not None - out = safe_merge_dicts( - cls.vision_model_converter_class.export_config(config.vision_encoder), - { - "text_config": cls.language_model_converter_class.export_config(config), - "image_token_index": config.image_token_index, - "vision_feature_select_strategy": "full", - "vision_feature_layer": -1, - }, - ) - return out + def _validate_export(cls, config: MultiModalBaseModelConfig) -> None: + # Llava requires both a vision encoder and an image_token_index to be set. + Assert.custom(lambda v: v is not None, config.vision_encoder) + Assert.custom(lambda v: v is not None, config.image_token_index) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: From d35e39c95cbb921b13d4e04bfc71d5a32e28a089 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:54:41 -0400 Subject: [PATCH 23/27] Flatten LlamaDecoderConverter chain + Qwen2 MRoPE declarative Mistral/Qwen2/Mixtral DecoderConverter subclasses disappear; their BaseModelConverters now plug in block_converter_class directly, and LlamaBaseModelConverter inlines the Fixed/Pattern dispatch (config + weight sides) parameterised by that ClassVar. LlamaDecoderConverter stays as an imperative helper for the cases that don't fit the common pattern: Pixtral's vision encoder dispatch (Llava) and Apriel's per-position hybrid layout dispatch. AprielBaseModelConverter overrides the "decoder" declaration to delegate to AprielDecoderConverter (held via a new apriel_decoder_converter_class ClassVar) instead of using the inlined Llama dispatch. Qwen2BaseModelConverter.import_config (one-line MRoPE guard) becomes a declarative ImportOnlyConfigConverter claiming use_mrope and asserting on import. 
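The inlined dispatch, reduced to a standalone sketch (function name and toy values are
illustrative, not code from the tree):

    def export_uniform_decoder(block_exports: list[dict], num_blocks: int) -> dict:
        # Llama's flat HF layout has no per-block discriminator, so a Pattern
        # sequence only exports cleanly when every block exports the same dict.
        first = block_exports[0]
        for other in block_exports[1:]:
            assert other == first
        return {**first, "num_hidden_layers": num_blocks}

    export_uniform_decoder([{"head_dim": 128}, {"head_dim": 128}], num_blocks=32)
    # -> {"head_dim": 128, "num_hidden_layers": 32}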
Co-Authored-By: Claude Opus 4.7 --- fast_llm/models/gpt/conversion/apriel.py | 54 ++++++++++-- fast_llm/models/gpt/conversion/llama.py | 84 +++++++++++++------ fast_llm/models/gpt/conversion/mistral.py | 7 +- fast_llm/models/gpt/conversion/mixtral.py | 7 +- fast_llm/models/gpt/conversion/qwen2.py | 25 ++++-- .../models/multimodal/conversion/llava.py | 23 +++-- 6 files changed, 142 insertions(+), 58 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index befabad4d..ebb5a54b5 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -4,6 +4,7 @@ from transformers import PretrainedConfig +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConfigSectionConverter, @@ -17,9 +18,10 @@ from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig -from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( + LlamaDecoderConverter, effective_bias, get_parameter_converter, get_weight_and_bias_converters, @@ -27,7 +29,6 @@ from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) @@ -476,10 +477,11 @@ def get_converters( ) -class AprielDecoderConverter(MistralDecoderConverter): +class AprielDecoderConverter(LlamaDecoderConverter): """Pattern-style decoder dispatched via Apriel's ``hybrid_block_layout`` list (one entry per block). Stays imperative because the layout-list shape doesn't match the declarative ``decoder.type`` - discriminator that Apriel2 uses. + discriminator that Apriel2 uses. Overrides every classmethod from + :class:`LlamaDecoderConverter`; the parent is used only as a nominal base. """ block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @@ -551,8 +553,50 @@ class AprielHeadConverter(MistralHeadConverter): class AprielBaseModelConverter(MistralBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + """Apriel needs the per-position hybrid layout dispatcher (:class:`AprielDecoderConverter`) instead of + the standard Fixed/Pattern dispatch inlined in :class:`LlamaBaseModelConverter`. The override below + replaces the parent's ``"decoder"`` declaration with one that delegates to Apriel's dispatcher. 
+ """ + head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter + apriel_decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + + @classmethod + def _create_config_converters(cls) -> dict: + decoder_cls = cls.apriel_decoder_converter_class + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_cls.export_config(parent.decoder).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return {("decoder",): decoder_cls.import_config(hf_dict)} + + return { + **super()._create_config_converters(), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + # Block converter is the per-position dispatcher (:class:`AprielBlockConverter`), which + # unions the HF claims of every leaf-mixer's block converter. + hf_paths=( + ("num_hidden_layers",), + ("hybrid_block_layout",), + *decoder_cls.block_converter_class._consumed_hf_paths(), + ), + export_fn=_decoder_export, + import_fn=_decoder_import, + recurses=True, + ), + } + + # --- weight side (imperative): use Apriel's per-position dispatcher instead of the standard inline loop. + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.apriel_decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config, exported_config), + ] class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index af4cc9ef0..f2b2479a4 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -472,12 +472,35 @@ def get_converters( ] +def _llama_decoder_export( + decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, + block_converter_class: type[ConfigSectionConverter], +) -> dict: + """Convert a Fast-LLM polymorphic Fixed/Pattern block sequence to Llama's flat HF representation. + + Pattern: assert all blocks export identical HF (Llama's format has no per-block discriminator), then use + the common export. Fixed: just delegate to the single block. + """ + if isinstance(decoder_config, PatternBlockSequenceConfig): + exports = [block_converter_class.export_config(block) for block in decoder_config.blocks.values()] + for other in exports[1:]: + Assert.eq(exports[0], other) + block_hf = exports[0] + elif isinstance(decoder_config, FixedBlockSequenceConfig): + block_hf = block_converter_class.export_config(decoder_config.block) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return {**block_hf, "num_hidden_layers": decoder_config.num_blocks} + + class LlamaDecoderConverter: - """Converts ``BlockSequenceConfig`` (polymorphic Fixed/Pattern) ↔ Llama's flat block + ``num_hidden_layers``. + """Imperative dispatcher for the polymorphic Fixed/Pattern block sequence. - Kept as a regular class (not a :class:`ConfigSectionConverter`) so it can stay imperative — the polymorphism - between Fixed/Pattern block sequences doesn't lend itself to the declarative shape, and subclasses (Mistral, - Qwen2, MTP-Llama, ...) plug in different block converters via ``block_converter_class``. 
+ Used by formats that don't compose at the :class:`LlamaBaseModelConverter` level — currently only + Pixtral's vision encoder (:class:`PixtralEncoderConverter`) and Apriel's per-position hybrid layout + dispatcher inherit from it. The standard text formats (Mistral/Qwen2/Mixtral) use the inline dispatch + inside :class:`LlamaBaseModelConverter._create_config_converters` instead, parameterised by + ``block_converter_class``. """ block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @@ -491,16 +514,7 @@ def import_config(cls, hf_dict: dict) -> dict: @classmethod def export_config(cls, decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: - if isinstance(decoder_config, PatternBlockSequenceConfig): - exports = [cls.block_converter_class.export_config(block) for block in decoder_config.blocks.values()] - for other in exports[1:]: - Assert.eq(exports[0], other) - block_hf = exports[0] - elif isinstance(decoder_config, FixedBlockSequenceConfig): - block_hf = cls.block_converter_class.export_config(decoder_config.block) - else: - raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") - return {**block_hf, "num_hidden_layers": decoder_config.num_blocks} + return _llama_decoder_export(decoder_config, cls.block_converter_class) @classmethod def get_converters( @@ -599,37 +613,44 @@ def get_converters( class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): - """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict.""" + """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict. + + Subclasses (Mistral, Qwen2, Mixtral, MTP-Llama, …) override ``block_converter_class`` to plug their + per-block declarations into the polymorphic Fixed/Pattern decoder dispatch held here. + """ fast_llm_config_class = GPTBaseModelConfig - decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter embeddings_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaEmbeddingsConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter head_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaHeadConverter @classmethod def _create_config_converters(cls) -> dict: - decoder_converter_class = cls.decoder_converter_class + block_converter_class = cls.block_converter_class def _decoder_export(parent: Config) -> dict: - return {(k,): v for k, v in decoder_converter_class.export_config(parent.decoder).items()} + return {(k,): v for k, v in _llama_decoder_export(parent.decoder, block_converter_class).items()} def _decoder_import(hf_dict: dict) -> dict: - return {("decoder",): decoder_converter_class.import_config(hf_dict)} + return { + ("decoder",): { + "block": block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], + } + } return { "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), "head": NestedConfigConverter(("head",), cls.head_converter_class), "decoder": CustomConfigConverter( fast_llm_paths=(("decoder",),), - # The Custom wraps the imperative LlamaDecoderConverter, which delegates to - # cls.decoder_converter_class.block_converter_class (a ConfigSectionConverter). The - # block converter's flat-merge declarations claim all per-block top-level keys; pull - # them up here so the HF coverage check sees them as covered. ``num_hidden_layers`` - # is consumed by LlamaDecoderConverter itself. 
+ # The block converter's flat-merge declarations claim all per-block top-level keys; pull + # them up here so the HF coverage check sees them as covered. ``num_hidden_layers`` is + # consumed by the Fixed/Pattern dispatch above. hf_paths=( ("num_hidden_layers",), - *cls.decoder_converter_class.block_converter_class._consumed_hf_paths(), + *block_converter_class._consumed_hf_paths(), ), export_fn=_decoder_export, import_fn=_decoder_import, @@ -650,9 +671,20 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + decoder_config = config.decoder + block_config = ( + decoder_config.block + if isinstance(decoder_config, FixedBlockSequenceConfig) + else next(iter(decoder_config.blocks.values())) + ) + block_converters: list[WeightConverter] = [] + for block_index in range(decoder_config.num_blocks): + block_converters += cls.block_converter_class.get_converters( + block_config, f"decoder.{block_index}", f"model.layers.{block_index}" + ) return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *block_converters, *cls.head_converter_class.get_converters(config, exported_config), ] diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 7664a195c..18251c760 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -7,7 +7,6 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, @@ -40,16 +39,12 @@ class MistralBlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter -class MistralDecoderConverter(LlamaDecoderConverter): - block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter - - class MistralHeadConverter(LlamaHeadConverter): block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter class MistralBaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[MistralDecoderConverter]] = MistralDecoderConverter + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index d1ead7309..9b341981e 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -14,7 +14,6 @@ from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) @@ -84,16 +83,12 @@ class MixtralBlockConverter(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter -class MixtralDecoderConverter(MistralDecoderConverter): - block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter - - class MixtralHeadConverter(MistralHeadConverter): block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter class MixtralBaseModelConverter(MistralBaseModelConverter): - decoder_converter_class: 
typing.ClassVar[type[MixtralDecoderConverter]] = MixtralDecoderConverter + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter head_converter_class: typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index b785fb0a6..6a4f4f385 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -16,7 +16,6 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, @@ -119,22 +118,32 @@ class Qwen2BlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter -class Qwen2DecoderConverter(LlamaDecoderConverter): +class Qwen2HeadConverter(LlamaHeadConverter): block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter -class Qwen2HeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter +def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: + if hf_dict.get("use_mrope") is True: + raise AssertionError("MRoPE (use_mrope=True) is not supported by the Qwen2 converter") + return {} class Qwen2BaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[Qwen2DecoderConverter]] = Qwen2DecoderConverter + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod - def import_config(cls, hf_dict: dict) -> dict: - assert hf_dict.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" - return super().import_config(hf_dict) + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Refuse MRoPE on import; the export path can't produce ``use_mrope=True`` because Fast-LLM + # has no rotary type that maps to it. 
+ "use_mrope_guard": ImportOnlyConfigConverter( + fast_llm_paths=(), + hf_paths=(("use_mrope",),), + import_fn=_qwen2_mrope_guard_import, + ), + } @classmethod def _validate_export(cls, config: GPTBaseModelConfig) -> None: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index b504bc033..77ed565ce 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -19,6 +19,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import Rotary2DConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig @@ -427,17 +428,25 @@ def _validate_export(cls, config: MultiModalBaseModelConfig) -> None: @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + text_base_cls = cls.language_model_converter_class + decoder_config = config.decoder + block_config = ( + decoder_config.block + if isinstance(decoder_config, FixedBlockSequenceConfig) + else next(iter(decoder_config.blocks.values())) + ) + block_converters: list[WeightConverter] = [] + for block_index in range(decoder_config.num_blocks): + block_converters += text_base_cls.block_converter_class.get_converters( + block_config, f"decoder.{block_index}", f"language_model.model.layers.{block_index}" + ) return [ *cls.vision_model_converter_class.get_converters(config.vision_encoder), - *cls.language_model_converter_class.embeddings_converter_class.get_converters( + *text_base_cls.embeddings_converter_class.get_converters( config.embeddings, "embeddings", "language_model.model" ), - *cls.language_model_converter_class.decoder_converter_class.get_converters( - config.decoder, "decoder", "language_model.model.layers" - ), - *cls.language_model_converter_class.head_converter_class.get_converters( - config, {"tie_word_embeddings": False} - ), + *block_converters, + *text_base_cls.head_converter_class.get_converters(config, {"tie_word_embeddings": False}), ] From b3b41b7519bbf07c12707b10ad19ddbaa51b6d2a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 17:27:57 -0400 Subject: [PATCH 24/27] Claim transformers metadata keys in Llava vision_config transformers' PretrainedConfig.to_dict() populates _name_or_path/architectures/ torch_dtype/transformers_version on nested configs (vision_config is a PretrainedConfig under transformers.LlavaConfig), so a round-tripped save carries these keys back through the HF coverage check. The top-level _HF_METADATA_ALLOWLIST only matches single-key paths, so we mark them explicitly ignored inside LlavaVisionModelConverter. Fixes test_conversion[llava] which failed on "unknown key 'vision_config.architectures'". 
Co-Authored-By: Claude Opus 4.7 --- fast_llm/models/multimodal/conversion/llava.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 77ed565ce..502de0457 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -303,6 +303,19 @@ def _encoder_import(hf_dict: dict) -> dict: # Llava's vision_config carries a literal ``model_type: "pixtral"``; # ``ConstantExportConfigConverter`` emits on export and asserts equality on import. "model_type": ConstantExportConfigConverter(("model_type",), cls.model_type), + # transformers' ``PretrainedConfig.to_dict()`` populates these metadata fields on nested + # configs (vision_config is a PretrainedConfig under transformers.LlavaConfig). The top-level + # ``_HF_METADATA_ALLOWLIST`` only matches single-key paths, so we explicitly mark them ignored + # within this scope so the recursive HF coverage check doesn't flag round-tripped saves. + "hf_metadata": IgnoredConfigConverter( + hf_paths=( + ("_name_or_path",), + ("architectures",), + ("dtype",), + ("torch_dtype",), + ("transformers_version",), + ), + ), # Adapter is handled at LlavaBaseModelConverter scope (sees text_config). Mark recursively # consumed here so the architecture walker sees the sub-tree as claimed at this level too. "adapter": IgnoredConfigConverter(("adapter",)), From 70c63ba3f96727070a5ee210437f4b1468b0a2d6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 17:48:05 -0400 Subject: [PATCH 25/27] Apply HF metadata allowlist recursively in coverage check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The top-level _HF_METADATA_ALLOWLIST covers generic PretrainedConfig fields (architectures, torch_dtype, transformers_version, output_hidden_states, …), but the recursive coverage walker only matched it on single-key paths. After a round-tripped save, transformers populates the same metadata on nested sub-configs like Llava's vision_config, which then trip the walker. Match the allowlist against any segment of the path. Revert the previous local-scope ignore claim on LlavaVisionModelConverter, which only patched a subset of the keys and didn't help apriel2 or future nested formats. Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 9 ++++++--- fast_llm/models/multimodal/conversion/llava.py | 13 ------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index ac87589bd..d0aee8bbc 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -594,8 +594,11 @@ def check_hf_coverage(cls, hf_dict: dict, *, allowlist: frozenset[str] = frozens """Raise :class:`ValueError` if the input HF dict carries keys not consumed by any declaration. Walks ``hf_dict`` recursively. A path is considered covered if it (or any of its prefixes) is in - :meth:`_consumed_hf_paths`, or — for top-level keys — appears in ``allowlist``. Uncovered leaves - raise; uncovered sub-dicts trigger descent into their entries to surface the offending leaf path. 
+ :meth:`_consumed_hf_paths`, or if any segment of the path appears in ``allowlist`` (so transformers' + generic ``PretrainedConfig`` metadata keys — ``architectures``, ``torch_dtype``, ``transformers_version``, + … — are accepted at any depth, including under nested sub-configs like Llava's ``vision_config``). + Uncovered leaves raise; uncovered sub-dicts trigger descent into their entries to surface the offending + leaf path. Catches transformers-version drift, manual edits, and corrupted configs at the import boundary — the symmetric counterpart to the architecture-coverage check (which is statically verified by @@ -607,7 +610,7 @@ def walk(value: typing.Any, path: tuple[str, ...]) -> None: for length in range(1, len(path) + 1): if path[:length] in prefixes: return - if len(path) == 1 and path[0] in allowlist: + if any(segment in allowlist for segment in path): return if isinstance(value, dict): for key, sub in value.items(): diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 502de0457..77ed565ce 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -303,19 +303,6 @@ def _encoder_import(hf_dict: dict) -> dict: # Llava's vision_config carries a literal ``model_type: "pixtral"``; # ``ConstantExportConfigConverter`` emits on export and asserts equality on import. "model_type": ConstantExportConfigConverter(("model_type",), cls.model_type), - # transformers' ``PretrainedConfig.to_dict()`` populates these metadata fields on nested - # configs (vision_config is a PretrainedConfig under transformers.LlavaConfig). The top-level - # ``_HF_METADATA_ALLOWLIST`` only matches single-key paths, so we explicitly mark them ignored - # within this scope so the recursive HF coverage check doesn't flag round-tripped saves. - "hf_metadata": IgnoredConfigConverter( - hf_paths=( - ("_name_or_path",), - ("architectures",), - ("dtype",), - ("torch_dtype",), - ("transformers_version",), - ), - ), # Adapter is handled at LlavaBaseModelConverter scope (sees text_config). Mark recursively # consumed here so the architecture walker sees the sub-tree as claimed at this level too. "adapter": IgnoredConfigConverter(("adapter",)), From 1b2cf9d4db2e50167e85d58606b0b25eb9d2c999 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 18:01:03 -0400 Subject: [PATCH 26/27] Claim transformers' Pixtral and Llava HF defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit transformers.LlavaConfig.from_dict(...).save_pretrained(...) round-trips the config through transformers.LlavaConfig and PixtralVisionConfig, which fill in many model-specific defaults Fast-LLM doesn't consume (head_dim, image_size, layer_norm_eps, initializer_factor, projection_dim, vocab_size in vision_config; image_seq_length, tie_word_embeddings at the top level). Add IgnoredConfigConverter claims for these so the recursive HF coverage check accepts round-tripped saves. tie_word_embeddings is intentionally claimed only at the top level — Fast-LLM tracks it inside text_config via Llama's tied_embedding_weight declaration. 
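Even with the recursive allowlist from the previous patch, these keys still need explicit
claims, since they are model-specific defaults rather than generic metadata (standalone
sketch; allowlist contents abbreviated):

    allowlist = frozenset({"architectures", "torch_dtype", "transformers_version"})
    path = ("vision_config", "head_dim")
    assert not any(segment in allowlist for segment in path)  # not metadata -> must be claimed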
Co-Authored-By: Claude Opus 4.7 --- .../models/multimodal/conversion/llava.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 77ed565ce..48b86951a 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -303,6 +303,21 @@ def _encoder_import(hf_dict: dict) -> dict: # Llava's vision_config carries a literal ``model_type: "pixtral"``; # ``ConstantExportConfigConverter`` emits on export and asserts equality on import. "model_type": ConstantExportConfigConverter(("model_type",), cls.model_type), + # ``transformers.LlavaConfig.from_dict(...).save_pretrained(...)`` round-trips the + # vision_config through :class:`PixtralVisionConfig`, which fills in these default fields. + # Fast-LLM does not consume them; mark them ignored so the recursive coverage check accepts + # round-tripped saves. (``head_dim`` is normally not emitted because we override head_size to + # ImportOnly, but transformers fills it from ``hidden_size // num_attention_heads`` on load.) + "pixtral_hf_defaults": IgnoredConfigConverter( + hf_paths=( + ("head_dim",), + ("image_size",), + ("initializer_factor",), + ("layer_norm_eps",), + ("projection_dim",), + ("vocab_size",), + ), + ), # Adapter is handled at LlavaBaseModelConverter scope (sees text_config). Mark recursively # consumed here so the architecture walker sees the sub-tree as claimed at this level too. "adapter": IgnoredConfigConverter(("adapter",)), @@ -416,6 +431,13 @@ def _text_import(hf_dict: dict) -> dict: ("vision_feature_select_strategy",), "full" ), "vision_feature_layer": ConstantExportConfigConverter(("vision_feature_layer",), -1), + # ``transformers.LlavaConfig.save_pretrained(...)`` round-trips the top-level config through + # :class:`transformers.LlavaConfig`, which fills these defaults. Fast-LLM tracks + # ``tie_word_embeddings`` *inside* text_config (Llama's tied_embedding_weight), not at the + # Llava level; ``image_seq_length`` is a runtime/inference field, not architecture. + "llava_hf_defaults": IgnoredConfigConverter( + hf_paths=(("image_seq_length",), ("tie_word_embeddings",)), + ), } @classmethod From 5b31071a43bf3bbaaf0aa745b6a482b07d3b89fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 19:24:32 -0400 Subject: [PATCH 27/27] Port Gemma4BaseModelConverter to ConfigSectionConverter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrate the top-level Gemma 4 base-model converter from the imperative import_config/export_config shape to the declarative ConfigSectionConverter API used by the rest of the codebase. The embeddings/decoder/head sections remain imperative helpers — each is wrapped in a recursing CustomConfigConverter because Gemma 4's HF format cross-references hidden_size (embeddings, MoE router scale) and merges two block variants (sliding_attention / full_attention) into one HF dict, neither of which fits the standard per-section decomposition. The previously-imperative top-level guards (PLE, KV sharing, double-wide MLP, bidirectional attention) become declarative ConstantExportConfigConverter / CustomConfigConverter entries that preserve the rejected-feature checks while running through the same HF-coverage walker as every other format. 
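For reference, the hidden_size cross-references called out above reduce to two derived
scales (toy value shown; the names mirror the docstring added below):

    hidden_size = 2048
    embedding_scale = hidden_size ** 0.5       # ~45.25, applied in the embeddings section
    router_input_scale = hidden_size ** -0.5   # ~0.0221, applied to the MoE router input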
``vocab_size_per_layer_input`` is claimed via IgnoredConfigConverter so the coverage walker accepts the value transformers fills on save_pretrained round-trip. Co-Authored-By: Claude Opus 4.7 --- fast_llm/models/gpt/conversion/gemma4.py | 166 +++++++++++++++++------ 1 file changed, 125 insertions(+), 41 deletions(-) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 3c6f22788..fcb92df0c 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -2,12 +2,18 @@ import typing +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, ProportionalRotaryConfig @@ -622,54 +628,132 @@ def export_config(cls, config: LanguageModelHeadConfig) -> dict: return out -class Gemma4BaseModelConverter: +def _gemma4_bidirectional_export(_: Config) -> dict: + # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the multimodal + # model) is not implemented. Always emit ``None``. + return {("use_bidirectional_attention",): None} + + +def _gemma4_bidirectional_import(hf_dict: dict) -> dict: + # ``use_bidirectional_attention="vision"`` only affects vision tokens; the text path stays + # causal. Only ``"all"`` toggles ``is_causal=False`` for the text decoder, which we don't + # implement. + if hf_dict.get("use_bidirectional_attention") == "all": + raise NotImplementedError( + 'Gemma 4 `use_bidirectional_attention="all"` is not supported (text path stays causal).' + ) + return {} + + +class Gemma4BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for ``GPTBaseModelConfig`` ↔ Gemma 4 HF dict. + + Gemma 4 has several wrinkles that prevent the standard per-section decomposition used by Llama: + + * The decoder is a :class:`PatternBlockSequenceConfig` whose two named blocks + (``sliding_attention`` / ``full_attention``) share most HF keys but diverge on ``head_dim`` and + rope parameters. The HF format emits both block variants from a single root-level config, so + the block-level transform inherently sees both Fast-LLM blocks at once. + * ``embedding_scale = hidden_size ** 0.5`` and ``router_input_scale = hidden_size ** -0.5`` make + the embeddings and routed MLP cross-reference the root-level ``hidden_size``. + + Each section ((embeddings, decoder, head)) is therefore expressed as a :class:`CustomConfigConverter` + that delegates to an imperative helper class (kept private to this module). Coverage at the + section level is satisfied via ``recurses=True``. 
+ """ + + fast_llm_config_class = GPTBaseModelConfig + decoder_converter_class: typing.ClassVar[type[Gemma4DecoderConverter]] = Gemma4DecoderConverter embeddings_converter_class: typing.ClassVar[type[Gemma4EmbeddingsConverter]] = Gemma4EmbeddingsConverter head_converter_class: typing.ClassVar[type[Gemma4HeadConverter]] = Gemma4HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - if config.get("hidden_size_per_layer_input") not in (None, 0): - raise NotImplementedError( - "Gemma 4 Per-Layer Embeddings (`hidden_size_per_layer_input != 0`) are not supported." - ) - if config.get("num_kv_shared_layers", 0): - raise NotImplementedError("Gemma 4 cross-layer KV sharing (`num_kv_shared_layers != 0`) is not supported.") - if config.get("use_double_wide_mlp", False): - raise NotImplementedError("Gemma 4 `use_double_wide_mlp=True` is not supported.") - # `use_bidirectional_attention="vision"` only affects vision tokens; the text path stays causal. - # Only `"all"` toggles `is_causal=False` for the text decoder, which we don't implement. - if config.get("use_bidirectional_attention") == "all": - raise NotImplementedError( - 'Gemma 4 `use_bidirectional_attention="all"` is not supported (text path stays causal).' - ) + def _create_config_converters(cls) -> dict: + decoder_cls = cls.decoder_converter_class + embeddings_cls = cls.embeddings_converter_class + head_cls = cls.head_converter_class + + def _embeddings_export(parent: Config) -> dict: + return {(k,): v for k, v in embeddings_cls.export_config(parent.embeddings, parent.hidden_size).items()} + + def _embeddings_import(hf_dict: dict) -> dict: + return {("embeddings",): embeddings_cls.import_config(hf_dict)} + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_cls.export_config(parent.decoder, parent.hidden_size).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return {("decoder",): decoder_cls.import_config(hf_dict)} + + def _head_export(parent: Config) -> dict: + return {(k,): v for k, v in head_cls.export_config(parent.head).items()} + + def _head_import(hf_dict: dict) -> dict: + return {("head",): head_cls.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": CustomConfigConverter( + fast_llm_paths=(("embeddings",),), + hf_paths=(("vocab_size",),), + export_fn=_embeddings_export, + import_fn=_embeddings_import, + recurses=True, + ), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + hf_paths=( + ("num_hidden_layers",), + ("layer_types",), + ("num_attention_heads",), + ("num_key_value_heads",), + ("head_dim",), + ("global_head_dim",), + ("num_global_key_value_heads",), + ("attention_bias",), + ("attention_dropout",), + ("sliding_window",), + ("rms_norm_eps",), + ("attention_k_eq_v",), + ("rope_parameters",), + ("intermediate_size",), + ("hidden_activation",), + ("enable_moe_block",), + ("num_experts",), + ("top_k_experts",), + ("moe_intermediate_size",), + ), + export_fn=_decoder_export, + import_fn=_decoder_import, + recurses=True, + ), + "head": CustomConfigConverter( + fast_llm_paths=(("head",),), + hf_paths=(("final_logit_softcapping",),), + export_fn=_head_export, + import_fn=_head_import, + recurses=True, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), 
("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + "peft": IgnoredConfigConverter(("peft",)), + # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; explicitly + # zero to disable the feature in the exported model until Fast-LLM supports it natively. + "hidden_size_per_layer_input": ConstantExportConfigConverter(("hidden_size_per_layer_input",), 0), + "num_kv_shared_layers": ConstantExportConfigConverter(("num_kv_shared_layers",), 0), + "use_double_wide_mlp": ConstantExportConfigConverter(("use_double_wide_mlp",), False), + "use_bidirectional_attention": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("use_bidirectional_attention",),), + export_fn=_gemma4_bidirectional_export, + import_fn=_gemma4_bidirectional_import, + ), + # Vocab-size-per-layer is part of Per-Layer Embeddings (PLE), gated by + # ``hidden_size_per_layer_input``. PLE is rejected above, so we ignore the size field too. + "vocab_size_per_layer_input": IgnoredConfigConverter(hf_paths=(("vocab_size_per_layer_input",),)), } - @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings, config.hidden_size), - cls.decoder_converter_class.export_config(config.decoder, config.hidden_size), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; - # explicitly zero to disable the feature in the exported model until Fast-LLM - # supports it natively. - "hidden_size_per_layer_input": 0, - # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the - # multimodal model) is not implemented. - "use_bidirectional_attention": None, - }, - ) + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: