Skip to content

feat(moe): add moe support and fused topk & moe kernels#37

Merged
drewjin merged 6 commits intoSJTU-DENG-Lab:mainfrom
MikanAffine:fusedmoe
Apr 14, 2026
Merged

feat(moe): add moe support and fused topk & moe kernels#37
drewjin merged 6 commits intoSJTU-DENG-Lab:mainfrom
MikanAffine:fusedmoe

Conversation

@MikanAffine
Copy link
Copy Markdown
Contributor

@MikanAffine MikanAffine commented Mar 31, 2026

Feature:

  • add MoE support to SDAR-MOE
  • add expert parallelism support
  • refactored MoE code structure like SGLang style
  • add fused TopK and fused MoE triton kernel, and unit tests

Summary by CodeRabbit

  • New Features

    • Added expert-parallelism (EP) configuration and support for distributed MoE execution.
    • Introduced fused, high-performance MoE and Top‑K kernels (Triton/CUDA) for faster inference.
    • Improved checkpoint weight resolution and flexible weight-loading behaviors.
    • Sampler selection extended with an additional MoE-aware option.
  • Refactor

    • Overhauled model-parallel initialization and utilities to unify tensor- and expert-parallel layouts.
  • Chores

    • Updated ignore list to exclude new local tooling files.

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Walkthrough

Replaces the legacy sparse MoE with a unified FusedMoE framework (trivial/tensor-parallel/expert-parallel), adds Triton fused kernels for top‑k routing and expert GEMMs, introduces TP/EP parallelism metadata and init helpers, and refactors checkpoint weight resolution/loading.

Changes

Cohort / File(s) Summary
MoE public surface
diffulex/moe/__init__.py, diffulex/moe/moe_impl.py
Removed legacy SparseMoEBlock (module deleted) and switched MoE construction to build_moe_block(...)/FusedMoE exports.
FusedMoE base & impls
diffulex/moe/layer/base.py, diffulex/moe/layer/__init__.py, diffulex/moe/layer/trivial_impl.py, diffulex/moe/layer/tp_impl.py, diffulex/moe/layer/ep_impl.py
Added FusedMoE abstract class and three concrete implementations (TrivialFusedMoE, TPFusedMoE, EPFusedMoE) with specialized forward, checkpoint resolution, and expert-weight sharding/replication logic; added build_moe_block factory.
Top-k router abstraction
diffulex/moe/topk/base.py, diffulex/moe/topk/datatype.py, diffulex/moe/topk/trivial.py, diffulex/moe/topk/triton.py, diffulex/moe/topk/__init__.py
Introduced TopKRouter base, TopKOutput dataclass, TrivialTopKRouter and TritonFusedTopKRouter, plus build_topk_router() factory.
Triton kernels (native)
diffulex_kernel/python/fused_moe_triton.py, diffulex_kernel/python/fused_topk_triton.py, diffulex_kernel/__init__.py
Added fused_moe(...) and fused_topk(...) Triton JIT implementations and lazy kernel exports (fused_moe, fused_topk) in kernel package.
Parallelism infra
diffulex/utils/parallelism.py
New ModelParallelismMetadata and process-local storage plus accessors (get_tp_rank, get_ep_world_size, is_tp_enabled, is_ep_enabled), init_process_group and helpers to manage TP/EP layouts.
Checkpoint loader refactor
diffulex/utils/checkpoint.py, diffulex/utils/loader.py
Added LoadContext and ResolvedWeight dataclasses; loader now resolves per-weight specs via resolve_weight_spec and apply_resolved_weight, with sequential fallbacks and new helper functions for packed/direct loading.
Config & workers
diffulex/config.py, diffulex/engine/dp_worker.py, diffulex/engine/tp_worker.py, diffulex/engine/model_runner.py
Added expert_parallel_size to config; worker/device/child spawning logic updated to compute combined model-parallel world size from TP and EP; process-group init and model-runner world/rank handling replaced by new parallelism utilities.
Model & layer TP accessors
diffulex/layer/embed_head.py, diffulex/layer/linear.py, diffulex/model/dream.py, diffulex/model/fast_dllm_v2.py, diffulex/model/llada.py, diffulex/model/sdar.py
Replaced direct torch.distributed world-size/rank calls with get_tp_rank() / get_tp_world_size() accessors for consistent tensor-parallel sizing.
Sampler & misc
diffulex/sampler/sdar.py, .gitignore
Registered SDARSampler also under sdar_moe; added .claude and .codex to .gitignore.

Sequence Diagram(s)

sequenceDiagram
    participant Tokens as Tokens (hidden_states)
    participant Gate as Replicated Gate
    participant Router as TopK Router
    participant Expert as Local Experts
    participant AllReduce as All-Reduce / Scatter
    participant Output as Output (hidden_states, logits)

    Tokens->>Gate: compute router_logits
    Gate->>Router: router_logits
    Router->>Router: compute scores (softmax/sigmoid) and select top-k
    Router->>Expert: topk_ids, topk_weights
    Expert->>Expert: packed GEMM / SiLU gating → per-slot outputs
    Expert->>AllReduce: aggregate across TP/EP (all_reduce / a2a)
    AllReduce->>Output: accumulate & reshape
    Output->>Output: return (hidden_states, router_logits)
Loading
sequenceDiagram
    participant Loader as Checkpoint Loader
    participant Resolver as resolve_weight_spec()
    participant Module as Module Hierarchy
    participant Spec as ResolvedWeight
    participant Applier as apply_resolved_weight()
    participant Param as Parameter/Buffer

    Loader->>Resolver: request resolution for weight_name
    Resolver->>Module: walk module tree for resolve_checkpoint_weight
    Module-->>Spec: return ResolvedWeight (or None)
    Spec->>Applier: apply_resolved_weight(loaded_tensor)
    Applier->>Param: call loader / assign param / write buffer
    Param->>Loader: weight applied
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Poem

🐰 I hopped through layers, gates, and kernels bright,

Replaced old sparse paths with fused delight.
Triton hummed, routers chose the best,
TP and EP now share one nested nest.
Checkpoints whisper where each weight belongs—hop on, take flight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.24% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title accurately reflects the primary changes: adding MoE support, expert parallelism, and new fused TopK/MoE kernels, which are the central features of this changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Nitpick comments (3)
diffulex/moe/__init__.py (1)

10-16: Remove the dead fallback or make it a real branch.

Line 15 returns unconditionally, so Line 16 can never execute. If SparseMoEBlock is still meant to be a fallback, this needs an availability/config check instead of an unreachable return.

Suggested cleanup
 def build_mlp_or_moe(config, layer_idx: int, dense_factory):
     """Build a dense MLP or MoE block according to the config."""
     if is_moe_layer(config, layer_idx):
         return FusedSparseMoEBlock.from_config(config)
-        return SparseMoEBlock.from_config(config)
     return dense_factory()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/__init__.py` around lines 10 - 16, The function build_mlp_or_moe
currently returns FusedSparseMoEBlock.from_config(config) unconditionally,
making the subsequent return of SparseMoEBlock.from_config(config) dead code;
either remove the unreachable fallback or implement a real branch that chooses
SparseMoEBlock when FusedSparseMoEBlock is unavailable. Update build_mlp_or_moe
to check availability (e.g., try/except ImportError or a feature flag) before
calling FusedSparseMoEBlock.from_config(config) and only call
SparseMoEBlock.from_config(config) when the fused implementation is not
available, or delete the redundant return line if the fused block is the sole
supported implementation; reference the symbols build_mlp_or_moe,
FusedSparseMoEBlock, SparseMoEBlock, and is_moe_layer when making the change.
diffulex_kernel/__init__.py (1)

42-50: Keep the lazy-export surface consistent.

fused_topk is exposed through both __getattr__ and __all__, but fused_moe is only exposed through __getattr__ while Line 62 keeps it commented out. If fused_moe is public, export it consistently; if it is private, hiding it in one place and exposing it in another is confusing.

Also applies to: 55-63

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/__init__.py` around lines 42 - 50, The lazy-export surface is
inconsistent: __getattr__ exposes fused_topk and fused_moe but __all__ only
lists fused_topk (fused_moe is commented out). Make the exports consistent by
either adding "fused_moe" to the module-level __all__ list (or uncommenting the
existing entry) if it should be public, or remove/deny export in __getattr__ for
fused_moe if it should be private; update the code paths around __getattr__, the
fused_topk and fused_moe import lines, and the __all__ definition so both
symbols are treated the same way.
diffulex/moe/topk.py (1)

9-10: Lazy-import the Triton backend.

TopKRouter(impl="torch") and topk_pytorch_reference() do not need the kernel package, but the module-level from diffulex_kernel import fused_topk makes diffulex.moe.topk depend on that stack at import time anyway. If diffulex_kernel resolves the Triton module eagerly, CPU/reference users will fail before they can ever select the torch path.

