Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
58b0900
Add XPU MoE decode kernel with INT4 sym/asym and FP16/BF16 baselines
Copilot May 14, 2026
527eede
Document int4 sign-extension trick
Copilot May 14, 2026
78ecc0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
5dc9d95
Add INT8/INT2/FP8 decode MoE GEMV kernels and tests
Copilot May 14, 2026
f15093a
docs: clarify int2 bit-indexing notation in moe_gemm_decode
Copilot May 14, 2026
430868d
Merge remote-tracking branch 'origin/main' into copilot/add-xpu-moe-d…
May 18, 2026
4395884
test: add perf comparison UT — moe_gemm_decode vs default XPU MoE
Copilot May 19, 2026
a864bed
test: clearer skip reasons for moe_gemm_decode perf UT
Copilot May 20, 2026
407da75
fix(ark): correct duplicated bestla include path in sycl_tla_moe_deco…
Copilot May 26, 2026
70dc320
perf: vectorize moe_gemm_decode loads, parallelize expert-id fill, dr…
Copilot May 26, 2026
1da1977
feat(ark): add ARK_FP8_DECODE_USE_LUT switch for FP8 decode in MoE ke…
Copilot May 26, 2026
c297d37
feat(ark): make FP8 decode LUT switch runtime via ARK_FP8_DECODE_USE_…
Copilot May 26, 2026
26dbeaa
fix precommit
a32543254 May 27, 2026
72b19e9
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 27, 2026
608bf28
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 28, 2026
7ef8dbf
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 29, 2026
3dd5a8e
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 Jun 5, 2026
dab6219
Apply remaining changes
Copilot Jun 5, 2026
6cf6c3a
feat(ark): add quantized MoE prefill kernel (functional baseline)
Copilot Jun 8, 2026
9dc0f15
fix: restore patch_torch_sdpa def; relocate TestMoEGemmPrefill out of…
Copilot Jun 8, 2026
8a22540
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 Jun 17, 2026
0e78f31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2026
d48f45d
feat: add MoE prefill performance test with TFLOPS calculation
Copilot Jun 17, 2026
b07b59f
docs: add MoE prefill performance test documentation
Copilot Jun 17, 2026
5feb707
refactor: change ARK() instance to module reference in MoE perf tests
Copilot Jun 17, 2026
2b311bd
fix(test): correct MoE prefill perf test layout and kernel entry
Copilot Jun 17, 2026
e51cd9d
test(moe-perf): add dequant-inclusive baseline column for Stage-1 fai…
Copilot Jun 17, 2026
834d03c
perf(moe-prefill): cache the [E, K, N] dequant workspace across calls
Copilot Jun 17, 2026
2caa066
perf(moe_prefill): skip dequant for experts with zero tokens
Copilot Jun 17, 2026
3e5192c
perf(moe_prefill): pack PACK_K K-outputs per dequant work-item
Copilot Jun 17, 2026
c5f7c21
docs(moe_prefill): clarify PACK_K must divide group_size for hoist
Copilot Jun 17, 2026
81712c8
test: add accuracy UT for MoE prefill (ark.moe_gemm / moe_gemm_prefill)
Copilot Jun 17, 2026
26aa210
feat(ark): unified `moe` API + model-level perf test
Copilot Jun 17, 2026
5502624
test(ark): add model-level MoE perf benchmark on XPU
Copilot Jun 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
518 changes: 518 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/__init__.py

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/ark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ typedef uintptr_t torch_ptr;
// Only include declarations, implementations are in separate .cpp files
#include "sycl_tla_common.hpp"
#include "sycl_tla_moe.hpp"
#include "sycl_tla_moe_decode.hpp"
#include "sycl_tla_moe_mixed.hpp"
#include "sycl_tla_sdpa.hpp"
#endif
#else
Expand Down Expand Up @@ -222,6 +224,27 @@ static void moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr
(void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts);
}

