diff --git a/CLAUDE.md b/CLAUDE.md index 836b2ba7b..4f97d071f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -174,7 +174,7 @@ Tests live in `tests/`. The following patterns work well in this codebase. - **Comments**: Write no comments by default. Only add one when the *why* is non-obvious — a hidden constraint, a subtle invariant, a workaround for a specific bug, behavior that would surprise a reader. Never restate what the code already says; well-named identifiers do that. - **Imports**: Third-party → `import package.module` (keep fully qualified). First-party → `from fast_llm.module import Thing`. No relative imports. Optional/slow imports inside methods or under `if typing.TYPE_CHECKING:`. -- **Naming**: No abbreviations (use `batch_size` not `bs`). Private members get a single `_` prefix; never use `__`. Keep public interfaces lean. +- **Naming**: No abbreviations (use `batch_size` not `bs`). Non-public members (private or protected) get a single `_` prefix; never use `__`. Keep public interfaces lean. - **Types**: Always type-hint public interfaces. Use modern syntax (`X | Y`, `list[T]` not `List[T]`, PEP 695 generics like `class X[T: Bound]:` instead of `typing.TypeVar`). - **Assert**: Use the `Assert` namespace from `fast_llm.utils` for contract checks (`Assert.eq`, `Assert.geq`, `Assert.incl`, `Assert.custom`, etc.) — error messages auto-format with actual values. Bare `assert` is reserved for internal state-validity invariants (`assert self._is_setup`). - **Exceptions**: Raise stdlib exceptions for runtime errors (`ValueError`, `RuntimeError`, `NotImplementedError`). Custom exception classes are rare — only `ValidationError`, `NestedValidationError`, `FieldTypeError` in `config.py`. diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 886c706c1..d0aee8bbc 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,4 +1,6 @@ import abc +import dataclasses +import functools import logging import pathlib import typing @@ -6,7 +8,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import Config +from fast_llm.config import Config, FieldHint, get_nested_dict_value, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -18,6 +20,664 @@ logger = logging.getLogger(__name__) +def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: + current = config + for name in path: + current = getattr(current, name) + return current + + +def _collect_architecture_paths(config: Config) -> list[tuple[str, ...]]: + """Walk ``config`` and return every architecture-hint field path reachable from it. + + Descends into any field whose runtime value is a :class:`Config`, a ``dict[str, Config]`` + (paths are extended with the entry's string key), or a ``list[Config]`` (paths are extended + with the entry's index as a string), so the path list reflects the actual instance. 
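+
+    As an illustration (the field names here are hypothetical), a config whose ``decoder``
+    field holds a ``dict`` of blocks could yield paths such as
+    ``("decoder", "blocks", "attn", "mixer", "heads")``: one tuple per architecture-hint leaf,
+    with the dict entry's key ``"attn"`` spliced into the path as a string.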
+ """ + paths: list[tuple[str, ...]] = [] + + def descend(value: typing.Any, prefix: tuple[str, ...]) -> None: + if isinstance(value, Config): + walk(value, prefix) + elif isinstance(value, dict): + for key, sub in value.items(): + descend(sub, prefix + (str(key),)) + elif isinstance(value, (list, tuple)): + for index, sub in enumerate(value): + descend(sub, prefix + (str(index),)) + + def walk(node: Config, prefix: tuple[str, ...]) -> None: + for name, field in type(node).fields(): + if field._field_type != dataclasses._FIELD or not field.init: + continue + full = prefix + (name,) + if field.hint == FieldHint.architecture: + paths.append(full) + descend(getattr(node, name), full) + + walk(config, ()) + return paths + + +# ============================================================ +# Config conversion primitives (declarative) +# ============================================================ + + +class ConfigConverter(abc.ABC): + """A declarative description of how one or more Fast-LLM config fields map to one or more HF config keys. + + Each primitive owns a set of ``fast_llm_paths`` (tuples of attribute names rooted at the section's config) and + ``hf_paths`` (tuples of dict keys rooted at the section's HF subdict). The walker calls ``export_to`` to produce + HF entries from a Fast-LLM config object, and ``import_to`` to produce a Fast-LLM config dict from an HF dict. + + ``recurses`` controls how :meth:`ConfigSectionConverter._check_architecture_coverage` interprets the paths: + + * ``recurses = False`` (default) — paths are exact-match leaves. Every architecture-hint field at every depth + under the section's config class must be exactly listed by some declaration. + * ``recurses = True`` — paths are recursive prefixes covering the entire subtree. Used by primitives that + delegate to a sub-converter that runs its own coverage check (Nested/Dispatch/TypedDictContainer), by + :class:`IgnoredConfigConverter` (the format intentionally does not represent the subtree), and by + :class:`CustomConfigConverter` when its author opts in (escape hatch for cases like rotary that don't + decompose into per-leaf renames). + """ + + fast_llm_paths: tuple[tuple[str, ...], ...] = () + hf_paths: tuple[tuple[str, ...], ...] = () + recurses: typing.ClassVar[bool] = False + + @abc.abstractmethod + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: ... + + @abc.abstractmethod + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: ... + + +class RenameConfigConverter(ConfigConverter): + """One-to-one rename between a Fast-LLM attribute path and an HF dict path.""" + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...]): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class ConstantExportConfigConverter(ConfigConverter): + """Write a constant to the HF dict on export. On import, assert that the HF dict has this constant value. + + Used when a HF format requires a key whose value Fast-LLM doesn't store (or always pins to a constant). 
+ """ + + def __init__(self, hf_path: tuple[str, ...], value: typing.Any): + self.hf_paths = (hf_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + set_nested_dict_value(hf_out, self.hf_paths[0], self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + try: + actual = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: + return + Assert.eq(actual, self._value) + + +class ConstantImportConfigConverter(ConfigConverter): + """Inject a constant into the Fast-LLM dict on import. On export, assert the config matches the constant. + + Used when a Fast-LLM field is required but the HF format implies a fixed value (e.g., gated MLP for Llama). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], value: typing.Any): + self.fast_llm_paths = (fast_llm_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + actual = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + Assert.eq(actual, self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], self._value) + + +class DefaultConfigConverter(ConfigConverter): + """Rename with an HF-side fallback used when the HF key is missing on import. + + ``hf_default_fn`` is called with the full HF dict if the path is absent; otherwise it's a plain rename. + On export, behaves like ``RenameConfigConverter``. + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + hf_default_fn: typing.Callable[[dict], typing.Any], + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._hf_default_fn = hf_default_fn + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + try: + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: + value = self._hf_default_fn(hf_dict) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class OptionalConfigConverter(ConfigConverter): + """Emit/import only when the value differs from a sentinel (default ``None``). + + Useful for fields that round-trip cleanly only when present (e.g. ``window_size``). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...], sentinel: typing.Any = None): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._sentinel = sentinel + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + if value != self._sentinel: + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + try: + value = get_nested_dict_value(hf_dict, self.hf_paths[0]) + except KeyError: + return + if value != self._sentinel: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class IgnoredConfigConverter(ConfigConverter): + """Declares Fast-LLM architecture fields and/or HF dict keys as intentionally not converted by this format. + + Use ``fast_llm_paths`` (positional) when Fast-LLM has architecture fields with no HF representation; the + Fast-LLM default round-trips. 
Use ``hf_paths`` (kw-only) when the HF format carries fields Fast-LLM does + not consume (generation-only toggles like Mixtral's ``router_aux_loss_coef``, Qwen2's ``sliding_window``). + Both kinds of claim are no-ops at conversion time and serve only the per-side coverage checks. The claim + covers the entire subtree under each listed path on the side it applies to. + """ + + recurses: typing.ClassVar[bool] = True + + def __init__(self, *fast_llm_paths: tuple[str, ...], hf_paths: tuple[tuple[str, ...], ...] = ()): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + pass + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + pass + + +class CustomConfigConverter(ConfigConverter): + """Escape hatch for cross-field transforms (e.g., rotary, where one HF blob ↔ several Fast-LLM fields). + + ``fast_llm_paths`` and ``hf_paths`` are declared so the per-side coverage checks see the fields as consumed. + Cross-field validators that produce nothing on the HF side belong on :py:meth:`ConfigSectionConverter._validate_export` + instead; this primitive is for shape-changing transforms. + + Pass ``recurses=True`` when the converter genuinely owns a sub-config subtree (e.g. rotary, per-layer biases) — + the listed paths then act as recursive prefixes and the architecture-coverage check stops at them. The author + is trusted to handle every architecture field of the claimed subtree; prefer Nested/Dispatch primitives when + the subtree decomposes cleanly. + """ + + def __init__( + self, + fast_llm_paths: tuple[tuple[str, ...], ...], + export_fn: typing.Callable[[Config], dict], + import_fn: typing.Callable[[dict], dict], + hf_paths: tuple[tuple[str, ...], ...] = (), + recurses: bool = False, + ): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths + self._export_fn = export_fn + self._import_fn = import_fn + self.recurses = recurses + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + produced = self._export_fn(fast_llm_config) + for path, value in produced.items(): + set_nested_dict_value(hf_out, path, value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + produced = self._import_fn(hf_dict) + for path, value in produced.items(): + set_nested_dict_value(fast_llm_out, path, value) + + +class ImportOnlyConfigConverter(ConfigConverter): + """One-way mapping that runs only on import; emits nothing on export. + + Used when the HF format derives a Fast-LLM field from sibling fields (e.g. ``head_size`` from + ``hidden_size // num_attention_heads`` in Qwen2) or implies a value the Fast-LLM side stores + explicitly (e.g. Qwen2's hardcoded Q/K/V biases, Pixtral's mirrored ``patch_size`` ↔ ``patch_width``). + On export the field is redundant and validated through ``_validate_export``; on import the + ``import_fn`` produces the Fast-LLM dict entries. The fast_llm_paths register as consumed for the + architecture-coverage check; ``hf_paths`` register as consumed for the HF-side check. + + Pass ``recurses=True`` when the converter populates a sub-config subtree (e.g. Qwen2's per-layer + biases that target ``query_layer``/``key_layer``/...). Same trade-off as + :class:`CustomConfigConverter`: the listed paths become recursive prefixes and the framework no + longer enforces leaf coverage under them. + """ + + def __init__( + self, + fast_llm_paths: tuple[tuple[str, ...], ...], + import_fn: typing.Callable[[dict], dict], + hf_paths: tuple[tuple[str, ...], ...] 
= (), + recurses: bool = False, + ): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = hf_paths + self._import_fn = import_fn + self.recurses = recurses + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + pass + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + produced = self._import_fn(hf_dict) + for path, value in produced.items(): + set_nested_dict_value(fast_llm_out, path, value) + + +class NestedConfigConverter(ConfigConverter): + """Recurse into a fixed-typed sub-config field via another section converter class. + + Default (``hf_path=None``): the HF side is flat-merged — the sub-converter's output becomes top-level keys + of the parent's HF dict, asserting any pre-existing keys agree. + + With ``hf_path`` set: the sub-converter's output is placed under that nested key. Use this for HF formats + that mirror Fast-LLM's modular layout (e.g. Apriel2's ``"decoder": {...}`` and ``"head": {...}`` blocks). + + When the target ``converter_class`` declares ``hf_type_name``, an HF discriminator (``"type"`` by default) + is auto-injected on export and validated/stripped on import — matching DispatchConfigConverter's behavior + for homogeneous (single-target) cases. + """ + + recurses: typing.ClassVar[bool] = True + + def __init__( + self, + fast_llm_path: tuple[str, ...], + converter_class: "type[ConfigSectionConverter]", + hf_path: tuple[str, ...] | None = None, + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self._converter_class = converter_class + self._hf_path = hf_path + self._hf_discriminator_key = hf_discriminator_key + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + sub_hf = self._converter_class.export_config(sub_config) + if self._converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: self._converter_class.hf_type_name, **sub_hf} + if self._hf_path is None: + for key, value in sub_hf.items(): + if key in hf_out: + Assert.eq(hf_out[key], value) + else: + hf_out[key] = value + else: + set_nested_dict_value(hf_out, self._hf_path, sub_hf) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf = get_nested_dict_value(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict + if self._converter_class.hf_type_name is not None and self._hf_discriminator_key in sub_hf: + Assert.eq(sub_hf[self._hf_discriminator_key], self._converter_class.hf_type_name) + sub_hf = {key: value for key, value in sub_hf.items() if key != self._hf_discriminator_key} + sub_fast_llm = self._converter_class.import_config(sub_hf) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +class DispatchConfigConverter(ConfigConverter): + """Polymorphic sub-config dispatch. + + The Fast-LLM field's runtime type selects the section converter; the HF format selects via a ``type`` discriminator. + Both registries (Fast-LLM type → converter class, HF discriminator → converter class) must agree at runtime. + """ + + recurses: typing.ClassVar[bool] = True + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...] 
| None, + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) if hf_path is not None else () + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at {'.'.join(self.fast_llm_paths[0])}" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + if self.hf_paths: + set_nested_dict_value(hf_out, self.hf_paths[0], sub_hf) + else: + for key, value in sub_hf.items(): + hf_out[key] = value + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf = get_nested_dict_value(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at {'.'.join(self.fast_llm_paths[0])}" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +class TypedDictContainerConfigConverter(ConfigConverter): + """Maps a Fast-LLM ``dict[str, Config]`` field to an HF ``dict[str, dict]`` where each entry is round-tripped + through a per-class section converter selected via the entry's runtime type (export) or HF discriminator (import). + + Each entry's HF subdict carries a discriminator key (``"type"`` by default) populated from the converter's + ``hf_type_name``. For homogeneous dicts, register a single class with ``hf_type_name = None``; the discriminator + is then omitted on export and ignored on import. 
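+
+    A sketch of the exported shape (entry names are hypothetical)::
+
+        {"attn_block": {"type": "attention", ...}, "ssm_block": {"type": "mamba", ...}}
+
+    where each entry's ``"type"`` comes from the selected converter's ``hf_type_name`` and the
+    remaining keys from that converter's ``export_config``.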
+ """ + + recurses: typing.ClassVar[bool] = True + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + self._homogeneous = len(registry) == 1 and next(iter(registry.values())).hf_type_name is None + if self._homogeneous: + self._homogeneous_class = next(iter(registry.values())) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_dict = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + out: dict = {} + for name, sub_config in sub_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at " + f"{'.'.join(self.fast_llm_paths[0])}[{name!r}]" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + out[name] = sub_hf + set_nested_dict_value(hf_out, self.hf_paths[0], out) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf_dict = get_nested_dict_value(hf_dict, self.hf_paths[0]) + out: dict = {} + for name, sub_hf in sub_hf_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at " + f"{'.'.join(self.hf_paths[0])}[{name!r}]" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + out[name] = sub_fast_llm + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], out) + + +# ============================================================ +# Section converter — converts one Fast-LLM config class +# ============================================================ + + +class ConfigSectionConverter(abc.ABC): + """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. + + Subclasses declare the conversion via ``_create_config_converters``. Format-specific cross-field + invariants go on the ``_validate_export`` hook. The weight side is still imperative (per-converter + ``get_converters`` classmethods on the concrete subclasses); a declarative weight-side primitive will be + added when the weight-converter migration lands. + + Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value + used by the HF format (e.g. ``"attention"``, ``"mamba"``). + """ + + fast_llm_config_class: typing.ClassVar[type[Config]] + hf_type_name: typing.ClassVar[str | None] = None + + @classmethod + @functools.cache + def _create_config_converters(cls) -> dict[str, ConfigConverter]: + """Return declarations keyed by stable string name. Subclasses override entries by re-declaring the key. + + Cached per class — declarations are immutable and depend only on ``cls``. 
Subclasses must build + and return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating + the returned dict in place would corrupt the parent's cache entry for every subsequent caller. + """ + raise NotImplementedError + + @classmethod + def _validate_export(cls, config: Config) -> None: + """Hook for format-specific export-time validation. Default no-op. + + Runs after the architecture-coverage check and before any declaration emits. Use this for cross-field + invariants the format imposes on the Fast-LLM config (e.g. per-layer biases must match a parent flag, + certain sub-configs must be at their default). Subclasses override; super-calls are not required when + the rule is fully replaced (e.g. Qwen2 vs Llama attention biases). + """ + return + + @classmethod + def export_config(cls, config: Config) -> dict: + """Convert a Fast-LLM config object to an HF config dict via this section's declarations.""" + cls._validate_export(config) + out: dict = {} + for converter in cls._create_config_converters().values(): + converter.export_to(config, out) + return out + + @classmethod + def import_config(cls, hf_dict: dict) -> dict: + """Convert an HF config dict to a Fast-LLM config dict via this section's declarations. + + When ``fast_llm_config_class`` carries a ``dynamic_type_name`` (i.e. the target is a registered + dynamic-type subclass), inject ``"type": `` so the caller's ``from_dict`` dispatches to the + correct subclass without each section converter having to prepend it manually. + """ + out: dict = {} + for converter in cls._create_config_converters().values(): + converter.import_to(hf_dict, out) + fast_llm_type = getattr(cls.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + out = {"type": fast_llm_type, **out} + return out + + @classmethod + @functools.cache + def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: + """Set of HF dict paths consumed by this section's declaration tree. + + Each entry is a tuple-of-keys from the section's HF subdict root. The :meth:`check_hf_coverage` + walker treats every entry as a *recursive prefix* — once an input path matches any prefix, + descent into deeper sub-dicts stops. + + Nested sub-converters (``NestedConfigConverter``/``DispatchConfigConverter`` with ``hf_path`` + set) expand their sub-converter's claims under the nested prefix instead of contributing the + bare prefix, so the walker descends into the subdict and flags unknown keys inside it (e.g. + ``head.normalization.surprise_field``). + + :class:`TypedDictContainerConfigConverter` keeps a blanket prefix because its per-entry sub-dicts + are user-named (pattern keys); we can't statically enumerate which entries will appear or what + keys those entries should claim. + """ + paths: set[tuple[str, ...]] = set() + for declaration in cls._create_config_converters().values(): + if isinstance(declaration, NestedConfigConverter): + sub_class = declaration._converter_class + if declaration._hf_path is None: + # Flat-merge: sub-converter shares the parent's HF namespace. + paths |= sub_class._consumed_hf_paths() + if sub_class.hf_type_name is not None: + paths.add((declaration._hf_discriminator_key,)) + else: + # Nested: prepend hf_path to each sub-claim so the walker recurses into the subdict. 
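+                    # For example, a hypothetical hf_path ("head",) combined with a sub-claim
+                    # ("normalization", "epsilon") yields the claim ("head", "normalization", "epsilon").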
+ prefix = declaration._hf_path + for sub_path in sub_class._consumed_hf_paths(): + paths.add(prefix + sub_path) + if sub_class.hf_type_name is not None: + paths.add(prefix + (declaration._hf_discriminator_key,)) + elif isinstance(declaration, DispatchConfigConverter): + if declaration.hf_paths: + # Nested dispatch: union of all registered sub-classes' claims under the shared + # hf_path prefix. At runtime only one sub-class fires; the static union is a safe + # over-claim (we only need to *not* flag known keys, never to flag missing ones). + prefix = declaration.hf_paths[0] + paths.add(prefix + (declaration._hf_discriminator_key,)) + for sub_class in declaration._registry.values(): + for sub_path in sub_class._consumed_hf_paths(): + paths.add(prefix + sub_path) + else: + paths.add((declaration._hf_discriminator_key,)) + for sub_class in declaration._registry.values(): + paths |= sub_class._consumed_hf_paths() + elif isinstance(declaration, TypedDictContainerConfigConverter): + paths.add(declaration.hf_paths[0]) + else: + for path in declaration.hf_paths: + if path: + paths.add(path) + return frozenset(paths) + + @classmethod + def check_hf_coverage(cls, hf_dict: dict, *, allowlist: frozenset[str] = frozenset()) -> None: + """Raise :class:`ValueError` if the input HF dict carries keys not consumed by any declaration. + + Walks ``hf_dict`` recursively. A path is considered covered if it (or any of its prefixes) is in + :meth:`_consumed_hf_paths`, or if any segment of the path appears in ``allowlist`` (so transformers' + generic ``PretrainedConfig`` metadata keys — ``architectures``, ``torch_dtype``, ``transformers_version``, + … — are accepted at any depth, including under nested sub-configs like Llava's ``vision_config``). + Uncovered leaves raise; uncovered sub-dicts trigger descent into their entries to surface the offending + leaf path. + + Catches transformers-version drift, manual edits, and corrupted configs at the import boundary — + the symmetric counterpart to the architecture-coverage check (which is statically verified by + ``tests/models/test_converters.py``). + """ + prefixes = cls._consumed_hf_paths() + + def walk(value: typing.Any, path: tuple[str, ...]) -> None: + for length in range(1, len(path) + 1): + if path[:length] in prefixes: + return + if any(segment in allowlist for segment in path): + return + if isinstance(value, dict): + for key, sub in value.items(): + walk(sub, path + (key,)) + return + raise ValueError( + f"{cls.__name__}: HF config has unknown key '{'.'.join(path)}' (value: {value!r}). " + "Possible transformers-version mismatch, manual edit, or corrupted config." + ) + + for key, value in hf_dict.items(): + walk(value, (key,)) + + @classmethod + def check_architecture_coverage(cls, config: Config) -> None: + """Raise if any architecture-hint field reachable from the section's config (recursively) is not consumed. + + Coverage is structural (based on field hints), not value-based: every architecture field at every depth + must be accounted for, even when it currently holds its Fast-LLM default. The walker descends into any + field whose runtime value is a :class:`Config`, collecting an architecture-leaf list, and matches each + leaf against the section's declarations: + + * Recursive declarations (``recurses = True`` — Nested/Dispatch/TypedDictContainer/Ignored, plus Custom + when its author opts in) cover the entire subtree under each listed prefix. 
Nested/Dispatch/TypedDict + delegate to a sub-converter that runs its own coverage check; Ignored and recursive Custom assume the + author has handled the subtree. + * Non-recursive declarations (Rename, ConstantImport/Export, Default, Optional, ImportOnly, Custom by + default) must list every architecture leaf they consume by exact path. + + Invoked from a test fixture (``tests/models/test_converters.py``) — not from the production + export/import paths. Architecture coverage is a structural invariant of the converter declarations, + so it only needs to hold once per (converter, config-class) pair, not on every save. + """ + Assert.is_(type(config), cls.fast_llm_config_class) + declarations = cls._create_config_converters() + explicit_paths: set[tuple[str, ...]] = set() + recursive_prefixes: list[tuple[str, ...]] = [] + for converter in declarations.values(): + if converter.recurses: + recursive_prefixes.extend(converter.fast_llm_paths) + else: + explicit_paths.update(converter.fast_llm_paths) + missing: list[tuple[str, ...]] = [] + for path in _collect_architecture_paths(config): + if path in explicit_paths: + continue + if any(len(prefix) <= len(path) and path[: len(prefix)] == prefix for prefix in recursive_prefixes): + continue + missing.append(path) + if missing: + # If every missing path shares a top-level prefix that IS claimed (just non-recursively), + # the contributor likely needs a recursive primitive there — surface that as a hint. + shared_prefixes = {path[:1] for path in missing if path[:1] in explicit_paths} + hint = "" + if shared_prefixes: + names = sorted(prefix[0] for prefix in shared_prefixes) + hint = ( + f" (declarations for {names} claim the parent path non-recursively; " + f"either list every architecture sub-field or switch to Nested/Dispatch — " + f"or pass ``recurses=True`` to a Custom/ImportOnly converter when claiming the whole subtree)" + ) + raise ValueError( + f"{cls.__name__}: architecture-hint fields on {type(config).__name__} " + f"have no converter declaration: {[ '.'.join(p) for p in missing ]}{hint}" + ) + + class WeightConverter: def __init__( self, @@ -76,18 +736,6 @@ def import_weight( ) -class CopyWeightConverter(WeightConverter): - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.export_name) - 1)] - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.fast_llm_name) - 1)] - - class SplitWeightConverter(WeightConverter): def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
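To make the declarative flow above concrete, here is a minimal sketch of a section converter built from these primitives. ``ToyNormalizationConfig``, its ``epsilon`` field, and the HF key ``norm_eps`` are hypothetical placeholders; only ``ConfigSectionConverter``, ``RenameConfigConverter``, and the ``export_config``/``import_config`` entry points come from the module above.

    from fast_llm.engine.checkpoint.external import ConfigSectionConverter, RenameConfigConverter

    class ToyNormalizationConverter(ConfigSectionConverter):
        # Hypothetical target config class with a single architecture-hint field `epsilon`.
        fast_llm_config_class = ToyNormalizationConfig
        # Discriminator value used if this class is the target of a Nested/Dispatch converter.
        hf_type_name = "toy_norm"

        @classmethod
        def _create_config_converters(cls) -> dict:
            # Fresh dict on every call; keys are stable names that subclasses may re-declare.
            return {
                "epsilon": RenameConfigConverter(("epsilon",), ("norm_eps",)),
            }

    hf_dict = ToyNormalizationConverter.export_config(toy_config)       # e.g. {"norm_eps": 1e-05}
    fast_llm_dict = ToyNormalizationConverter.import_config(hf_dict)    # e.g. {"epsilon": 1e-05}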
diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 8cdb779dd..9074e72fc 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -9,7 +9,12 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig -from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler, WeightConverter, logger +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ExternalStateDictCheckpointHandler, + WeightConverter, + logger, +) from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, safe_merge_dicts @@ -120,10 +125,58 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: }, ) + # Top-level HF metadata keys that are always permitted, regardless of the converter tree. + # Covers transformers' generic ``PretrainedConfig`` fields (always present after ``to_dict()``) + # plus a handful of widely-shared metadata that Fast-LLM intentionally does not store. + _HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset( + { + # transformers PretrainedConfig + "_name_or_path", + "architectures", + "auto_map", + "chunk_size_feed_forward", + "dtype", + "id2label", + "is_encoder_decoder", + "label2id", + "model_type", + "output_attentions", + "output_hidden_states", + "problem_type", + "return_dict", + "torch_dtype", + "transformers_version", + "use_cache", + # Token ids — generation/inference, not architecture. + "bos_token_id", + "decoder_start_token_id", + "eos_token_id", + "pad_token_id", + "sep_token_id", + # Initialization / pretraining metadata Fast-LLM does not consume. + "initializer_range", + "max_position_embeddings", + "pretraining_tp", + } + ) + + @classmethod + def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: + """Run the HF-side coverage check at the import boundary. + + Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter` + (e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``). + Subclasses that override :meth:`_import_config` should call this explicitly to keep the check + active. 
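+
+        A sketch of the intended call site in such an override (the subclass is hypothetical)::
+
+            @classmethod
+            def _import_config(cls, config):
+                cls._check_hf_coverage(config)
+                ...  # remaining import logic for the aggregate format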
+ """ + if issubclass(cls.base_model_converter_class, ConfigSectionConverter): + cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) + @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["model_type"], cls.get_huggingface_model_type()) Assert.eq(config["architectures"], [cls.architecture]) + cls._check_hf_coverage(config) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index cc5d80e88..7282de090 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -63,7 +63,7 @@ class AttentionConfig(MixerConfig): ) dense_layer: AffineLinearConfig = Field( desc="Initialization configuration for the dense layer.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) # TODO: Review names rotary: RotaryConfig = Field( @@ -116,6 +116,7 @@ class AttentionConfig(MixerConfig): " Under Standard Parameterization (SP): default to 0.5. " " Under muP (if scaling head_size size): use 1. " " Under muP (if scaling number of heads instead of head_size): use 0.5.", + hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) implementation: AttentionImplementation = Field( diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index f834f089b..46aa420be 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -79,10 +79,10 @@ class Llama3RotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. - scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) - high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) + low_frequency_factor: float = Field(default=1.0, hint=FieldHint.architecture) + high_frequency_factor: float = Field(default=4.0, hint=FieldHint.architecture) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: super()._validate() @@ -103,20 +103,20 @@ class YarnRotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. 
- scale_factor: float = Field(default=8.0, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) attention_factor: None | float = Field( default=None, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_fast: float = Field( default=32.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_slow: float = Field( default=1.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: if self.attention_factor is None: diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index aa47a5f2e..25c5fcc82 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -146,7 +146,10 @@ def last_block_config(self) -> BlockConfig: @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) class PatternBlockSequenceConfig(BlockSequenceConfig): _abstract = False - blocks: dict[str, BlockConfig] = Field() + blocks: dict[str, BlockConfig] = Field( + desc="Named block configurations referenced by `pattern`.", + hint=FieldHint.architecture, + ) pattern: list[str] = Field( default=None, desc="The name of each block (key in `blocks`) in the repeated pattern.", diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index d66e50d56..277afb8f1 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -167,7 +167,7 @@ class StochasticMixerConfig(MixerConfig): "Used for inference/eval, checkpoint loading (receives pretrained weights), " "and checkpoint saving (only this mixer is exported). " "If None, uses the first mixer in the dict.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) predefined_layouts: list[list[str]] = Field( diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 672cd05a7..022179c3e 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -64,7 +64,7 @@ class MLPConfig(MLPBaseConfig): activation: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto recompute_level: MLPRecomputeLevel = Field( @@ -97,7 +97,7 @@ class MoEMLPConfig(MLPConfig): router: LinearConfig = Field( # TODO: Improve default? 
desc="Configuration for the MoE router.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) router_normalization: NormalizationConfig | None = Field( default=None, diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 5920a85ee..47cf43391 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -34,12 +34,12 @@ class PatchEmbeddingsConfig(BlockConfig): patch_height: int = Field( default=16, desc="Height of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) patch_width: int = Field( default=16, desc="Width of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) full_precision_residual: bool = Field( default=False, diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ac732ba22..ebb5a54b5 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -1,72 +1,125 @@ +import functools import math import typing from transformers import PretrainedConfig +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig -from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.llama import ( + LlamaDecoderConverter, + effective_bias, + get_parameter_converter, + get_weight_and_bias_converters, +) from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) from fast_llm.utils import Assert, safe_merge_dicts -class AprielMambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "type": "mamba", - "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), - "add_linear_biases": config["ssm_cfg"]["bias"], - "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], - "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, - "dt_rank": ( - math.ceil(config["hidden_size"] / 16) - if config["ssm_cfg"].get("dt_rank", "auto") == "auto" - else config["ssm_cfg"]["dt_rank"] - ), - "repeat_kv_before_conv": config["ssm_cfg"].get("repeat_kv_before_conv", True), - } +class AprielMambaConverter(ConfigSectionConverter): + """Converts ``MambaConfig`` ↔ Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). 
+ + A few of MambaConfig's defaults are derived from the HF root's ``hidden_size`` (``d_inner`` defaults + to ``hidden_size * expand``, ``d_xb`` defaults to ``hidden_size``, ``dt_rank="auto"`` resolves to + ``ceil(hidden_size / 16)``). Those declarations read the root HF dict directly, so each leaf + converter sees the full HF root passed by the parent block dispatcher. + """ + + fast_llm_config_class = MambaConfig @classmethod - def export_config(cls, config: MambaConfig) -> dict: - cls._check_config(config) + def _create_config_converters(cls) -> dict: return { - "ssm_cfg": { - "d_state": config.state_size, - "d_inner": config.d_inner, - "bias": config.add_linear_biases, - "conv_bias": ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), - "d_xb": config.d_xb, - "dt_proj_bias": ( - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled - ), - "dt_rank": config.dt_rank, - "repeat_kv_before_conv": config.repeat_kv_before_conv, - } + "state_size": RenameConfigConverter(("state_size",), ("ssm_cfg", "d_state")), + "d_inner": DefaultConfigConverter( + ("d_inner",), + ("ssm_cfg", "d_inner"), + hf_default_fn=lambda hf: hf["hidden_size"] * hf.get("ssm_cfg", {}).get("expand", 1), + ), + "d_xb": DefaultConfigConverter( + ("d_xb",), + ("ssm_cfg", "d_xb"), + hf_default_fn=lambda hf: hf["hidden_size"], + ), + "dt_rank": CustomConfigConverter( + fast_llm_paths=(("dt_rank",),), + hf_paths=(("ssm_cfg", "dt_rank"),), + export_fn=lambda c: {("ssm_cfg", "dt_rank"): c.dt_rank}, + import_fn=lambda hf: { + ("dt_rank",): ( + math.ceil(hf["hidden_size"] / 16) + if hf.get("ssm_cfg", {}).get("dt_rank", "auto") == "auto" + else hf["ssm_cfg"]["dt_rank"] + ) + }, + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("ssm_cfg", "bias")), + "repeat_kv_before_conv": DefaultConfigConverter( + ("repeat_kv_before_conv",), + ("ssm_cfg", "repeat_kv_before_conv"), + hf_default_fn=lambda hf: True, + ), + "convolution_layer_bias": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "bias")), + hf_paths=(("ssm_cfg", "conv_bias"),), + export_fn=lambda c: { + ("ssm_cfg", "conv_bias"): effective_bias(c.convolution_layer, c.add_linear_biases) + }, + import_fn=lambda hf: { + ("convolution_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("conv_bias", True) + }, + ), + # CausalConv1dConfig fields not represented in Apriel HF: weight rides the tensor side via + # ``conv1d.weight``; kernel_size/activation round-trip implicitly at the Fast-LLM defaults. + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "kernel_size"), + ("convolution_layer", "activation"), + ), + "dt_layer_bias": CustomConfigConverter( + fast_llm_paths=(("dt_layer",), ("dt_layer", "bias")), + hf_paths=(("ssm_cfg", "dt_proj_bias"),), + export_fn=lambda c: {("ssm_cfg", "dt_proj_bias"): effective_bias(c.dt_layer, c.add_linear_biases)}, + import_fn=lambda hf: { + ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) + }, + ), + # AffineLinearConfig.weight rides the tensor side via ``dt_proj.weight``. + "dt_layer_unmapped": IgnoredConfigConverter(("dt_layer", "weight")), + # Per-layer biases that must round-trip implicitly via add_linear_biases (validated below). 
+ "linear_layers": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ), + # Parameter sub-configs Mamba doesn't expose to HF; coverage-only. + "parameters": IgnoredConfigConverter(("d_weight",), ("a_log_weight",)), } @classmethod - def _check_config(cls, config: MambaConfig) -> None: - # Opportunity to make derived classes less constrained. - Assert.is_(type(config), MambaConfig) + def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) @@ -99,17 +152,13 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dt_proj", f"{hf_prefix}.dt_proj", - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled, + effective_bias(config.dt_layer, config.add_linear_biases), drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( f"{fast_llm_prefix}.convolution", f"{hf_prefix}.conv1d", - ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), + effective_bias(config.convolution_layer, config.add_linear_biases), drop_on_export=drop_on_export, ), get_parameter_converter( @@ -131,30 +180,42 @@ def get_converters( ] -class GatedDeltaNetConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "type": "gdn", - "value_heads": config["linear_attn_config"]["gdn_num_value_heads"], - "key_heads": config["linear_attn_config"]["gdn_num_key_heads"], - "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], - "value_head_dim": config["linear_attn_config"]["gdn_value_head_dim"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], - }, - } +class GatedDeltaNetConverter(ConfigSectionConverter): + """Converts ``GatedDeltaNetConfig`` ↔ Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = GatedDeltaNetConfig @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "linear_attn_config": { - "gdn_num_value_heads": config.value_heads, - "gdn_num_key_heads": config.key_heads, - "gdn_key_head_dim": config.key_head_dim, - "gdn_value_head_dim": config.value_head_dim, - "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, - }, + "value_heads": RenameConfigConverter(("value_heads",), ("linear_attn_config", "gdn_num_value_heads")), + "key_heads": RenameConfigConverter(("key_heads",), ("linear_attn_config", "gdn_num_key_heads")), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("linear_attn_config", "gdn_key_head_dim")), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("linear_attn_config", "gdn_value_head_dim")), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("linear_attn_config", "gdn_linear_conv_kernel_size"),), + export_fn=lambda c: { + ("linear_attn_config", "gdn_linear_conv_kernel_size"): c.convolution_layer.kernel_size + }, + import_fn=lambda hf: { + ("convolution_layer", "kernel_size"): hf["linear_attn_config"]["gdn_linear_conv_kernel_size"] + }, + ), + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + 
("convolution_layer", "bias"), + ("convolution_layer", "activation"), + ), + # Sub-configs without HF representation; coverage-only. + "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod @@ -209,26 +270,46 @@ def get_converters( ] -class KimiDeltaAttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "type": "kda", - "head_dim": config["linear_attn_config"]["head_dim"], - "heads": config["linear_attn_config"]["num_heads"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], - }, - } +class KimiDeltaAttentionConverter(ConfigSectionConverter): + """Converts ``KimiDeltaAttentionConfig`` ↔ Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = KimiDeltaAttentionConfig @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "linear_attn_config": { - "head_dim": config.head_dim, - "num_heads": config.heads, - "short_conv_kernel_size": config.convolution_layer.kernel_size, - }, + "head_dim": RenameConfigConverter(("head_dim",), ("linear_attn_config", "head_dim")), + "heads": RenameConfigConverter(("heads",), ("linear_attn_config", "num_heads")), + "convolution_layer_kernel": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("linear_attn_config", "short_conv_kernel_size"),), + export_fn=lambda c: { + ("linear_attn_config", "short_conv_kernel_size"): c.convolution_layer.kernel_size + }, + import_fn=lambda hf: { + ("convolution_layer", "kernel_size"): hf["linear_attn_config"]["short_conv_kernel_size"] + }, + ), + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), + ), + # Sub-configs without HF representation; coverage-only. + "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod @@ -347,6 +428,11 @@ class AprielGatedDeltaNetBlockConverter(MistralBlockConverter): class AprielBlockConverter: + """Per-block dispatcher: the mixer type is encoded in the parent's ``hybrid_block_layout`` list, + not in a nested HF discriminator, so this dispatcher stays imperative rather than using + :class:`DispatchConfigConverter`. Each branch delegates to a regular declarative block converter. 
+ """ + layout_names = { AttentionConfig: "t", MambaConfig: "m2", @@ -368,6 +454,16 @@ def import_config(cls, config: dict, layout_name: str = "t") -> dict: def export_config(cls, config) -> dict: return cls._converter_classes[type(config.mixer)].export_config(config) + @classmethod + @functools.cache + def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: + """Union of consumed HF paths across every per-mixer-type block converter — used by the parent's + decoder Custom to pre-claim Apriel's flat top-level keys for the HF coverage check.""" + paths: set[tuple[str, ...]] = set() + for sub in cls._converter_classes.values(): + paths |= sub._consumed_hf_paths() + return frozenset(paths) + @classmethod def get_converters( cls, @@ -381,7 +477,13 @@ def get_converters( ) -class AprielDecoderConverter(MistralDecoderConverter): +class AprielDecoderConverter(LlamaDecoderConverter): + """Pattern-style decoder dispatched via Apriel's ``hybrid_block_layout`` list (one entry per block). + Stays imperative because the layout-list shape doesn't match the declarative ``decoder.type`` + discriminator that Apriel2 uses. Overrides every classmethod from + :class:`LlamaDecoderConverter`; the parent is used only as a nominal base. + """ + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod @@ -413,7 +515,8 @@ def export_config(cls, config: BlockSequenceConfig) -> dict: pattern_block_configs = [config.blocks[block_name] for block_name in config.pattern] else: raise NotImplementedError() - # There may be all sorts of blocks, but `safe_merge_dicts` ensures they are compatible. + # Each block emits non-overlapping HF keys (attention -> flat, mamba -> ssm_cfg.*, + # gdn/kda -> linear_attn_config.*) so safe_merge_dicts is sufficient to combine them. return safe_merge_dicts( *[cls.block_converter_class.export_config(block_config) for block_config in block_configs], { @@ -450,8 +553,50 @@ class AprielHeadConverter(MistralHeadConverter): class AprielBaseModelConverter(MistralBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + """Apriel needs the per-position hybrid layout dispatcher (:class:`AprielDecoderConverter`) instead of + the standard Fixed/Pattern dispatch inlined in :class:`LlamaBaseModelConverter`. The override below + replaces the parent's ``"decoder"`` declaration with one that delegates to Apriel's dispatcher. + """ + head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter + apriel_decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + + @classmethod + def _create_config_converters(cls) -> dict: + decoder_cls = cls.apriel_decoder_converter_class + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_cls.export_config(parent.decoder).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return {("decoder",): decoder_cls.import_config(hf_dict)} + + return { + **super()._create_config_converters(), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + # Block converter is the per-position dispatcher (:class:`AprielBlockConverter`), which + # unions the HF claims of every leaf-mixer's block converter. 
+ hf_paths=( + ("num_hidden_layers",), + ("hybrid_block_layout",), + *decoder_cls.block_converter_class._consumed_hf_paths(), + ), + export_fn=_decoder_export, + import_fn=_decoder_import, + recurses=True, + ), + } + + # --- weight side (imperative): use Apriel's per-position dispatcher instead of the standard inline loop. + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.apriel_decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config, exported_config), + ] class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 9b6657b03..d5f89ef96 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -4,11 +4,33 @@ from transformers import PretrainedConfig +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DispatchConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + RenameConfigConverter, + TypedDictContainerConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + RMSNormalizationConfig, +) from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig, StochasticMixerSamplingStrategy +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat @@ -17,93 +39,144 @@ LlamaEmbeddingsConverter, LlamaNormalizationConverter, MLPLayer2Converter, - QueryWeightConverter, SplitWeightConverter, + assert_no_peft, + effective_bias, get_parameter_converter, get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts +# ============================================================ +# Helpers +# ============================================================ + + +def _per_layer_bias_export(config: Config, layer_names: tuple[str, ...]) -> dict: + """Emit per-layer ``{layer: {"bias": {"enabled": bool}}}`` only for layers whose bias is explicitly set.""" + out: dict = {} + for layer_name in layer_names: + layer = getattr(config, layer_name) + if layer.bias.enabled is not None: + out[(layer_name,)] = {"bias": {"enabled": layer.bias.enabled}} + return out + + +def _per_layer_bias_import(hf_dict: dict, 
layer_names: tuple[str, ...]) -> dict: + """Pass through HF ``{layer: {"bias": {...}}}`` entries to the Fast-LLM dict.""" + out: dict = {} + for layer_name in layer_names: + if layer_name in hf_dict: + out[(layer_name,)] = hf_dict[layer_name] + return out + + +def _per_layer_bias_converter(layer_names: tuple[str, ...]) -> CustomConfigConverter: + """Per-layer ``bias.enabled`` round-trip for the named sub-layers of an attention or MLP config: + emits/consumes the HF ``{layer: {"bias": {"enabled": ...}}}`` tree.""" + return CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + hf_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + recurses=True, + ) + + +def _apriel2_conv_kernel_converter() -> CustomConfigConverter: + """Round-trip Apriel2's flat ``convolution_layer.kernel_size`` against the Fast-LLM + ``convolution_layer`` sub-config. Shared between :class:`Apriel2GatedDeltaNetConverter` and + :class:`Apriel2KimiDeltaAttentionConverter`.""" + return CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("convolution_layer", "kernel_size")), + hf_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ({("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {}), + ) + + +# ============================================================ +# Mixer converters +# ============================================================ + + +def _apriel2_attention_rotary_export(config: AttentionConfig) -> dict: + """Emit Apriel2's typed rotary subdict. + + Asymmetric with the Fast-LLM type only for the default→``mistral_1d`` rename; ``llama3``/``yarn`` round-trip + by name. The scale parameters of ``llama3``/``yarn`` are emitted under their Fast-LLM field names since + the matching :func:`_apriel2_attention_rotary_import` is a wholesale pass-through of ``hf_dict["rotary"]``. 
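+
+    Illustrative example (values assumed): a ``DefaultRotaryConfig`` with ``theta=10000.0`` exports as
+    ``{("rotary",): {"type": "mistral_1d", "theta": 10000.0}}``; the matching importer renames
+    ``"mistral_1d"`` back to ``"default"`` and passes every other entry through unchanged.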
+ """ + rotary = config.rotary + if type(rotary) is DefaultRotaryConfig: + return {("rotary",): {"type": "mistral_1d", "theta": rotary.theta}} + if type(rotary) is Llama3RotaryConfig: + return { + ("rotary",): { + "type": "llama3", + "theta": rotary.theta, + "scale_factor": rotary.scale_factor, + "low_frequency_factor": rotary.low_frequency_factor, + "high_frequency_factor": rotary.high_frequency_factor, + "original_context_length": rotary.original_context_length, + } + } + if type(rotary) is YarnRotaryConfig: + return { + ("rotary",): { + "type": "yarn", + "theta": rotary.theta, + "scale_factor": rotary.scale_factor, + "attention_factor": rotary.attention_factor, + "beta_fast": rotary.beta_fast, + "beta_slow": rotary.beta_slow, + "original_context_length": rotary.original_context_length, + } + } + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + +def _apriel2_attention_rotary_import(hf_dict: dict) -> dict: + rotary = dict(hf_dict["rotary"]) + if rotary.get("type") == "mistral_1d": + rotary["type"] = "default" + return {("rotary",): rotary} + + +class Apriel2AttentionConverter(ConfigSectionConverter): + fast_llm_config_class = AttentionConfig + hf_type_name = "attention" -class Apriel2AttentionConverter: @classmethod - def import_config(cls, config: dict) -> dict: - rotary = config["rotary"] - # Map Apriel2 HuggingFace rotary type to Fast-LLM internal type - if rotary.get("type") == "mistral_1d": - rotary = {**rotary, "type": "default"} - result = { - "type": "attention", - "heads": config["heads"], - "head_groups": config["head_groups"], - "head_size": config["head_size"], - "rotary": rotary, - } - # Per-layer bias configuration mirroring Fast-LLM structure - # If per-layer configs exist, use them; otherwise fall back to add_linear_biases - if "query_layer" in config: - result["query_layer"] = config["query_layer"] - if "key_layer" in config: - result["key_layer"] = config["key_layer"] - if "value_layer" in config: - result["value_layer"] = config["value_layer"] - if "dense_layer" in config: - result["dense_layer"] = config["dense_layer"] - # add_linear_biases serves as default for layers without explicit config - if "add_linear_biases" in config: - result["add_linear_biases"] = config["add_linear_biases"] - if "window_size" in config: - result["window_size"] = config["window_size"] - return result - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig - - if type(config.rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - elif type(config.rotary) is Llama3RotaryConfig: - rotary_type = "llama3" - elif type(config.rotary) is YarnRotaryConfig: - rotary_type = "yarn" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - result = { - "type": "attention", - "heads": config.heads, - "head_groups": config.head_groups, - "head_size": config.head_size, - "rotary": { - "type": rotary_type, - "theta": config.rotary.theta, - }, + def _create_config_converters(cls) -> dict: + layer_names = ("query_layer", "key_layer", "value_layer", "dense_layer") + return { + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("head_groups",)), + "head_size": RenameConfigConverter(("head_size",), ("head_size",)), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rotary",),), + 
export_fn=_apriel2_attention_rotary_export, + import_fn=_apriel2_attention_rotary_import, + recurses=True, + ), + # Apriel2 emits add_linear_biases only when False; the True default is implicit. + "add_linear_biases": OptionalConfigConverter( + ("add_linear_biases",), ("add_linear_biases",), sentinel=True + ), + "window_size": OptionalConfigConverter(("window_size",), ("window_size",)), + "linear_layers": _per_layer_bias_converter(layer_names), + "causal": IgnoredConfigConverter(("causal",)), + "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), + "query_norm": ConstantImportConfigConverter(("query_norm",), None), + "key_norm": ConstantImportConfigConverter(("key_norm",), None), + "value_norm": ConstantImportConfigConverter(("value_norm",), None), + "shared_key_value": ConstantImportConfigConverter(("shared_key_value",), False), } - if config.window_size is not None: - result["window_size"] = config.window_size - # Export per-layer bias configuration - # Only include if explicitly set (not None) - if config.query_layer.bias.enabled is not None: - result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} - if config.key_layer.bias.enabled is not None: - result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} - if config.value_layer.bias.enabled is not None: - result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} - if config.dense_layer.bias.enabled is not None: - result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} - # add_linear_biases as fallback default; omit when True (the Fast-LLM default) to avoid - # round-trip inflation on configs that don't set it explicitly. - if not config.add_linear_biases: - result["add_linear_biases"] = config.add_linear_biases - return result - - @classmethod - def _get_effective_bias(cls, layer_config, default: bool) -> bool: - """Get effective bias setting: use layer-specific if set, else default.""" - if layer_config.bias.enabled is not None: - return layer_config.bias.enabled - return default + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -113,13 +186,11 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Determine effective bias for each projection - q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) - k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) - v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) - o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) - # For key_value, both k and v must have same bias setting - # (they're combined in Fast-LLM's key_value layer) + q_bias = effective_bias(config.query_layer, config.add_linear_biases) + k_bias = effective_bias(config.key_layer, config.add_linear_biases) + v_bias = effective_bias(config.value_layer, config.add_linear_biases) + o_bias = effective_bias(config.dense_layer, config.add_linear_biases) + # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. 
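+        # Illustrative (assumed values): with add_linear_biases=False and only key_layer.bias.enabled=True,
+        # effective_bias yields k_bias=True and v_bias=False, so the merged tensor is exported without a bias.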
kv_bias = k_bias and v_bias return [ @@ -127,8 +198,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", q_bias, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -148,40 +217,65 @@ def get_converters( ] -class Apriel2MambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "mamba", - "state_size": config["state_size"], - "d_inner": config["d_inner"], - "add_linear_biases": config["add_linear_biases"], - } - if "d_xb" in config: - result["d_xb"] = config["d_xb"] - if "dt_rank" in config: - result["dt_rank"] = config["dt_rank"] - return result - - @classmethod - def export_config(cls, config: MambaConfig) -> dict: - exported = { - "type": "mamba", - "state_size": config.state_size, - "d_inner": config.d_inner, - "d_conv": config.convolution_layer.kernel_size, - "add_linear_biases": config.add_linear_biases, - "conv_bias": config.convolution_layer.bias.enabled, - "dt_proj_bias": config.dt_layer.bias.enabled, - } +def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: + """Emit Apriel2's mamba-specific HF auxiliaries (``d_conv`` from convolution kernel size, plus the + convolution and dt-projection effective bias flags). These have no flat Fast-LLM analogue.""" + return { + ("d_conv",): config.convolution_layer.kernel_size, + ("conv_bias",): config.convolution_layer.bias.enabled, + ("dt_proj_bias",): config.dt_layer.bias.enabled, + } - if config.d_xb is not None: - exported["d_xb"] = config.d_xb - if config.dt_rank != "auto": - exported["dt_rank"] = config.dt_rank +def _apriel2_mamba_aux_import(hf_dict: dict) -> dict: + """Reverse of :func:`_apriel2_mamba_aux_export`. ``conv_bias`` / ``dt_proj_bias`` can diverge from the + mixer-wide ``add_linear_biases`` flag, so they must populate the per-layer ``bias.enabled`` directly; + dropping them on import would silently mask HF bias weights when the weight loader checks the + per-layer flag.""" + out: dict = {} + if "d_conv" in hf_dict: + out[("convolution_layer", "kernel_size")] = hf_dict["d_conv"] + if "conv_bias" in hf_dict: + out[("convolution_layer", "bias", "enabled")] = hf_dict["conv_bias"] + if "dt_proj_bias" in hf_dict: + out[("dt_layer", "bias", "enabled")] = hf_dict["dt_proj_bias"] + return out - return exported + +class Apriel2MambaConverter(ConfigSectionConverter): + fast_llm_config_class = MambaConfig + hf_type_name = "mamba" + + @classmethod + def _create_config_converters(cls) -> dict: + return { + "state_size": RenameConfigConverter(("state_size",), ("state_size",)), + "d_inner": RenameConfigConverter(("d_inner",), ("d_inner",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "d_xb": RenameConfigConverter(("d_xb",), ("d_xb",)), + "dt_rank": RenameConfigConverter(("dt_rank",), ("dt_rank",)), + "aux": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("dt_layer",)), + hf_paths=(("d_conv",), ("conv_bias",), ("dt_proj_bias",)), + export_fn=_apriel2_mamba_aux_export, + import_fn=_apriel2_mamba_aux_import, + recurses=True, + ), + # Architecture fields with no HF counterpart; they round-trip at their Fast-LLM defaults. 
+ "layers_unmapped": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ("a_log_weight",), + ("d_weight",), + ("repeat_kv_before_conv",), + ), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -235,33 +329,38 @@ def get_converters( ] -class Apriel2GatedDeltaNetConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "gdn", - "value_heads": config["value_heads"], - "key_heads": config["key_heads"], - "key_head_dim": config["key_head_dim"], - "value_head_dim": config["value_head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - return result +class Apriel2GatedDeltaNetConverter(ConfigSectionConverter): + fast_llm_config_class = GatedDeltaNetConfig + hf_type_name = "gdn" @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "gdn", - "value_heads": config.value_heads, - "key_heads": config.key_heads, - "key_head_dim": config.key_head_dim, - "value_head_dim": config.value_head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, + "value_heads": RenameConfigConverter(("value_heads",), ("value_heads",)), + "key_heads": RenameConfigConverter(("key_heads",), ("key_heads",)), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("key_head_dim",)), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), + "convolution_layer_kernel": _apriel2_conv_kernel_converter(), + # CausalConv1dConfig sub-fields the Apriel2 HF format does not surface (weight rides the tensor + # side; bias/activation round-trip at their Fast-LLM defaults). + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), + ), + # Architecture fields not surfaced in HF; round-trip at default. 
+ "layers_unmapped": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -314,34 +413,51 @@ def get_converters( ] -class Apriel2KimiDeltaAttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "kda", - "heads": config["heads"], - "head_dim": config["head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - if "normalization" in config: - result["normalization"] = config["normalization"] - return result +class Apriel2KimiDeltaAttentionConverter(ConfigSectionConverter): + fast_llm_config_class = KimiDeltaAttentionConfig + hf_type_name = "kda" @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "kda", - "heads": config.heads, - "head_dim": config.head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, - "normalization": { - "epsilon": config.normalization.epsilon, - }, + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), + "convolution_layer_kernel": _apriel2_conv_kernel_converter(), + # CausalConv1dConfig sub-fields not surfaced in HF (same as :class:`Apriel2GatedDeltaNetConverter`). + "convolution_layer_unmapped": IgnoredConfigConverter( + ("convolution_layer", "weight"), + ("convolution_layer", "bias"), + ("convolution_layer", "activation"), + ), + "normalization_epsilon": CustomConfigConverter( + fast_llm_paths=(("normalization",), ("normalization", "epsilon")), + hf_paths=(("normalization",),), + export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, + import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), + ), + # Other GatedRMSNormalizationConfig architecture fields are dropped on the HF side. + "normalization_unmapped": IgnoredConfigConverter( + ("normalization", "weight"), + ("normalization", "zero_centered"), + ), + # Architecture fields not surfaced in HF; round-trip at default. 
+ "layers_unmapped": IgnoredConfigConverter( + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -350,11 +466,7 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Fast-LLM KDA uses abbreviated names matching the external module: - # q_proj, k_proj, v_proj, q_conv, k_conv, v_conv, f_a_proj, f_b_proj, - # g_a_proj, g_b_proj, beta_proj, o_proj, A_log, dt_bias, norm return [ - # Q/K/V projections *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_proj", f"{hf_prefix}.q_proj", @@ -373,7 +485,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Convolutions (Q, K, V) *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_conv", f"{hf_prefix}.q_conv", @@ -392,7 +503,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Gate projections (f_a, f_b, g_a, g_b) *get_weight_and_bias_converters( f"{fast_llm_prefix}.f_a_proj", f"{hf_prefix}.f_a_proj", @@ -417,21 +527,18 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Beta projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.beta_proj", f"{hf_prefix}.beta_proj", False, drop_on_export=drop_on_export, ), - # Output projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.o_proj", f"{hf_prefix}.o_proj", False, drop_on_export=drop_on_export, ), - # Learnable parameters get_parameter_converter( f"{fast_llm_prefix}.A_log", f"{hf_prefix}.A_log", @@ -442,7 +549,6 @@ def get_converters( f"{hf_prefix}.dt_bias", drop_on_export=drop_on_export, ), - # Normalization *LlamaNormalizationConverter.get_converters( config.normalization, f"{fast_llm_prefix}.norm", @@ -452,56 +558,36 @@ def get_converters( ] -class Apriel2StochasticMixerConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - mixers = {} - for name, sub_mixer_config in config["mixers"].items(): - mixer_type = sub_mixer_config["type"] - if mixer_type == "attention": - mixers[name] = Apriel2AttentionConverter.import_config(sub_mixer_config) - elif mixer_type == "mamba": - mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config) - elif mixer_type == "gdn": - mixers[name] = Apriel2GatedDeltaNetConverter.import_config(sub_mixer_config) - elif mixer_type == "kda": - mixers[name] = Apriel2KimiDeltaAttentionConverter.import_config(sub_mixer_config) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config["main_mixer_name"], - } - if "sampling_strategy" in config: - result["sampling_strategy"] = config["sampling_strategy"] - return result +# Mixer dispatch registry — used inside StochasticMixer (no nested-stochastic) and at the block level. 
+APRIEL2_LEAF_MIXER_REGISTRY: dict = { + AttentionConfig: Apriel2AttentionConverter, + MambaConfig: Apriel2MambaConverter, + GatedDeltaNetConfig: Apriel2GatedDeltaNetConverter, + KimiDeltaAttentionConfig: Apriel2KimiDeltaAttentionConverter, +} + + +class Apriel2StochasticMixerConverter(ConfigSectionConverter): + fast_llm_config_class = StochasticMixerConfig + hf_type_name = "stochastic" @classmethod - def export_config(cls, config: StochasticMixerConfig) -> dict: - mixers = {} - for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - mixers[name] = Apriel2AttentionConverter.export_config(sub_mixer) - elif mixer_type is MambaConfig: - mixers[name] = Apriel2MambaConverter.export_config(sub_mixer) - elif mixer_type is GatedDeltaNetConfig: - mixers[name] = Apriel2GatedDeltaNetConverter.export_config(sub_mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixers[name] = Apriel2KimiDeltaAttentionConverter.export_config(sub_mixer) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config.main_mixer_name, + def _create_config_converters(cls) -> dict: + return { + "mixers": TypedDictContainerConfigConverter( + fast_llm_path=("mixers",), + hf_path=("mixers",), + registry=APRIEL2_LEAF_MIXER_REGISTRY, + ), + "main_mixer_name": RenameConfigConverter(("main_mixer_name",), ("main_mixer_name",)), + "sampling_strategy": OptionalConfigConverter( + ("sampling_strategy",), + ("sampling_strategy",), + sentinel=StochasticMixerSamplingStrategy.uniform, + ), } - if config.sampling_strategy != StochasticMixerSamplingStrategy.uniform: - result["sampling_strategy"] = config.sampling_strategy.value - return result + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -513,136 +599,181 @@ def get_converters( ) -> list[WeightConverter]: converters = [] for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + converter_class = APRIEL2_LEAF_MIXER_REGISTRY.get(type(sub_mixer)) + if converter_class is None: + raise ValueError(f"Unknown sub-mixer type: {type(sub_mixer)}") converters.extend( converter_class.get_converters( sub_mixer, f"{fast_llm_prefix}.mixers.{name}", - hf_sub_mixer_prefix, + f"{hf_prefix}.mixers.{name}", drop_on_export=drop_on_export, ) ) - return converters -class Apriel2BlockConverter: - @classmethod - def import_config(cls, config: dict, block_config: dict) -> dict: - mixer_config = block_config["mixer"] - mixer_type = mixer_config["type"] - - if mixer_type == "attention": - mixer = Apriel2AttentionConverter.import_config(mixer_config) - elif mixer_type == "mamba": - mixer = Apriel2MambaConverter.import_config(mixer_config) - elif mixer_type == "stochastic": - mixer = Apriel2StochasticMixerConverter.import_config(mixer_config) - elif mixer_type == "gdn": - mixer = 
Apriel2GatedDeltaNetConverter.import_config(mixer_config) - elif mixer_type == "kda": - mixer = Apriel2KimiDeltaAttentionConverter.import_config(mixer_config) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - from fast_llm.functional.config import ActivationType - - mlp_config = block_config["mlp"] - mlp = { - "type": "mlp", - "intermediate_size": mlp_config["intermediate_size"], - "activation": ActivationType.from_hf_name(mlp_config["activation"]), - "gated": mlp_config["gated"], - "add_linear_biases": mlp_config["add_linear_biases"], +# Block-level mixer registry includes StochasticMixer (which can wrap leaf mixers). +APRIEL2_BLOCK_MIXER_REGISTRY: dict = { + **APRIEL2_LEAF_MIXER_REGISTRY, + StochasticMixerConfig: Apriel2StochasticMixerConverter, +} + + +# ============================================================ +# Normalization converters +# ============================================================ + + +class Apriel2RMSNormConverter(ConfigSectionConverter): + fast_llm_config_class = RMSNormalizationConfig + hf_type_name = "rms_norm" + + @classmethod + def _create_config_converters(cls) -> dict: + return { + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } - # Import per-layer MLP bias settings (layer_1, layer_2) - for layer_name in ("layer_1", "layer_2"): - if layer_name in mlp_config: - layer_cfg = mlp_config[layer_name] - if "bias" in layer_cfg: - mlp[layer_name] = {"bias": layer_cfg["bias"]} - normalization = block_config["normalization"] +class Apriel2LayerNormConverter(ConfigSectionConverter): + fast_llm_config_class = LayerNormalizationConfig + hf_type_name = "layer_norm" + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "bias": IgnoredConfigConverter(("bias",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } + +class Apriel2NoNormConverter(ConfigSectionConverter): + fast_llm_config_class = NoNormalizationConfig + hf_type_name = "none" + @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - from fast_llm.layers.common.normalization.config import ( - LayerNormalizationConfig, - NoNormalizationConfig, - RMSNormalizationConfig, - ) + def _create_config_converters(cls) -> dict: + return {} + + +APRIEL2_NORM_REGISTRY: dict = { + RMSNormalizationConfig: Apriel2RMSNormConverter, + LayerNormalizationConfig: Apriel2LayerNormConverter, + NoNormalizationConfig: Apriel2NoNormConverter, +} + + +# ============================================================ +# MLP, Block, Decoder, Head +# ============================================================ - mixer_type = type(config.mixer) - - if mixer_type is AttentionConfig: - mixer = Apriel2AttentionConverter.export_config(config.mixer) - elif mixer_type is MambaConfig: - mixer = Apriel2MambaConverter.export_config(config.mixer) - elif mixer_type is StochasticMixerConfig: - mixer = Apriel2StochasticMixerConverter.export_config(config.mixer) - elif mixer_type is GatedDeltaNetConfig: - mixer = Apriel2GatedDeltaNetConverter.export_config(config.mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixer = Apriel2KimiDeltaAttentionConverter.export_config(config.mixer) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - 
norm_type = type(config.normalization) - if norm_type is RMSNormalizationConfig: - norm_type_str = "rms_norm" - elif norm_type is LayerNormalizationConfig: - norm_type_str = "layer_norm" - elif norm_type is NoNormalizationConfig: - norm_type_str = "none" - else: - raise ValueError(f"Unknown normalization type: {norm_type}") - - from fast_llm.layers.decoder.mlp.config import MLPConfig - - if not isinstance(config.mlp, MLPConfig): - raise ValueError(f"Unsupported MLP type: {type(config.mlp)}") - - mlp = { - "type": "mlp", - "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.value, - "gated": config.mlp.gated, - "add_linear_biases": config.mlp.add_linear_biases, + +class Apriel2MLPConverter(ConfigSectionConverter): + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + + @classmethod + def _create_config_converters(cls) -> dict: + layer_names = ("layer_1", "layer_2") + return { + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "layers": _per_layer_bias_converter(layer_names), + "pre_norm": ConstantImportConfigConverter(("pre_norm",), None), + "post_norm": ConstantImportConfigConverter(("post_norm",), None), } - # Export per-layer MLP bias settings (layer_1, layer_2) - if config.mlp.layer_1.bias.enabled is not None: - mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} - if config.mlp.layer_2.bias.enabled is not None: - mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} - normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + layer_1_bias = effective_bias(config.layer_1, config.add_linear_biases) + layer_2_bias = effective_bias(config.layer_2, config.add_linear_biases) + if config.gated: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + f"{hf_prefix}.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + +class Apriel2BlockConverter(ConfigSectionConverter): + fast_llm_config_class = DecoderBlockConfig + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "mixer": DispatchConfigConverter( + fast_llm_path=("mixer",), + hf_path=("mixer",), + registry=APRIEL2_BLOCK_MIXER_REGISTRY, + ), + "mlp": NestedConfigConverter(("mlp",), Apriel2MLPConverter, hf_path=("mlp",)), + 
"normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), + "pre_mixer_normalization": ConstantImportConfigConverter(("pre_mixer_normalization",), None), + "pre_mlp_normalization": ConstantImportConfigConverter(("pre_mlp_normalization",), None), + "post_mixer_normalization": ConstantImportConfigConverter(("post_mixer_normalization",), None), + "post_mlp_normalization": ConstantImportConfigConverter(("post_mlp_normalization",), None), + "output_scale": IgnoredConfigConverter(("output_scale",)), } + @classmethod + def _validate_export(cls, config: DecoderBlockConfig) -> None: + # Apriel2 HF format only represents plain MLP. ``NestedConfigConverter`` dispatches by fixed class + # (``Apriel2MLPConverter`` registered against ``MLPConfig``) and would silently descend into a + # ``MoEMLPConfig`` via MRO, dropping every MoE-specific architecture field. + Assert.is_(type(config.mlp), MLPConfig) + assert not config.output_scale.enabled + + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -651,86 +782,25 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - converters = [] - mixer_type = type(config.mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is StochasticMixerConfig: - converter_class = Apriel2StochasticMixerConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - converters.extend( - converter_class.get_converters( + mixer_converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) + if mixer_converter_class is None: + raise ValueError(f"Unknown mixer type: {type(config.mixer)}") + converters: list[WeightConverter] = list( + mixer_converter_class.get_converters( config.mixer, f"{fast_llm_prefix}.mixer", - hf_mixer_prefix, + f"{hf_prefix}.mixer", drop_on_export=drop_on_export, ) ) - - # Per-layer MLP bias: use layer-specific setting if set, else default - def get_mlp_layer_bias(layer_config, default: bool) -> bool: - if layer_config.bias.enabled is not None: - return layer_config.bias.enabled - return default - - layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) - layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) - - if config.mlp.gated: - # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 - converters.extend( - [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - layer_1_bias, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - ) - else: - # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 - # Note: layer_2 still needs MLPLayer2Converter for the transpose - converters.extend( - [ - 
*get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - f"{hf_prefix}.mlp.up_proj", - layer_1_bias, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + converters.extend( + Apriel2MLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, ) - + ) converters.extend( [ *LlamaNormalizationConverter.get_converters( @@ -747,128 +817,121 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: ), ] ) - return converters -class Apriel2DecoderConverter: - block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter +class Apriel2FixedDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = FixedBlockSequenceConfig + hf_type_name = "fixed" + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @classmethod - def import_config(cls, config: dict) -> dict: - decoder_config = config["decoder"] - decoder_type = decoder_config["type"] - - if decoder_type == "fixed": - block_config = decoder_config["block"] - imported_block = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "fixed", - "num_blocks": decoder_config["num_blocks"], - "block": imported_block, - } - - elif decoder_type == "pattern": - blocks = {} - for name, block_config in decoder_config["blocks"].items(): - blocks[name] = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "pattern", - "blocks": blocks, - "pattern": decoder_config["pattern"], - "num_blocks": decoder_config["num_blocks"], - } - - else: - raise ValueError(f"Unknown decoder type: {decoder_type}") + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "block": NestedConfigConverter(("block",), cls.block_converter_class, hf_path=("block",)), + } @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters: list[WeightConverter] = [] + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) + return converters - if isinstance(config, FixedBlockSequenceConfig): - block_config = cls.block_converter_class.export_config(config.block) - return { - "decoder": { - "type": "fixed", - "num_blocks": config.num_blocks, - "block": block_config, - } - } - elif isinstance(config, PatternBlockSequenceConfig): - blocks = {} - for name, block_config in config.blocks.items(): - blocks[name] = cls.block_converter_class.export_config(block_config) - - return { - "decoder": { - "type": "pattern", - "blocks": blocks, - "pattern": config.pattern, - "num_blocks": config.num_blocks, - } - } +class Apriel2PatternDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = PatternBlockSequenceConfig + hf_type_name = "pattern" + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter - else: - raise ValueError(f"Unknown decoder config 
type: {type(config)}") + @classmethod + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "pattern": RenameConfigConverter(("pattern",), ("pattern",)), + "blocks": TypedDictContainerConfigConverter( + fast_llm_path=("blocks",), + hf_path=("blocks",), + registry={DecoderBlockConfig: cls.block_converter_class}, + ), + } @classmethod def get_converters( cls, - config, + config: PatternBlockSequenceConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - converters = [] - if type(config) is FixedBlockSequenceConfig: - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - elif type(config) is PatternBlockSequenceConfig: - for block_index in range(config.num_blocks): - block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - else: - raise NotImplementedError(f"Unsupported config type: {type(config).__name__}") + converters: list[WeightConverter] = [] + for block_index in range(config.num_blocks): + block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) return converters -class Apriel2HeadConverter: +APRIEL2_DECODER_REGISTRY: dict = { + FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, + PatternBlockSequenceConfig: Apriel2PatternDecoderConverter, +} + + +def get_apriel2_decoder_converter( + decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, +) -> type[ConfigSectionConverter]: + """Look up the Apriel2 per-shape decoder converter for a given decoder config instance.""" + converter_class = APRIEL2_DECODER_REGISTRY.get(type(decoder_config)) + if converter_class is None: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return converter_class + + +class Apriel2HeadConverter(ConfigSectionConverter): + fast_llm_config_class = LanguageModelHeadConfig + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter @classmethod - def import_config(cls, config: dict) -> dict: - norm_config = config["head"]["normalization"] - return {"normalization": {"type": "rms_norm", "epsilon": norm_config["epsilon"]}} + def _create_config_converters(cls) -> dict: + return { + "normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), + "output_weight": IgnoredConfigConverter(("output_weight",)), + # Apriel2 HF format does not support multi-token prediction; pin to 1 so any non-default value + # fails on export instead of silently round-tripping. 
+ "prediction_heads": ConstantImportConfigConverter(("prediction_heads",), 1), + "final_logit_softcap": ConstantImportConfigConverter(("final_logit_softcap",), None), + } @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.language_model.config import LanguageModelHeadConfig + def _validate_export(cls, config: LanguageModelHeadConfig) -> None: + # The config side dispatches normalization through APRIEL2_NORM_REGISTRY (RMS/Layer/None), but the + # weight side below hardcodes ``normalization_converter_class`` (RMSNorm-only). Fail loudly here so a + # LayerNorm/NoNorm head config doesn't silently round-trip through the wrong weight conversion. + Assert.is_(type(config.normalization), RMSNormalizationConfig) - Assert.custom(isinstance, config, LanguageModelHeadConfig) - return { - "head": { - "normalization": { - "type": "rms_norm", - "epsilon": config.normalization.epsilon, - } - } - } + # --- weight side (imperative) --- @classmethod def get_converters( @@ -892,39 +955,47 @@ def get_converters( ] -class Apriel2BaseModelConverter: - decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = Apriel2DecoderConverter +class Apriel2BaseModelConverter(ConfigSectionConverter): + fast_llm_config_class = GPTBaseModelConfig + embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "decoder": DispatchConfigConverter( + fast_llm_path=("decoder",), + hf_path=("decoder",), + registry=APRIEL2_DECODER_REGISTRY, + ), + "head": NestedConfigConverter(("head",), cls.head_converter_class, hf_path=("head",)), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + "peft": IgnoredConfigConverter(("peft",)), + # ``Apriel2TextConfig`` default-injects ``{"embeddings": {"max_position_embeddings": 2048}}`` + # the Fast-LLM converter doesn't use — vocab_size rides at top level via the flat-merged + # ``LlamaEmbeddingsConverter``. Claim only the specific injected leaf so any future field + # transformers adds to the same subdict trips the HF coverage check. 
+ "embeddings_subdict_unmapped": IgnoredConfigConverter( + hf_paths=(("embeddings", "max_position_embeddings"),) + ), } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + assert_no_peft(config) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks"), + *get_apriel2_decoder_converter(config.decoder).get_converters( + config.decoder, "decoder", "model.decoder.blocks" + ), *cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] @@ -955,7 +1026,7 @@ def get_model_files(cls) -> tuple[str, str, str | None]: @classmethod def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: base_model = config.base_model - exported = safe_merge_dicts( + return safe_merge_dicts( cls.base_model_converter_class.export_config(base_model), { "architectures": [cls.architecture], @@ -967,10 +1038,10 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: }, }, ) - return exported @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: + cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} @classmethod diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 712160663..fcb92df0c 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -2,12 +2,18 @@ import typing +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, ProportionalRotaryConfig @@ -27,7 +33,6 @@ LlamaHeadConverter, LlamaNormalizationConverter, MLPLayer2Converter, - QueryWeightConverter, get_parameter_converter, get_weight_and_bias_converters, ) @@ -178,8 +183,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", False, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *kv_converters, @@ -625,54 +628,132 @@ def export_config(cls, config: LanguageModelHeadConfig) -> dict: return out -class Gemma4BaseModelConverter: +def _gemma4_bidirectional_export(_: Config) -> dict: + # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the multimodal + # model) is not 
implemented. Always emit ``None``. + return {("use_bidirectional_attention",): None} + + +def _gemma4_bidirectional_import(hf_dict: dict) -> dict: + # ``use_bidirectional_attention="vision"`` only affects vision tokens; the text path stays + # causal. Only ``"all"`` toggles ``is_causal=False`` for the text decoder, which we don't + # implement. + if hf_dict.get("use_bidirectional_attention") == "all": + raise NotImplementedError( + 'Gemma 4 `use_bidirectional_attention="all"` is not supported (text path stays causal).' + ) + return {} + + +class Gemma4BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for ``GPTBaseModelConfig`` ↔ Gemma 4 HF dict. + + Gemma 4 has several wrinkles that prevent the standard per-section decomposition used by Llama: + + * The decoder is a :class:`PatternBlockSequenceConfig` whose two named blocks + (``sliding_attention`` / ``full_attention``) share most HF keys but diverge on ``head_dim`` and + rope parameters. The HF format emits both block variants from a single root-level config, so + the block-level transform inherently sees both Fast-LLM blocks at once. + * ``embedding_scale = hidden_size ** 0.5`` and ``router_input_scale = hidden_size ** -0.5`` make + the embeddings and routed MLP cross-reference the root-level ``hidden_size``. + + Each section (embeddings, decoder, head) is therefore expressed as a :class:`CustomConfigConverter` + that delegates to an imperative helper class (kept private to this module). Coverage at the + section level is satisfied via ``recurses=True``. + """ + + fast_llm_config_class = GPTBaseModelConfig + decoder_converter_class: typing.ClassVar[type[Gemma4DecoderConverter]] = Gemma4DecoderConverter embeddings_converter_class: typing.ClassVar[type[Gemma4EmbeddingsConverter]] = Gemma4EmbeddingsConverter head_converter_class: typing.ClassVar[type[Gemma4HeadConverter]] = Gemma4HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - if config.get("hidden_size_per_layer_input") not in (None, 0): - raise NotImplementedError( - "Gemma 4 Per-Layer Embeddings (`hidden_size_per_layer_input != 0`) are not supported." - ) - if config.get("num_kv_shared_layers", 0): - raise NotImplementedError("Gemma 4 cross-layer KV sharing (`num_kv_shared_layers != 0`) is not supported.") - if config.get("use_double_wide_mlp", False): - raise NotImplementedError("Gemma 4 `use_double_wide_mlp=True` is not supported.") - # `use_bidirectional_attention="vision"` only affects vision tokens; the text path stays causal. - # Only `"all"` toggles `is_causal=False` for the text decoder, which we don't implement. - if config.get("use_bidirectional_attention") == "all": - raise NotImplementedError( - 'Gemma 4 `use_bidirectional_attention="all"` is not supported (text path stays causal).' 
- ) + def _create_config_converters(cls) -> dict: + decoder_cls = cls.decoder_converter_class + embeddings_cls = cls.embeddings_converter_class + head_cls = cls.head_converter_class + + def _embeddings_export(parent: Config) -> dict: + return {(k,): v for k, v in embeddings_cls.export_config(parent.embeddings, parent.hidden_size).items()} + + def _embeddings_import(hf_dict: dict) -> dict: + return {("embeddings",): embeddings_cls.import_config(hf_dict)} + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_cls.export_config(parent.decoder, parent.hidden_size).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return {("decoder",): decoder_cls.import_config(hf_dict)} + + def _head_export(parent: Config) -> dict: + return {(k,): v for k, v in head_cls.export_config(parent.head).items()} + + def _head_import(hf_dict: dict) -> dict: + return {("head",): head_cls.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": CustomConfigConverter( + fast_llm_paths=(("embeddings",),), + hf_paths=(("vocab_size",),), + export_fn=_embeddings_export, + import_fn=_embeddings_import, + recurses=True, + ), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + hf_paths=( + ("num_hidden_layers",), + ("layer_types",), + ("num_attention_heads",), + ("num_key_value_heads",), + ("head_dim",), + ("global_head_dim",), + ("num_global_key_value_heads",), + ("attention_bias",), + ("attention_dropout",), + ("sliding_window",), + ("rms_norm_eps",), + ("attention_k_eq_v",), + ("rope_parameters",), + ("intermediate_size",), + ("hidden_activation",), + ("enable_moe_block",), + ("num_experts",), + ("top_k_experts",), + ("moe_intermediate_size",), + ), + export_fn=_decoder_export, + import_fn=_decoder_import, + recurses=True, + ), + "head": CustomConfigConverter( + fast_llm_paths=(("head",),), + hf_paths=(("final_logit_softcapping",),), + export_fn=_head_export, + import_fn=_head_import, + recurses=True, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + "peft": IgnoredConfigConverter(("peft",)), + # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; explicitly + # zero to disable the feature in the exported model until Fast-LLM supports it natively. + "hidden_size_per_layer_input": ConstantExportConfigConverter(("hidden_size_per_layer_input",), 0), + "num_kv_shared_layers": ConstantExportConfigConverter(("num_kv_shared_layers",), 0), + "use_double_wide_mlp": ConstantExportConfigConverter(("use_double_wide_mlp",), False), + "use_bidirectional_attention": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("use_bidirectional_attention",),), + export_fn=_gemma4_bidirectional_export, + import_fn=_gemma4_bidirectional_import, + ), + # Vocab-size-per-layer is part of Per-Layer Embeddings (PLE), gated by + # ``hidden_size_per_layer_input``. PLE is rejected above, so we ignore the size field too. 
+ "vocab_size_per_layer_input": IgnoredConfigConverter(hf_paths=(("vocab_size_per_layer_input",),)), } - @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings, config.hidden_size), - cls.decoder_converter_class.export_config(config.decoder, config.hidden_size), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; - # explicitly zero to disable the feature in the exported model until Fast-LLM - # supports it natively. - "hidden_size_per_layer_input": 0, - # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the - # multimodal model) is not implemented. - "use_bidirectional_attention": None, - }, - ) + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index d3ad266ee..fd3f66362 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -5,10 +5,18 @@ import torch import transformers +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, IgnoreExportWeightConverter, IgnoreImportWeightConverter, + NestedConfigConverter, + RenameConfigConverter, SplitWeightConverter, WeightConverter, ) @@ -18,7 +26,9 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.common.peft.config import NoPeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import ( @@ -30,13 +40,29 @@ from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div, safe_merge_dicts +from fast_llm.utils import Assert, div _TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) logger = logging.getLogger(__name__) +def assert_no_peft(config: GPTBaseModelConfig) -> None: + """Reject any non-trivial PEFT config: HuggingFace formats serialize the base weights only, + so a configured LoRA (or other adapter) would be silently dropped on export.""" + Assert.custom(isinstance, config.peft, NoPeftConfig) + + +def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: + """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default.""" + return default if layer_config.bias.enabled is None else layer_config.bias.enabled + + +# ============================================================ +# Weight converters 
(imperative — kept as-is during config migration) +# ============================================================ + + def get_parameter_converter( fast_llm_name: str | tuple[str, ...], hf_name: str | tuple[str, ...], @@ -97,16 +123,139 @@ def get_weight_and_bias_converters( return converters -class LlamaNormalizationConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. + # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +# ============================================================ +# Config converters (declarative) +# ============================================================ + + +def _llama_rotary_export(config: AttentionConfig) -> dict: + """Build the HF rotary block(s) from a Fast-LLM rotary config. + + Returns a dict keyed by the (Llama-flat) HF paths the converter declares; values vary with rotary subtype and + the active transformers major version (v4 puts ``rope_theta`` flat with optional ``rope_scaling``; + v5 consolidates everything into ``rope_parameters``). 
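+
+    Illustrative example (values assumed): a ``DefaultRotaryConfig`` with ``theta=10000.0`` exports as
+    ``{("rope_theta",): 10000.0}`` under transformers v4 and as
+    ``{("rope_parameters",): {"rope_theta": 10000.0, "rope_type": "default"}}`` under v5.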
+ """ + rotary = config.rotary + rope_parameters = {"rope_theta": rotary.theta} + if type(rotary) is DefaultRotaryConfig: + rope_parameters["rope_type"] = "default" + elif type(rotary) is Llama3RotaryConfig: + rope_parameters.update( + { + "rope_type": "llama3", + "factor": rotary.scale_factor, + "low_freq_factor": rotary.low_frequency_factor, + "high_freq_factor": rotary.high_frequency_factor, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + elif type(rotary) is YarnRotaryConfig: + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": rotary.attention_factor, + "beta_fast": rotary.beta_fast, + "beta_slow": rotary.beta_slow, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + if _TRANSFORMERS_V4: + out: dict = {("rope_theta",): rope_parameters["rope_theta"]} + if type(rotary) is not DefaultRotaryConfig: + out[("rope_scaling",)] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} + return out + return {("rope_parameters",): rope_parameters} + + +def _llama_rotary_import(hf_dict: dict) -> dict: + """Reverse of :func:`_llama_rotary_export`. Detects v4/v5 layout from the HF dict.""" + if "rope_parameters" in hf_dict: # transformers v5 + rope_params = hf_dict["rope_parameters"] + rope_theta = rope_params["rope_theta"] + else: # transformers v4 + rope_params = hf_dict.get("rope_scaling") or {} + rope_theta = hf_dict["rope_theta"] + rope_type = rope_params.get("rope_type", "default") + rotary_config: dict = {"type": rope_type, "theta": rope_theta} + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") + return {("rotary",): rotary_config} + + +class LlamaNormalizationConverter(ConfigSectionConverter): + """Converts ``RMSNormalizationConfig`` ↔ Llama's flat ``rms_norm_eps`` field.""" + + fast_llm_config_class = RMSNormalizationConfig @classmethod - def export_config(cls, config: RMSNormalizationConfig) -> dict: - Assert.custom(isinstance, config, RMSNormalizationConfig) - assert not config.zero_centered - return {"rms_norm_eps": config.epsilon} + def _create_config_converters(cls) -> dict: + return { + "type": ConstantImportConfigConverter(("type",), "rms_norm"), + "epsilon": RenameConfigConverter(("epsilon",), ("rms_norm_eps",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -124,31 +273,38 @@ def get_converters( ) -class LlamaMLPConverter: +class LlamaMLPConverter(ConfigSectionConverter): + """Converts ``MLPConfig`` ↔ Llama's flat ``intermediate_size``/``mlp_bias``/``hidden_act`` fields. + + Llama is always gated (``ConstantImportConfigConverter(("gated",), True)``). 
+ """ + + fast_llm_config_class = MLPConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "intermediate_size": config["intermediate_size"], - "add_linear_biases": config["mlp_bias"], - "activation": ActivationType.from_hf_name(config["hidden_act"]), - "gated": True, + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("mlp_bias",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("hidden_act",),), + export_fn=lambda c: {("hidden_act",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, + ), + "gated": ConstantImportConfigConverter(("gated",), True), + # Llama doesn't expose per-layer bias overrides; the bias-match check lives on _validate_export. + "layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), + "pre_norm": ConstantImportConfigConverter(("pre_norm",), None), + "post_norm": ConstantImportConfigConverter(("post_norm",), None), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - assert config.gated - if config.pre_norm is not None: - raise NotImplementedError(f"MLP `pre_norm` is not supported by `{cls.__name__}`.") - if config.post_norm is not None: - raise NotImplementedError(f"MLP `post_norm` is not supported by `{cls.__name__}`.") - return { - "intermediate_size": config.intermediate_size, - "mlp_bias": config.add_linear_biases, - "hidden_act": config.activation.hf_name, - } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -176,135 +332,64 @@ def get_converters( ] -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) +class LlamaAttentionConverter(ConfigSectionConverter): + """Converts ``AttentionConfig`` ↔ Llama's flat attention fields. - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) + Notable wrinkles: + - ``head_dim`` is computed from ``hidden_size // num_attention_heads`` when missing on import. + - Rotary handling is delegated to a :class:`CustomConfigConverter` because it spans v4/v5 transformers + layouts and three rotary subtypes. + - Per-layer linear biases (query/key/value/dense) are validated to match ``add_linear_biases`` on + ``_validate_export``; Llama does not expose layer-level overrides, so the sub-config fields are + blanket-consumed via :class:`IgnoredConfigConverter`. + """ + fast_llm_config_class = AttentionConfig -class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict) -> dict: - # Normalize rope params to a single dict before dispatching on rope_type. 
- # transformers v5 consolidates rope_theta + rope_scaling into rope_parameters. - # transformers v4: rope_theta at top level, rope_scaling dict for non-default types. - # Note: detection is on checkpoint format, not transformers version — old checkpoints - # remain loadable with v5 transformers. - if "rope_parameters" in config: # transformers v5 - rope_params = config["rope_parameters"] - rope_theta = rope_params["rope_theta"] - else: # transformers v4 - rope_params = config.get("rope_scaling") or {} - rope_theta = config["rope_theta"] - rope_type = rope_params.get("rope_type", "default") - rotary_config = {"type": rope_type, "theta": rope_theta} - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_params["factor"], - "low_frequency_factor": rope_params["low_freq_factor"], - "high_frequency_factor": rope_params["high_freq_factor"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_params["attention_factor"], - "beta_fast": rope_params["beta_fast"], - "beta_slow": rope_params["beta_slow"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {rope_type}") - out = { - "rotary": rotary_config, - "heads": config["num_attention_heads"], - "head_groups": config["num_key_value_heads"], - "head_size": config.get("head_dim"), - "add_linear_biases": config["attention_bias"], - "dropout": config["attention_dropout"], + def _create_config_converters(cls) -> dict: + return { + "heads": RenameConfigConverter(("heads",), ("num_attention_heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("num_key_value_heads",)), + "head_size": DefaultConfigConverter( + ("head_size",), + ("head_dim",), + hf_default_fn=lambda hf: div(hf["hidden_size"], hf["num_attention_heads"]), + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("attention_bias",)), + "dropout": RenameConfigConverter(("dropout",), ("attention_dropout",)), + "causal": ConstantImportConfigConverter(("causal",), True), + "softmax_scale_power": ConstantImportConfigConverter(("softmax_scale_power",), 0.5), + "linear_layers": IgnoredConfigConverter( + ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",) + ), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rope_theta",), ("rope_scaling",), ("rope_parameters",)), + export_fn=_llama_rotary_export, + import_fn=_llama_rotary_import, + recurses=True, + ), + "query_norm": ConstantImportConfigConverter(("query_norm",), None), + "key_norm": ConstantImportConfigConverter(("key_norm",), None), + "value_norm": ConstantImportConfigConverter(("value_norm",), None), + "shared_key_value": ConstantImportConfigConverter(("shared_key_value",), False), } - if out["head_size"] is None: - out["head_size"] = div(config["hidden_size"], out["heads"]) - - return out @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - cls._check_config(config) - Assert.eq(config.softmax_scale_power, 0.5) - if config.query_norm is not None: - raise NotImplementedError(f"`query_norm` is not supported by `{cls.__name__}`.") - if config.key_norm is not None: - raise NotImplementedError(f"`key_norm` is not supported by `{cls.__name__}`.") - if config.value_norm is not None: - raise NotImplementedError(f"`value_norm` is not supported by `{cls.__name__}`.") - if 
config.shared_key_value: - raise NotImplementedError(f"`shared_key_value` is not supported by `{cls.__name__}`.") - rope_parameters = {"rope_theta": config.rotary.theta} - if type(config.rotary) is DefaultRotaryConfig: - rope_parameters["rope_type"] = "default" - elif type(config.rotary) is Llama3RotaryConfig: - rope_parameters.update( - { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - elif type(config.rotary) is YarnRotaryConfig: - rope_parameters.update( - { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - common = { - "num_attention_heads": config.heads, - "num_key_value_heads": config.head_groups, - "head_dim": config.head_size, - "attention_bias": config.add_linear_biases, - "attention_dropout": config.dropout, - } - if _TRANSFORMERS_V4: - out = {**common, "rope_theta": rope_parameters["rope_theta"]} - if type(config.rotary) is not DefaultRotaryConfig: - out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} - return out - return {**common, "rope_parameters": rope_parameters} + def _validate_export(cls, config: AttentionConfig) -> None: + """Default: Llama requires per-layer biases to be unset (``None``) or to match ``add_linear_biases``. - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - # Opportunity to make derived classes less constrained. + Subclasses (e.g. Qwen2 with always-on Q/K/V biases and no dense bias) override. + """ Assert.is_(type(config), AttentionConfig) Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -318,8 +403,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", config.add_linear_biases, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -339,77 +422,38 @@ def get_converters( ] -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - return key, value +class LlamaBlockConverter(ConfigSectionConverter): + """Converts ``DecoderBlockConfig`` ↔ Llama block fields (flat-merged into the parent's HF dict).""" - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - key_value = torch.cat([key[:], value[:]]) - return (key_value,) + fast_llm_config_class = DecoderBlockConfig + mixer_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaAttentionConverter + mlp_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter -class LlamaBlockConverter: - mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter - mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "self_attn" hf_mlp_name: typing.ClassVar[str] = "mlp" hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config), - "mlp": cls.mlp_converter_class.import_config(config), - "normalization": cls.normalization_converter_class.import_config(config), + "mixer": NestedConfigConverter(("mixer",), cls.mixer_converter_class), + "mlp": NestedConfigConverter(("mlp",), cls.mlp_converter_class), + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + "pre_mixer_normalization": ConstantImportConfigConverter(("pre_mixer_normalization",), None), + "pre_mlp_normalization": ConstantImportConfigConverter(("pre_mlp_normalization",), None), + "post_mixer_normalization": ConstantImportConfigConverter(("post_mixer_normalization",), None), + "post_mlp_normalization": ConstantImportConfigConverter(("post_mlp_normalization",), None), + "output_scale": IgnoredConfigConverter(("output_scale",)), } @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - Assert.custom(isinstance, config, DecoderBlockConfig) - if config.output_scale.enabled: - raise NotImplementedError(f"`output_scale` is not supported by `{cls.__name__}`.") - if config.pre_mixer_normalization is not None: - raise NotImplementedError(f"`pre_mixer_normalization` is not supported by `{cls.__name__}`.") - if config.pre_mlp_normalization is not None: - raise NotImplementedError(f"`pre_mlp_normalization` is not supported by `{cls.__name__}`.") - if config.post_mixer_normalization is not None: - raise NotImplementedError(f"`post_mixer_normalization` is not supported by `{cls.__name__}`.") - if config.post_mlp_normalization is not None: - raise NotImplementedError(f"`post_mlp_normalization` is not supported by `{cls.__name__}`.") - return safe_merge_dicts( - cls.mixer_converter_class.export_config(config.mixer), - cls.mlp_converter_class.export_config(config.mlp), - cls.normalization_converter_class.export_config(config.normalization), - ) + def _validate_export(cls, config: DecoderBlockConfig) -> None: + assert not config.output_scale.enabled + + # --- weight side (imperative) 
--- @classmethod def get_converters( @@ -443,35 +487,49 @@ def get_converters( ] +def _llama_decoder_export( + decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, + block_converter_class: type[ConfigSectionConverter], +) -> dict: + """Convert a Fast-LLM polymorphic Fixed/Pattern block sequence to Llama's flat HF representation. + + Pattern: assert all blocks export identical HF (Llama's format has no per-block discriminator), then use + the common export. Fixed: just delegate to the single block. + """ + if isinstance(decoder_config, PatternBlockSequenceConfig): + exports = [block_converter_class.export_config(block) for block in decoder_config.blocks.values()] + for other in exports[1:]: + Assert.eq(exports[0], other) + block_hf = exports[0] + elif isinstance(decoder_config, FixedBlockSequenceConfig): + block_hf = block_converter_class.export_config(decoder_config.block) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return {**block_hf, "num_hidden_layers": decoder_config.num_blocks} + + class LlamaDecoderConverter: - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + """Imperative dispatcher for the polymorphic Fixed/Pattern block sequence. + + Used by formats that don't compose at the :class:`LlamaBaseModelConverter` level — currently only + Pixtral's vision encoder (:class:`PixtralEncoderConverter`) and Apriel's per-position hybrid layout + dispatcher inherit from it. The standard text formats (Mistral/Qwen2/Mixtral) use the inline dispatch + inside :class:`LlamaBaseModelConverter._create_config_converters` instead, parameterised by + ``block_converter_class``. + """ + + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict) -> dict: + def import_config(cls, hf_dict: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "block": cls.block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: - if isinstance(config, PatternBlockSequenceConfig): - # All exported block configs must be equal - exported_block_configs = [ - safe_merge_dicts( - cls.block_converter_class.export_config(block_config), - {"num_hidden_layers": config.num_blocks}, - ) - for block_config in config.blocks.values() - ] - for other in exported_block_configs[1:]: - Assert.eq(exported_block_configs[0], other) - return exported_block_configs[0] - Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, - ) + def export_config(cls, decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: + return _llama_decoder_export(decoder_config, cls.block_converter_class) @classmethod def get_converters( @@ -481,11 +539,10 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config block_config = ( config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values())) ) - converters = [] + converters: list[WeightConverter] = [] for block_index in range(config.num_blocks): converters += 
cls.block_converter_class.get_converters( block_config, @@ -496,17 +553,29 @@ def get_converters( return converters -class LlamaEmbeddingsConverter: +class LlamaEmbeddingsConverter(ConfigSectionConverter): + """Converts ``LanguageModelEmbeddingsConfig`` ↔ Llama (flat ``vocab_size``). + + Llama has no learnable position embeddings; ``num_position_embeddings`` is irrelevant when + ``position_embeddings.enabled`` is ``False``/``None`` and is therefore blanket-consumed. + """ + + fast_llm_config_class = LanguageModelEmbeddingsConfig + @classmethod - def import_config(cls, config: dict) -> dict: - return {"vocab_size": config["vocab_size"]} + def _create_config_converters(cls) -> dict: + return { + "vocab_size": RenameConfigConverter(("vocab_size",), ("vocab_size",)), + "word_embeddings": IgnoredConfigConverter(("word_embeddings",)), + "position_embeddings": IgnoredConfigConverter(("position_embeddings",), ("num_position_embeddings",)), + "embedding_scale": ConstantImportConfigConverter(("embedding_scale",), 1.0), + } @classmethod - def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) - assert not config.position_embeddings.enabled - Assert.eq(config.embedding_scale, 1.0) - return {"vocab_size": config.vocab_size} + def _validate_export(cls, config: LanguageModelEmbeddingsConfig) -> None: + Assert.incl(config.position_embeddings.enabled, (None, False)) + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -515,20 +584,29 @@ def get_converters( return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] -class LlamaHeadConverter: - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter +class LlamaHeadConverter(ConfigSectionConverter): + """Converts ``LanguageModelHeadConfig`` ↔ Llama final-norm fields (flat-merged).""" - @classmethod - def import_config(cls, config: dict) -> dict: - return {"normalization": cls.normalization_converter_class.import_config(config)} + fast_llm_config_class = LanguageModelHeadConfig + + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter + # Used by MTP-Llama subclass to emit per-prediction-head block weight converters; Llama itself doesn't read it. + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelHeadConfig) - if config.final_logit_softcap is not None: - raise NotImplementedError(f"`final_logit_softcap` is not supported by `{cls.__name__}`.") - return cls.normalization_converter_class.export_config(config.normalization) + def _create_config_converters(cls) -> dict: + return { + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + "output_weight": IgnoredConfigConverter(("output_weight",)), + # Llama HF format does not represent ``prediction_heads``; pin to 1 so any non-default value + # fails on export instead of silently round-tripping. MTP-Llama overrides this entry with a + # ``RenameConfigConverter`` (the override replaces the parent's declaration in the returned + # dict, so this ConstantImport never fires for MTP-Llama configs). 
+ "prediction_heads": ConstantImportConfigConverter(("prediction_heads",), 1), + "final_logit_softcap": ConstantImportConfigConverter(("final_logit_softcap",), None), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -551,40 +629,79 @@ def get_converters( ] -class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): - # TODO: Peft? - decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter - embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[LlamaHeadConverter]] = LlamaHeadConverter +class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict. + + Subclasses (Mistral, Qwen2, Mixtral, MTP-Llama, …) override ``block_converter_class`` to plug their + per-block declarations into the polymorphic Fixed/Pattern decoder dispatch held here. + """ + + fast_llm_config_class = GPTBaseModelConfig + + embeddings_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaEmbeddingsConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter + head_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaHeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: + block_converter_class = cls.block_converter_class + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in _llama_decoder_export(parent.decoder, block_converter_class).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return { + ("decoder",): { + "block": block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], + } + } + return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "head": NestedConfigConverter(("head",), cls.head_converter_class), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + # The block converter's flat-merge declarations claim all per-block top-level keys; pull + # them up here so the HF coverage check sees them as covered. ``num_hidden_layers`` is + # consumed by the Fixed/Pattern dispatch above. + hf_paths=( + ("num_hidden_layers",), + *block_converter_class._consumed_hf_paths(), + ), + export_fn=_decoder_export, + import_fn=_decoder_import, + recurses=True, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + # Llama format cannot represent PEFT; the NoPeftConfig assertion lives on _validate_export so a + # user-configured LoRA fails clearly rather than being silently dropped on export. 
+ "peft": IgnoredConfigConverter(("peft",)), } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + assert_no_peft(config) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + decoder_config = config.decoder + block_config = ( + decoder_config.block + if isinstance(decoder_config, FixedBlockSequenceConfig) + else next(iter(decoder_config.blocks.values())) + ) + block_converters: list[WeightConverter] = [] + for block_index in range(decoder_config.num_blocks): + block_converters += cls.block_converter_class.get_converters( + block_config, f"decoder.{block_index}", f"model.layers.{block_index}" + ) return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *block_converters, *cls.head_converter_class.get_converters(config, exported_config), ] diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index d4a669b22..18251c760 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -1,57 +1,37 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.layers.attention.config import AttentionConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.engine.checkpoint.external import ConstantImportConfigConverter, RenameConfigConverter from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) -from fast_llm.utils import safe_merge_dicts class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - return safe_merge_dicts( - super().import_config(config), - {"window_size": config["sliding_window"]}, - ) - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = safe_merge_dicts( - super().export_config(config), - {"sliding_window": config.window_size}, - ) - del out["attention_bias"] - return out - - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - # Mistral doesn't support biases. - assert not config.add_linear_biases + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `attention_bias` HF field; biases are always disabled. 
+ "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "window_size": RenameConfigConverter(("window_size",), ("sliding_window",)), + } class MistralMLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - assert not config.add_linear_biases - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class MistralBlockConverter(LlamaBlockConverter): @@ -59,16 +39,12 @@ class MistralBlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter -class MistralDecoderConverter(LlamaDecoderConverter): - block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter - - class MistralHeadConverter(LlamaHeadConverter): block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter class MistralBaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[MistralDecoderConverter]] = MistralDecoderConverter + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index e13d1481a..64917df1b 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,54 +1,60 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter -from fast_llm.layers.decoder.mlp.config import MoEMLPConfig +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig, RoutingType from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) -from fast_llm.utils import Assert, safe_merge_dicts class MixtralMLPConverter(LlamaMLPConverter): + fast_llm_config_class = MoEMLPConfig + @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return safe_merge_dicts( - super().import_config(config), - { - "type": "moe", - "experts": config["num_local_experts"], - "experts_per_token": config["num_experts_per_tok"], - }, - ) + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mixtral has no `mlp_bias` HF field; biases are always disabled. 
+ "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "experts": RenameConfigConverter(("experts",), ("num_local_experts",)), + "experts_per_token": RenameConfigConverter(("experts_per_token",), ("num_experts_per_tok",)), + # Mixtral has no shared experts and uses the topk default; assert on export, inject defaults on import. + "shared_experts": ConstantImportConfigConverter(("shared_experts",), 0), + "routing": ConstantImportConfigConverter(("routing",), RoutingType.topk), + # Mixtral has no HF representation for the router sub-config. The blanket consume satisfies + # architecture coverage; non-architecture fields (lr_scale, apply_peft, weight.initialization, + # weight.lr_scale) cannot round-trip through the HF format by design — Fast-LLM keeps them on + # the in-memory config independently. The only architecture-hint sub-field is ``router.weight``, + # a ParameterConfig with no architecture sub-fields, so the blanket carries no real risk. + "router": IgnoredConfigConverter(("router",)), + "router_normalization": ConstantImportConfigConverter(("router_normalization",), None), + "router_scale": IgnoredConfigConverter(("router_scale",)), + "router_input_scale": ConstantImportConfigConverter(("router_input_scale",), 1.0), + "router_per_expert_scale": IgnoredConfigConverter(("router_per_expert_scale",)), + # Router / inference toggles surfaced by HF but not consumed by Fast-LLM's MoEMLPConfig + # (auxiliary_loss_coefficient and jitter_eps are FieldHint.feature, not architecture). + "router_runtime_unsupported": IgnoredConfigConverter( + hf_paths=(("router_aux_loss_coef",), ("router_jitter_noise",), ("output_router_logits",)), + ), + } @classmethod - def export_config(cls, config: MoEMLPConfig) -> dict: - Assert.custom(isinstance, config, MoEMLPConfig) - assert not config.add_linear_biases - if config.router_normalization is not None: - raise NotImplementedError(f"`router_normalization` is not supported by `{cls.__name__}`.") - if config.router_scale.enabled: - raise NotImplementedError(f"`router_scale` is not supported by `{cls.__name__}`.") - if config.router_input_scale != 1.0: - raise NotImplementedError(f"`router_input_scale != 1.0` is not supported by `{cls.__name__}`.") - if config.router_per_expert_scale.enabled: - raise NotImplementedError(f"`router_per_expert_scale` is not supported by `{cls.__name__}`.") - out = super().export_config(config) - del out["mlp_bias"] - return safe_merge_dicts( - out, - { - "num_local_experts": config.experts, - "num_experts_per_tok": config.experts_per_token, - }, - ) + def _validate_export(cls, config: MoEMLPConfig) -> None: + super()._validate_export(config) + assert not config.router_scale.enabled + assert not config.router_per_expert_scale.enabled @classmethod def get_converters( @@ -87,16 +93,12 @@ class MixtralBlockConverter(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter -class MixtralDecoderConverter(MistralDecoderConverter): - block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter - - class MixtralHeadConverter(MistralHeadConverter): block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter class MixtralBaseModelConverter(MistralBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[MixtralDecoderConverter]] = MixtralDecoderConverter + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter head_converter_class: 
typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index f681c4a24..6f6d9e88a 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -3,44 +3,36 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelHeadConfig +from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, get_parameter_converter, ) -from fast_llm.utils import Assert, safe_merge_dicts +from fast_llm.utils import safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - **super().import_config(config), - "prediction_heads": config["prediction_heads"], + **super()._create_config_converters(), + # MTP-Llama exposes the prediction-heads count via the HF config; Llama itself blanket-ignores it. + "prediction_heads": RenameConfigConverter(("prediction_heads",), ("prediction_heads",)), } - @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - return safe_merge_dicts( - super().export_config(config), - {"prediction_heads": config.prediction_heads}, - ) - @classmethod def get_converters( cls, config: LanguageModelConfig, exported_config: dict, ) -> list[WeightConverter]: - # Override: map head.final_norm to model.mtp_norms.0 (not model.norm as in standard Llama), - # since MTPLlamaModel uses mtp_norms[0] for the first prediction head. + # MTP-Llama uses ``model.mtp_norms.0`` for the first prediction head's final norm + # instead of the standard ``model.norm``. converters = [ *cls.normalization_converter_class.get_converters( config.head.normalization, @@ -68,26 +60,7 @@ def get_converters( return converters -class MTPLlamaDecoderConverter(LlamaDecoderConverter): - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], - } - - @classmethod - def export_config(cls, config: FixedBlockSequenceConfig) -> dict: - # TODO: Support PatternBlockSequenceConfig with compatible configs. 
- Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, - ) - - class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[MTPLlamaDecoderConverter]] = MTPLlamaDecoderConverter head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 9aa2f8c8e..6a4f4f385 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,10 +1,14 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + IgnoredConfigConverter, + ImportOnlyConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -12,42 +16,50 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, - QueryWeightConverter, get_weight_and_bias_converters, ) -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - out = super().import_config(config) - out["query_layer"] = {"bias": {"enabled": True}} - out["key_layer"] = {"bias": {"enabled": True}} - out["value_layer"] = {"bias": {"enabled": True}} - out["dense_layer"] = {"bias": {"enabled": False}} - return out - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = super().export_config(config) - del out["attention_bias"] - # Qwen2Config does not have head_dim as a standard field; it is always - # derivable as hidden_size // num_attention_heads. - del out["head_dim"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. + "head_size": ImportOnlyConfigConverter( + fast_llm_paths=(("head_size",),), + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ), + # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. + # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. 
+ "linear_layers": ImportOnlyConfigConverter( + fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), + import_fn=lambda hf: { + ("query_layer",): {"bias": {"enabled": True}}, + ("key_layer",): {"bias": {"enabled": True}}, + ("value_layer",): {"bias": {"enabled": True}}, + ("dense_layer",): {"bias": {"enabled": False}}, + }, + recurses=True, + ), + # Sliding-window machinery surfaced by Qwen2 HF but not yet supported here (see TODO above). + "sliding_window_unsupported": IgnoredConfigConverter( + hf_paths=(("sliding_window",), ("use_sliding_window",), ("max_window_layers",), ("layer_types",)), + ), + } @classmethod - def _check_config(cls, config: AttentionConfig) -> None: + def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) - # There are multiple ways to enable biases on QKV only + # There are multiple ways to enable biases on QKV only. if config.add_linear_biases: Assert.incl(config.query_layer.bias.enabled, (None, True)) Assert.incl(config.key_layer.bias.enabled, (None, True)) @@ -72,8 +84,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", True, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -95,15 +105,12 @@ def get_converters( class Qwen2MLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Qwen2 has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class Qwen2BlockConverter(LlamaBlockConverter): @@ -111,25 +118,36 @@ class Qwen2BlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter -class Qwen2DecoderConverter(LlamaDecoderConverter): +class Qwen2HeadConverter(LlamaHeadConverter): block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter -class Qwen2HeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter +def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: + if hf_dict.get("use_mrope") is True: + raise AssertionError("MRoPE (use_mrope=True) is not supported by the Qwen2 converter") + return {} class Qwen2BaseModelConverter(LlamaBaseModelConverter): - decoder_converter_class: typing.ClassVar[type[Qwen2DecoderConverter]] = Qwen2DecoderConverter + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - assert config.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" - return super().import_config(config) + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Refuse MRoPE on import; the export path can't produce ``use_mrope=True`` because Fast-LLM + # has no rotary type that maps to it. 
+ "use_mrope_guard": ImportOnlyConfigConverter( + fast_llm_paths=(), + hf_paths=(("use_mrope",),), + import_fn=_qwen2_mrope_guard_import, + ), + } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + super()._validate_export(config) block = ( config.decoder.block if isinstance(config.decoder, FixedBlockSequenceConfig) @@ -141,7 +159,6 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: config.hidden_size, msg="Qwen2 format omits head_dim; requires heads * head_size == hidden_size", ) - return super().export_config(config) class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 8a947baaa..2064f6106 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -3,17 +3,31 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + RenameConfigConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Rotary2DConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( Apriel2BaseModelConverter, - Apriel2DecoderConverter, Apriel2HeadConverter, + Apriel2RMSNormConverter, + get_apriel2_decoder_converter, ) from fast_llm.models.gpt.conversion.llama import ( LlamaEmbeddingsConverter, @@ -25,170 +39,297 @@ from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat from fast_llm.models.multimodal.conversion.llava import ( LlavaVisionAdapterConverter, - LlavaVisionModelConverter, PatchEmbeddingWeightConverter, PixtralAttentionConverter, - PixtralBlockConverter, - PixtralEncoderConverter, ) from fast_llm.models.multimodal.model import MultiModalModel -from fast_llm.utils import Assert, safe_merge_dicts +from fast_llm.utils import Assert + + +def _apriel2_vision_attention_rotary_export(config: AttentionConfig) -> dict: + """Emit the Apriel2-vision rotary subdict. Two rotary types are supported: + :class:`Rotary2DConfig` (HF ``pixtral_2d``) and :class:`DefaultRotaryConfig` (HF ``mistral_1d``). 
+ ``patch_size``/``max_image_size`` HF metadata is injected by the parent vision-model converter + (it derives from ``embeddings.patch_height``, outside this scope).""" + rotary = config.rotary + if type(rotary) is Rotary2DConfig: + return {("rotary",): {"type": "pixtral_2d", "theta": rotary.theta}} + if type(rotary) is DefaultRotaryConfig: + return {("rotary",): {"type": "mistral_1d", "theta": rotary.theta}} + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + +def _apriel2_vision_attention_rotary_import(hf_dict: dict) -> dict: + rotary = dict(hf_dict["rotary"]) + if rotary.get("type") == "pixtral_2d": + rotary["type"] = "default_2d" + elif rotary.get("type") == "mistral_1d": + rotary["type"] = "default" + rotary.pop("patch_size", None) + rotary.pop("max_image_size", None) + return {("rotary",): rotary} class Apriel2VisionAttentionConverter(PixtralAttentionConverter): + """Converts :class:`AttentionConfig` ↔ Apriel2 vision attention HF subdict (typed ``"attention"``). + + Apriel2's vision attention shape uses Apriel2-native field names (``heads``, ``head_groups``, ``head_size``, + ``add_linear_biases``, ``causal``) plus an explicit ``cross_document_attention=False`` flag and a nested + typed ``rotary`` block. Differs from the text :class:`Apriel2AttentionConverter` mainly in the rotary type + set (``pixtral_2d``/``mistral_1d`` instead of ``mistral_1d``/``llama3``/``yarn``) and the lack of + per-layer-bias and ``window_size`` representations. + + Inherits :meth:`get_converters` from :class:`PixtralAttentionConverter` (Llama-style q/k/v/o weight layout). + """ + + hf_type_name = "attention" + @classmethod - def import_config(cls, config: dict) -> dict: - rotary = config["rotary"].copy() - # Map Apriel2 HuggingFace rotary type to Fast-LLM internal type - if rotary.get("type") == "pixtral_2d": - rotary["type"] = "default_2d" - # Strip HF-specific fields not needed by Fast-LLM's Rotary2DConfig - # (Fast-LLM computes patch_positions dynamically from actual image patches) - rotary.pop("max_image_size", None) - rotary.pop("patch_size", None) + def _create_config_converters(cls) -> dict: + # Replace Pixtral's declarations wholesale: Apriel2 vision uses Apriel2-native field names, allows GQA + # and both Rotary2D + DefaultRotary, and has no HF representation for per-layer biases or window_size. return { - "rotary": rotary, - "heads": config["heads"], - "head_groups": config["head_groups"], - "head_size": config["head_size"], - "add_linear_biases": config["add_linear_biases"], - "causal": config["causal"], + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("head_groups",)), + "head_size": RenameConfigConverter(("head_size",), ("head_size",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "causal": RenameConfigConverter(("causal",), ("causal",)), + "cross_document_attention": ConstantExportConfigConverter(("cross_document_attention",), False), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rotary",),), + export_fn=_apriel2_vision_attention_rotary_export, + import_fn=_apriel2_vision_attention_rotary_import, + recurses=True, + ), + # Apriel2 vision attention has no per-layer bias representation; the Fast-LLM defaults round-trip. 
+ "linear_layers": IgnoredConfigConverter( + ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",) + ), + "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), + "query_norm": ConstantImportConfigConverter(("query_norm",), None), + "key_norm": ConstantImportConfigConverter(("key_norm",), None), + "value_norm": ConstantImportConfigConverter(("value_norm",), None), + "shared_key_value": ConstantImportConfigConverter(("shared_key_value",), False), } @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Rotary2DConfig + def _validate_export(cls, config: AttentionConfig) -> None: + # Replace Pixtral's Rotary2D-only + head_groups==heads checks (Apriel2 vision allows both rotary types + # and supports GQA). Keep the per-layer bias consistency check from the Llama base. + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + +class Apriel2VisionMLPConverter(ConfigSectionConverter): + """The vision-side MLP shape ``{type: mlp, intermediate_size, activation, gated, add_linear_biases}``. + + Distinct from the text :class:`Apriel2MLPConverter` only in lacking the per-layer-bias declaration: the + Apriel2 vision MLP HF shape has no representation for per-layer ``bias.enabled`` overrides, so the + Fast-LLM defaults are dropped on export (declared :class:`IgnoredConfigConverter`) and re-defaulted on + import. Weight-side ``get_converters`` is shared with the text MLP. + """ - if type(config.rotary) is Rotary2DConfig: - rotary_type = "pixtral_2d" - elif type(config.rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod + def _create_config_converters(cls) -> dict: return { - "type": "attention", - "heads": config.heads, - "head_groups": config.head_groups, - "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, - "causal": config.causal, - "cross_document_attention": False, - "rotary": { - "type": rotary_type, - "theta": config.rotary.theta, - }, + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), + "pre_norm": ConstantImportConfigConverter(("pre_norm",), None), + "post_norm": ConstantImportConfigConverter(("post_norm",), None), } + @classmethod + def get_converters( + cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + from fast_llm.models.gpt.conversion.apriel2 import Apriel2MLPConverter + + return Apriel2MLPConverter.get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export) + + +class 
Apriel2VisionBlockConverter(ConfigSectionConverter): + """Converts a vision :class:`DecoderBlockConfig` ↔ Apriel2's nested ``{mixer, mlp, normalization}`` block. + + Distinct from :class:`PixtralBlockConverter` (which flat-merges its children into the parent's HF dict) + because the Apriel2 vision format nests each sub-section under a typed sub-key, matching the Apriel2 text + decoder shape. + """ + + fast_llm_config_class = DecoderBlockConfig -class Apriel2VisionBlockConverter(PixtralBlockConverter): mixer_converter_class: typing.ClassVar[type[Apriel2VisionAttentionConverter]] = Apriel2VisionAttentionConverter + mlp_converter_class: typing.ClassVar[type[Apriel2VisionMLPConverter]] = Apriel2VisionMLPConverter + # Config-side: the Apriel2 HF format nests normalization as ``{"type": "rms_norm", "epsilon": ...}``; + # ``Apriel2RMSNormConverter`` handles the typed shape. Weight side uses LlamaNormalizationConverter + # directly (flat parameter names — independent of how the surrounding HF config is structured). + normalization_converter_class: typing.ClassVar[type[Apriel2RMSNormConverter]] = Apriel2RMSNormConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" hf_mlp_name: typing.ClassVar[str] = "mlp" hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict, block_config: dict) -> dict: - mixer_config = block_config["mixer"] - mlp_config = block_config["mlp"] - norm_config = block_config["normalization"] - + def _create_config_converters(cls) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(mixer_config), - "mlp": { - "type": "mlp", - "intermediate_size": mlp_config["intermediate_size"], - "activation": ActivationType.from_hf_name(mlp_config["activation"]), - "gated": mlp_config["gated"], - "add_linear_biases": mlp_config["add_linear_biases"], - }, - "normalization": cls.normalization_converter_class.import_config(norm_config), + "mixer": NestedConfigConverter(("mixer",), cls.mixer_converter_class, hf_path=("mixer",)), + "mlp": NestedConfigConverter(("mlp",), cls.mlp_converter_class, hf_path=("mlp",)), + "normalization": NestedConfigConverter( + ("normalization",), cls.normalization_converter_class, hf_path=("normalization",) + ), + "pre_mixer_normalization": ConstantImportConfigConverter(("pre_mixer_normalization",), None), + "pre_mlp_normalization": ConstantImportConfigConverter(("pre_mlp_normalization",), None), + "post_mixer_normalization": ConstantImportConfigConverter(("post_mixer_normalization",), None), + "post_mlp_normalization": ConstantImportConfigConverter(("post_mlp_normalization",), None), + "output_scale": IgnoredConfigConverter(("output_scale",)), } @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.decoder.config import DecoderBlockConfig + def _validate_export(cls, config: DecoderBlockConfig) -> None: + assert not config.output_scale.enabled - Assert.custom(isinstance, config, DecoderBlockConfig) - return { - "mixer": cls.mixer_converter_class.export_config(config.mixer), - "mlp": { - "type": "mlp", - "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.hf_name, - "gated": config.mlp.gated, - "add_linear_biases": config.mlp.add_linear_biases, - }, - "normalization": { - "type": "rms_norm", - "epsilon": config.normalization.epsilon, - }, - } + # --- weight side (imperative) --- + @classmethod + def get_converters( + cls, config: DecoderBlockConfig, fast_llm_prefix: str, 
hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + return [ + *cls.mixer_converter_class.get_converters( + config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.{cls.hf_mixer_name}", drop_on_export + ), + *cls.mlp_converter_class.get_converters( + config.mlp, f"{fast_llm_prefix}.mlp", f"{hf_prefix}.{cls.hf_mlp_name}", drop_on_export + ), + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.{cls.hf_norm_1_name}", + drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.{cls.hf_norm_2_name}", + drop_on_export, + ), + ] + + +class Apriel2VisionEncoderConverter(ConfigSectionConverter): + """Converts a :class:`FixedBlockSequenceConfig` (vision encoder) ↔ Apriel2 HF ``encoder`` subdict + the + flat ``num_hidden_layers`` mirror that the HF format also requires at the surrounding vision_config level. + + No ``hf_type_name`` is set: the ``type: "fixed"`` discriminator lives *inside* the ``encoder`` subdict + (emitted by the Custom's export_fn), not at the parent vision_config level. The Fast-LLM-side ``type`` + is auto-injected by :meth:`ConfigSectionConverter.import_config` via ``fast_llm_config_class.dynamic_type_name``. + """ + + fast_llm_config_class = FixedBlockSequenceConfig -class Apriel2VisionEncoderConverter(PixtralEncoderConverter): block_converter_class: typing.ClassVar[type[Apriel2VisionBlockConverter]] = Apriel2VisionBlockConverter @classmethod - def import_config(cls, config: dict) -> dict: - encoder_config = config["encoder"] - num_blocks = encoder_config["num_blocks"] - block_config = encoder_config["block"] - + def _create_config_converters(cls) -> dict: return { - "type": "fixed", - "num_blocks": num_blocks, - "block": cls.block_converter_class.import_config(config, block_config), + "encoder": CustomConfigConverter( + fast_llm_paths=(("num_blocks",), ("block",)), + hf_paths=(("encoder",),), + export_fn=lambda c: { + ("encoder",): { + "type": "fixed", + "num_blocks": c.num_blocks, + "block": cls.block_converter_class.export_config(c.block), + }, + }, + import_fn=lambda hf: { + ("num_blocks",): hf["encoder"]["num_blocks"], + ("block",): cls.block_converter_class.import_config(hf["encoder"]["block"]), + }, + recurses=True, + ), + "num_hidden_layers_mirror": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("num_hidden_layers",),), + export_fn=lambda c: {("num_hidden_layers",): c.num_blocks}, + import_fn=lambda hf: {}, + ), } + # --- weight side (imperative) --- + @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.block.config import FixedBlockSequenceConfig + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters: list[WeightConverter] = [] + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters - Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return { - "encoder": { - "type": "fixed", - "num_blocks": config.num_blocks, - "block": cls.block_converter_class.export_config(config.block), - }, - "num_hidden_layers": config.num_blocks, - } +class Apriel2EmbeddingsConverter(ConfigSectionConverter): + """Converts :class:`PatchEmbeddingsConfig` ↔ Apriel2 HF ``embeddings`` 
subdict, with top-level + ``patch_size``/``num_channels`` mirrors that the Apriel2 vision_config also requires.""" -class Apriel2EmbeddingsConverter: - """Converts between Fast-LLM PatchEmbeddingsConfig and Apriel2 HF embeddings format.""" + fast_llm_config_class = PatchEmbeddingsConfig - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + normalization_converter_class: typing.ClassVar[type[Apriel2RMSNormConverter]] = Apriel2RMSNormConverter @classmethod - def import_config(cls, config: dict) -> dict: - embeddings_config = config["embeddings"] - Assert.eq(embeddings_config["input_channels"], 3) + def _create_config_converters(cls) -> dict: return { - "normalization": embeddings_config["normalization"], - "patch_height": embeddings_config["patch_height"], - "patch_width": embeddings_config["patch_width"], + "patch_height": RenameConfigConverter(("patch_height",), ("embeddings", "patch_height")), + "patch_width": RenameConfigConverter(("patch_width",), ("embeddings", "patch_width")), + "normalization": NestedConfigConverter( + ("normalization",), + cls.normalization_converter_class, + hf_path=("embeddings", "normalization"), + ), + # ``patch_embeddings`` (AffineLinearConfig) carries no HF architecture info; bias presence validated below. + "patch_embeddings": IgnoredConfigConverter(("patch_embeddings",)), + # ``input_channels`` is a cached_property pinned to 3 on the Fast-LLM side; HF emits it under + # ``embeddings`` and again as a top-level ``num_channels`` mirror. + "embeddings_input_channels": ConstantExportConfigConverter(("embeddings", "input_channels"), 3), + "num_channels_mirror": ConstantExportConfigConverter(("num_channels",), 3), + # ``patch_size`` HF top-level mirror of ``embeddings.patch_height`` — emit on export, ignored on + # import (the under-``embeddings`` path is the authoritative source). + "patch_size_mirror": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(("patch_size",),), + export_fn=lambda c: {("patch_size",): c.patch_height}, + import_fn=lambda hf: {}, + ), } @classmethod - def export_config(cls, config: PatchEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, PatchEmbeddingsConfig) + def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.eq(config.patch_height, config.patch_width) Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) - return { - "embeddings": { - "patch_height": config.patch_height, - "patch_width": config.patch_width, - "input_channels": config.input_channels, - "normalization": {"type": "rms_norm", "epsilon": config.normalization.epsilon}, - }, - "patch_size": config.patch_height, - "num_channels": config.input_channels, - } - @classmethod def get_converters( cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str @@ -201,82 +342,100 @@ def get_converters( PatchEmbeddingWeightConverter, config, ), - *cls.normalization_converter_class.get_converters( + *LlamaNormalizationConverter.get_converters( config.normalization, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.normalization" ), ] class Apriel2VisionAdapterConverter(LlavaVisionAdapterConverter): + """Converts :class:`MLPConfig` (adapter) ↔ Apriel2 HF ``adapter`` subdict. + + Apriel2 nests the adapter shape under ``adapter`` and uses the typed ``{type: mlp, ...}`` dict-of-fields + layout, distinct from Llava's flat top-level ``projector_hidden_act``/``multimodal_projector_bias`` shape. 
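+
+    An illustrative exported ``adapter`` subdict (field values are examples only, not defaults)::
+
+        {"type": "mlp", "intermediate_size": 4096, "activation": "gelu",
+         "gated": False, "add_linear_biases": True}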
+ + Inherits declarative ``import_config``/``export_config`` from :class:`ConfigSectionConverter` via + :class:`LlavaVisionAdapterConverter`, and weight-side ``get_converters`` from Llava (same ``linear_1`` / + ``linear_2`` weight names as Llava). + """ + + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod - def import_config(cls, config: dict) -> dict: - adapter_config = config["adapter"] + def _create_config_converters(cls) -> dict: return { - "intermediate_size": adapter_config["intermediate_size"], - "add_linear_biases": adapter_config["add_linear_biases"], - "gated": adapter_config["gated"], - "activation": ActivationType.from_hf_name(adapter_config["activation"]), + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), + "pre_norm": ConstantImportConfigConverter(("pre_norm",), None), + "post_norm": ConstantImportConfigConverter(("post_norm",), None), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - return { - "adapter": { - "type": "mlp", - "intermediate_size": config.intermediate_size, - "activation": config.activation.hf_name, - "add_linear_biases": config.add_linear_biases, - "gated": config.gated, - }, - } +class Apriel2VisionModelConverter(ConfigSectionConverter): + """Top-level vision-encoder converter. The HF representation lives under a single ``vision_encoder`` key, + so declarations are written relative to that nested subdict. + + ``patch_size``/``max_image_size`` rotary metadata is injected here (cross-section reference to + ``embeddings.patch_height``) — the attention converter cannot see it from its own scope. 
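+
+    A minimal sketch of the injected keys for a ``pixtral_2d`` rotary (``theta`` and ``patch_height``
+    values are illustrative)::
+
+        "encoder": {"block": {"mixer": {"rotary": {
+            "type": "pixtral_2d", "theta": 10000.0, "patch_size": 16, "max_image_size": 1024}}}}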
+ """ + + fast_llm_config_class = VisionEncoderConfig -class Apriel2VisionModelConverter(LlavaVisionModelConverter): + embeddings_converter_class: typing.ClassVar[type[Apriel2EmbeddingsConverter]] = Apriel2EmbeddingsConverter + encoder_converter_class: typing.ClassVar[type[Apriel2VisionEncoderConverter]] = Apriel2VisionEncoderConverter vision_adapter_converter_class: typing.ClassVar[type[Apriel2VisionAdapterConverter]] = ( Apriel2VisionAdapterConverter ) - embeddings_converter_class: typing.ClassVar[type[Apriel2EmbeddingsConverter]] = Apriel2EmbeddingsConverter - encoder_converter_class: typing.ClassVar[type[Apriel2VisionEncoderConverter]] = Apriel2VisionEncoderConverter - # HF path prefixes for Apriel2 (external HF model format) hf_embeddings_prefix: typing.ClassVar[str] = "model.vision_encoder.embeddings" hf_encoder_prefix: typing.ClassVar[str] = "model.vision_encoder.encoder.blocks" hf_adapter_prefix: typing.ClassVar[str] = "model.vision_encoder.adapter" @classmethod - def import_config(cls, config: dict) -> dict: - vision_config = config["vision_encoder"] + def _create_config_converters(cls) -> dict: return { - "embeddings": cls.embeddings_converter_class.import_config(vision_config), - "encoder": cls.encoder_converter_class.import_config(vision_config), - "adapter": cls.vision_adapter_converter_class.import_config(vision_config), - "hidden_size": vision_config["hidden_size"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "encoder": NestedConfigConverter(("encoder",), cls.encoder_converter_class), + "adapter": NestedConfigConverter(("adapter",), cls.vision_adapter_converter_class, hf_path=("adapter",)), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + # Cross-section rotary metadata: the Apriel2 HF format requires patch_size + max_image_size inside + # ``encoder.block.mixer.rotary`` (for ``pixtral_2d``), derived from embeddings.patch_height plus a + # constant 1024. Written here because this converter is the smallest scope that sees both. + # No fast_llm_paths/hf_paths claims — the encoder's recursive rotary claim covers HF coverage; the + # values land on import via the same recursive claim and are stripped by the attention import_fn. 
+ "rotary_metadata": CustomConfigConverter( + fast_llm_paths=(), + hf_paths=(), + export_fn=cls._inject_rotary_metadata, + import_fn=lambda hf: {}, + ), } - @classmethod - def export_config(cls, config: VisionEncoderConfig) -> dict: - Assert.custom(isinstance, config, VisionEncoderConfig) - - vision_config = safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.encoder_converter_class.export_config(config.encoder), - cls.vision_adapter_converter_class.export_config(config.adapter), - {"hidden_size": config.hidden_size}, - ) + @staticmethod + def _inject_rotary_metadata(config: VisionEncoderConfig) -> dict: + rotary = config.encoder.block.mixer.rotary + if type(rotary) is Rotary2DConfig: + return { + ("encoder", "block", "mixer", "rotary", "patch_size"): config.embeddings.patch_height, + ("encoder", "block", "mixer", "rotary", "max_image_size"): 1024, + } + return {} - # Add patch_size and max_image_size to rotary config for pixtral_2d - patch_size = config.embeddings.patch_height - encoder_block = vision_config["encoder"]["block"] - rotary = encoder_block["mixer"]["rotary"] - if rotary["type"] == "pixtral_2d": - rotary["patch_size"] = patch_size - rotary["max_image_size"] = 1024 # Standard max image size for Pixtral - - return {"vision_encoder": vision_config} + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: @@ -316,51 +475,85 @@ def get_converters( ] -class Apriel2MultimodalBaseModelConverter: +class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for Apriel2 multimodal. Composes the Apriel2 text base (flat-merged into the HF + top-level dict) with an optional vision encoder (under HF key ``vision_encoder``) and an optional + ``image_token_index`` field. + + Architecturally the Fast-LLM config (:class:`MultiModalBaseModelConfig`) multi-inherits from both + :class:`GPTBaseModelConfig` (text) and :class:`VisionMultiModalModelConfig` (vision/image_token_index), + so a single declaration set drives both halves. 
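+
+    Rough shape of the exported top-level HF dict (flat-merged text-base keys elided, values illustrative)::
+
+        {..., "vision_encoder": {...}, "image_token_index": 10}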
+ """ + + fast_llm_config_class = MultiModalBaseModelConfig + + text_base_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BaseModelConverter vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter - decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = Apriel2DecoderConverter - embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - text_config = Apriel2BaseModelConverter.import_config(config) - vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None + def _create_config_converters(cls) -> dict: + text_base_cls = cls.text_base_converter_class + vision_cls = cls.vision_model_converter_class - result = safe_merge_dicts( - text_config, - {"vision_encoder": vision_config}, - ) - if "image_token_index" in config: - result["image_token_index"] = config["image_token_index"] - return result + def _vision_export(config: MultiModalBaseModelConfig) -> dict: + if config.vision_encoder is None: + return {} + return {("vision_encoder",): vision_cls.export_config(config.vision_encoder)} - @classmethod - def export_config(cls, config: MultiModalBaseModelConfig) -> dict: - Assert.custom(isinstance, config, MultiModalBaseModelConfig) - exported = Apriel2BaseModelConverter.export_config(config) - if config.vision_encoder is not None: - exported = safe_merge_dicts( - exported, - cls.vision_model_converter_class.export_config(config.vision_encoder), - ) + def _vision_import(hf_dict: dict) -> dict: + if "vision_encoder" not in hf_dict: + return {} + return {("vision_encoder",): vision_cls.import_config(hf_dict["vision_encoder"])} - if config.image_token_index is not None: - exported["image_token_index"] = config.image_token_index + return { + # Flat-merge the Apriel2 text base into the top-level HF dict. The text base claims every + # GPTBaseModelConfig architecture leaf via its own declarations; we mark them recursively + # consumed here and forward HF coverage via the text base's ``_consumed_hf_paths``. + "text_base": CustomConfigConverter( + fast_llm_paths=( + ("embeddings",), + ("decoder",), + ("head",), + ("hidden_size",), + ("tied_embedding_weight",), + ("peft",), + ), + hf_paths=tuple(text_base_cls._consumed_hf_paths()), + export_fn=lambda c: {(k,): v for k, v in text_base_cls.export_config(c).items()}, + import_fn=lambda hf: {(k,): v for k, v in text_base_cls.import_config(hf).items()}, + recurses=True, + ), + # Optional vision encoder. The Fast-LLM ``vision_encoder`` field is architecture-hint and + # ``None`` by default; the HF ``vision_encoder`` key is absent for text-only models. + "vision_encoder": CustomConfigConverter( + fast_llm_paths=(("vision_encoder",),), + hf_paths=(("vision_encoder",),), + export_fn=_vision_export, + import_fn=_vision_import, + recurses=True, + ), + # ``image_token_index`` is FieldHint.optional so it's not in the architecture-coverage set, + # but it does live on the HF dict for vision-enabled checkpoints. 
+ "image_token_index": OptionalConfigConverter(("image_token_index",), ("image_token_index",)), + } - return exported + # --- weight side (imperative) --- + + embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - converters = [] + converters: list[WeightConverter] = [] if config.vision_encoder is not None: converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) converters.extend( - cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.decoder.blocks") + get_apriel2_decoder_converter(config.decoder).get_converters( + config.decoder, "decoder", "model.decoder.blocks" + ) ) converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) - return converters @@ -391,21 +584,17 @@ def get_model_files(cls) -> tuple[str, str, str | None]: @classmethod def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: - base_model = config.base_model - exported = safe_merge_dicts( - cls.base_model_converter_class.export_config(base_model), - { - "architectures": [cls.architecture], - "model_type": cls.get_huggingface_model_type(), - "auto_map": { - "AutoConfig": "configuration_apriel2.Apriel2Config", - "AutoModel": "modeling_apriel2.Apriel2Model", - "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration", - "AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration", - }, + return { + **cls.base_model_converter_class.export_config(config.base_model), + "architectures": [cls.architecture], + "model_type": cls.get_huggingface_model_type(), + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2Config", + "AutoModel": "modeling_apriel2.Apriel2Model", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration", + "AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration", }, - ) - return exported + } @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index fe7c77f5e..4f8e70868 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -3,17 +3,28 @@ import torch from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + ImportOnlyConfigConverter, + NestedConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import Rotary2DConfig -from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from 
fast_llm.layers.block.config import FixedBlockSequenceConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.llama import ( + _TRANSFORMERS_V4, LlamaAttentionConverter, LlamaBlockConverter, LlamaDecoderConverter, @@ -27,52 +38,75 @@ from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div, safe_merge_dicts +from fast_llm.utils import Assert, div class PixtralNormalizationConverter(LlamaNormalizationConverter): - """ - epsilon hard-coded to 1e-5. - """ + """RMS norm with HF-side hardcoded epsilon=1e-5 (Pixtral's HF format omits the field).""" @classmethod - def import_config(cls, config: dict) -> dict: - return {"type": "rms_norm", "epsilon": 1e-5} + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Pin epsilon to 1e-5: assert on export, inject on import. No HF write/read. + "epsilon": ConstantImportConfigConverter(("epsilon",), 1e-5), + } + + +def _pixtral_rotary_export(config: AttentionConfig) -> dict: + if _TRANSFORMERS_V4: + return {("rope_theta",): config.rotary.theta} + return {("rope_parameters",): {"rope_theta": config.rotary.theta, "rope_type": "default"}} - @classmethod - def export_config(cls, config: RMSNormalizationConfig) -> dict: - Assert.custom(isinstance, config, RMSNormalizationConfig) - assert not config.zero_centered - # TODO: Too strict? - Assert.eq(config.epsilon, 1e-5) - return {} + +def _pixtral_rotary_import(hf_dict: dict) -> dict: + if "rope_parameters" in hf_dict: + theta = hf_dict["rope_parameters"]["rope_theta"] + else: + theta = hf_dict["rope_theta"] + return {("rotary",): {"type": "default_2d", "theta": theta}} class PixtralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["num_key_value_heads"] = config["num_attention_heads"] - config["attention_bias"] = False - out = super().import_config(config) - out["rotary"]["type"] = "default_2d" - out["causal"] = False - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # PixtralConfig hardcodes Q/K/V/O biases off and does not surface ``attention_bias``. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + # Pixtral attention is non-causal (vision encoder). + "causal": ConstantImportConfigConverter(("causal",), False), + # No GQA in Pixtral; ``head_groups`` derives from ``num_attention_heads`` on import and is redundant + # on export (``_validate_export`` enforces equality with ``heads``). + "head_groups": ImportOnlyConfigConverter( + fast_llm_paths=(("head_groups",),), + import_fn=lambda hf: {("head_groups",): hf["num_attention_heads"]}, + ), + # Llava's PixtralVisionConfig has no ``head_dim`` field — it is derived as ``hidden_size // + # num_attention_heads``. Don't emit head_dim on export (would otherwise need to be popped + # downstream); on import, derive head_size from the same expression. Invariant validated by + # :class:`LlavaVisionModelConverter._validate_export`, which has access to the parent's + # ``hidden_size``. 
+ "head_size": ImportOnlyConfigConverter( + fast_llm_paths=(("head_size",),), + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ), + # Pixtral always uses 2D rotary; only ``theta`` round-trips. The flat (v4) vs ``rope_parameters`` (v5) + # layout follows the active transformers major version, mirroring the Llama parent. + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + hf_paths=(("rope_theta",), ("rope_parameters",)), + export_fn=_pixtral_rotary_export, + import_fn=_pixtral_rotary_import, + recurses=True, + ), + } @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - cls._check_config(config) - Assert.eq(config.softmax_scale_power, 0.5) + def _validate_export(cls, config: AttentionConfig) -> None: + super()._validate_export(config) Assert.is_(type(config.rotary), Rotary2DConfig) - assert not config.add_linear_biases - assert not config.causal Assert.eq(config.head_groups, config.heads) - return { - "num_attention_heads": config.heads, - "attention_dropout": config.dropout, - "rope_theta": config.rotary.theta, - # Not in PixtralConfig, but needed for consistency check in LlavaVisionModelConverter. - "head_dim": config.head_size, - } class PixtralBlockConverter(LlamaBlockConverter): @@ -117,32 +151,38 @@ def import_weight( ) -class PixtralEmbeddingsConverter: +class PixtralEmbeddingsConverter(ConfigSectionConverter): + """Converts ``PatchEmbeddingsConfig`` ↔ Pixtral HF flat fields (``patch_size`` / ``num_channels``). + + Pixtral's HF ``vision_config`` carries a single ``patch_size`` field (height == width); the converter + expands it to both Fast-LLM dimensions on import and validates equality on export. + """ + + fast_llm_config_class = PatchEmbeddingsConfig normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter @classmethod - def import_config(cls, config: dict) -> dict: - Assert.eq(config["num_channels"], 3) + def _create_config_converters(cls) -> dict: return { - "normalization": cls.normalization_converter_class.import_config(config), - "patch_height": config["patch_size"], - "patch_width": config["patch_size"], + "patch_height": RenameConfigConverter(("patch_height",), ("patch_size",)), + # Pixtral has one `patch_size`; mirror it to `patch_width` on import and validate equality on export. + "patch_width": ImportOnlyConfigConverter( + fast_llm_paths=(("patch_width",),), + import_fn=lambda hf: {("patch_width",): hf["patch_size"]}, + ), + # `input_channels` is a derived cached_property pinned to 3; assert on import, emit on export. + "num_channels": ConstantExportConfigConverter(("num_channels",), 3), + # PixtralNormalizationConverter exports {} (epsilon pinned), so flat-merge is a no-op on export. + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + # patch_embeddings (the AffineLinearConfig) has no HF representation; bias presence validated below. 
+ "patch_embeddings": IgnoredConfigConverter(("patch_embeddings",)), } @classmethod - def export_config(cls, config: PatchEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, PatchEmbeddingsConfig) + def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.eq(config.patch_height, config.patch_width) Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) - return safe_merge_dicts( - { - "patch_size": config.patch_height, - "num_channels": config.input_channels, - }, - cls.normalization_converter_class.export_config(config.normalization), - ) - @classmethod def get_converters( cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str @@ -161,27 +201,48 @@ def get_converters( ] -class LlavaVisionAdapterConverter: +class LlavaVisionAdapterConverter(ConfigSectionConverter): + """Converts the vision adapter :class:`MLPConfig` ↔ Llava's flat top-level adapter fields + (``projector_hidden_act``, ``multimodal_projector_bias``). + + Wrinkle: the adapter's ``intermediate_size`` derives from the **text** half of the model + (``text_config["hidden_size"]``). The cross-section reference is reachable because this converter is + flat-merged at the :class:`LlavaBaseModelConverter` scope, where ``text_config`` lives as a sibling + HF top-level key. + """ + + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "intermediate_size": config["text_config"]["hidden_size"], - "add_linear_biases": config["multimodal_projector_bias"], - "gated": False, - "activation": ActivationType.from_hf_name(config["projector_hidden_act"]), + # Cross-section: imported from text_config.hidden_size. No HF claim — text_config is claimed + # by the language model converter at the base level. + "intermediate_size": ImportOnlyConfigConverter( + fast_llm_paths=(("intermediate_size",),), + import_fn=lambda hf: {("intermediate_size",): hf["text_config"]["hidden_size"]}, + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("multimodal_projector_bias",)), + "gated": ConstantImportConfigConverter(("gated",), False), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + hf_paths=(("projector_hidden_act",),), + export_fn=lambda c: {("projector_hidden_act",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["projector_hidden_act"])}, + ), + # Per-layer ``bias.enabled`` has no HF representation; defaults round-trip. Validated below. 
+ "linear_layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), + "pre_norm": ConstantImportConfigConverter(("pre_norm",), None), + "post_norm": ConstantImportConfigConverter(("post_norm",), None), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - assert not config.gated - return { - "projector_hidden_act": config.activation.hf_name, - "multimodal_projector_bias": config.add_linear_biases, - } + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -201,39 +262,78 @@ def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) ] -class LlavaVisionModelConverter: - vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = LlavaVisionAdapterConverter +class LlavaVisionModelConverter(ConfigSectionConverter): + """Converts :class:`VisionEncoderConfig` ↔ Llava's ``vision_config`` HF subdict. + + Declarations operate relative to ``vision_config`` (parent nests this converter via + ``NestedConfigConverter(hf_path=("vision_config",))``). The adapter is *not* declared here — it + lives at the base level because its Fast-LLM intermediate_size derives from text_config.hidden_size, + a cross-section reference only visible at the top of the HF dict. + """ + + fast_llm_config_class = VisionEncoderConfig + embeddings_converter_class: typing.ClassVar[type[PixtralEmbeddingsConverter]] = PixtralEmbeddingsConverter encoder_converter_class: typing.ClassVar[type[PixtralEncoderConverter]] = PixtralEncoderConverter model_type: typing.ClassVar[str] = "pixtral" @classmethod - def import_config(cls, config: dict) -> dict: - Assert.eq(config["vision_config"]["model_type"], cls.model_type) + def _create_config_converters(cls) -> dict: + encoder_cls = cls.encoder_converter_class + + def _encoder_export(config: VisionEncoderConfig) -> dict: + return {(k,): v for k, v in encoder_cls.export_config(config.encoder).items()} + + def _encoder_import(hf_dict: dict) -> dict: + return {("encoder",): encoder_cls.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config["vision_config"]), - "encoder": cls.encoder_converter_class.import_config(config["vision_config"]), - "adapter": cls.vision_adapter_converter_class.import_config(config), - "hidden_size": config["vision_config"]["hidden_size"], + # Flat-merged into vision_config: embeddings (PatchEmbeddingsConverter writes patch_size/etc), + # encoder (LlamaDecoderConverter dispatch — Custom-wrapped since it stays imperative). + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "encoder": CustomConfigConverter( + fast_llm_paths=(("encoder",),), + hf_paths=( + ("num_hidden_layers",), + *encoder_cls.block_converter_class._consumed_hf_paths(), + ), + export_fn=_encoder_export, + import_fn=_encoder_import, + recurses=True, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + # Llava's vision_config carries a literal ``model_type: "pixtral"``; + # ``ConstantExportConfigConverter`` emits on export and asserts equality on import. 
+ "model_type": ConstantExportConfigConverter(("model_type",), cls.model_type), + # ``transformers.LlavaConfig.from_dict(...).save_pretrained(...)`` round-trips the + # vision_config through :class:`PixtralVisionConfig`, which fills in these default fields. + # Fast-LLM does not consume them; mark them ignored so the recursive coverage check accepts + # round-tripped saves. (``head_dim`` is normally not emitted because we override head_size to + # ImportOnly, but transformers fills it from ``hidden_size // num_attention_heads`` on load.) + "pixtral_hf_defaults": IgnoredConfigConverter( + hf_paths=( + ("head_dim",), + ("image_size",), + ("initializer_factor",), + ("layer_norm_eps",), + ("projection_dim",), + ("vocab_size",), + ), + ), + # Adapter is handled at LlavaBaseModelConverter scope (sees text_config). Mark recursively + # consumed here so the architecture walker sees the sub-tree as claimed at this level too. + "adapter": IgnoredConfigConverter(("adapter",)), } @classmethod - def export_config(cls, config: VisionEncoderConfig) -> dict: - Assert.custom(isinstance, config, VisionEncoderConfig) - vision_config = safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.encoder_converter_class.export_config(config.encoder), - {"hidden_size": config.hidden_size, "model_type": cls.model_type}, - ) - - Assert.eq( - vision_config.pop("head_dim"), div(vision_config["hidden_size"], vision_config["num_attention_heads"]) - ) + def _validate_export(cls, config: VisionEncoderConfig) -> None: + # Llava's PixtralVisionConfig does not carry head_dim — it is derived as ``hidden_size // + # num_attention_heads``. Validate the Fast-LLM head_size satisfies this invariant. + mixer = config.encoder.block.mixer + if isinstance(mixer, AttentionConfig): + Assert.eq(mixer.head_size * mixer.heads, config.hidden_size) - return safe_merge_dicts( - {"vision_config": vision_config}, - cls.vision_adapter_converter_class.export_config(config.adapter), - ) + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: @@ -244,7 +344,7 @@ def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: *cls.encoder_converter_class.get_converters( config.encoder, "vision_encoder.encoder", "vision_tower.transformer.layers" ), - *cls.vision_adapter_converter_class.get_converters( + *LlavaVisionAdapterConverter.get_converters( config.adapter, "vision_encoder.adapter", "multi_modal_projector" ), ] @@ -275,50 +375,102 @@ class LlavaLanguageModelConverter(MistralBaseModelConverter): head_converter_class: typing.ClassVar[type[LlavaHeadConverter]] = LlavaHeadConverter -class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): +class LlavaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for Llava. Composes: + + * ``text_config`` HF subdict ← :class:`LlavaLanguageModelConverter` (Mistral text base). + * ``vision_config`` HF subdict ← :class:`LlavaVisionModelConverter` (Pixtral vision encoder). + * Top-level adapter fields (``projector_hidden_act``, ``multimodal_projector_bias``) ← + :class:`LlavaVisionAdapterConverter`, flat-merged because the adapter's ``intermediate_size`` + derives from ``text_config.hidden_size``. + * Top-level multimodal metadata (``image_token_index``, ``vision_feature_select_strategy``, + ``vision_feature_layer``). 
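+
+    Rough shape of the exported HF dict (values illustrative; ``"full"`` and ``-1`` are the pinned constants)::
+
+        {"text_config": {...}, "vision_config": {...},
+         "projector_hidden_act": "gelu", "multimodal_projector_bias": False,
+         "image_token_index": 10,
+         "vision_feature_select_strategy": "full", "vision_feature_layer": -1}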
+ """ + + fast_llm_config_class = MultiModalBaseModelConfig + vision_model_converter_class: typing.ClassVar[type[LlavaVisionModelConverter]] = LlavaVisionModelConverter + vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = LlavaVisionAdapterConverter # TODO: Make it flexible? language_model_converter_class: typing.ClassVar[type[LlavaLanguageModelConverter]] = LlavaLanguageModelConverter # TODO: Is tie_word_embeddings supported? @classmethod - def import_config(cls, config: dict) -> dict: - return safe_merge_dicts( - { - "vision_encoder": cls.vision_model_converter_class.import_config(config), - "image_token_index": config["image_token_index"], - }, - cls.language_model_converter_class.import_config(config["text_config"]), - ) + def _create_config_converters(cls) -> dict: + text_base_cls = cls.language_model_converter_class + vision_cls = cls.vision_model_converter_class + adapter_cls = cls.vision_adapter_converter_class + + # The Fast-LLM ``MultiModalBaseModelConfig`` IS-A ``GPTBaseModelConfig`` (multi-inherits via + # ``VisionMultiModalModelConfig``), so ``text_base_cls.export_config(config)`` works directly on + # the multimodal config: its declarations only touch GPTBaseModelConfig fields, which exist here. + def _text_export(config: MultiModalBaseModelConfig) -> dict: + return {("text_config",): text_base_cls.export_config(config)} + + def _text_import(hf_dict: dict) -> dict: + return {(k,): v for k, v in text_base_cls.import_config(hf_dict["text_config"]).items()} + + return { + "text_base": CustomConfigConverter( + fast_llm_paths=( + ("embeddings",), + ("decoder",), + ("head",), + ("hidden_size",), + ("tied_embedding_weight",), + ("peft",), + ), + hf_paths=(("text_config",),), + export_fn=_text_export, + import_fn=_text_import, + recurses=True, + ), + "vision_encoder": NestedConfigConverter(("vision_encoder",), vision_cls, hf_path=("vision_config",)), + # Adapter flat-merged at top level: its import sees text_config.hidden_size as a sibling key. + "adapter": NestedConfigConverter(("vision_encoder", "adapter"), adapter_cls, hf_path=None), + "image_token_index": RenameConfigConverter(("image_token_index",), ("image_token_index",)), + "vision_feature_select_strategy": ConstantExportConfigConverter( + ("vision_feature_select_strategy",), "full" + ), + "vision_feature_layer": ConstantExportConfigConverter(("vision_feature_layer",), -1), + # ``transformers.LlavaConfig.save_pretrained(...)`` round-trips the top-level config through + # :class:`transformers.LlavaConfig`, which fills these defaults. Fast-LLM tracks + # ``tie_word_embeddings`` *inside* text_config (Llama's tied_embedding_weight), not at the + # Llava level; ``image_seq_length`` is a runtime/inference field, not architecture. + "llava_hf_defaults": IgnoredConfigConverter( + hf_paths=(("image_seq_length",), ("tie_word_embeddings",)), + ), + } @classmethod - def export_config(cls, config: MultiModalBaseModelConfig) -> dict: - Assert.custom(isinstance, config, MultiModalBaseModelConfig) - assert config.image_token_index is not None - out = safe_merge_dicts( - cls.vision_model_converter_class.export_config(config.vision_encoder), - { - "text_config": cls.language_model_converter_class.export_config(config), - "image_token_index": config.image_token_index, - "vision_feature_select_strategy": "full", - "vision_feature_layer": -1, - }, - ) - return out + def _validate_export(cls, config: MultiModalBaseModelConfig) -> None: + # Llava requires both a vision encoder and an image_token_index to be set. 
+ Assert.custom(lambda v: v is not None, config.vision_encoder) + Assert.custom(lambda v: v is not None, config.image_token_index) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + text_base_cls = cls.language_model_converter_class + decoder_config = config.decoder + block_config = ( + decoder_config.block + if isinstance(decoder_config, FixedBlockSequenceConfig) + else next(iter(decoder_config.blocks.values())) + ) + block_converters: list[WeightConverter] = [] + for block_index in range(decoder_config.num_blocks): + block_converters += text_base_cls.block_converter_class.get_converters( + block_config, f"decoder.{block_index}", f"language_model.model.layers.{block_index}" + ) return [ *cls.vision_model_converter_class.get_converters(config.vision_encoder), - *cls.language_model_converter_class.embeddings_converter_class.get_converters( + *text_base_cls.embeddings_converter_class.get_converters( config.embeddings, "embeddings", "language_model.model" ), - *cls.language_model_converter_class.decoder_converter_class.get_converters( - config.decoder, "decoder", "language_model.model.layers" - ), - *cls.language_model_converter_class.head_converter_class.get_converters( - config, {"tie_word_embeddings": False} - ), + *block_converters, + *text_base_cls.head_converter_class.get_converters(config, {"tie_word_embeddings": False}), ] diff --git a/tests/models/test_converters.py b/tests/models/test_converters.py new file mode 100644 index 000000000..9ce8b0893 --- /dev/null +++ b/tests/models/test_converters.py @@ -0,0 +1,129 @@ +"""Static checks on every checkpoint format's converter tree. + +For each registered ``HuggingfaceStateDictCheckpointHandler``, walk its modular converter structure — +``base_model_converter_class`` and the ``ConfigSectionConverter`` classes reached transitively through +``Nested``/``Dispatch``/``TypedDictContainer`` declarations — and verify, at every node: + +* Architecture-hint fields on ``cls.fast_llm_config_class`` are all consumed by some declaration. +* OptionalConfigConverter sentinels match the resolved field default. Otherwise an exported value equal + to the sentinel becomes absent on disk and re-imports as a different default, silently breaking round-trip. + +Replaces the per-export ``check_architecture_coverage`` invocation that used to happen on every save. +""" + +import typing + +import pytest + +# Force registration of every format handler. +import fast_llm.models.gpt.conversion.auto # noqa: F401 +import fast_llm.models.multimodal.conversion.auto # noqa: F401 +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + DispatchConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + TypedDictContainerConfigConverter, + _get_attr_path, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import PatternBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig + +# Configs that don't default-construct cleanly need a minimal-valid factory. 
+_DEFAULT_FACTORIES: dict[type, typing.Callable[[], typing.Any]] = { + PatternBlockSequenceConfig: lambda: PatternBlockSequenceConfig( + blocks={"x": DecoderBlockConfig()}, + pattern=("x",), + ), + StochasticMixerConfig: lambda: StochasticMixerConfig( + mixers={"x": AttentionConfig()}, + main_mixer_name="x", + ), +} + + +def _default_instance(config_class: type) -> typing.Any: + factory = _DEFAULT_FACTORIES.get(config_class) + return factory() if factory is not None else config_class() + + +def _all_format_handlers() -> list[type[HuggingfaceStateDictCheckpointHandler]]: + seen: set[type[HuggingfaceStateDictCheckpointHandler]] = set() + out: list[type[HuggingfaceStateDictCheckpointHandler]] = [] + + def visit(cls: type) -> None: + for sub in cls.__subclasses__(): + if sub in seen: + continue + seen.add(sub) + # Concrete handlers declare a ``base_model_converter_class``; abstract intermediaries don't. + if getattr(sub, "base_model_converter_class", None) is not None: + out.append(sub) + visit(sub) + + visit(HuggingfaceStateDictCheckpointHandler) + return out + + +def _children(node: type) -> list[type]: + """Return every sub-converter class reachable from ``node``. + + Picks up two complementary structures: + * ``ConfigSectionConverter`` declarations — the ``_converter_class`` on each Nested/Dispatch/TypedDict. + * ``*_converter_class`` ClassVars — the polymorphism extension points used by aggregator nodes + (e.g. ``LlavaBaseModelConverter`` is not itself a section converter but exposes + ``vision_model_converter_class`` and ``language_model_converter_class``). + """ + out: list[type] = [] + if isinstance(node, type) and issubclass(node, ConfigSectionConverter): + for declaration in node._create_config_converters().values(): + if isinstance(declaration, NestedConfigConverter): + out.append(declaration._converter_class) + elif isinstance(declaration, (DispatchConfigConverter, TypedDictContainerConfigConverter)): + out.extend(declaration._registry.values()) + for name in dir(node): + if not name.endswith("_converter_class") or name == "base_model_converter_class": + continue + attr = getattr(node, name, None) + if isinstance(attr, type): + out.append(attr) + return out + + +def _walk(root: type) -> typing.Iterator[type]: + """Yield ``root`` and every converter class reachable from it (each at most once).""" + seen: set[type] = set() + stack: list[type] = [root] + while stack: + node = stack.pop() + if node in seen: + continue + seen.add(node) + yield node + stack.extend(_children(node)) + + +_HANDLERS = _all_format_handlers() + + +@pytest.mark.parametrize("handler_class", _HANDLERS, ids=lambda h: h.__name__) +def test_format_converter_tree(handler_class: type[HuggingfaceStateDictCheckpointHandler]) -> None: + """Walk the format's converter tree from ``base_model_converter_class``; check every section node.""" + for converter_class in _walk(handler_class.base_model_converter_class): + if not (isinstance(converter_class, type) and issubclass(converter_class, ConfigSectionConverter)): + continue + if getattr(converter_class, "fast_llm_config_class", None) is None: + continue + config = _default_instance(converter_class.fast_llm_config_class) + converter_class.check_architecture_coverage(config) + for name, declaration in converter_class._create_config_converters().items(): + if not isinstance(declaration, OptionalConfigConverter): + continue + path = declaration.fast_llm_paths[0] + default = _get_attr_path(config, path) + assert declaration._sentinel == default, ( + 
f"{converter_class.__name__}.{name}: sentinel {declaration._sentinel!r} " + f"does not match field default {default!r} at path {'.'.join(path)}" + )