♻️ Suggested change
-from diffulex_kernel import fused_topk
-
 def topk_pytorch_reference(
     router_logits: torch.Tensor,
     top_k: int,
@@
         if impl == "torch":
             self.impl = topk_pytorch_reference
         elif impl == "triton":
+            from diffulex_kernel import fused_topk
             self.impl = fused_topk
         else:
             raise ValueError(f"Unsupported impl: {impl!r}")

Also applies to: 59-64

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk.py` around lines 9 - 10, The module currently imports
fused_topk at top-level causing an eager dependency; change to lazy-import
diffulex_kernel.fused_topk only where needed: move the import into the code
paths that actually call it (e.g., inside TopKRouter implementation branch that
selects the Triton backend and inside the function that calls fused_topk), so
TopKRouter(impl="torch") and topk_pytorch_reference() can be imported without
resolving diffulex_kernel; ensure you reference fused_topk by importing it
locally just before use and keep existing function/class names (TopKRouter,
topk_pytorch_reference, fused_topk) to locate the spots to modify.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 21-23: Rename the single-letter kernel dimension `I` to a
descriptive name (e.g., `I_dim` or `INPUT_SIZE`) both in the Triton kernel
signature (the `I: tl.constexpr` parameter) and in the wrapper/local variables
that reference it so Ruff E741 is resolved; update every usage (pointer
arithmetic, index calculations, and any calls that pass the argument) to the new
name, including the other occurrences noted around the later blocks (the local
wrapper variable and the other kernel signatures/usages), ensuring all
references (e.g., kernel definition, launch invocation, and any local variable
named `I`) are consistently renamed.
- Around line 266-343: Replace the fragile asserts in _run_fused_moe_kernels
(and validate in fused_moe caller) with explicit runtime checks that raise clear
exceptions: verify w13 and w2 are 3D tensors, check w13.shape[1] % 2 == 0 and
compute I = w13.shape[1] // 2 only after that, ensure w13.shape[2] == H and
w2.shape == (E, H, I), confirm topk_ids and topk_weights are 2D with identical
shapes and that topk_ids.shape[0] == M; raise ValueError (or TypeError) with
descriptive messages naming the offending tensor (w13, w2, topk_ids,
topk_weights) so kernel launches fail fast with clear Python errors.

In `@diffulex/engine/dllm_block.py`:
- Around line 36-47: The file defines __getstate__ and __setstate__ multiple
times which causes the later/older definitions to shadow the new weakref-based
serialization path; consolidate to a single pair of methods that implement the
weakref handling: keep the implementations that use weakref_fn and convert
s['_req'] and s['_dllm_block_buffer'] to/from weak references, remove or merge
the older duplicate __getstate__/__setstate__ definitions so only one canonical
implementation remains, and ensure the final __setstate__/__getstate__ reference
the weakref_fn helper and the _req and _dllm_block_buffer attributes
consistently.

In `@diffulex/moe/fused_moe.py`:
- Around line 29-38: Constructor currently accepts arbitrary hidden_act causing
instantiation of a fused MoE that only supports "silu"; update the validation to
fail fast by checking hidden_act in __init__ (and the alternate
constructor/loader referenced at lines ~148-156, e.g., from_config or similar
factory) and raise a clear ValueError if hidden_act != "silu" that points users
to use the unfused MoE block instead; ensure the check is implemented early in
the FusedMoE initialization path (reference symbols: __init__, hidden_act,
from_config) so unsupported configs never create a fused instance.

In `@test/python/kernel/test_fused_moe.py`:
- Around line 12-18: Rename the ambiguous dimension name `I` to a descriptive
name (e.g., `intermediate_size` or `intermediate_dim`) everywhere in this test
module: update the function signature of fused_moe_pytorch_reference (change
comments and type hints for w13 and w2), update usages inside
fused_moe_pytorch_reference, update the helper function `_run_test` and any
local variables or test locals that use `I` (including the later locals around
lines 58-69) so all occurrences are consistently renamed and lint E741 is
resolved.
- Around line 375-395: The test named test_determinism does not actually
guarantee determinism because top_k=2 allows atomic_add race-induced FP32
variations; update the test to either set top_k=1 (change the local top_k
variable to 1 so fused_moe runs without expert conflicts and true determinism is
validated) or rename the test (e.g., test_approximation) and its docstring to
reflect it verifies bounded numerical closeness for fused_moe with top_k=2; also
update the test name/docstring and any inline comment accordingly so readers and
CI expectations match the chosen behavior.

---

Nitpick comments:
In `@diffulex_kernel/__init__.py`:
- Around line 42-50: The lazy-export surface is inconsistent: __getattr__
exposes fused_topk and fused_moe but __all__ only lists fused_topk (fused_moe is
commented out). Make the exports consistent by either adding "fused_moe" to the
module-level __all__ list (or uncommenting the existing entry) if it should be
public, or remove/deny export in __getattr__ for fused_moe if it should be
private; update the code paths around __getattr__, the fused_topk and fused_moe
import lines, and the __all__ definition so both symbols are treated the same
way.

In `@diffulex/moe/__init__.py`:
- Around line 10-16: The function build_mlp_or_moe currently returns
FusedSparseMoEBlock.from_config(config) unconditionally, making the subsequent
return of SparseMoEBlock.from_config(config) dead code; either remove the
unreachable fallback or implement a real branch that chooses SparseMoEBlock when
FusedSparseMoEBlock is unavailable. Update build_mlp_or_moe to check
availability (e.g., try/except ImportError or a feature flag) before calling
FusedSparseMoEBlock.from_config(config) and only call
SparseMoEBlock.from_config(config) when the fused implementation is not
available, or delete the redundant return line if the fused block is the sole
supported implementation; reference the symbols build_mlp_or_moe,
FusedSparseMoEBlock, SparseMoEBlock, and is_moe_layer when making the change.

In `@diffulex/moe/topk.py`:
- Around line 9-10: The module currently imports fused_topk at top-level causing
an eager dependency; change to lazy-import diffulex_kernel.fused_topk only where
needed: move the import into the code paths that actually call it (e.g., inside
TopKRouter implementation branch that selects the Triton backend and inside the
function that calls fused_topk), so TopKRouter(impl="torch") and
topk_pytorch_reference() can be imported without resolving diffulex_kernel;
ensure you reference fused_topk by importing it locally just before use and keep
existing function/class names (TopKRouter, topk_pytorch_reference, fused_topk)
to locate the spots to modify.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ad030d3-f366-4b9c-9e61-c3dd60f9c15e

📥 Commits

Reviewing files that changed from the base of the PR and between 9ead055 and 79f111a.

📒 Files selected for processing (15)
  • diffulex/engine/dllm_block.py
  • diffulex/mixin/multi_block/engine/request.py
  • diffulex/model/sdar_moe.py
  • diffulex/moe/__init__.py
  • diffulex/moe/fused_moe.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/topk.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/loader.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
  • pyproject.toml
  • test/python/kernel/test_fused_moe.py
  • test/python/kernel/test_fused_topk.py

Comment on lines +36 to +47
def __getstate__(self):
s = self.__dict__.copy()
s['_req'] = s['_req']()
if '_dllm_block_buffer' in s:
s['_dllm_block_buffer'] = s['_dllm_block_buffer']()
return s

def __setstate__(self, state):
s = self.__dict__ = state.copy()
s['_req'] = weakref_fn(s['_req'])
if '_dllm_block_buffer' in s:
s['_dllm_block_buffer'] = weakref_fn(s['_dllm_block_buffer'])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

These new pickle hooks are shadowed by the older ones below.

Lines 73-80 and 223-229 redefine __getstate__/__setstate__, so Python drops the versions added here. That makes the new weakref rehydration path unreachable and leaves two conflicting serialization implementations in the same file.

Also applies to: 198-205

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/engine/dllm_block.py` around lines 36 - 47, The file defines
__getstate__ and __setstate__ multiple times which causes the later/older
definitions to shadow the new weakref-based serialization path; consolidate to a
single pair of methods that implement the weakref handling: keep the
implementations that use weakref_fn and convert s['_req'] and
s['_dllm_block_buffer'] to/from weak references, remove or merge the older
duplicate __getstate__/__setstate__ definitions so only one canonical
implementation remains, and ensure the final __setstate__/__getstate__ reference
the weakref_fn helper and the _req and _dllm_block_buffer attributes
consistently.

Comment on lines +29 to +38
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_experts: int,
top_k: int,
*,
hidden_act: str = "silu",
norm_topk_prob: bool = True,
) -> None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Fail fast on unsupported activations.

from_config() forwards arbitrary config.hidden_act, but the backend currently only supports "silu". Right now an unsupported model will instantiate fine and then abort on its first forward. Guard it here or route those configs to the unfused MoE block.

♻️ Suggested change
         self.num_experts = num_experts
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
-        self.hidden_act = hidden_act
+        if hidden_act != "silu":
+            raise ValueError(
+                "FusedSparseMoEBlock only supports hidden_act='silu'"
+            )
+        self.hidden_act = hidden_act

Also applies to: 148-156

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/fused_moe.py` around lines 29 - 38, Constructor currently
accepts arbitrary hidden_act causing instantiation of a fused MoE that only
supports "silu"; update the validation to fail fast by checking hidden_act in
__init__ (and the alternate constructor/loader referenced at lines ~148-156,
e.g., from_config or similar factory) and raise a clear ValueError if hidden_act
!= "silu" that points users to use the unfused MoE block instead; ensure the
check is implemented early in the FusedMoE initialization path (reference
symbols: __init__, hidden_act, from_config) so unsupported configs never create
a fused instance.

Comment on lines +12 to +18
def fused_moe_pytorch_reference(
hidden_states: torch.Tensor, # (M, H)
w13: torch.Tensor, # (E, 2*I, H)
w2: torch.Tensor, # (E, H, I)
topk_weights: torch.Tensor, # (M, top_k)
topk_ids: torch.Tensor, # (M, top_k)
top_k: int,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Rename I across this file.

Ruff is already flagging every I here as E741, so the new test module will stay lint-red until the dimension name is expanded (intermediate_size, intermediate_dim, etc.). The same rename needs to be applied in _run_test and the later test locals.

Also applies to: 58-69

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 12 - 18, Rename the
ambiguous dimension name `I` to a descriptive name (e.g., `intermediate_size` or
`intermediate_dim`) everywhere in this test module: update the function
signature of fused_moe_pytorch_reference (change comments and type hints for w13
and w2), update usages inside fused_moe_pytorch_reference, update the helper
function `_run_test` and any local variables or test locals that use `I`
(including the later locals around lines 58-69) so all occurrences are
consistently renamed and lint E741 is resolved.

Comment on lines +375 to +395
def test_determinism():
"""Same inputs → same outputs across two runs."""
device = "cuda"
M, H, I, E, top_k = 64, 128, 64, 8, 2

torch.manual_seed(42)
hidden_states = torch.randn(M, H, device=device) * 0.1
w13 = torch.randn(E, 2 * I, H, device=device) * 0.1
w2 = torch.randn(E, H, I, device=device) * 0.1
router_logits = torch.randn(M, E, device=device)
scores = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)

out1 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)
out2 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)

# Note: atomic_add may have non-deterministic ordering for fp32,
# but values should still be very close
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*.py" | grep -E "(fused_moe|test_fused)" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 217


🏁 Script executed:

find . -type f -name "*fused_moe*" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 177


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | head -150

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 6333


🏁 Script executed:

rg "atomic_add" ./diffulex_kernel/python/fused_moe_triton.py -A 3 -B 3

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 351


🏁 Script executed:

cat -n ./test/python/kernel/test_fused_moe.py | sed -n '370,400p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1621


🏁 Script executed:

cat -n ./diffulex/moe/fused_moe.py | head -100

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 3809


🏁 Script executed:

rg "determinism|atomic" ./test/python/kernel/test_fused_moe.py -B 2 -A 2

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 497


🏁 Script executed:

wc -l ./diffulex_kernel/python/fused_moe_triton.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 114


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '150,180p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1515


🏁 Script executed:

rg "top_k" ./test/python/kernel/test_fused_moe.py | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 960


🏁 Script executed:

rg "deterministic|CUDA_LAUNCH_BLOCKING" . --type py -l

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 216


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '175,185p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 533


Rename or adjust this test—it does not validate determinism.

With top_k=2 and tl.atomic_add(), thread ordering is non-deterministic on GPU, causing rounding variations. The test mitigates this with relaxed tolerances (atol=1e-5, rtol=1e-5), making it a bounded-approximation test, not a determinism test. Either:

  1. Use top_k=1 to eliminate expert conflicts and actual determinism, or
  2. Rename to test_approximation() or similar to reflect what it actually validates.
🧰 Tools
🪛 Ruff (0.15.7)

[error] 378-378: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 375 - 395, The test named
test_determinism does not actually guarantee determinism because top_k=2 allows
atomic_add race-induced FP32 variations; update the test to either set top_k=1
(change the local top_k variable to 1 so fused_moe runs without expert conflicts
and true determinism is validated) or rename the test (e.g., test_approximation)
and its docstring to reflect it verifies bounded numerical closeness for
fused_moe with top_k=2; also update the test name/docstring and any inline
comment accordingly so readers and CI expectations match the chosen behavior.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
diffulex/moe/topk/trivial.py (1)

21-22: ⚠️ Potential issue | 🟡 Minor

Potential division by zero in renormalization.

If topk_weights sums to zero (possible with extreme numerical edge cases), this division produces NaN. Consider adding a small epsilon for robustness, consistent with the Triton kernel which uses tl.maximum(selected_sum, 1e-20).

Suggested fix
         if self.renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True).clamp(min=1e-20)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/trivial.py` around lines 21 - 22, The renormalization step
