fix(olmoe): score MoE router with full softmax then gather, not top-k softmax#391
Merged
Conversation
… softmax OLMoE-1B-7B-0125-Instruct produced greedy temp-0 output that started coherent then drifted into non-sequiturs. The defect is in the MoE router scoring in SparseMoeBlock::forward, not the #316 q_norm/k_norm fix and not the fused MoE kernel (the corruption reproduces with MLXCEL_FUSED_MOE=0 on the gather_qmm path). The router scored the top-k experts by softmaxing only the selected top-k logits. That always sums to 1, i.e. it silently behaves as if norm_topk_prob were always true. ml-explore/mlx-lm OlmoeSparseMoeBlock instead softmaxes over ALL experts first, gathers those full-softmax probabilities at the top-k, and renormalizes only when norm_topk_prob is set. The two are equal only when norm_topk_prob is true. OLMoE-1B-7B-0125 ships norm_topk_prob=false, so the correct scores are the full-softmax-over-64 probabilities at the top-8 experts, which sum to less than 1 and must not be renormalized. The top-k-only softmax over-weighted the MoE block output by ~1/(sum of top-k probs) every layer, and the error compounded with depth and generation length, producing the observed drift. Extract the routing into router_topk_scores, which softmaxes over all experts with precise f32 accumulation (matching mlx-lm's precise=True via softmax_precise), gathers at the top-k via take_along_axis, and renormalizes only when norm_topk_prob is set. The expert selection is unchanged: argpartition still runs on the raw logits and softmax is monotonic, so the top-k index set is byte-identical to before and only the scores change. Both the fused-kernel path and the SwitchGLU + moe_weighted_sum fallback consume the same indices and scores, so the fix applies to both. Add olmoe_tests.rs with two model-free unit tests over known router logits: with norm_topk_prob=false the scores equal the full softmax gathered at the top-k and sum to < 1 (not 1, which is the pre-fix bug signature); with norm_topk_prob=true they equal the renormalized top-k and sum to 1. The existing fused-dispatch gate test moves into the same file. Closes #318
Record the test checkpoint and note that the full-softmax-then-gather router fix (#318) resolved degenerate greedy temp-0 output; perf sweep still pending.
Member
Author
PR FinalizationTests
Documentation
Lint / Format
Commit
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
OLMoE-1B-7B-0125-Instruct produced greedy temp-0 output that started coherent ("Paris, France") and then drifted into non-sequiturs. The defect is the MoE router scoring in
SparseMoeBlock::forward, not the #316 q_norm/k_norm fix and not the fused decode-MoE kernel: the corruption reproduces withMLXCEL_FUSED_MOE=0(the gather_qmm baseline).Root cause
The router scored the selected experts by softmaxing only the top-k logits (
softmax(take_along_axis(logits, topk_indices))). That always sums to 1, i.e. it silently behaves as ifnorm_topk_probwere always true. ml-explore/mlx-lmOlmoeSparseMoeBlock(https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/olmoe.py) instead softmaxes over all experts first, gathers those full-softmax probabilities at the top-k, and renormalizes only whennorm_topk_probis set:The two formulations are equal only when
norm_topk_prob == true. OLMoE-1B-7B-0125 shipsnorm_topk_prob = false(verified in config.json), so the correct scores are the full-softmax-over-64 probabilities at the top-8 experts, which sum to less than 1 and must not be renormalized. The top-k-only softmax over-weighted the MoE block output by ~1/(sum of top-k probs) at every layer; the error compounded with depth and generation length, producing the "starts coherent then degrades" drift.What changed
src/models/olmoe.rs: extract the routing intorouter_topk_scores(logits, k, norm_topk_prob), which softmaxes over all experts with precise f32 accumulation (matching mlx-lm'sprecise=Trueviasoftmax_precise), gathers at the top-k viatake_along_axis, and renormalizes only whennorm_topk_probis set.SparseMoeBlock::forwardnow calls it; both the fused-kernel path and the SwitchGLU +moe_weighted_sumfallback consume the same indices and scores, so the fix covers both.argpartitionstill runs on the raw logits and softmax is monotonic, so the top-k index set is byte-identical to before. Only the scores change. q_norm/k_norm, RoPE, scale, and cache code are untouched.src/models/olmoe_tests.rs(new): two model-free unit tests over known router logits, plus the existing fused-dispatch gate test moved here.Test plan
cargo check -p mlxcel --lib(clean)cargo clippy -p mlxcel --lib --tests -- -D warnings(clean)cargo fmt -p mlxcelcargo test -p mlxcel --lib olmoe::tests -- --include-ignored --test-threads=1(3 passed):router_scores_are_full_softmax_gathered_when_norm_topk_prob_falseasserts the scores equal the full softmax gathered at the top-k and sum to < 1 (not 1, the pre-fix bug signature);router_scores_are_renormalized_when_norm_topk_prob_trueasserts the renormalized top-k sums to 1.Pending (out of this sandbox's scope, for the maintainer): release build and a real
OLMoE-1B-7B-0125-Instruct-4bitgeneration to confirm coherent greedy temp-0 output, that it stays within the f16 jitter class of mlx-lm, and thatMLXCEL_FUSED_MOE=1remains within the jitter class of the gather_qmm baseline on the now-healthy model.Closes #318