static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales,
torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf,
int act_dtype, int weight_dtype, int N, int K, int group_size,
torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) {
ark::moe_gemm_decode((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr,
zeros ? (void*)zeros : nullptr, (void*)outputs, (int*)expert_id_per_token_buf,
(BTLA_DTYPE)(act_dtype), (BTLA_DTYPE)(weight_dtype), N, K, group_size,
(int*)num_tokens_per_expert, num_experts, total_tokens, asym);
}

static void moe_gemm_prefill_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales,
torch_ptr zeros, torch_ptr outputs, torch_ptr dequant_workspace, int act_dtype,
int weight_dtype, int N, int K, int group_size, torch_ptr num_tokens_per_expert,
int num_experts, int total_tokens, bool asym) {
ark::moe_gemm_prefill((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr,
zeros ? (void*)zeros : nullptr, (void*)outputs,
dequant_workspace ? (void*)dequant_workspace : nullptr, (BTLA_DTYPE)(act_dtype),
(BTLA_DTYPE)(weight_dtype), N, K, group_size, (int*)num_tokens_per_expert, num_experts,
total_tokens, asym);
}

static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out,
int num_rows, int head_dim, int block_size) {
auto* q = (sycl::queue*)stream;
Expand Down Expand Up @@ -439,5 +462,7 @@ PYBIND11_MODULE(PY_NAME, m) {
m.def("sage_dynamic_quant_layout", &ark::sage_dynamic_quant_layout);
m.def("sage_dynamic_quant_v_layout", &ark::sage_dynamic_quant_v_layout);
m.def("moe_gemm", &ark::moe_gemm_wrapper);
m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper);
m.def("moe_gemm_prefill", &ark::moe_gemm_prefill_wrapper);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,67 @@ namespace ark {
void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N,
int K, int* num_tokens_per_expert, int num_experts);

/**
* @brief MoE GEMV optimized for the decode phase (M per expert is typically
* 1-2 tokens). Supports unquantized FP16/BF16 weights and int4 (S4_CLIP)
* weights with group-wise scales and optional zero-points.
*
* Implementation is header-only in `sycl_tla_moe_decode.hpp`.
*
* @param q SYCL queue
* @param activations [total_tokens, K] in `act_dtype`
* @param weights Unquantized: [num_experts, N, K] in act_dtype
* Int4: packed [num_experts, N, K/2] uint8
* @param scales [num_experts, N, K/group_size] (act_dtype),
* ignored when weight_dtype is FP16/BF16
* @param zeros [num_experts, N, K/group_size] (act_dtype) or
* nullptr; required when asym==true
* @param outputs [total_tokens, N] in act_dtype
* @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device)
* @param act_dtype BTLA_DTYPE::F16 or BTLA_DTYPE::BF16
* @param weight_dtype BTLA_DTYPE::F16/BF16/S4_CLIP
* @param N Output feature dim (must be multiple of 16)
* @param K Input feature dim
* @param group_size Quantization group along K (int4 only); must
* divide K and be even. Default 128.
* @param num_tokens_per_expert [num_experts] int32
* @param num_experts Number of experts
* @param total_tokens Sum of num_tokens_per_expert (== rows of
* activations / outputs)
* @param asym Whether int4 weights are asymmetric
* (zeros required when true).
Comment thread
a32543254 marked this conversation as resolved.
*/
void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs,
int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K,
int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym);

/**
* @brief MoE Grouped GEMM optimized for the prefill phase, supporting the
* same set of weight encodings as `moe_gemm_decode` (FP16/BF16, INT8 sym/asym,
* INT4 sym/asym, INT2 sym/asym, FP8 E4M3/E5M2).
*
* Stage-1 implementation: dequantizes weights into a `[num_experts, K, N]`
* temporary buffer (must be supplied by the caller via `dequant_workspace`,
* sized `num_experts * K * N * sizeof(act_dtype)`) and then dispatches to the
* existing `moe_gemm` baseline. This guarantees numerical parity with the
* decode path. Mainloop fusion is the follow-up perf-tuning step.
*
* Implementation is header-only in `sycl_tla_moe_mixed.hpp`.
*
* Layout convention (matches `moe_gemm_decode`):
* - activations: [total_tokens, K] in act_dtype
* - weights (quantized): [num_experts, N, K_p] uint8 (decode-style packed)
* - weights (FP16/BF16): [num_experts, K, N] in act_dtype (matches `moe_gemm`)
* - scales: [num_experts, N, K/group_size] in act_dtype
* - zeros (asym only): [num_experts, N, K/group_size] in act_dtype
* - dequant_workspace: [num_experts, K, N] in act_dtype, may be null
* for the unquantized fast path
* - outputs: [total_tokens, N] in act_dtype
*/
void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs,
void* dequant_workspace, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K,
int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym);

// ========================================================================
// Public API
// ========================================================================
Expand Down
Loading
Loading