in trivial.py can divide by zero when topk_weights.sum(...) == 0; modify the
block in the method where self.renormalize is used to compute a safe denominator
(e.g., denom = topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20) or
torch.maximum(..., torch.tensor(1e-20, device=topk_weights.device))) and then do
topk_weights = topk_weights / denom so the division is robust to zero sums (use
1e-20 to match the Triton kernel).
♻️ Duplicate comments (1)
diffulex_kernel/python/fused_moe_triton.py (1)

41-51: ⚠️ Potential issue | 🟡 Minor

Missing validation: topk_ids.shape[0] must equal hidden_states.shape[0].

The validation checks that topk_ids and topk_weights have matching shapes, but does not verify that the number of tokens in topk_ids matches hidden_states. A mismatch would cause incorrect indexing in the kernel.

🛡️ Proposed fix
     if topk_ids.shape != topk_weights.shape:
         raise ValueError(
             f"topk_ids and topk_weights must have the same shape, got {topk_ids.shape} and {topk_weights.shape}."
         )
+    if topk_ids.shape[0] != hidden_states.shape[0]:
+        raise ValueError(
+            f"topk_ids must have one row per token, got {topk_ids.shape[0]} rows but {hidden_states.shape[0]} tokens."
+        )
     if w13.shape[0] != w2.shape[0] or w13.shape[2] != hidden_states.shape[1]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 41 - 51, The code
validates shapes of topk_ids/topk_weights and weight matrices but misses
verifying that the number of tokens in topk_ids matches hidden_states, which can
break indexing; add a validation after the existing topk_ids/topk_weights check
that asserts topk_ids.shape[0] == hidden_states.shape[0] and raises a ValueError
with a clear message referencing topk_ids.shape and hidden_states.shape so
callers can see the mismatch (place this check alongside the other shape
validations in the same block inside fused_moe_triton where topk_ids,
topk_weights, and hidden_states are validated).
🧹 Nitpick comments (16)
diffulex/moe/topk/bypass.py (2)

7-8: Typo in docstring.

Minor: "implemenation" → "implementation".

✏️ Proposed fix
 class BypassTopKRouter(TopKRouter):
