Skip to content

feat(models): add Mellum 2 hybrid-attention MoE text model#397

Merged
inureyes merged 2 commits into
mainfrom
feature/issue-361-mellum-model
Jun 22, 2026
Merged

feat(models): add Mellum 2 hybrid-attention MoE text model#397
inureyes merged 2 commits into
mainfrom
feature/issue-361-mellum-model

Conversation

@inureyes

@inureyes inureyes commented Jun 22, 2026

Copy link
Copy Markdown
Member

Summary

Add JetBrains' Mellum 2 (model_type: mellum), a sliding/full hybrid-attention sparse-MoE code model, ported from mlx-lm (https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mellum.py). The architecture is a composition of primitives mlxcel already has (QK-RMSNorm attention, softmax-routed MoE, mixed KVCache/RotatingKVCache, YaRN RoPE), wired into a new family loader.

What changed

  • src/models/mellum.rs (new): config (serde), QK-RMSNorm Attention, softmax-routed SparseMoeBlock (top-k via argpartition, norm_topk_prob, routed through the shared SwitchGLU), an optional dense MLP path driven by mlp_layer_types, the pre-norm DecoderLayer, MellumModel, sanitize, make_caches, and the MellumWrapper that owns the mixed caches and implements LanguageModel.
  • src/models/mellum_tests.rs (new): config parse, YaRN attention-factor derivation, sanitize expert-stacking, per-layer cache selection, and tied/untied head handling.
  • src/models/detection.rs + src/models/detection_tests.rs: detect "mellum" -> ModelType::Mellum, with a detection test.
  • src/models/mod.rs: mellum module + re-exports, the ModelType::Mellum variant, and its ALL_MODEL_TYPES / metadata() / arch-completeness entries, plus the mellum_tests module.
  • src/loaded_model.rs: LoadedModel::Mellum variant + delegate_language_model! arm.
  • src/model_metadata.rs: config-backed registration entry.
  • src/distributed/tensor_parallel/inference.rs: fallback_architecture arm.
  • docs/supported-models.md: Mellum entry.

Design notes

  • Per-layer RoPE. rope_parameters is keyed by layer type. full_attention layers use YaRN (base 500000, factor 16) built once and shared per full layer; the derived mscale equals the config's explicit attention_factor (1.2772588722239782). sliding_attention layers use default RoPE. This mirrors mlx-lm initialize_rope / YarnRoPE.
  • Hybrid caches. layer_types[i] selects a windowed mask + RotatingKVCache(max_size=sliding_window) for sliding layers and a causal mask + KVCache for full layers; the wrapper owns the mixed caches (supports_batching=false).
  • QK-norm. q_norm/k_norm are RMSNorm(head_dim) applied to the reshaped per-head q/k before transpose and RoPE.
  • MoE routing. Precise (f32-accumulation) softmax_precise over the router logits (matching upstream precise=True), top-k via argpartition, scores gathered with take_along_axis, normalized when norm_topk_prob, experts through the shared SwitchGLU (single-token fused kernel with the gather fallback).
  • Untied head. tie_word_embeddings drives Option<lm_head>; sanitize pops a tied lm_head.weight and from_weights projects through embed_tokens.as_linear when tied. This checkpoint is untied.
  • Expert stacking. sanitize stacks per-expert experts.{e}.{gate,up,down}_proj.* into switch_mlp.{proj}.* in place; the shared SwitchGLU also stacks lazily, so the weight-route load path works without an explicit pass.
  • Attribution. mlx-lm model ports are already covered by the top-level NOTICE; the file carries the standard header plus an upstream URL reference, matching the existing mlx-lm ports (e.g. gpt_oss.rs).

Test plan

  • cargo check --lib --tests --features metal,accelerate
  • cargo clippy --lib --tests --features metal,accelerate -- -D warnings
  • cargo fmt --check -p mlxcel
  • cargo test --lib --features metal,accelerate -- models::mellum (6 pass)
  • cargo test --lib --features metal,accelerate -- models::detection (detection + new Mellum case pass)
  • arch-completeness tests (every_variant_is_registered_for_arch, all_model_types_covers_every_variant) pass

Real-model validation

Validated on the real JetBrains/Mellum2-12B-A2.5B-Base checkpoint (24GB bf16, 12B total / 2.5B active, 64 experts top-8) on M1 Ultra (Apple GPU / Metal): loaded in ~5.8s at 22.63 GB bf16 resident, greedy decode ~62-65 tok/s. mlxcel generate produced correct, coherent code on two prompts: a def quicksort(arr): completion that emitted the canonical pivot / left / middle / right recursive implementation, and a def binary_search(sorted_list, target): completion that emitted a correct low/high/mid loop. Degenerate output would have exposed a YaRN, QK-norm, or hybrid-attention bug, so this confirms the model graph. The router softmax was then upgraded to softmax_precise for exact parity with upstream and re-validated to still produce coherent output.

Closes #361

Add JetBrains' Mellum 2 (model_type `mellum`), a sliding/full hybrid-attention sparse-MoE code model ported from mlx-lm (https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mellum.py).

src/models/mellum.rs implements the config, QK-RMSNorm attention, the softmax-routed sparse MoE block (top-k via argpartition with norm_topk_prob, routed through the shared SwitchGLU), an optional dense MLP path driven by mlp_layer_types, the pre-norm decoder layer, the model, sanitize (per-expert to switch_mlp stacking plus tied-head pop), and make_caches. Attention is driven per layer by layer_types: full_attention layers use a causal mask, a standard KVCache, and YaRN-scaled RoPE, while sliding_attention layers use a windowed mask, a RotatingKVCache bounded to sliding_window, and default RoPE. The YaRN frequencies and the mscale attention factor are derived from the per-layer-type rope_parameters dict, matching mlx-lm's YarnRoPE (the config's explicit attention_factor equals the derived mscale). Both tied and untied LM heads are supported, driven by tie_word_embeddings (this checkpoint is untied). Mixed full/sliding caches are owned inside a MellumWrapper that implements LanguageModel.

Wire-up: register `"mellum"` in detection.rs, add the ModelType::Mellum variant plus its ALL_MODEL_TYPES, metadata, and arch-completeness entries in models/mod.rs, the LoadedModel::Mellum variant and delegate arm in loaded_model.rs, the config-backed registration in model_metadata.rs, and the tensor-parallel fallback-architecture arm. Update docs/supported-models.md.

Tests: src/models/mellum_tests.rs covers config parsing (real Mellum2-12B field set with the rope_parameters dict and layer_types schedule), the YaRN attention-factor derivation, sanitize expert-stacking, per-layer cache selection (RotatingKVCache for sliding, KVCache for full), and tied/untied head handling. detection_tests.rs gains a Mellum detection case.
@inureyes inureyes added status:review Under review type:enhancement New features, capabilities, or significant additions area:models Model architectures, weights, loading, metadata labels Jun 22, 2026
Upstream mellum.py routes experts with mx.softmax(gates, axis=-1, precise=True) for router numerical stability. The port used the non-precise softmax; switch to mlxcel_core::softmax_precise (f32 accumulation), matching upstream and the same choice olmoe already makes for its router. With norm_topk_prob=true the top-k scores are renormalized, so the practical effect is sub-jitter, but this aligns the routing distribution exactly with the reference.
@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 22, 2026
@inureyes inureyes merged commit 0add095 into main Jun 22, 2026
5 checks passed
@inureyes inureyes deleted the feature/issue-361-mellum-model branch June 22, 2026 01:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat(models): add Mellum (Mellum 2) text model support

1 participant