Skip to content

fix(olmoe): score MoE router with full softmax then gather, not top-k softmax#391

Merged
inureyes merged 2 commits into
mainfrom
fix/issue-318-olmoe-router-scores
Jun 21, 2026
Merged

fix(olmoe): score MoE router with full softmax then gather, not top-k softmax#391
inureyes merged 2 commits into
mainfrom
fix/issue-318-olmoe-router-scores

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

OLMoE-1B-7B-0125-Instruct produced greedy temp-0 output that started coherent ("Paris, France") and then drifted into non-sequiturs. The defect is the MoE router scoring in SparseMoeBlock::forward, not the #316 q_norm/k_norm fix and not the fused decode-MoE kernel: the corruption reproduces with MLXCEL_FUSED_MOE=0 (the gather_qmm baseline).

Root cause

The router scored the selected experts by softmaxing only the top-k logits (softmax(take_along_axis(logits, topk_indices))). That always sums to 1, i.e. it silently behaves as if norm_topk_prob were always true. ml-explore/mlx-lm OlmoeSparseMoeBlock (https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/olmoe.py) instead softmaxes over all experts first, gathers those full-softmax probabilities at the top-k, and renormalizes only when norm_topk_prob is set:

routing_weights = mx.softmax(router_logits, axis=1, precise=True)   # over all 64 experts
indices = argpartition(-routing_weights, kth=k-1)[..., :k]
scores = take_along_axis(routing_weights, indices, axis=-1)
if norm_topk_prob:
    scores = scores / scores.sum(axis=-1, keepdims=True)

The two formulations are equal only when norm_topk_prob == true. OLMoE-1B-7B-0125 ships norm_topk_prob = false (verified in config.json), so the correct scores are the full-softmax-over-64 probabilities at the top-8 experts, which sum to less than 1 and must not be renormalized. The top-k-only softmax over-weighted the MoE block output by ~1/(sum of top-k probs) at every layer; the error compounded with depth and generation length, producing the "starts coherent then degrades" drift.