-    """Bypass implemenation, use this if fused moe runner also handles topk"""
+    """Bypass implementation, use this if fused moe runner also handles topk."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/bypass.py` around lines 7 - 8, Fix the typo in the
BypassTopKRouter class docstring: change "Bypass implemenation, use this if
fused moe runner also handles topk" to "Bypass implementation, use this if fused
moe runner also handles topk" in the BypassTopKRouter (subclass of TopKRouter)
docstring.

10-15: Ensure correct router/runner pairing at runtime.

This bypass router returns weights=None and ids=None, which will cause TritonFusedMoERunner to raise a RuntimeError (per diffulex/moe/runner/triton.py:10-14). The design is intentional for runners that handle top-k internally, but consider adding a note in the docstring about compatible runners to prevent misconfiguration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/bypass.py` around lines 10 - 15, The bypass router's
forward currently returns weights=None and ids=None which is intentional for
routers that let runners perform top-k, but it lacks documentation and can lead
to a confusing RuntimeError when paired with runners that expect ids/weights
(see TritonFusedMoERunner). Update the forward method's docstring in the bypass
router to clearly state that returning None for weights and ids is intentional
and that this router must only be used with runners that implement internal
top-k (e.g., TritonFusedMoERunner-like runners); additionally, add an optional
runtime compatibility check in forward or router initialization to raise a clear
error message if paired with an incompatible runner instead of allowing the
downstream RuntimeError.
diffulex/utils/loader.py (2)

157-169: Consider using direct attribute access instead of getattr.

Static analysis (B009) flags getattr(param, "weight_loader") with a constant attribute. Since the missing attribute case is already handled by the try/except block, you could use direct attribute access for clarity.

♻️ Proposed fix
             try:
                 param = model.get_parameter(param_name)
-                weight_loader = partial(
-                    getattr(param, "weight_loader"),
-                    param,
-                    loaded_weight,
-                )
+                weight_loader = partial(param.weight_loader, param, loaded_weight)
                 if shard_id is None:
                     weight_loader()
                 else:
                     weight_loader(shard_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` around lines 157 - 169, The code uses
getattr(param, "weight_loader") inside loader logic though
AttributeError/KeyError is already caught; change to direct attribute access
param.weight_loader when building the partial (in the block that calls
model.get_parameter and constructs weight_loader) to satisfy static analysis and
improve clarity while keeping the same behavior — continue to wrap it in the
existing try/except, create the partial with param.weight_loader, param,
loaded_weight, and then call weight_loader() or weight_loader(shard_id)
depending on shard_id.

139-139: Prefer explicit exception over assert for runtime validation.

Using assert for runtime validation can be disabled with -O flag. Consider raising ValueError for clearer error handling.

🛡️ Proposed fix
-            assert v == "lm_head"
+            if v != "lm_head":
+                raise ValueError(f"Expected 'lm_head' for transformer.ff_out mapping, got '{v}'")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` at line 139, Replace the runtime assert "assert v
== 'lm_head'" with an explicit check that raises a ValueError when the condition
fails; specifically, change it to: if v != "lm_head": raise
ValueError(f"Expected 'lm_head' for variable v, got {v!r}"). This ensures the
validation cannot be disabled with -O and provides a clear, descriptive error
message including the actual value.
diffulex/moe/topk/triton.py (1)

2-2: Remove unused import.

torch.nn.functional as F is imported but never used in this file.

🧹 Proposed fix
 import torch
-import torch.nn.functional as F
 
 from diffulex.moe.topk.base import TopKRouter
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/triton.py` at line 2, Remove the unused import "import
torch.nn.functional as F" from triton.py; locate the import statement in the top
of the file (the symbol to remove is "torch.nn.functional as F") and delete it
so the module no longer includes an unused dependency and to satisfy linting.
diffulex/moe/runner/__init__.py (2)

19-24: Consider sorting __all__ for consistency.

Per static analysis (RUF022), __all__ is not sorted. Sorting improves readability and makes diffs cleaner when adding new exports.

🧹 Proposed fix
 __all__ = [
     "MoERunner",
-    "TrivialMoERunner",
     "TritonFusedMoERunner",
+    "TrivialMoERunner",
     "build_runner",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/__init__.py` around lines 19 - 24, The __all__ export
list is unsorted; reorder the list assigned to __all__ so its string entries are
in lexicographical order (e.g., place "MoERunner", "TritonFusedMoERunner",
"TrivialMoERunner", "build_runner" sorted appropriately), updating the __all__
assignment to the sorted sequence to satisfy RUF022 and keep exports consistent.

15-16: Include the invalid impl value in the error message.

When an unsupported backend is requested, the error should indicate what value was passed for easier debugging. This is consistent with the error message pattern in build_dispatcher.

🔧 Proposed fix
     elif impl == "triton":
         return TritonFusedMoERunner(*args, **kwargs)
     else:
-        raise NotImplementedError
+        raise NotImplementedError(f"Unsupported runner backend: {impl!r}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/__init__.py` around lines 15 - 16, The else branch
raising NotImplementedError in __init__.py should include the actual invalid
impl value for clearer debugging; update the error raised in that branch (the
raise in the same block that handles backend selection and mirrors
build_dispatcher’s pattern) to include impl in the message (e.g., "Unsupported
impl: {impl}") so callers see what value was passed.
diffulex/moe/runner/trivial.py (1)

1-5: Missing torch import.

The file uses dispatch_output.hidden_states.new_zeros(...) which works via the tensor method, but adding import torch would be more explicit and consistent with other files.

Suggested addition
 from __future__ import annotations
+import torch
 import torch.nn.functional as F
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/trivial.py` around lines 1 - 5, The module is missing an
explicit torch import used when creating new tensors (e.g.,
dispatch_output.hidden_states.new_zeros(...)); add "import torch" to trivial.py
so tensor creation is explicit and consistent with other modules that use torch,
and ensure any references in MoERunner, CombineInput, or DispatchOutput code
that rely on torch remain unchanged.
diffulex/moe/topk/__init__.py (2)

19-20: Improve error message for unsupported implementations.

The NotImplementedError should include the invalid impl value and available options for easier debugging.

Suggested improvement
     else:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"Unknown top-k router implementation: {impl!r}. "
+            f"Supported: 'trivial', 'bypass', 'triton'."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/__init__.py` around lines 19 - 20, The current else branch
raises a bare NotImplementedError; update the else in
diffulex.moe.topk.__init__.py to raise a NotImplementedError that includes the
invalid impl value and the list of supported implementations (e.g.,
f"Unsupported impl '{impl}'; supported: {supported_list}") so callers can see
what was passed and what options exist; locate the else that currently does
raise NotImplementedError and replace it with an informative message referencing
the impl variable and the available implementations constant or keys from the
dispatch map used in this module.

23-30: Consider sorting __all__ for consistency.

Ruff flagged that __all__ is not sorted (RUF022). While optional, alphabetical ordering improves maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/__init__.py` around lines 23 - 30, The __all__ list in this
module is not alphabetically sorted (RUF022); reorder the entries in the __all__
variable so they are in alphabetical order (e.g., "BypassTopKRouter",
"TopKOutput", "TopKRouter", "TritonFusedTopKRouter", "TrivialTopKRouter",
"build_topk_router") to satisfy the lint rule and improve maintainability;
update the __all__ definition accordingly.
diffulex/engine/model_runner.py (1)

60-61: Redundant assignments overwritten later.

self.world_size and self.rank are assigned here from config values, but they are immediately overwritten at lines 84-85 from the layout object. Consider removing these initial assignments to avoid confusion.

Suggested cleanup
-        self.world_size = config.tensor_parallel_size
-        self.rank = rank
         self.event = event
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/engine/model_runner.py` around lines 60 - 61, The assignments to
self.world_size and self.rank from config at the start of ModelRunner.__init__
are redundant because they are overwritten later from layout; remove the initial
assignments (self.world_size = config.tensor_parallel_size and self.rank = rank)
to avoid confusion and keep the single authoritative source (layout) — update or
delete those two lines in ModelRunner.__init__ and ensure any subsequent logic
relies on the values set from layout rather than the earlier config-derived
values.
diffulex/moe/runner/base.py (1)

36-38: Weight tensors not registered as parameters.

w13 and w2 are stored as plain attributes rather than nn.Parameter or registered buffers. This means they won't appear in state_dict() or be moved by .to(). If this is intentional (e.g., weights managed externally by the loader), consider adding a brief comment explaining the design choice.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/base.py` around lines 36 - 38, w13 and w2 are being
stored as plain attributes (self.w13, self.w2) so they won't be included in
state_dict() or moved with .to(); update the class initializer to register them
properly: if these are trainable weights wrap them with torch.nn.Parameter and
assign to self.w13/self.w2, otherwise register them as buffers via
self.register_buffer('w13', w13) / self.register_buffer('w2', w2); if the
current design intentionally keeps them external, add a short comment next to
local_expert_start/w13/w2 explaining that these tensors are externally managed
and intentionally not registered so reviewers understand the choice.
diffulex_kernel/python/fused_topk_triton.py (1)

48-63: Output validation adds safety at runtime cost.

The _validate_fused_topk_outputs function performs GPU-to-CPU synchronization (detach().cpu()) on every call, which can impact performance. Consider making this validation optional (e.g., via an environment variable or debug flag) for production use.

Suggested change
+import os
+
+_VALIDATE_OUTPUTS = os.environ.get("DIFFULEX_VALIDATE_TOPK_OUTPUTS", "0") == "1"
+
 def _validate_fused_topk_outputs(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     *,
     num_experts: int,
 ) -> None:
+    if not _VALIDATE_OUTPUTS:
+        return
     invalid_id_mask = (topk_ids < 0) | (topk_ids >= num_experts)
     ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 48 - 63, The
runtime validation in _validate_fused_topk_outputs causes GPU->CPU sync
(detach().cpu()) on every call; make this validation conditional by introducing
a toggle (e.g., module-level boolean FUSED_TOPK_VALIDATE_OUTPUTS set from an
environment variable like "FUSED_TOPK_VALIDATE_OUTPUTS" or a debug flag) and
only perform the invalid_id_mask check, detach().cpu(), and isfinite checks when
that toggle is true; update the function to early-return if validation is
disabled and ensure the env-var default is disabled in production but can be
enabled for testing/debugging.
diffulex/moe/dispatcher/trivial.py (1)

67-70: Strict expert ID matching may be overly restrictive.

The current logic raises if active_expert_ids doesn't exactly match the expected list. Consider allowing a subset for scenarios where only specific experts need processing (e.g., during debugging or selective expert evaluation).

Alternative: allow subset
         expert_ids = list(active_expert_ids)
-        if expert_ids != expected:
-            raise ValueError(f"Expected active_expert_ids={expected}, got {expert_ids}.")
+        if not set(expert_ids).issubset(set(expected)):
+            raise ValueError(
+                f"active_expert_ids must be a subset of {expected}, got {expert_ids}."
+            )
         return expert_ids
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/dispatcher/trivial.py` around lines 67 - 70, Change the strict
equality check between active_expert_ids and expected to allow expected to be a
subset: compute expert_ids = list(active_expert_ids), then validate with
something like if expected and not set(expected).issubset(set(expert_ids)):
raise ValueError(...); finally return only the relevant experts (e.g., filter
expert_ids to those in expected if expected is provided, otherwise return all
expert_ids). Update the error message in the raise to reflect that expected must
be a subset when applicable; reference variables active_expert_ids, expected,
expert_ids and the existing ValueError raise site.
diffulex_kernel/python/fused_moe_triton.py (2)

131-132: Early return shape may be inconsistent with weight dimensions.

When hidden_states.shape[0] == 0, the function returns hidden_states.new_zeros((0, w2.shape[1])). However, w2.shape[1] is hidden_size (the output dimension of down_proj). This is correct, but when topk_ids.shape[1] == 0 (zero top-k), returning zeros is semantically correct since no experts contribute.

Consider adding a comment clarifying that w2.shape[1] == hidden_size to improve readability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 131 - 132, The
early-return when hidden_states.shape[0] == 0 or topk_ids.shape[1] == 0 returns
a zero tensor using w2.shape[1]; clarify this by adding an inline comment in
fused_moe_triton.py next to the return explaining that w2.shape[1] is the model
hidden_size (the output dim of down_proj), and that returning zeros is correct
when there are no tokens or no top-k experts; keep the existing return value but
add the comment referencing hidden_states, topk_ids, and w2 for readability.

101-112: 3D tensor load may cause high register pressure.

The weight load creates a (BLOCK_M, BLOCK_N, BLOCK_K) tensor in registers before reduction. With BLOCK_M=8, BLOCK_N=64, BLOCK_K=32, this allocates 16K elements per thread block. This approach works but may limit occupancy on memory-bound workloads.

Consider whether a more traditional 2D tiled GEMM with per-row expert lookup would reduce register pressure, though this is acceptable for a draft implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 101 - 112, The
current tl.load builds a full (BLOCK_M, BLOCK_N, BLOCK_K) tensor (see the load
using w_ptr + expert_offsets... + current_k * stride_wk and mask combining
row_mask, valid_expert, offs_n, k_mask) which raises register pressure; change
the implementation to tile along the K dimension and load smaller 2D slices
instead (e.g., loop over K sub-blocks: load (BLOCK_M, BLOCK_K) or (BLOCK_N,
BLOCK_K) chunks per iteration, multiply-accumulate into acc incrementally, and
only keep a small b-slice in registers per loop) so you perform partial
reductions across K rather than materializing the full 3D block in registers.
Ensure masks (row_mask, valid_expert, offs_n < num_cols, k_mask) and strides
(stride_we, stride_wn, stride_wk) are applied per sub-load and that the
accumulation acc += tl.sum(a[:, None, :] * b, axis=2) still operates on the
smaller b-slice.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex/utils/loader.py`:
- Line 1: The file contains a UTF-8 BOM at the start (before the first token
"import os"); remove the leading BOM character so the file is plain UTF-8
without BOM and re-save loader.py (ensure your editor/IDE writes UTF-8 without
BOM) so the module import line ("import os") and the rest of the file no longer
include the hidden character.

---

Outside diff comments:
In `@diffulex/moe/topk/trivial.py`:
- Around line 21-22: The renormalization step in trivial.py can divide by zero
when topk_weights.sum(...) == 0; modify the block in the method where
self.renormalize is used to compute a safe denominator (e.g., denom =
topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20) or torch.maximum(...,
torch.tensor(1e-20, device=topk_weights.device))) and then do topk_weights =
topk_weights / denom so the division is robust to zero sums (use 1e-20 to match
the Triton kernel).

---

Duplicate comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 41-51: The code validates shapes of topk_ids/topk_weights and
weight matrices but misses verifying that the number of tokens in topk_ids
matches hidden_states, which can break indexing; add a validation after the
existing topk_ids/topk_weights check that asserts topk_ids.shape[0] ==
hidden_states.shape[0] and raises a ValueError with a clear message referencing
topk_ids.shape and hidden_states.shape so callers can see the mismatch (place
this check alongside the other shape validations in the same block inside
fused_moe_triton where topk_ids, topk_weights, and hidden_states are validated).

---

Nitpick comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 131-132: The early-return when hidden_states.shape[0] == 0 or
topk_ids.shape[1] == 0 returns a zero tensor using w2.shape[1]; clarify this by
adding an inline comment in fused_moe_triton.py next to the return explaining
that w2.shape[1] is the model hidden_size (the output dim of down_proj), and
that returning zeros is correct when there are no tokens or no top-k experts;
keep the existing return value but add the comment referencing hidden_states,
topk_ids, and w2 for readability.
- Around line 101-112: The current tl.load builds a full (BLOCK_M, BLOCK_N,
BLOCK_K) tensor (see the load using w_ptr + expert_offsets... + current_k *
stride_wk and mask combining row_mask, valid_expert, offs_n, k_mask) which
raises register pressure; change the implementation to tile along the K
dimension and load smaller 2D slices instead (e.g., loop over K sub-blocks: load
(BLOCK_M, BLOCK_K) or (BLOCK_N, BLOCK_K) chunks per iteration,
multiply-accumulate into acc incrementally, and only keep a small b-slice in
registers per loop) so you perform partial reductions across K rather than
materializing the full 3D block in registers. Ensure masks (row_mask,
valid_expert, offs_n < num_cols, k_mask) and strides (stride_we, stride_wn,
stride_wk) are applied per sub-load and that the accumulation acc += tl.sum(a[:,
None, :] * b, axis=2) still operates on the smaller b-slice.

In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 48-63: The runtime validation in _validate_fused_topk_outputs
causes GPU->CPU sync (detach().cpu()) on every call; make this validation
conditional by introducing a toggle (e.g., module-level boolean
FUSED_TOPK_VALIDATE_OUTPUTS set from an environment variable like
"FUSED_TOPK_VALIDATE_OUTPUTS" or a debug flag) and only perform the
invalid_id_mask check, detach().cpu(), and isfinite checks when that toggle is
true; update the function to early-return if validation is disabled and ensure
the env-var default is disabled in production but can be enabled for
testing/debugging.

In `@diffulex/engine/model_runner.py`:
- Around line 60-61: The assignments to self.world_size and self.rank from
config at the start of ModelRunner.__init__ are redundant because they are
overwritten later from layout; remove the initial assignments (self.world_size =
config.tensor_parallel_size and self.rank = rank) to avoid confusion and keep
the single authoritative source (layout) — update or delete those two lines in
ModelRunner.__init__ and ensure any subsequent logic relies on the values set
from layout rather than the earlier config-derived values.

In `@diffulex/moe/dispatcher/trivial.py`:
- Around line 67-70: Change the strict equality check between active_expert_ids
and expected to allow expected to be a subset: compute expert_ids =
list(active_expert_ids), then validate with something like if expected and not
set(expected).issubset(set(expert_ids)): raise ValueError(...); finally return
only the relevant experts (e.g., filter expert_ids to those in expected if
expected is provided, otherwise return all expert_ids). Update the error message
in the raise to reflect that expected must be a subset when applicable;
reference variables active_expert_ids, expected, expert_ids and the existing
ValueError raise site.

In `@diffulex/moe/runner/__init__.py`:
- Around line 19-24: The __all__ export list is unsorted; reorder the list
assigned to __all__ so its string entries are in lexicographical order (e.g.,
place "MoERunner", "TritonFusedMoERunner", "TrivialMoERunner", "build_runner"
sorted appropriately), updating the __all__ assignment to the sorted sequence to
satisfy RUF022 and keep exports consistent.
- Around line 15-16: The else branch raising NotImplementedError in __init__.py
should include the actual invalid impl value for clearer debugging; update the
error raised in that branch (the raise in the same block that handles backend
selection and mirrors build_dispatcher’s pattern) to include impl in the message
(e.g., "Unsupported impl: {impl}") so callers see what value was passed.

In `@diffulex/moe/runner/base.py`:
- Around line 36-38: w13 and w2 are being stored as plain attributes (self.w13,
self.w2) so they won't be included in state_dict() or moved with .to(); update
the class initializer to register them properly: if these are trainable weights
wrap them with torch.nn.Parameter and assign to self.w13/self.w2, otherwise
register them as buffers via self.register_buffer('w13', w13) /
self.register_buffer('w2', w2); if the current design intentionally keeps them
external, add a short comment next to local_expert_start/w13/w2 explaining that
these tensors are externally managed and intentionally not registered so
reviewers understand the choice.

In `@diffulex/moe/runner/trivial.py`:
- Around line 1-5: The module is missing an explicit torch import used when
creating new tensors (e.g., dispatch_output.hidden_states.new_zeros(...)); add
"import torch" to trivial.py so tensor creation is explicit and consistent with
other modules that use torch, and ensure any references in MoERunner,
CombineInput, or DispatchOutput code that rely on torch remain unchanged.

In `@diffulex/moe/topk/__init__.py`:
- Around line 19-20: The current else branch raises a bare NotImplementedError;
update the else in diffulex.moe.topk.__init__.py to raise a NotImplementedError
that includes the invalid impl value and the list of supported implementations
(e.g., f"Unsupported impl '{impl}'; supported: {supported_list}") so callers can
see what was passed and what options exist; locate the else that currently does
raise NotImplementedError and replace it with an informative message referencing
the impl variable and the available implementations constant or keys from the
dispatch map used in this module.
- Around line 23-30: The __all__ list in this module is not alphabetically
sorted (RUF022); reorder the entries in the __all__ variable so they are in
alphabetical order (e.g., "BypassTopKRouter", "TopKOutput", "TopKRouter",
"TritonFusedTopKRouter", "TrivialTopKRouter", "build_topk_router") to satisfy
the lint rule and improve maintainability; update the __all__ definition
accordingly.

In `@diffulex/moe/topk/bypass.py`:
- Around line 7-8: Fix the typo in the BypassTopKRouter class docstring: change
"Bypass implemenation, use this if fused moe runner also handles topk" to
"Bypass implementation, use this if fused moe runner also handles topk" in the
BypassTopKRouter (subclass of TopKRouter) docstring.
- Around line 10-15: The bypass router's forward currently returns weights=None
and ids=None which is intentional for routers that let runners perform top-k,
but it lacks documentation and can lead to a confusing RuntimeError when paired
with runners that expect ids/weights (see TritonFusedMoERunner). Update the
forward method's docstring in the bypass router to clearly state that returning
None for weights and ids is intentional and that this router must only be used
with runners that implement internal top-k (e.g., TritonFusedMoERunner-like
runners); additionally, add an optional runtime compatibility check in forward
or router initialization to raise a clear error message if paired with an
incompatible runner instead of allowing the downstream RuntimeError.

In `@diffulex/moe/topk/triton.py`:
- Line 2: Remove the unused import "import torch.nn.functional as F" from
triton.py; locate the import statement in the top of the file (the symbol to
remove is "torch.nn.functional as F") and delete it so the module no longer
includes an unused dependency and to satisfy linting.

In `@diffulex/utils/loader.py`:
- Around line 157-169: The code uses getattr(param, "weight_loader") inside
loader logic though AttributeError/KeyError is already caught; change to direct
attribute access param.weight_loader when building the partial (in the block
that calls model.get_parameter and constructs weight_loader) to satisfy static
analysis and improve clarity while keeping the same behavior — continue to wrap
it in the existing try/except, create the partial with param.weight_loader,
param, loaded_weight, and then call weight_loader() or weight_loader(shard_id)
depending on shard_id.
- Line 139: Replace the runtime assert "assert v == 'lm_head'" with an explicit
check that raises a ValueError when the condition fails; specifically, change it
to: if v != "lm_head": raise ValueError(f"Expected 'lm_head' for variable v, got
{v!r}"). This ensures the validation cannot be disabled with -O and provides a
clear, descriptive error message including the actual value.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 984d317a-ff55-46a2-a604-ddeb568c046d

📥 Commits

Reviewing files that changed from the base of the PR and between 79f111a and c962e83.

📒 Files selected for processing (35)
  • .gitignore
  • diffulex/config.py
  • diffulex/engine/dp_worker.py
  • diffulex/engine/model_runner.py
  • diffulex/engine/tp_worker.py
  • diffulex/layer/embed_head.py
  • diffulex/layer/linear.py
  • diffulex/model/dream.py
  • diffulex/model/fast_dllm_v2.py
  • diffulex/model/llada.py
  • diffulex/model/sdar.py
  • diffulex/moe/__init__.py
  • diffulex/moe/dispatcher/__init__.py
  • diffulex/moe/dispatcher/base.py
  • diffulex/moe/dispatcher/datatype.py
  • diffulex/moe/dispatcher/trivial.py
  • diffulex/moe/layers.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/runner/__init__.py
  • diffulex/moe/runner/base.py
  • diffulex/moe/runner/triton.py
  • diffulex/moe/runner/trivial.py
  • diffulex/moe/topk/__init__.py
  • diffulex/moe/topk/base.py
  • diffulex/moe/topk/bypass.py
  • diffulex/moe/topk/datatype.py
  • diffulex/moe/topk/triton.py
  • diffulex/moe/topk/trivial.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/checkpoint.py
  • diffulex/utils/loader.py
  • diffulex/utils/parallelism.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
💤 Files with no reviewable changes (1)
  • diffulex/moe/moe_impl.py
✅ Files skipped from review due to trivial changes (4)
  • .gitignore
  • diffulex/moe/topk/datatype.py
  • diffulex/layer/embed_head.py
  • diffulex/utils/checkpoint.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • diffulex/sampler/sdar.py
  • diffulex/moe/init.py
  • diffulex_kernel/init.py

@@ -1,4 +1,4 @@
import os
import os
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove BOM character from file.

The file starts with a UTF-8 BOM (byte order mark: ), which is unusual for Python source files and can cause issues with some tools. This was also flagged by static analysis (EXE002).

🔧 Proposed fix
-import os
+import os
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import os
import os
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 1-1: The file is executable but no shebang is present

(EXE002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` at line 1, The file contains a UTF-8 BOM at the
start (before the first token "import os"); remove the leading BOM character so
the file is plain UTF-8 without BOM and re-save loader.py (ensure your
editor/IDE writes UTF-8 without BOM) so the module import line ("import os") and
the rest of the file no longer include the hidden character.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
diffulex/config.py (1)

95-111: ⚠️ Potential issue | 🟠 Major

Reject unsupported TP/EP combinations during config validation.

tensor_parallel_size=2, expert_parallel_size=4 passes here today, but diffulex/utils/parallelism.py only supports TP-only, pure EP, or tp == ep. That turns a bad user config into a much later distributed-init failure instead of a clear ValueError at construction time.

Suggested fix
         if not 1 <= self.expert_parallel_size <= 32768:
             raise ValueError(
                 "expert_parallel_size must be in [1, 32768], "
                 f"got: {self.expert_parallel_size}"
             )
+
+        if (
+            self.tensor_parallel_size != 1
+            and self.expert_parallel_size != 1
+            and self.tensor_parallel_size != self.expert_parallel_size
+        ):
+            raise ValueError(
+                "Unsupported model-parallel layout: only TP-only, pure EP, "
+                "and tp == ep are supported, got "
+                f"tensor_parallel_size={self.tensor_parallel_size}, "
+                f"expert_parallel_size={self.expert_parallel_size}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/config.py` around lines 95 - 111, Add a validation in the config
constructor (where tensor_parallel_size and expert_parallel_size are checked) to
reject unsupported TP/EP combinations by raising ValueError: allow only TP-only
(expert_parallel_size == 1), pure EP (tensor_parallel_size == 1), or equal TP
and EP (tensor_parallel_size == expert_parallel_size); if none of these hold,
raise a clear ValueError mentioning tensor_parallel_size and
expert_parallel_size so invalid combos are caught during config validation
instead of later in diffulex/utils/parallelism.py.
🧹 Nitpick comments (1)
diffulex_kernel/python/fused_topk_triton.py (1)

129-184: Replace commented-out validator block with an explicit debug toggle.

Lines 179-183 leave dead commented code in a hot path. Prefer a deliberate debug guard so validation can be enabled when needed without code edits.

Suggested refactor
+import os
@@
 def _launch_fused_topk_kernels(
@@
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    debug_validate = os.getenv("DIFFULEX_VALIDATE_FUSED_TOPK", "0") == "1"
@@
-    `#_validate_fused_topk_outputs`(
-    #    topk_weights=topk_weights,
-    #    topk_ids=topk_ids,
-    #    num_experts=num_experts,
-    #)
+    if debug_validate:
+        _validate_fused_topk_outputs(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            num_experts=num_experts,
+        )
     return topk_weights, topk_ids
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 129 - 184, The
commented-out validator call in _launch_fused_topk_kernels should be replaced
with an opt-in debug guard so validation can be enabled without editing code:
add a boolean debug parameter (default False) to _launch_fused_topk_kernels (or
read a module-level DEBUG flag) and, when debug is True, call
_validate_fused_topk_outputs(topk_weights=topk_weights, topk_ids=topk_ids,
num_experts=num_experts) after the Triton launch; ensure the default hot path
leaves validation disabled to avoid perf impact and keep the function signature
(_launch_fused_topk_kernels) and return values unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 20-45: In _validate_fused_topk_inputs add an explicit guard for
zero experts: detect when router_logits.shape[-1] == 0 and raise a ValueError
(e.g., "fused_topk requires num_experts > 0") before computing effective_topk;
keep the existing top_k and scoring_func checks and the 4096 upper-bound check
but ensure the zero-experts check runs early so effective_topk cannot become 0
and empty routing outputs are prevented.

In `@diffulex/moe/layer/base.py`:
- Around line 34-46: Validate top_k against num_experts in the constructor
before calling build_topk_router: check that top_k is an int and 1 <= top_k <=
num_experts, and raise a clear ValueError (or assert) mentioning both top_k and
num_experts if the check fails; perform this check right before the
build_topk_router(...) call so the router creation is never attempted with
invalid parameters (referencing top_k, num_experts, build_topk_router, and
router to locate the code).

In `@diffulex/moe/layer/ep_impl.py`:
- Around line 349-371: The allocated output tensors for dist.all_gather must
match the input padded shapes: change the creation of gathered_token_indices and
gathered_outputs so each entry has shape (max_local_tokens,) and
(max_local_tokens, hidden_size) respectively (matching
padded_local_token_indices and padded_local_outputs), call dist.all_gather into
those padded buffers, and then slice each gathered buffer down to the actual
per-rank counts in gathered_counts to reconstruct the variable-length results;
update references to padded_local_token_indices, padded_local_outputs,
gathered_token_indices, gathered_outputs, max_local_tokens, hidden_size, and the
dist.all_gather call accordingly.

---

Outside diff comments:
In `@diffulex/config.py`:
- Around line 95-111: Add a validation in the config constructor (where
tensor_parallel_size and expert_parallel_size are checked) to reject unsupported
TP/EP combinations by raising ValueError: allow only TP-only
(expert_parallel_size == 1), pure EP (tensor_parallel_size == 1), or equal TP
and EP (tensor_parallel_size == expert_parallel_size); if none of these hold,
raise a clear ValueError mentioning tensor_parallel_size and
expert_parallel_size so invalid combos are caught during config validation
instead of later in diffulex/utils/parallelism.py.

---

Nitpick comments:
In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 129-184: The commented-out validator call in
_launch_fused_topk_kernels should be replaced with an opt-in debug guard so
validation can be enabled without editing code: add a boolean debug parameter
(default False) to _launch_fused_topk_kernels (or read a module-level DEBUG
flag) and, when debug is True, call
_validate_fused_topk_outputs(topk_weights=topk_weights, topk_ids=topk_ids,
num_experts=num_experts) after the Triton launch; ensure the default hot path
leaves validation disabled to avoid perf impact and keep the function signature
(_launch_fused_topk_kernels) and return values unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 83719749-6fcd-437f-8053-d53e116e4074

📥 Commits

Reviewing files that changed from the base of the PR and between c962e83 and 5956680.

📒 Files selected for processing (11)
  • diffulex/config.py
  • diffulex/moe/__init__.py
  • diffulex/moe/layer/__init__.py
  • diffulex/moe/layer/base.py
  • diffulex/moe/layer/ep_impl.py
  • diffulex/moe/layer/tp_impl.py
  • diffulex/moe/layer/trivial_impl.py
  • diffulex/moe/topk/__init__.py
  • diffulex/utils/parallelism.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • diffulex/moe/init.py
  • diffulex_kernel/python/fused_moe_triton.py

Comment on lines +20 to +45
def _validate_fused_topk_inputs(
router_logits: torch.Tensor,
top_k: int,
scoring_func: str,
) -> int:
if not router_logits.is_cuda:
raise ValueError("fused_topk requires CUDA tensors.")
if router_logits.dim() != 2:
raise ValueError(
f"fused_topk expects a 2D [num_tokens, num_experts] tensor, got {router_logits.shape}."
)
if router_logits.dtype not in (torch.float16, torch.bfloat16, torch.float32):
raise TypeError(
f"fused_topk only supports fp16/bf16/fp32 router logits, got {router_logits.dtype}."
)
if top_k <= 0:
raise ValueError(f"fused_topk requires top_k > 0, got {top_k}.")
if scoring_func not in {"softmax", "sigmoid"}:
raise ValueError(f"Unsupported scoring function: {scoring_func!r}.")

effective_topk = min(top_k, router_logits.shape[-1])
if router_logits.shape[-1] > 4096:
raise NotImplementedError(
"This Triton top-k draft currently requires num_experts <= 4096."
)
return effective_topk
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add explicit guard for zero experts.

top_k > 0 is enforced, but effective_topk can still become 0 when router_logits.shape[-1] == 0 (Line 40). This silently violates the routing contract and can propagate invalid empty routing outputs.

Proposed fix
 def _validate_fused_topk_inputs(
     router_logits: torch.Tensor,
     top_k: int,
     scoring_func: str,
 ) -> int:
@@
     if scoring_func not in {"softmax", "sigmoid"}:
         raise ValueError(f"Unsupported scoring function: {scoring_func!r}.")
+    if router_logits.shape[-1] == 0:
+        raise ValueError("fused_topk requires num_experts > 0.")
 
     effective_topk = min(top_k, router_logits.shape[-1])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 20 - 45, In
_validate_fused_topk_inputs add an explicit guard for zero experts: detect when
router_logits.shape[-1] == 0 and raise a ValueError (e.g., "fused_topk requires
num_experts > 0") before computing effective_topk; keep the existing top_k and
scoring_func checks and the 4096 upper-bound check but ensure the zero-experts
check runs early so effective_topk cannot become 0 and empty routing outputs are
prevented.

Comment on lines +34 to +46
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.top_k = top_k
self.hidden_act = hidden_act
self.norm_topk_prob = norm_topk_prob

self.router = build_topk_router(
"triton",
top_k=top_k,
renormalize=norm_topk_prob,
scoring_func="softmax",
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate top_k against num_experts before building the router.

These values come straight from model config, but the constructor never checks 1 <= top_k <= num_experts. If a checkpoint/config drifts, the failure gets deferred into the routing or fused-kernel path instead of stopping here with a clear error.

Suggested fix
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_experts = num_experts
-        self.top_k = top_k
+        if num_experts < 1:
+            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
+        if not 1 <= top_k <= num_experts:
+            raise ValueError(
+                f"top_k must be in [1, num_experts], got top_k={top_k}, num_experts={num_experts}"
+            )
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_experts = num_experts
+        self.top_k = top_k
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/layer/base.py` around lines 34 - 46, Validate top_k against
num_experts in the constructor before calling build_topk_router: check that
top_k is an int and 1 <= top_k <= num_experts, and raise a clear ValueError (or
assert) mentioning both top_k and num_experts if the check fails; perform this
check right before the build_topk_router(...) call so the router creation is
never attempted with invalid parameters (referencing top_k, num_experts,
build_topk_router, and router to locate the code).

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

♻️ Duplicate comments (2)
diffulex/moe/layer/base.py (1)

31-46: ⚠️ Potential issue | 🟠 Major

Reject invalid top_k/num_experts before building the router.

The constructor still accepts top_k <= 0 or top_k > num_experts and defers the failure into routing/kernel code. This is config validation and should fail here with a clear error.

🔧 Proposed fix
         super().__init__()

         if hidden_act != "silu":
             raise NotImplementedError("only silu is supported currently")
+        if num_experts < 1:
+            raise ValueError(f"num_experts must be >= 1, got {num_experts}.")
+        if not 1 <= top_k <= num_experts:
+            raise ValueError(
+                f"top_k must be in [1, num_experts], got top_k={top_k}, num_experts={num_experts}."
+            )
         
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/layer/base.py` around lines 31 - 46, Validate the MoE layer
constructor inputs before building the router: check that num_experts is > 0 and
top_k is > 0 and top_k <= num_experts, and raise a clear ValueError if any
condition fails; perform these checks near the start of the __init__ (before
calling build_topk_router) and refer to the parameters hidden_act, top_k, and
num_experts so failures are explicit and caught early.
diffulex_kernel/python/fused_moe_triton.py (1)

72-93: ⚠️ Potential issue | 🟠 Major

Validate w13/w2 ranks and routing rows before touching their shapes.

This validator still assumes w13/w2 are 3-D and that topk_ids has one row per token, but it never checks those preconditions before reading w13.shape[2] and entering the pack path. It also never verifies w2.shape[2] == w13.shape[1] // 2, so malformed inputs still fall into IndexError/mask mismatches instead of a clear Python-side error.

🔧 Proposed fix
     if hidden_states.dim() != 2 or topk_ids.dim() != 2 or topk_weights.dim() != 2:
         raise ValueError(
             "fused_moe expects hidden_states/topk_ids/topk_weights to be 2D tensors."
         )
+    if w13.dim() != 3 or w2.dim() != 3:
+        raise ValueError("fused_moe expects w13 and w2 to be 3D tensors.")
@@
     if topk_ids.shape != topk_weights.shape:
         raise ValueError(
             f"topk_ids and topk_weights must have the same shape, got {topk_ids.shape} and {topk_weights.shape}."
         )
-    if w13.shape[0] != w2.shape[0] or w13.shape[2] != hidden_states.shape[1]:
+    if topk_ids.shape[0] != hidden_states.shape[0]:
+        raise ValueError(
+            "fused_moe expects topk_ids/topk_weights to have one row per token in hidden_states."
+        )
+    if w13.shape[1] % 2 != 0:
+        raise ValueError("fused_moe expects w13.shape[1] to be 2 * intermediate_size.")
+    intermediate_size = w13.shape[1] // 2
+    if (
+        w13.shape[0] != w2.shape[0]
+        or w13.shape[2] != hidden_states.shape[1]
+        or w2.shape[1] != hidden_states.shape[1]
+        or w2.shape[2] != intermediate_size
+    ):
         raise ValueError(
             "Weight shapes do not match hidden_states or local expert count."
         )
-    if w13.shape[1] % 2 != 0:
-        raise ValueError("fused_moe expects w13.shape[1] to be 2 * intermediate_size.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 72 - 93, The
validator currently reads w13.shape[2] and assumes w13/w2 are 3-D and that
topk_ids has one row per token; add explicit rank and routing-row checks up
front in the fused_moe validation: verify w13.dim() == 3 and w2.dim() == 3,
verify topk_ids.dim() == 2 and topk_ids.shape[0] == hidden_states.shape[0],
verify w13.shape[0] == w2.shape[0] (already checked) and add the missing check
w2.shape[2] == w13.shape[1] // 2 so the inner dims match, and keep the existing
checks (dtype, w13.shape[1] % 2 == 0) to raise clear TypeError/ValueError before
any indexing or packing occurs.
🧹 Nitpick comments (2)
diffulex_kernel/python/fused_topk_triton.py (1)

100-104: Minor: redundant mask assignment in sigmoid branch.

Line 102 sets masked positions to 0.0, but Line 104 immediately overwrites all masked positions to -1.0e20 regardless of scoring mode. Line 102 is redundant.

Optional cleanup
     else:
         scores = 1.0 / (1.0 + tl.exp(-logits))
-        scores = tl.where(mask, scores, 0.0)
 
     scores = tl.where(mask, scores, -1.0e20)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 100 - 104, The
sigmoid branch redundantly sets masked positions to 0.0 then immediately
overwrites them to -1.0e20; remove the intermediate masking line so that in the
else/sigmoid path you only compute scores = 1.0 / (1.0 + tl.exp(-logits)) and
let the subsequent scores = tl.where(mask, scores, -1.0e20) apply the final
mask; look for the variable scores and the mask usage in fused_topk_triton.py to
update the else branch accordingly.
diffulex/moe/layer/__init__.py (1)

7-14: Make the layout note a real docstring or plain comments.

Because the imports come first, Lines 7-14 are a discarded string literal, not the module __doc__. Move it above the imports or convert it to # comments if you want this note to stay discoverable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/layer/__init__.py` around lines 7 - 14, The triple-quoted layout
note in diffulex/moe/layer/__init__.py is currently placed after imports and
thus becomes a discarded string literal; move that string above the import
statements so it becomes the module docstring (module __doc__) or alternatively
rewrite each line with leading # to make them real comments; update the block in
__init__.py (the existing triple-quoted note) accordingly and ensure imports
remain below it so the note is properly recognized as a docstring or as normal
comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 499-530: combined_output is a reusable workspace from
_get_workspace_tensor and currently returned directly when
packed_slot_outputs.dtype is float32, risking callers getting the cached buffer;
update the return so fused_moe (the caller of this block) always receives a
fresh tensor instance by cloning or copying combined_output before returning
(e.g., use combined_output.clone() or an equivalent non-in-place copy and then
.to(packed_slot_outputs.dtype)), ensuring you still preserve the dtype
conversion and do this for both early-return (num_valid_slots == 0) and the
final return path.

In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 179-183: Re-enable the output validation by making the commented
_validate_fused_topk_outputs call toggleable via a debug flag (e.g., add a
parameter debug: bool = False to the function that calls it) or an env var; when
debug is True, call _validate_fused_topk_outputs(topk_weights=topk_weights,
topk_ids=topk_ids, num_experts=num_experts) to validate expert IDs and weights,
otherwise skip it for performance. Update the function signature that currently
contains the commented block to accept the new debug parameter and wire any
callers to pass debug through or default to False so validation can be enabled
at runtime for debugging. Ensure the validation call references the exact
symbols _validate_fused_topk_outputs, topk_weights, topk_ids, and num_experts.

In `@diffulex/config.py`:
- Around line 107-112: The current range check on expert_parallel_size doesn't
ensure TP/EP layouts are compatible and can later fail in get_world_size; in
Config.__post_init__ call get_world_size(self.tp, self.ep,
self.expert_parallel_size) and validate its result (or catch errors) and raise a
clear ValueError if the layout is unsupported (e.g., get_world_size raises or
returns an inconsistent/invalid world size for tp/ep/expert_parallel_size like
tp=2, ep=3); update the error message to reference tp, ep, and
expert_parallel_size so misconfigured layouts fail fast during initialization.

In `@diffulex/moe/layer/tp_impl.py`:
- Line 67: The TP reduction currently uses dist.all_reduce(final_hidden_states)
without a TP-only process group, causing reductions across the full world in
mixed TP+EP setups; update the parallelism init to create and store TP/EP
subgroups (either in init_process_group() or by adding a get_tp_group() API
alongside get_tp_rank() and get_tp_world_size()), then change the reduction call
in tp_impl.py (the dist.all_reduce(final_hidden_states) invocation) to pass the
TP subgroup explicitly; apply the same pattern to embed_head.py and linear.py so
all TP-dimension all_reduce calls use the new TP group.

In `@diffulex/utils/loader.py`:
- Around line 131-170: The loop over packed_modules_mapping prematurely returns
True even when parameter/buffer lookup or weight_loader invocation fails; update
the logic in the loop that handles mapping keys (using packed_modules_mapping,
weight_name, config.model_name) so the return True only happens after a
successful copy or successful call to the resolved weight_loader (i.e., after
model.get_parameter(...)/weight_loader(...) or
model.get_buffer(...).copy_(loaded_weight) completes without exceptions).
Concretely, move the return True into the success paths (after
default_weight_loader or getattr(param, "weight_loader") is invoked successfully
and after buffer.copy_ succeeds), and ensure that caught AttributeError/KeyError
branches do not return True so the caller can fall back.

---

Duplicate comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 72-93: The validator currently reads w13.shape[2] and assumes
w13/w2 are 3-D and that topk_ids has one row per token; add explicit rank and
routing-row checks up front in the fused_moe validation: verify w13.dim() == 3
and w2.dim() == 3, verify topk_ids.dim() == 2 and topk_ids.shape[0] ==
hidden_states.shape[0], verify w13.shape[0] == w2.shape[0] (already checked) and
add the missing check w2.shape[2] == w13.shape[1] // 2 so the inner dims match,
and keep the existing checks (dtype, w13.shape[1] % 2 == 0) to raise clear
TypeError/ValueError before any indexing or packing occurs.

In `@diffulex/moe/layer/base.py`:
- Around line 31-46: Validate the MoE layer constructor inputs before building
the router: check that num_experts is > 0 and top_k is > 0 and top_k <=
num_experts, and raise a clear ValueError if any condition fails; perform these
checks near the start of the __init__ (before calling build_topk_router) and
refer to the parameters hidden_act, top_k, and num_experts so failures are
explicit and caught early.

---

Nitpick comments:
In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 100-104: The sigmoid branch redundantly sets masked positions to
0.0 then immediately overwrites them to -1.0e20; remove the intermediate masking
line so that in the else/sigmoid path you only compute scores = 1.0 / (1.0 +
tl.exp(-logits)) and let the subsequent scores = tl.where(mask, scores, -1.0e20)
apply the final mask; look for the variable scores and the mask usage in
fused_topk_triton.py to update the else branch accordingly.

In `@diffulex/moe/layer/__init__.py`:
- Around line 7-14: The triple-quoted layout note in
diffulex/moe/layer/__init__.py is currently placed after imports and thus
becomes a discarded string literal; move that string above the import statements
so it becomes the module docstring (module __doc__) or alternatively rewrite
each line with leading # to make them real comments; update the block in
__init__.py (the existing triple-quoted note) accordingly and ensure imports
remain below it so the note is properly recognized as a docstring or as normal
comments.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 74887528-809f-4eab-8d9d-6d01dd7613fd

📥 Commits

Reviewing files that changed from the base of the PR and between 5956680 and 7b02278.

📒 Files selected for processing (30)
  • .gitignore
  • diffulex/config.py
  • diffulex/engine/dp_worker.py
  • diffulex/engine/model_runner.py
  • diffulex/engine/tp_worker.py
  • diffulex/layer/embed_head.py
  • diffulex/layer/linear.py
  • diffulex/model/dream.py
  • diffulex/model/fast_dllm_v2.py
  • diffulex/model/llada.py
  • diffulex/model/sdar.py
  • diffulex/moe/__init__.py
  • diffulex/moe/layer/__init__.py
  • diffulex/moe/layer/base.py
  • diffulex/moe/layer/ep_impl.py
  • diffulex/moe/layer/tp_impl.py
  • diffulex/moe/layer/trivial_impl.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/topk/__init__.py
  • diffulex/moe/topk/base.py
  • diffulex/moe/topk/datatype.py
  • diffulex/moe/topk/triton.py
  • diffulex/moe/topk/trivial.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/checkpoint.py
  • diffulex/utils/loader.py
  • diffulex/utils/parallelism.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
💤 Files with no reviewable changes (1)
  • diffulex/moe/moe_impl.py
✅ Files skipped from review due to trivial changes (5)
  • .gitignore
  • diffulex/sampler/sdar.py
  • diffulex/moe/topk/datatype.py
  • diffulex/moe/topk/base.py
  • diffulex/utils/checkpoint.py
🚧 Files skipped from review as they are similar to previous changes (12)
  • diffulex/model/dream.py
  • diffulex/engine/tp_worker.py
  • diffulex/model/fast_dllm_v2.py
  • diffulex/layer/linear.py
  • diffulex/model/llada.py
  • diffulex/layer/embed_head.py
  • diffulex/moe/topk/triton.py
  • diffulex/moe/init.py
  • diffulex_kernel/init.py
  • diffulex/moe/topk/trivial.py
  • diffulex/engine/dp_worker.py
  • diffulex/moe/layer/trivial_impl.py

Comment on lines +499 to +530
combined_output = _get_workspace_tensor(
"combined_output",
(num_tokens, packed_slot_outputs.shape[1]),
dtype=torch.float32,
device=packed_slot_outputs.device,
zero=True,
)
if packed_inputs.num_valid_slots == 0:
return combined_output.to(packed_slot_outputs.dtype)

block_m = PACKED_BLOCK_M
block_n = 64 if packed_slot_outputs.shape[1] >= 64 else triton.next_power_of_2(packed_slot_outputs.shape[1])
_weighted_scatter_add[
(triton.cdiv(packed_inputs.num_padded_slots, block_m), triton.cdiv(packed_slot_outputs.shape[1], block_n))
](
packed_slot_outputs,
packed_inputs.packed_token_ids,
packed_inputs.packed_weights,
combined_output,
packed_inputs.num_padded_slots,
packed_slot_outputs.shape[1],
packed_slot_outputs.stride(0),
packed_slot_outputs.stride(1),
packed_inputs.packed_token_ids.stride(0),
packed_inputs.packed_weights.stride(0),
combined_output.stride(0),
combined_output.stride(1),
BLOCK_M=block_m,
BLOCK_N=block_n,
num_warps=4,
)
return combined_output.to(packed_slot_outputs.dtype)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t return the reusable workspace buffer in fp32 mode.

combined_output comes from _WORKSPACE_CACHE, and when packed_slot_outputs.dtype is already torch.float32, the .to() here is a no-op. That means fused_moe() can hand its caller the cached workspace itself, and the next MoE invocation on the same device can overwrite the previous layer’s result before it is consumed.

🔧 Proposed fix
-    return combined_output.to(packed_slot_outputs.dtype)
+    if packed_slot_outputs.dtype == torch.float32:
+        return combined_output.clone()
+    return combined_output.to(packed_slot_outputs.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 499 - 530,
combined_output is a reusable workspace from _get_workspace_tensor and currently
returned directly when packed_slot_outputs.dtype is float32, risking callers
getting the cached buffer; update the return so fused_moe (the caller of this
block) always receives a fresh tensor instance by cloning or copying
combined_output before returning (e.g., use combined_output.clone() or an
equivalent non-in-place copy and then .to(packed_slot_outputs.dtype)), ensuring
you still preserve the dtype conversion and do this for both early-return
(num_valid_slots == 0) and the final return path.

Comment on lines +179 to +183
#_validate_fused_topk_outputs(
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# num_experts=num_experts,
#)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Re-enable output validation or make it toggleable.

The output validation that catches out-of-range expert IDs and non-finite weights is commented out. While disabling it for performance in production is reasonable, having it entirely commented out means:

  1. No way to enable it for debugging without code changes
  2. Silent corruption if the kernel produces invalid outputs

Consider making validation toggleable via an optional parameter (e.g., debug: bool = False) or environment variable, or remove the dead code entirely if validation is intentionally never used.

Proposed fix: add debug flag
 def _launch_fused_topk_kernels(
     router_logits: torch.Tensor,
     top_k: int,
     renormalize: bool,
     scoring_func: str,
+    debug: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     ...
     _fused_topk[(num_tokens,)](
         ...
     )
-    `#_validate_fused_topk_outputs`(
-    #    topk_weights=topk_weights,
-    #    topk_ids=topk_ids,
-    #    num_experts=num_experts,
-    #)
+    if debug:
+        _validate_fused_topk_outputs(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            num_experts=num_experts,
+        )
     return topk_weights, topk_ids
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 179 - 183,
Re-enable the output validation by making the commented
_validate_fused_topk_outputs call toggleable via a debug flag (e.g., add a
parameter debug: bool = False to the function that calls it) or an env var; when
debug is True, call _validate_fused_topk_outputs(topk_weights=topk_weights,
topk_ids=topk_ids, num_experts=num_experts) to validate expert IDs and weights,
otherwise skip it for performance. Update the function signature that currently
contains the commented block to accept the new debug parameter and wire any
callers to pass debug through or default to False so validation can be enabled
at runtime for debugging. Ensure the validation call references the exact
symbols _validate_fused_topk_outputs, topk_weights, topk_ids, and num_experts.

Comment on lines +107 to +112
if not 1 <= self.expert_parallel_size <= 32768:
raise ValueError(
"expert_parallel_size must be in [1, 32768], "
f"got: {self.expert_parallel_size}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast on unsupported TP/EP layouts (Line 107+).

This range check is incomplete: it accepts TP/EP pairs that are later rejected by get_world_size(...) (e.g., tp=2, ep=3). Validate layout compatibility in Config.__post_init__ to prevent late runtime failures during worker startup.

Proposed fix
         if not 1 <= self.expert_parallel_size <= 32768:
             raise ValueError(
                 "expert_parallel_size must be in [1, 32768], "
                 f"got: {self.expert_parallel_size}"
             )
+
+        # Keep config validation aligned with diffulex.utils.parallelism.get_world_size
+        if (
+            self.tensor_parallel_size != 1
+            and self.expert_parallel_size != 1
+            and self.tensor_parallel_size != self.expert_parallel_size
+        ):
+            raise ValueError(
+                "Unsupported TP/EP layout: only TP-only (ep=1), "
+                "pure EP (tp=1), or tp==ep are supported, got "
+                f"tp={self.tensor_parallel_size}, ep={self.expert_parallel_size}"
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if not 1 <= self.expert_parallel_size <= 32768:
raise ValueError(
"expert_parallel_size must be in [1, 32768], "
f"got: {self.expert_parallel_size}"
)
if not 1 <= self.expert_parallel_size <= 32768:
raise ValueError(
"expert_parallel_size must be in [1, 32768], "
f"got: {self.expert_parallel_size}"
)
# Keep config validation aligned with diffulex.utils.parallelism.get_world_size
if (
self.tensor_parallel_size != 1
and self.expert_parallel_size != 1
and self.tensor_parallel_size != self.expert_parallel_size
):
raise ValueError(
"Unsupported TP/EP layout: only TP-only (ep=1), "
"pure EP (tp=1), or tp==ep are supported, got "
f"tp={self.tensor_parallel_size}, ep={self.expert_parallel_size}"
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/config.py` around lines 107 - 112, The current range check on
expert_parallel_size doesn't ensure TP/EP layouts are compatible and can later
fail in get_world_size; in Config.__post_init__ call get_world_size(self.tp,
self.ep, self.expert_parallel_size) and validate its result (or catch errors)
and raise a clear ValueError if the layout is unsupported (e.g., get_world_size
raises or returns an inconsistent/invalid world size for
tp/ep/expert_parallel_size like tp=2, ep=3); update the error message to
reference tp, ep, and expert_parallel_size so misconfigured layouts fail fast
during initialization.

local_expert_start=0,
hidden_act=self.hidden_act,
)
dist.all_reduce(final_hidden_states)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== TP parallelism helpers =="
rg -n -C3 'get_tp_|tp_group|new_group\(' --type=py diffulex

echo
echo "== all_reduce call sites =="
rg -n -C2 'all_reduce\s*\(' --type=py diffulex

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 12408


🏁 Script executed:

cat -n diffulex/utils/parallelism.py | head -200

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 7935


🏁 Script executed:

sed -n '60,75p' diffulex/moe/layer/tp_impl.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 567


🏁 Script executed:

sed -n '180,195p' diffulex/moe/layer/ep_impl.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 571


The TP reduction at line 67 lacks an explicit group parameter and will incorrectly reduce across the entire world in mixed TP+EP configurations.

diffulex.utils.parallelism does not expose a TP-specific group handle. The module provides get_tp_rank() and get_tp_world_size() but no process group API. The code currently calls dist.all_reduce() without a group parameter, which uses the default world group. This is correct only when TP is the sole parallelism mode, but fails in mixed TP+EP layouts where the world group includes both TP and EP processes—causing incorrect reduction across the entire world instead of just the TP dimension.

Create and store TP/EP subgroups in init_process_group() or expose a get_tp_group() function, then pass the appropriate group to all_reduce calls in tp_impl.py, embed_head.py, and linear.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/layer/tp_impl.py` at line 67, The TP reduction currently uses
dist.all_reduce(final_hidden_states) without a TP-only process group, causing
reductions across the full world in mixed TP+EP setups; update the parallelism
init to create and store TP/EP subgroups (either in init_process_group() or by
adding a get_tp_group() API alongside get_tp_rank() and get_tp_world_size()),
then change the reduction call in tp_impl.py (the
dist.all_reduce(final_hidden_states) invocation) to pass the TP subgroup
explicitly; apply the same pattern to embed_head.py and linear.py so all
TP-dimension all_reduce calls use the new TP group.

Comment on lines +131 to +170
for k in packed_modules_mapping:
if k not in weight_name:
continue

if config.model_name == "llada" and k == "ff_out" and "transformer.ff_out" in weight_name:
continue
elif config.model_name == "llada" and k == "transformer.ff_out":
v, shard_id = packed_modules_mapping[k]
assert v == "lm_head"
param_name = "lm_head.weight"
else:
v, shard_id = packed_modules_mapping[k]
param_name = weight_name.replace(k, v)

if "layernorm" in param_name:
try:
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
except (AttributeError, KeyError):
try:
buffer = model.get_buffer(param_name)
buffer.copy_(loaded_weight)
except (AttributeError, KeyError):
pass
else:
try:
param = model.get_parameter(param_name)
weight_loader = partial(
getattr(param, "weight_loader"),
param,
loaded_weight,
)
if shard_id is None:
weight_loader()
else:
weight_loader(shard_id)
except (AttributeError, KeyError):
pass
return True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Only report a packed mapping as handled after the copy succeeds.

This helper currently falls through to return True even when the mapped param/buffer lookup or weight_loader resolution fails. load_model() then skips the direct path, so a bad mapping silently drops the checkpoint tensor instead of falling back or surfacing the mismatch.

🔧 Proposed fix
     for k in packed_modules_mapping:
         if k not in weight_name:
             continue
+        loaded = False

         if config.model_name == "llada" and k == "ff_out" and "transformer.ff_out" in weight_name:
             continue
         elif config.model_name == "llada" and k == "transformer.ff_out":
             v, shard_id = packed_modules_mapping[k]
@@
         if "layernorm" in param_name:
             try:
                 param = model.get_parameter(param_name)
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded = True
             except (AttributeError, KeyError):
                 try:
                     buffer = model.get_buffer(param_name)
                     buffer.copy_(loaded_weight)
+                    loaded = True
                 except (AttributeError, KeyError):
                     pass
         else:
             try:
                 param = model.get_parameter(param_name)
-                weight_loader = partial(
-                    getattr(param, "weight_loader"),
-                    param,
-                    loaded_weight,
-                )
-                if shard_id is None:
-                    weight_loader()
+                weight_loader_fn = getattr(param, "weight_loader", None)
+                if weight_loader_fn is None:
+                    if shard_id is None:
+                        default_weight_loader(param, loaded_weight)
+                        loaded = True
                 else:
-                    weight_loader(shard_id)
+                    weight_loader = partial(weight_loader_fn, param, loaded_weight)
+                    if shard_id is None:
+                        weight_loader()
+                    else:
+                        weight_loader(shard_id)
+                    loaded = True
             except (AttributeError, KeyError):
                 pass
-        return True
+        if loaded:
+            return True
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 160-160: Do not call getattr with a constant attribute value. It is not any safer than normal property access.

Replace getattr with attribute access

(B009)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` around lines 131 - 170, The loop over
packed_modules_mapping prematurely returns True even when parameter/buffer
lookup or weight_loader invocation fails; update the logic in the loop that
handles mapping keys (using packed_modules_mapping, weight_name,
config.model_name) so the return True only happens after a successful copy or
successful call to the resolved weight_loader (i.e., after
model.get_parameter(...)/weight_loader(...) or
model.get_buffer(...).copy_(loaded_weight) completes without exceptions).
Concretely, move the return True into the success paths (after
default_weight_loader or getattr(param, "weight_loader") is invoked successfully
and after buffer.copy_ succeeds), and ensure that caught AttributeError/KeyError
branches do not return True so the caller can fall back.

@drewjin
Copy link
Copy Markdown
Contributor

drewjin commented Apr 14, 2026

LGTM

@drewjin drewjin merged commit 595d9be into SJTU-DENG-Lab:main Apr 14, 2026
1 check passed
@MikanAffine MikanAffine deleted the fusedmoe branch April 14, 2026 08:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants