feat(models): add Mellum 2 hybrid-attention MoE text model #397
Merged
Add JetBrains' Mellum 2 (model_type `mellum`), a sliding/full hybrid-attention sparse-MoE code model ported from mlx-lm (https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mellum.py).

src/models/mellum.rs implements the config, QK-RMSNorm attention, the softmax-routed sparse MoE block (top-k via argpartition with norm_topk_prob, routed through the shared SwitchGLU), an optional dense MLP path driven by mlp_layer_types, the pre-norm decoder layer, the model, sanitize (per-expert to switch_mlp stacking plus tied-head pop), and make_caches.

Attention is driven per layer by layer_types: full_attention layers use a causal mask, a standard KVCache, and YaRN-scaled RoPE, while sliding_attention layers use a windowed mask, a RotatingKVCache bounded to sliding_window, and default RoPE. The YaRN frequencies and the mscale attention factor are derived from the per-layer-type rope_parameters dict, matching mlx-lm's YarnRoPE (the config's explicit attention_factor equals the derived mscale).

Both tied and untied LM heads are supported, driven by tie_word_embeddings (this checkpoint is untied). Mixed full/sliding caches are owned inside a MellumWrapper that implements LanguageModel.

Wire-up: register `"mellum"` in detection.rs, add the ModelType::Mellum variant plus its ALL_MODEL_TYPES, metadata, and arch-completeness entries in models/mod.rs, the LoadedModel::Mellum variant and delegate arm in loaded_model.rs, the config-backed registration in model_metadata.rs, and the tensor-parallel fallback-architecture arm. Update docs/supported-models.md.

Tests: src/models/mellum_tests.rs covers config parsing (real Mellum2-12B field set with the rope_parameters dict and layer_types schedule), the YaRN attention-factor derivation, sanitize expert-stacking, per-layer cache selection (RotatingKVCache for sliding, KVCache for full), and tied/untied head handling. detection_tests.rs gains a Mellum detection case.
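The per-layer dispatch and the YaRN attention-factor derivation described in this commit can be sketched in plain Rust. This is only an illustration, not the code in src/models/mellum.rs: `LayerType`, `CacheKind`, `parse_layer_type`, `cache_for`, `yarn_get_mscale`, and `yarn_attention_factor` are hypothetical names, and the derivation assumes the common YaRN defaults `mscale = 1` and `mscale_all_dim = 0`.

```rust
/// Attention flavour of one decoder layer, as named in `layer_types`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum LayerType {
    Full,
    Sliding,
}

/// Hypothetical stand-ins for the two cache kinds mentioned in the commit.
#[derive(Debug, Clone, Copy, PartialEq)]
enum CacheKind {
    /// Unbounded cache, used by full-attention layers.
    Standard,
    /// Ring-buffer cache bounded to `sliding_window` entries.
    Rotating { max_size: usize },
}

/// Map a `layer_types` string to a layer flavour; unknown strings are rejected.
fn parse_layer_type(s: &str) -> Option<LayerType> {
    match s {
        "full_attention" => Some(LayerType::Full),
        "sliding_attention" => Some(LayerType::Sliding),
        _ => None,
    }
}

/// Pick the cache kind for a layer: rotating for sliding, standard for full.
fn cache_for(layer: LayerType, sliding_window: usize) -> CacheKind {
    match layer {
        LayerType::Full => CacheKind::Standard,
        LayerType::Sliding => CacheKind::Rotating { max_size: sliding_window },
    }
}

/// YaRN magnitude scale: 0.1 * mscale * ln(factor) + 1 when factor > 1, else 1.
fn yarn_get_mscale(factor: f64, mscale: f64) -> f64 {
    if factor <= 1.0 {
        1.0
    } else {
        0.1 * mscale * factor.ln() + 1.0
    }
}

/// Attention factor as the ratio of the two magnitude scales.
/// With mscale_all_dim = 0 (assumed default) the denominator is exactly 1.
fn yarn_attention_factor(factor: f64, mscale: f64, mscale_all_dim: f64) -> f64 {
    yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)
}
```

Under those assumptions, `yarn_attention_factor(16.0, 1.0, 0.0)` evaluates to 0.1 * ln(16) + 1, about 1.27726, which is consistent with the derived mscale equalling the config's explicit attention_factor.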
Upstream mellum.py routes experts with mx.softmax(gates, axis=-1, precise=True) for router numerical stability. The port used the non-precise softmax; switch to mlxcel_core::softmax_precise (f32 accumulation), matching upstream and the same choice olmoe already makes for its router. With norm_topk_prob=true the top-k scores are renormalized, so the practical effect is sub-jitter, but this aligns the routing distribution exactly with the reference.
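To illustrate why renormalizing the top-k scores makes the softmax precision nearly irrelevant, here is a minimal std-only sketch of softmax routing. It is not the mlxcel implementation: `softmax_f32` and `route` are hypothetical helpers, the inputs are already f32 (the real `softmax_precise` is described as accumulating in f32 over lower-precision logits), and `select_nth_unstable_by` stands in for argpartition.

```rust
/// Numerically stable softmax computed and accumulated in f32.
fn softmax_f32(logits: &[f32]) -> Vec<f32> {
    // Subtract the maximum so the largest exponent is exp(0) = 1.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Route one token: softmax over experts, keep the top-k, optionally renormalize.
/// Returns (expert indices, routing weights); the top-k order is unspecified,
/// as with a partition-based selection.
fn route(logits: &[f32], top_k: usize, norm_topk_prob: bool) -> (Vec<usize>, Vec<f32>) {
    assert!(top_k >= 1 && top_k <= logits.len());
    let probs = softmax_f32(logits);

    // Partition indices so the first `top_k` hold the largest probabilities.
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.select_nth_unstable_by(top_k - 1, |&a, &b| probs[b].total_cmp(&probs[a]));
    idx.truncate(top_k);

    // Gather the selected scores (the take_along_axis step).
    let mut weights: Vec<f32> = idx.iter().map(|&i| probs[i]).collect();

    if norm_topk_prob {
        // Dividing by the top-k sum cancels the full softmax denominator,
        // so small errors in that denominator drop out of the final weights.
        let s: f32 = weights.iter().sum();
        for w in weights.iter_mut() {
            *w /= s;
        }
    }
    (idx, weights)
}
```

With `norm_topk_prob` set, each weight reduces to exp(l_i) divided by the sum of exp(l_j) over the selected experts, so the full-vocabulary-of-experts denominator only influences which experts are selected, not how they are weighted.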
Summary
Add JetBrains' Mellum 2 (`model_type: mellum`), a sliding/full hybrid-attention sparse-MoE code model, ported from mlx-lm (https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mellum.py). The architecture is a composition of primitives mlxcel already has (QK-RMSNorm attention, softmax-routed MoE, mixed `KVCache`/`RotatingKVCache`, YaRN RoPE), wired into a new family loader.

What changed
- `src/models/mellum.rs` (new): config (serde), QK-RMSNorm `Attention`, softmax-routed `SparseMoeBlock` (top-k via argpartition, `norm_topk_prob`, routed through the shared `SwitchGLU`), an optional dense `MLP` path driven by `mlp_layer_types`, the pre-norm `DecoderLayer`, `MellumModel`, `sanitize`, `make_caches`, and the `MellumWrapper` that owns the mixed caches and implements `LanguageModel`.
- `src/models/mellum_tests.rs` (new): config parse, YaRN attention-factor derivation, sanitize expert-stacking, per-layer cache selection, and tied/untied head handling.
- `src/models/detection.rs` + `src/models/detection_tests.rs`: detect `"mellum"` -> `ModelType::Mellum`, with a detection test.
- `src/models/mod.rs`: `mellum` module + re-exports, the `ModelType::Mellum` variant, and its `ALL_MODEL_TYPES` / `metadata()` / arch-completeness entries, plus the `mellum_tests` module.
- `src/loaded_model.rs`: `LoadedModel::Mellum` variant + `delegate_language_model!` arm.
- `src/model_metadata.rs`: config-backed registration entry.
- `src/distributed/tensor_parallel/inference.rs`: `fallback_architecture` arm.
- `docs/supported-models.md`: Mellum entry.

Design notes
- `rope_parameters` is keyed by layer type. `full_attention` layers use YaRN (base 500000, factor 16) built once and shared per full layer; the derived mscale equals the config's explicit `attention_factor` (1.2772588722239782). `sliding_attention` layers use default RoPE. This mirrors mlx-lm `initialize_rope` / `YarnRoPE`.
- `layer_types[i]` selects a windowed mask + `RotatingKVCache(max_size=sliding_window)` for sliding layers and a causal mask + `KVCache` for full layers; the wrapper owns the mixed caches (`supports_batching=false`).
- `q_norm` / `k_norm` are `RMSNorm(head_dim)` applied to the reshaped per-head q/k before transpose and RoPE.
- `softmax_precise` over the router logits (matching upstream `precise=True`), top-k via `argpartition`, scores gathered with `take_along_axis`, normalized when `norm_topk_prob`, experts through the shared `SwitchGLU` (single-token fused kernel with the gather fallback).
- `tie_word_embeddings` drives `Option<lm_head>`; `sanitize` pops a tied `lm_head.weight` and `from_weights` projects through `embed_tokens.as_linear` when tied. This checkpoint is untied.
- `sanitize` stacks per-expert `experts.{e}.{gate,up,down}_proj.*` into `switch_mlp.{proj}.*` in place; the shared `SwitchGLU` also stacks lazily, so the weight-route load path works without an explicit pass.
- `NOTICE`; the file carries the standard header plus an upstream URL reference, matching the existing mlx-lm ports (e.g. `gpt_oss.rs`).

Test plan
- `cargo check --lib --tests --features metal,accelerate`
- `cargo clippy --lib --tests --features metal,accelerate -- -D warnings`
- `cargo fmt --check -p mlxcel`
- `cargo test --lib --features metal,accelerate -- models::mellum` (6 pass)
- `cargo test --lib --features metal,accelerate -- models::detection` (detection + new Mellum case pass)
- `every_variant_is_registered_for_arch` and `all_model_types_covers_every_variant` pass

Real-model validation
Validated on the real `JetBrains/Mellum2-12B-A2.5B-Base` checkpoint (24GB bf16, 12B total / 2.5B active, 64 experts top-8) on M1 Ultra (Apple GPU / Metal): loaded in ~5.8s at 22.63 GB bf16 resident, greedy decode ~62-65 tok/s. `mlxcel generate` produced correct, coherent code on two prompts: a `def quicksort(arr):` completion that emitted the canonical pivot / left / middle / right recursive implementation, and a `def binary_search(sorted_list, target):` completion that emitted a correct low/high/mid loop. Degenerate output would have exposed a YaRN, QK-norm, or hybrid-attention bug, so this confirms the model graph. The router softmax was then upgraded to `softmax_precise` for exact parity with upstream and re-validated to still produce coherent output.

Closes #361