What changed

  • src/models/olmoe.rs: extract the routing into router_topk_scores(logits, k, norm_topk_prob), which softmaxes over all experts with precise f32 accumulation (matching mlx-lm's precise=True via softmax_precise), gathers at the top-k via take_along_axis, and renormalizes only when norm_topk_prob is set. SparseMoeBlock::forward now calls it; both the fused-kernel path and the SwitchGLU + moe_weighted_sum fallback consume the same indices and scores, so the fix covers both.
  • Expert selection is unchanged: argpartition still runs on the raw logits and softmax is monotonic, so the top-k index set is byte-identical to before. Only the scores change. q_norm/k_norm, RoPE, scale, and cache code are untouched.
  • src/models/olmoe_tests.rs (new): two model-free unit tests over known router logits, plus the existing fused-dispatch gate test moved here.

Test plan

  • cargo check -p mlxcel --lib (clean)
  • cargo clippy -p mlxcel --lib --tests -- -D warnings (clean)
  • cargo fmt -p mlxcel
  • cargo test -p mlxcel --lib olmoe::tests -- --include-ignored --test-threads=1 (3 passed): router_scores_are_full_softmax_gathered_when_norm_topk_prob_false asserts the scores equal the full softmax gathered at the top-k and sum to < 1 (not 1, the pre-fix bug signature); router_scores_are_renormalized_when_norm_topk_prob_true asserts the renormalized top-k sums to 1.

Pending (out of this sandbox's scope, for the maintainer): release build and a real OLMoE-1B-7B-0125-Instruct-4bit generation to confirm coherent greedy temp-0 output, that it stays within the f16 jitter class of mlx-lm, and that MLXCEL_FUSED_MOE=1 remains within the jitter class of the gather_qmm baseline on the now-healthy model.

Closes #318

… softmax

OLMoE-1B-7B-0125-Instruct produced greedy temp-0 output that started coherent then drifted into non-sequiturs. The defect is in the MoE router scoring in SparseMoeBlock::forward, not the #316 q_norm/k_norm fix and not the fused MoE kernel (the corruption reproduces with MLXCEL_FUSED_MOE=0 on the gather_qmm path).

The router scored the top-k experts by softmaxing only the selected top-k logits. That always sums to 1, i.e. it silently behaves as if norm_topk_prob were always true. ml-explore/mlx-lm OlmoeSparseMoeBlock instead softmaxes over ALL experts first, gathers those full-softmax probabilities at the top-k, and renormalizes only when norm_topk_prob is set. The two are equal only when norm_topk_prob is true. OLMoE-1B-7B-0125 ships norm_topk_prob=false, so the correct scores are the full-softmax-over-64 probabilities at the top-8 experts, which sum to less than 1 and must not be renormalized. The top-k-only softmax over-weighted the MoE block output by ~1/(sum of top-k probs) every layer, and the error compounded with depth and generation length, producing the observed drift.

Extract the routing into router_topk_scores, which softmaxes over all experts with precise f32 accumulation (matching mlx-lm's precise=True via softmax_precise), gathers at the top-k via take_along_axis, and renormalizes only when norm_topk_prob is set. The expert selection is unchanged: argpartition still runs on the raw logits and softmax is monotonic, so the top-k index set is byte-identical to before and only the scores change. Both the fused-kernel path and the SwitchGLU + moe_weighted_sum fallback consume the same indices and scores, so the fix applies to both.

Add olmoe_tests.rs with two model-free unit tests over known router logits: with norm_topk_prob=false the scores equal the full softmax gathered at the top-k and sum to < 1 (not 1, which is the pre-fix bug signature); with norm_topk_prob=true they equal the renormalized top-k and sum to 1. The existing fused-dispatch gate test moves into the same file.

Closes #318
@inureyes inureyes added type:bug Bug fixes, error corrections, or issue resolutions priority:high High priority area:models Model architectures, weights, loading, metadata area:inference Generation, sampling, decoding (incl. speculative, DRY) platform:macos macOS (Apple Silicon) specific status:review Under review labels Jun 21, 2026
Record the test checkpoint and note that the full-softmax-then-gather router
fix (#318) resolved degenerate greedy temp-0 output; perf sweep still pending.
@inureyes

Copy link
Copy Markdown
Member Author

PR Finalization

Tests

src/models/olmoe_tests.rs already covers both semantics: norm_topk_prob=false (full-softmax gathered, sum < 1) and norm_topk_prob=true (renormalized, sum 1). No gaps found. The existing three tests are sufficient.

Documentation

  • docs/supported-models.md: OLMoE is listed at line 37 (OLMo / OLMo2 / OLMo3 / OLMoE). No change needed there.
  • docs/benchmark_results/model_tests_m1ultra.md: updated the olmoe row to record the test checkpoint (OLMoE-1B-7B-0125-Instruct-4bit) and a note that the router scoring fix (fix(olmoe): OLMoE-1B-7B produces corrupted/degenerate output beyond the q_norm fix (pre-existing, not MoE-related) #318) resolved degenerate greedy temp-0 output; status remains since perf numbers have not been measured yet. Consistent with how other rows record fixes (e.g. bf16 decode fix).
  • PR body fully documents the root cause, fix, and test plan.

Lint / Format

  • cargo fmt --check: clean (no output)
  • cargo clippy -p mlxcel --lib --tests -- -D warnings: clean (Finished with no warnings)

Commit

e651636cc docs(olmoe): note router scoring fix (#318) in M1 Ultra benchmark table

@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 21, 2026
@inureyes inureyes merged commit 25eecc8 into main Jun 21, 2026
1 check passed
@inureyes inureyes deleted the fix/issue-318-olmoe-router-scores branch June 21, 2026 19:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:inference Generation, sampling, decoding (incl. speculative, DRY) area:models Model architectures, weights, loading, metadata platform:macos macOS (Apple Silicon) specific priority:high High priority status:done Completed type:bug Bug fixes, error corrections, or issue resolutions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix(olmoe): OLMoE-1B-7B produces corrupted/degenerate output beyond the q_norm fix (pre-existing, not MoE-related)

1 participant