From 58b090033dd473c76b2f72cb23ac00ec3f8f7101 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 04:03:37 +0000 Subject: [PATCH 01/28] Add XPU MoE decode kernel with INT4 sym/asym and FP16/BF16 baselines Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../ark/auto_round_kernel/__init__.py | 150 +++++++++ .../ark/auto_round_kernel/ark.cpp | 12 + .../wrapper/include/sycl_tla_common.hpp | 34 +++ .../wrapper/include/sycl_tla_moe_decode.hpp | 286 ++++++++++++++++++ auto_round_extension/ark/test/test_moe.py | 223 ++++++++++++++ 5 files changed, 705 insertions(+) create mode 100644 auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 48778aff3..511dd7b95 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -795,6 +795,156 @@ def moe_gemm( ) return outputs + def moe_gemm_decode( + self, + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, + ) -> torch.Tensor: + """MoE GEMV optimized for the decode phase. + + Each expert typically processes only 1-2 tokens (top-k routing with + small batch). Activations must already be gathered/sorted by expert + (same convention as ``moe_gemm``). + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: + * ``weight_bits=4``: packed uint8 of shape ``[E, N, K // 2]`` + (two int4 values per byte; low nibble at the lower K index). 
+ * ``weight_bits=16``: fp16/bf16 of shape ``[E, N, K]`` matching + the activations dtype. + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required + when ``weight_bits=4``; ignored otherwise. + zeros: ``[E, N, K // group_size]`` in activations dtype. Required + when ``weight_bits=4`` and ``asym=True``; otherwise None. + weight_bits: 4 (int4, S4_CLIP) or 16 (no quantization). + group_size: group along K for int4 weights (default 128). + asym: if True, weights are stored as unsigned nibbles and ``zeros`` + must be provided. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. + """ + if activations.device.type != "xpu": + raise NotImplementedError("moe_gemm_decode is only supported on XPU") + + if activations.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}") + + if activations.ndim != 2: + raise ValueError("activations must be 2D [total_tokens, K]") + if weights.ndim != 3: + raise ValueError("weights must be 3D [E, N, K] or [E, N, K//2]") + + if not activations.is_contiguous(): + activations = activations.contiguous() + if not weights.is_contiguous(): + weights = weights.contiguous() + + if num_tokens_per_expert.dtype != torch.int32: + num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) + if not num_tokens_per_expert.is_contiguous(): + num_tokens_per_expert = num_tokens_per_expert.contiguous() + + total_tokens, K = activations.shape + num_experts = weights.shape[0] + N = weights.shape[1] + + if num_tokens_per_expert.shape[0] != num_experts: + raise ValueError( + f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" + ) + + # Validate weight layout / dtype combination. 
+ if weight_bits == 16: + if weights.dtype != activations.dtype: + raise ValueError("Unquantized weights must match activations dtype") + if weights.shape[2] != K: + raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}") + weight_dtype = cvt_dtype(activations.dtype) + if scales is not None or zeros is not None: + raise ValueError("scales/zeros must be None when weight_bits=16") + elif weight_bits == 4: + if weights.dtype != torch.uint8: + raise ValueError("Int4 packed weights must be torch.uint8") + if weights.shape[2] * 2 != K: + raise ValueError( + f"Int4 packed weights last dim {weights.shape[2]} must equal K/2 ({K // 2})" + ) + if scales is None: + raise ValueError("scales is required for int4 weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0 or (group_size & 1) != 0: + raise ValueError("K must be a multiple of group_size and group_size must be even") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if asym: + if zeros is None: + raise ValueError("zeros is required when asym=True") + if zeros.dtype != activations.dtype: + raise ValueError("zeros dtype must match activations dtype") + if tuple(zeros.shape) != expected_scale_shape: + raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}") + else: + if zeros is not None: + raise ValueError("zeros must be None when asym=False") + # S4_CLIP weight dtype, regardless of activations dtype. 
+ weight_dtype = ARK_DT.int4 + if not scales.is_contiguous(): + scales = scales.contiguous() + if asym and not zeros.is_contiguous(): + zeros = zeros.contiguous() + else: + raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 4, 16)") + + if N % 16 != 0: + raise ValueError(f"N must be a multiple of 16 (got {N})") + + expected_total = int(num_tokens_per_expert.sum().item()) + if expected_total != total_tokens: + raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") + + lib = self.get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + # Scratch buffer mapping each token to its expert id; filled on-device + # inside the kernel wrapper so we avoid host-device sync. + expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32) + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_decode( + stream, + activations.data_ptr(), + weights.data_ptr(), + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + expert_id_per_token.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + return outputs + if __name__ == "__main__": ark = ARK() diff --git a/auto_round_extension/ark/auto_round_kernel/ark.cpp b/auto_round_extension/ark/auto_round_kernel/ark.cpp index cd5c1ab03..39819f4db 100755 --- a/auto_round_extension/ark/auto_round_kernel/ark.cpp +++ b/auto_round_extension/ark/auto_round_kernel/ark.cpp @@ -25,6 +25,7 @@ typedef uintptr_t torch_ptr; // Only include declarations, implementations are in separate .cpp files #include "sycl_tla_common.hpp" #include "sycl_tla_moe.hpp" +#include "sycl_tla_moe_decode.hpp" #include "sycl_tla_sdpa.hpp" #endif #else @@ -153,6 +154,16 @@ static void 
moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr (void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts); } +static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales, + torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf, + int act_dtype, int weight_dtype, int N, int K, int group_size, + torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) { + ark::moe_gemm_decode((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr, + zeros ? (void*)zeros : nullptr, (void*)outputs, (int*)expert_id_per_token_buf, + (BTLA_DTYPE)(act_dtype), (BTLA_DTYPE)(weight_dtype), N, K, group_size, + (int*)num_tokens_per_expert, num_experts, total_tokens, asym); +} + static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr output, torch_ptr scale_out, int num_rows, int head_dim, int block_size) { auto* q = (sycl::queue*)stream; @@ -279,5 +290,6 @@ PYBIND11_MODULE(PY_NAME, m) { m.def("sage", &ark::sage); m.def("sage_dynamic_quant", &ark::sage_dynamic_quant); m.def("moe_gemm", &ark::moe_gemm_wrapper); + m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper); #endif } \ No newline at end of file diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp index 37752f79a..31b3994ce 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp @@ -35,6 +35,40 @@ namespace ark { void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N, int K, int* num_tokens_per_expert, int num_experts); +/** + * @brief MoE GEMV optimized for the decode phase (M per expert is typically + * 1-2 tokens). 
Supports unquantized FP16/BF16 weights and int4 (S4_CLIP) + * weights with group-wise scales and optional zero-points. + * + * Implementation is header-only in `sycl_tla_moe_decode.hpp`. + * + * @param q SYCL queue + * @param activations [total_tokens, K] in `act_dtype` + * @param weights Unquantized: [num_experts, N, K] in act_dtype + * Int4: packed [num_experts, N, K/2] uint8 + * @param scales [num_experts, N, K/group_size] (act_dtype), + * ignored when weight_dtype is FP16/BF16 + * @param zeros [num_experts, N, K/group_size] (act_dtype) or + * nullptr; required when asym==true + * @param outputs [total_tokens, N] in act_dtype + * @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device) + * @param act_dtype BTLA_DTYPE::F16 or BTLA_DTYPE::BF16 + * @param weight_dtype BTLA_DTYPE::F16/BF16/S4_CLIP + * @param N Output feature dim (must be multiple of 16) + * @param K Input feature dim + * @param group_size Quantization group along K (int4 only); must + * divide K and be even. Default 128. + * @param num_tokens_per_expert [num_experts] int32 + * @param num_experts Number of experts + * @param total_tokens Sum of num_tokens_per_expert (== rows of + * activations / outputs) + * @param asym Whether int4 weights are asymmetric + * (zeros required when true). 
+ */ +void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs, + int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, + int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); + // ======================================================================== // Public API // ======================================================================== diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp new file mode 100644 index 000000000..f1bb4cbac --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -0,0 +1,286 @@ +// SYCL MoE Decode Kernel +// +// GEMV-style MoE kernel optimized for the decode phase, where each expert +// typically processes only 1-2 tokens (top-k routing with batch size 1). +// +// Layout convention (caller already sorted activations per expert, +// identical to the prefill `moe_gemm` interface): +// - activations: [total_tokens, K] row-major +// - weights (fp/bf16): [num_experts, N, K] row-major +// - weights (int4 packed): [num_experts, N, K/2] row-major, two +// 4-bit values per byte (low nibble at lower K) +// - scales: [num_experts, N, K/group_size] +// - zeros (asym only): [num_experts, N, K/group_size] +// - num_tokens_per_expert: [num_experts] int32 +// - outputs: [total_tokens, N] +// +// Target: Intel BMG (Xe2), sub_group_size = 16. One sub-group per (token, N-tile) +// with N_TILE == SG_SIZE: each lane independently computes one output element, +// so no cross-lane reduction is needed and activation reads are coalesced across +// the sub-group through the L1 cache. 
+// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "bestla/bestla/bestla.h" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_decode_detail { + +constexpr int SG_SIZE = 16; +constexpr int N_TILE = SG_SIZE; // one output element per sub-group lane + +// ---------------------------------------------------------------------------- +// Kernel name tags (one per specialization, required for SYCL kernel naming) +// ---------------------------------------------------------------------------- +template +class MoEDecodeKernelFP; + +template +class MoEDecodeKernelInt4; + +// ---------------------------------------------------------------------------- +// Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. +// Runs on host (num_experts is small, total_tokens is small in decode). +// Caller-managed buffer (USM device allocation) keeps host noise out of the +// hot path; here we just fill it via a tiny SYCL kernel for simplicity. +// ---------------------------------------------------------------------------- +inline void fill_expert_id_per_token(sycl::queue* q, int* expert_id_per_token, + const int* num_tokens_per_expert, int num_experts, + int total_tokens) { + // Sequential prefix-scan on a single thread; cheap because num_experts is + // small (typ. <= 256) and we avoid host-device sync entirely. + q->single_task([=]() { + int offset = 0; + for (int e = 0; e < num_experts; ++e) { + int n = num_tokens_per_expert[e]; + for (int i = 0; i < n; ++i) { + if (offset + i < total_tokens) { + expert_id_per_token[offset + i] = e; + } + } + offset += n; + } + }).wait(); +} + +// ---------------------------------------------------------------------------- +// FP16 / BF16 baseline GEMV (no quantization). 
+// ---------------------------------------------------------------------------- +template +void launch_fp(sycl::queue* q, const ScalarT* activations, const ScalarT* weights, ScalarT* outputs, + const int* expert_id_per_token, int total_tokens, int N, int K) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode: N must be a multiple of 16"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + const ScalarT* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + + float acc = 0.0f; + // Unroll by 8 to hide latency; arbitrary K (any multiple of 8). + int k = 0; + constexpr int UNROLL = 8; + for (; k + UNROLL <= K; k += UNROLL) { +#pragma unroll + for (int u = 0; u < UNROLL; ++u) { + acc += static_cast(act_row[k + u]) * static_cast(w_row[k + u]); + } + } + for (; k < K; ++k) { + acc += static_cast(act_row[k]) * static_cast(w_row[k]); + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// INT4 (S4_CLIP) GEMV with group-wise dequantization. +// +// Asym=false: signed nibble in [-8, 7], dequant = q * scale +// Asym=true : unsigned nibble in [0, 15], dequant = (q - zero) * scale +// +// Packing: two 4-bit values per byte; the value at k = 2*i is the LOW nibble +// of byte i, the value at k = 2*i+1 is the HIGH nibble. 
This matches the +// existing CPU/XPU `packq` layout for S4_CLIP weights. +// ---------------------------------------------------------------------------- +template +void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): N must be a multiple of 16"); + } + if (K % group_size != 0 || (group_size & 1) != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): K must be a multiple of group_size and group_size must be even"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int4): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 2; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? 
zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // Two nibbles per byte; iterate in pairs. + for (int kk = 0; kk < group_size; kk += 2) { + const uint8_t packed = w_row[(k_base + kk) / 2]; + float w0, w1; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x0F); + const int q1 = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + } else { + // Sign-extend each nibble from 4-bit signed to 8-bit signed. + const int q0 = static_cast(static_cast(packed << 4) >> 4); + const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w0; + acc += static_cast(act_row[k_base + kk + 1]) * w1; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +} // namespace moe_decode_detail + +// ---------------------------------------------------------------------------- +// Public API +// +// weight_dtype: +// BTLA_DTYPE::F16 / BF16 : weights stored as [E, N, K] in matching +// floating dtype, no scales/zeros needed +// BTLA_DTYPE::S4_CLIP : packed int4 weights [E, N, K/2] (uint8), +// scales [E, N, K/group_size] in act dtype, +// zeros optional (asym==true requires it) +// act_dtype: F16 or BF16 (must match scales/outputs dtype) +// ---------------------------------------------------------------------------- +inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, + void* outputs, int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, + BTLA_DTYPE weight_dtype, int N, int K, int group_size, int* num_tokens_per_expert, + int num_experts, int 
total_tokens, bool asym) { + moe_decode_detail::fill_expert_id_per_token(q, expert_id_per_token_buf, num_tokens_per_expert, num_experts, + total_tokens); + + if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { + if (weight_dtype != act_dtype) { + throw std::invalid_argument("moe_gemm_decode: unquantized weight_dtype must match act_dtype"); + } + if (act_dtype == BTLA_DTYPE::F16) { + moe_decode_detail::launch_fp(q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K); + } else { + moe_decode_detail::launch_fp( + q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S4_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int4): act_dtype must be FP16 or BF16"); + } + return; + } + + throw std::invalid_argument("moe_gemm_decode: 
unsupported weight_dtype (supported: F16, BF16, S4_CLIP)"); +} + +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index 1134f7b4d..fe1e29065 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -39,6 +39,13 @@ def has_moe_gemm(): return hasattr(ark.xpu_lib, "moe_gemm") +def has_moe_gemm_decode(): + """Check if MoE decode GEMV kernel is available.""" + if ark.xpu_lib is None: + return False + return hasattr(ark.xpu_lib, "moe_gemm_decode") + + @pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") @pytest.mark.skipif(not has_moe_gemm(), reason="MOE GEMM kernel not built (need ARK_SYCL_TLA=ON)") class TestMoEGemm: @@ -166,5 +173,221 @@ def test_moe_gemm_various_sizes(self, N, K): print(f"MOE GEMM test passed for N={N}, K={K}") +# --------------------------------------------------------------------------- +# Decode-path tests (M per expert is typically 1-2, mirrors top-k routing +# after the activations have been gathered/sorted by the upper layer). +# --------------------------------------------------------------------------- + + +def _pack_int4_sym(w_float, scales, group_size): + """Quantize a [E, N, K] fp tensor to symmetric int4 packed [E, N, K/2]. + + scales is filled in-place with [E, N, K/group_size] values. + """ + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 7.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / (s.to(w.dtype).unsqueeze(-1))), -8, 7).to(torch.int8) + q = q.reshape(E, N, K) + # Pack two nibbles per byte: low nibble at lower K, high nibble at higher K. 
+ q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _pack_int4_asym(w_float, scales, zeros, group_size): + """Quantize to asymmetric int4 (range [0, 15]); returns packed weights.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 15.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 15) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 15).to(torch.int32) + q = q.reshape(E, N, K) + q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _dequant_int4_sym(packed, scales, group_size): + """Inverse of _pack_int4_sym. Returns [E, N, K] in scales.dtype.""" + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int8) + high = ((packed >> 4) & 0x0F).to(torch.int8) + # Sign extend 4-bit -> 8-bit + low = torch.where(low >= 8, low - 16, low) + high = torch.where(high >= 8, high - 16, high) + q = torch.empty(E, N, K, dtype=torch.int8, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int4_asym(packed, scales, zeros, group_size): + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int32) + high = ((packed >> 4) & 0x0F).to(torch.int32) + q = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _moe_decode_reference(activations, dequant_weights, 
num_tokens_per_expert): + """Reference: each token is matmul'd against its routed expert's weights.""" + total_tokens, K = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] # [n_tokens, K] + w = dequant_weights[e] # [N, K] + out[offset : offset + n_tokens] = a @ w.T + offset += n_tokens + return out + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif( + not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)" +) +class TestMoEGemmDecode: + """Unit tests for the MoE decode GEMV kernel. + + The activations layout follows the same convention as ``moe_gemm``: the + upper layer has already gathered/sorted tokens per expert, so the kernel + only needs ``num_tokens_per_expert`` (no top-k indices). + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_fp_basic(self, dtype): + num_experts = 4 + # One token per expert with one zero-token expert -> typical top-k=3 + # decode pattern after gather. 
+ tokens_per_expert = [1, 0, 1, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 128 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = torch.randn(num_experts, N, K, dtype=dtype, device="xpu") * 0.1 + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode(activations, weights, num_tokens_per_expert, weight_bits=16) + + ref = _moe_decode_reference(activations, weights, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + assert out.dtype == dtype + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int4_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int4_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = 
torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + def test_decode_validation_errors(self): + """Sanity-check that Python-side validation catches misuse.""" + num_experts = 2 + activations = torch.randn(2, 128, dtype=torch.float16, device="xpu") + num_tokens_per_expert = torch.tensor([1, 1], dtype=torch.int32, device="xpu") + + # N must be a multiple of 16 + bad_weights = torch.randn(num_experts, 17, 128, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode(activations, bad_weights, num_tokens_per_expert, weight_bits=16) + + # weight_bits=4 requires uint8 packed weights + bad_packed = torch.randn(num_experts, 64, 64, dtype=torch.float16, device="xpu") + scales = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, bad_packed, num_tokens_per_expert, scales=scales, weight_bits=4, group_size=128 + ) + + # asym=True without zeros must error + packed = torch.zeros(num_experts, 64, 64, dtype=torch.uint8, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + packed, + 
num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=128, + asym=True, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 527eede0635b5db494c9d4d65323be471ceeb1be Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 04:06:02 +0000 Subject: [PATCH 02/28] Document int4 sign-extension trick Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../wrapper/include/sycl_tla_moe_decode.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index f1bb4cbac..97668991b 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -192,7 +192,10 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig w0 = (static_cast(q0) - zero) * scale; w1 = (static_cast(q1) - zero) * scale; } else { - // Sign-extend each nibble from 4-bit signed to 8-bit signed. + // Sign-extend each 4-bit signed nibble to 8-bit signed: + // low nibble: shift left by 4 to move bit[3] into bit[7], + // then arithmetic right-shift by 4 replicates the sign bit. + // high nibble: same trick after masking off the low nibble. 
const int q0 = static_cast(static_cast(packed << 4) >> 4); const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); w0 = static_cast(q0) * scale; From 78ecc0c689a9e3089368a3f226411d30d367adf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 04:27:46 +0000 Subject: [PATCH 03/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/ark/auto_round_kernel/__init__.py | 4 +--- auto_round_extension/ark/test/test_moe.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 511dd7b95..04f5434ad 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -877,9 +877,7 @@ def moe_gemm_decode( if weights.dtype != torch.uint8: raise ValueError("Int4 packed weights must be torch.uint8") if weights.shape[2] * 2 != K: - raise ValueError( - f"Int4 packed weights last dim {weights.shape[2]} must equal K/2 ({K // 2})" - ) + raise ValueError(f"Int4 packed weights last dim {weights.shape[2]} must equal K/2 ({K // 2})") if scales is None: raise ValueError("scales is required for int4 weights") if scales.dtype != activations.dtype: diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index fe1e29065..e7905569e 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -265,9 +265,7 @@ def _moe_decode_reference(activations, dequant_weights, num_tokens_per_expert): @pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") -@pytest.mark.skipif( - not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)" -) +@pytest.mark.skipif(not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built 
(need ARK_SYCL_TLA=ON)") class TestMoEGemmDecode: """Unit tests for the MoE decode GEMV kernel. From 5dc9d95843b1037430e9d366aee4d2e290146ad1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 07:16:06 +0000 Subject: [PATCH 04/28] Add INT8/INT2/FP8 decode MoE GEMV kernels and tests Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../ark/auto_round_kernel/__init__.py | 102 ++++- .../wrapper/include/sycl_tla_moe_decode.hpp | 414 +++++++++++++++++- auto_round_extension/ark/test/test_moe.py | 326 ++++++++++++++ 3 files changed, 818 insertions(+), 24 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 04f5434ad..8f0be1e98 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -815,21 +815,35 @@ def moe_gemm_decode( Args: activations: ``[total_tokens, K]`` in fp16 or bf16. - weights: - * ``weight_bits=4``: packed uint8 of shape ``[E, N, K // 2]`` - (two int4 values per byte; low nibble at the lower K index). - * ``weight_bits=16``: fp16/bf16 of shape ``[E, N, K]`` matching - the activations dtype. + weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are: + + * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16`` + matching the activations dtype, ``K_packed == K``. + * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``. + Sym (``asym=False``) reinterprets each byte as signed int8; + asym (``asym=True``) treats each byte as ``uint8`` with a + per-group zero-point. + * Int4 (``weight_bits=4``): ``torch.uint8`` packed, + ``K_packed == K // 2`` (two 4-bit values per byte; low nibble + at the lower K index). 
+ * Int2 (``weight_bits=2``): ``torch.uint8`` packed, + ``K_packed == K // 4`` (four 2-bit values per byte; field j at + K index ``4*i + j`` is bits ``[2j+1:2j]`` of byte i). + * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): + ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must + be ``False`` (no zero-points for FP8). num_tokens_per_expert: ``[E]`` int32. Sum must equal ``activations.shape[0]``. scales: ``[E, N, K // group_size]`` in activations dtype. Required - when ``weight_bits=4``; ignored otherwise. + for all quantized paths (int8/int4/int2/fp8); must be ``None`` + for unquantized weights. zeros: ``[E, N, K // group_size]`` in activations dtype. Required - when ``weight_bits=4`` and ``asym=True``; otherwise None. - weight_bits: 4 (int4, S4_CLIP) or 16 (no quantization). - group_size: group along K for int4 weights (default 128). - asym: if True, weights are stored as unsigned nibbles and ``zeros`` - must be provided. + when ``asym=True`` (int8/int4/int2 only); otherwise ``None``. + weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8 + tensor (the FP8 sub-format is taken from ``weights.dtype``). + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding and ``zeros`` must + be provided. Not supported for FP8. Returns: outputs: ``[total_tokens, N]`` in the same dtype as activations. @@ -843,7 +857,7 @@ def moe_gemm_decode( if activations.ndim != 2: raise ValueError("activations must be 2D [total_tokens, K]") if weights.ndim != 3: - raise ValueError("weights must be 3D [E, N, K] or [E, N, K//2]") + raise ValueError("weights must be 3D [E, N, K_packed]") if not activations.is_contiguous(): activations = activations.contiguous() @@ -864,8 +878,32 @@ def moe_gemm_decode( f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" ) + # Detect FP8 weight dtype first (overrides weight_bits). 
+ is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + # Validate weight layout / dtype combination. - if weight_bits == 16: + if is_fp8: + if asym: + raise ValueError("FP8 weights do not support asym=True") + if weights.shape[2] != K: + raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}") + if scales is None: + raise ValueError("scales is required for FP8 weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if zeros is not None: + raise ValueError("zeros must be None for FP8 weights") + weight_dtype = ( + ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 + ) + if not scales.is_contiguous(): + scales = scales.contiguous() + elif weight_bits == 16: if weights.dtype != activations.dtype: raise ValueError("Unquantized weights must match activations dtype") if weights.shape[2] != K: @@ -873,17 +911,36 @@ def moe_gemm_decode( weight_dtype = cvt_dtype(activations.dtype) if scales is not None or zeros is not None: raise ValueError("scales/zeros must be None when weight_bits=16") - elif weight_bits == 4: + elif weight_bits in (8, 4, 2): if weights.dtype != torch.uint8: - raise ValueError("Int4 packed weights must be torch.uint8") - if weights.shape[2] * 2 != K: - raise ValueError(f"Int4 packed weights last dim {weights.shape[2]} must equal K/2 ({K // 2})") + raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8") + if weight_bits == 8: + k_packed_expected = K + k_div = 1 + elif weight_bits == 4: + k_packed_expected = K // 2 + k_div = 2 + else: # weight_bits == 2 + k_packed_expected = K // 4 + k_div = 4 + if K % k_div != 0: + raise 
ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}") + if weights.shape[2] != k_packed_expected: + raise ValueError( + f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} " + f"({k_packed_expected})" + ) if scales is None: - raise ValueError("scales is required for int4 weights") + raise ValueError(f"scales is required for int{weight_bits} weights") if scales.dtype != activations.dtype: raise ValueError("scales dtype must match activations dtype") - if K % group_size != 0 or (group_size & 1) != 0: - raise ValueError("K must be a multiple of group_size and group_size must be even") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + # Group_size constraints per dtype. + if weight_bits == 4 and (group_size & 1) != 0: + raise ValueError("group_size must be even for int4 weights") + if weight_bits == 2 and (group_size & 3) != 0: + raise ValueError("group_size must be a multiple of 4 for int2 weights") expected_scale_shape = (num_experts, N, K // group_size) if tuple(scales.shape) != expected_scale_shape: raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") @@ -897,14 +954,13 @@ def moe_gemm_decode( else: if zeros is not None: raise ValueError("zeros must be None when asym=False") - # S4_CLIP weight dtype, regardless of activations dtype. 
- weight_dtype = ARK_DT.int4 + weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits] if not scales.is_contiguous(): scales = scales.contiguous() if asym and not zeros.is_contiguous(): zeros = zeros.contiguous() else: - raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 4, 16)") + raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)") if N % 16 != 0: raise ValueError(f"N must be a multiple of 16 (got {N})") diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index 97668991b..343752633 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -7,8 +7,17 @@ // identical to the prefill `moe_gemm` interface): // - activations: [total_tokens, K] row-major // - weights (fp/bf16): [num_experts, N, K] row-major +// - weights (int8): [num_experts, N, K] row-major, one +// int8 per byte (sym: signed -128..127; +// asym: unsigned 0..255 with zero-point) // - weights (int4 packed): [num_experts, N, K/2] row-major, two // 4-bit values per byte (low nibble at lower K) +// - weights (int2 packed): [num_experts, N, K/4] row-major, four +// 2-bit values per byte (field j at K index +// 4*i+j is bits [2j+1:2j]) +// - weights (fp8): [num_experts, N, K] row-major, one +// FP8 byte per weight (E4M3 / E5M2); scales +// applied per-group, no zero-points // - scales: [num_experts, N, K/group_size] // - zeros (asym only): [num_experts, N, K/group_size] // - num_tokens_per_expert: [num_experts] int32 @@ -26,6 +35,7 @@ #include #include +#include #include #include "bestla/bestla/bestla.h" @@ -51,6 +61,64 @@ class MoEDecodeKernelFP; template class MoEDecodeKernelInt4; +template +class MoEDecodeKernelInt8; + +template +class MoEDecodeKernelInt2; + +template +class 
MoEDecodeKernelFP8; + +// ---------------------------------------------------------------------------- +// FP8 byte -> float decode (device-side, no LUT / SLM required). +// Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: +// E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. +// E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. +// We keep these inline rather than using a SLM LUT because the decode kernel +// runs only one sub-group per workgroup and per-lane bit ops are cheap relative +// to the global memory loads. +// ---------------------------------------------------------------------------- +inline float decode_fp8_e4m3(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + float v; + if (mag == 0u) { + v = 0.0f; + } else if (mag == 0x7Fu) { + v = sycl::nan(0u); + } else { + const int exp = static_cast((mag >> 3) & 0xFu); + const int man = static_cast(mag & 0x7u); + if (exp == 0) { + // subnormal: value = man * 2^(1 - bias - mbits) = man / 512 + v = static_cast(man) * (1.0f / 512.0f); + } else { + // normal: (1 + man/8) * 2^(exp - bias), bias = 7 + v = (1.0f + static_cast(man) * 0.125f) * sycl::ldexp(1.0f, exp - 7); + } + } + return sign ? -v : v; +} + +inline float decode_fp8_e5m2(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + const int exp = static_cast((mag >> 2) & 0x1Fu); + const int man = static_cast(mag & 0x3u); + float v; + if (exp == 0) { + // subnormal (incl. zero): value = man * 2^(1 - bias - mbits) = man / 65536 + v = static_cast(man) * (1.0f / 65536.0f); + } else if (exp == 31) { + v = (man == 0) ? std::numeric_limits::infinity() : sycl::nan(0u); + } else { + // normal: (1 + man/4) * 2^(exp - bias), bias = 15 + v = (1.0f + static_cast(man) * 0.25f) * sycl::ldexp(1.0f, exp - 15); + } + return sign ? 
-v : v; +} + // ---------------------------------------------------------------------------- // Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. // Runs on host (num_experts is small, total_tokens is small in decode). @@ -211,6 +279,242 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig .wait(); } +// ---------------------------------------------------------------------------- +// INT8 (S8) GEMV with group-wise dequantization. +// +// Asym=false: signed byte in [-128, 127], dequant = q * scale +// Asym=true : unsigned byte in [0, 255], dequant = (q - zero) * scale +// +// Weights are stored as raw uint8 bytes (1 byte per weight). The same buffer +// type is used for sym and asym; the only difference is the sign interpretation +// performed at decode time. +// ---------------------------------------------------------------------------- +template +void launch_int8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): K must be a multiple of group_size"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int8): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + 
const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + for (int kk = 0; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + // Reinterpret as signed int8. + w = static_cast(static_cast(raw)) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// INT2 (S2_CLIP) GEMV with group-wise dequantization. +// +// Packing: 4 values per byte. The value at K index 4*i + j is stored in +// bits [2j+1 : 2j] of byte i (i.e. byte = q0 | (q1<<2) | (q2<<4) | (q3<<6)). 
+// +// Asym=false: signed 2-bit value in [-2, 1]; dequant = q * scale +// Asym=true : unsigned 2-bit value in [0, 3]; dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): N must be a multiple of 16"); + } + if ((K & 0x3) != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): K must be a multiple of 4"); + } + if (K % group_size != 0 || (group_size & 0x3) != 0) { + throw std::invalid_argument( + "moe_gemm_decode(int2): K must be a multiple of group_size and group_size must be a multiple of 4"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int2): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 4; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + 
const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // 4 values per byte; iterate in quads. + for (int kk = 0; kk < group_size; kk += 4) { + const uint8_t packed = w_row[(k_base + kk) / 4]; + float w[4]; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x3); + const int q1 = static_cast((packed >> 2) & 0x3); + const int q2 = static_cast((packed >> 4) & 0x3); + const int q3 = static_cast((packed >> 6) & 0x3); + w[0] = (static_cast(q0) - zero) * scale; + w[1] = (static_cast(q1) - zero) * scale; + w[2] = (static_cast(q2) - zero) * scale; + w[3] = (static_cast(q3) - zero) * scale; + } else { + // Sign-extend each 2-bit signed value via shift-left-then-arith-shift-right. + // After placing the 2 bits in the high 2 of an int8, >>6 replicates the sign. + const int q0 = static_cast(static_cast(packed << 6) >> 6); + const int q1 = static_cast(static_cast((packed << 4) & 0xC0) >> 6); + const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); + const int q3 = static_cast(static_cast(packed & 0xC0) >> 6); + w[0] = static_cast(q0) * scale; + w[1] = static_cast(q1) * scale; + w[2] = static_cast(q2) * scale; + w[3] = static_cast(q3) * scale; + } + acc += static_cast(act_row[k_base + kk + 0]) * w[0]; + acc += static_cast(act_row[k_base + kk + 1]) * w[1]; + acc += static_cast(act_row[k_base + kk + 2]) * w[2]; + acc += static_cast(act_row[k_base + kk + 3]) * w[3]; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// FP8 (E4M3 / E5M2) GEMV with group-wise scale (no zero-point). +// +// Weights are 1 FP8 byte per element [E, N, K]. 
The byte is decoded inline by +// bit manipulation; the LUT in fp8_lut.h would also work but inline decode +// keeps this kernel self-contained and avoids touching SLM. +// ---------------------------------------------------------------------------- +template +void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): K must be a multiple of group_size"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + const int k_base = g * group_size; + for (int kk = 0; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + float w; + if constexpr (IsE4M3) { + w = decode_fp8_e4m3(raw) * scale; + } else { + w = decode_fp8_e5m2(raw) * scale; + } + acc += static_cast(act_row[k_base + kk]) * 
w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + } // namespace moe_decode_detail // ---------------------------------------------------------------------------- @@ -219,9 +523,16 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig // weight_dtype: // BTLA_DTYPE::F16 / BF16 : weights stored as [E, N, K] in matching // floating dtype, no scales/zeros needed +// BTLA_DTYPE::S8 : int8 weights [E, N, K] (uint8 buffer, +// interpreted as signed when asym==false, +// unsigned with zero-points when asym==true) // BTLA_DTYPE::S4_CLIP : packed int4 weights [E, N, K/2] (uint8), // scales [E, N, K/group_size] in act dtype, // zeros optional (asym==true requires it) +// BTLA_DTYPE::S2_CLIP : packed int2 weights [E, N, K/4] (uint8), +// 4 values per byte, sym/asym like int4 +// BTLA_DTYPE::F8_E4M3 / F8_E5M2 : FP8 weights [E, N, K] (uint8 buffer), +// group-wise scales, no zero-points // act_dtype: F16 or BF16 (must match scales/outputs dtype) // ---------------------------------------------------------------------------- inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, @@ -281,7 +592,108 @@ inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, vo return; } - throw std::invalid_argument("moe_gemm_decode: unsupported weight_dtype (supported: F16, BF16, S4_CLIP)"); + if (weight_dtype == BTLA_DTYPE::S8) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == 
BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int8): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S2_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int2): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::F8_E4M3 || weight_dtype == BTLA_DTYPE::F8_E5M2) { + if (asym) { + throw std::invalid_argument("moe_gemm_decode(fp8): asym mode is not 
supported"); + } + const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + if (act_dtype == BTLA_DTYPE::F16) { + if (is_e4m3) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (is_e4m3) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(fp8): act_dtype must be FP16 or BF16"); + } + return; + } + + throw std::invalid_argument( + "moe_gemm_decode: unsupported weight_dtype (supported: F16, BF16, S8, S4_CLIP, S2_CLIP, F8_E4M3, F8_E5M2)"); } } // namespace ark diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index e7905569e..bca8d84be 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -247,6 +247,163 @@ def _dequant_int4_asym(packed, scales, zeros, group_size): return deq.reshape(E, N, K) +# --------------------------------------------------------------------------- +# Int8 / Int2 / FP8 helpers (decode-path). 
+# --------------------------------------------------------------------------- + + +def _pack_int8_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> int8 (signed, [-127, 127]); fills scales in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 127.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -127, 127).to(torch.int8) + # Reinterpret as uint8 with no value change. + return q.reshape(E, N, K).view(torch.uint8).contiguous() + + +def _pack_int8_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> uint8 ([0, 255]); fills scales/zeros in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 255.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 255) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 255).to(torch.int32) + return q.reshape(E, N, K).to(torch.uint8).contiguous() + + +def _dequant_int8_sym(packed_u8, scales, group_size): + # Reinterpret uint8 bytes as signed int8. + q = packed_u8.view(torch.int8).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int8_asym(packed_u8, scales, zeros, group_size): + q = packed_u8.to(torch.int32).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_int2_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> packed int2 (signed [-2, 1]); shape [E, N, K/4]. 
+ + Packing: byte = q0 | (q1<<2) | (q2<<4) | (q3<<6), where the j-th 2-bit + field corresponds to K index 4*i + j. + """ + E, N, K = w_float.shape + assert K % 4 == 0, "K must be a multiple of 4 for int2 packing" + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + # Symmetric int2 has signed range [-2, 1] (i.e. clip at 2 and -2 but -2 inclusive). + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 2.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -2, 1).to(torch.int32) + q = q.reshape(E, N, K) + # Pack 4 values per byte. + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _pack_int2_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> packed int2 (unsigned [0, 3]); shape [E, N, K/4].""" + E, N, K = w_float.shape + assert K % 4 == 0 + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 3.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 3) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 3).to(torch.int32) + q = q.reshape(E, N, K) + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _dequant_int2_sym(packed, scales, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + # Sign-extend 2-bit (>=2 means negative). 
+ fields = torch.where(fields >= 2, fields - 4, fields).to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + return (fields * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int2_asym(packed, scales, zeros, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + fields = fields.to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + deq = (fields - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_fp8(w_float, scales, group_size, fp8_dtype): + """Quantize [E, N, K] fp -> FP8 (e4m3fn/e5m2) with per-group scale. + + Scales are filled in-place; ``w_float`` is divided by the scale, then cast + to ``fp8_dtype`` (rounding handled by torch). Returns the FP8 tensor. + """ + assert fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + # Pick fp8 max representable magnitude for the chosen format. + if fp8_dtype == torch.float8_e4m3fn: + fp8_max = 448.0 + else: # e5m2 + fp8_max = 57344.0 + s = (absmax / fp8_max).squeeze(-1).to(scales.dtype) + scales.copy_(s) + scaled = (w / s.to(w.dtype).unsqueeze(-1)).reshape(E, N, K) + # Clamp to fp8 representable range before cast to avoid Inf/NaN. 
+ scaled = scaled.clamp(-fp8_max, fp8_max) + return scaled.to(fp8_dtype).contiguous() + + +def _dequant_fp8(packed_fp8, scales, group_size, out_dtype): + """Reference dequant: cast fp8 -> out_dtype and multiply per-group scale.""" + E, N, K = packed_fp8.shape + w = packed_fp8.to(out_dtype).reshape(E, N, K // group_size, group_size) + return (w * scales.unsqueeze(-1)).reshape(E, N, K) + + def _moe_decode_reference(activations, dequant_weights, num_tokens_per_expert): """Reference: each token is matmul'd against its routed expert's weights.""" total_tokens, K = activations.shape @@ -386,6 +543,175 @@ def test_decode_validation_errors(self): asym=True, ) + # FP8 + asym is rejected + fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") + zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + fp8_w, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + group_size=128, + asym=True, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int8_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=8, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, 
num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int8_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int2_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( 
+ activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=2, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 has much higher quant error; relax tolerance vs int4. + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int2_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] + ) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_fp8(self, dtype, fp8_dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 0, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = 
torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # E5M2 has only 2 mantissa bits -> coarser; relax tolerance for both. + rtol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + atol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + torch.testing.assert_close(out, ref, rtol=rtol, atol=atol) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From f15093a1b6e4ff0efdab7bd5bb52b2da45a231ed Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 07:19:31 +0000 Subject: [PATCH 05/28] docs: clarify int2 bit-indexing notation in moe_gemm_decode Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- auto_round_extension/ark/auto_round_kernel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 8f0be1e98..44db0871d 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -828,7 +828,7 @@ def moe_gemm_decode( at the lower K index). 
* Int2 (``weight_bits=2``): ``torch.uint8`` packed, ``K_packed == K // 4`` (four 2-bit values per byte; field j at - K index ``4*i + j`` is bits ``[2j+1:2j]`` of byte i). + K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i). * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must be ``False`` (no zero-points for FP8). From 43958843090777f7ab66d7abfedc719700ed4b9c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 19 May 2026 06:22:40 +0000 Subject: [PATCH 06/28] =?UTF-8?q?test:=20add=20perf=20comparison=20UT=20?= =?UTF-8?q?=E2=80=94=20moe=5Fgemm=5Fdecode=20vs=20default=20XPU=20MoE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/intel/auto-round/sessions/132db2ab-85c0-45b6-81a7-b9baaa533e5e Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../ark/test/test_moe_decode_perf.py | 321 ++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 auto_round_extension/ark/test/test_moe_decode_perf.py diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py new file mode 100644 index 000000000..6b82f1d03 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance comparison: ``ark.moe_gemm_decode`` vs default XPU MoE. + +The "default XPU MoE implementation" used as the baseline is the standard +per-expert PyTorch matmul loop (the same approach ``_moe_decode_reference`` +uses in ``test_moe.py``). For quantized formats the weights are dequantized +once up-front (outside the timed region), so the baseline measures only the +matmul cost on XPU. This is what models fall back to when no fused decode +kernel is available. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_decode_perf.py + +The ``-s`` flag is required to see the printed timing tables. +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse the existing pack/dequant helpers from the correctness tests so that +# the benchmarked path matches what the unit tests already validate. +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, + has_moe_gemm_decode, + is_xpu_available, +) + +ark = auto_round_kernel.ARK() + + +# --------------------------------------------------------------------------- +# Timing utilities. +# --------------------------------------------------------------------------- + +# Warmup / iteration counts kept modest so the suite is still UT-shaped +# (finishes in seconds) but large enough for stable medians. 
+WARMUP = 5 +ITERS = 30 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + """Time ``fn`` on XPU using device events; returns median ms per call.""" + for _ in range(warmup): + fn() + torch.xpu.synchronize() + + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +def _default_moe_decode(activations, dequant_weights, num_tokens_per_expert): + """Default XPU MoE decode baseline: per-expert torch matmul loop. + + This mirrors the path a model would take when no fused MoE decode kernel + is available: gather/sort tokens by expert (done by the caller), then + iterate over experts and do a plain ``A @ W.T`` on each slice. + """ + total_tokens, _ = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + out[offset : offset + n_tokens] = a @ dequant_weights[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# Shape matrix. +# +# Picked to cover small / medium / large MoE expert FFNs (Mixtral-style +# 4096x14336 down-projection is the upper bound; smaller shapes catch +# launch-overhead-dominated cases). ``tokens_per_expert`` follows the +# expected decode-phase pattern (top-k routing with batch=1: each active +# expert sees one token). 
+# --------------------------------------------------------------------------- + +DECODE_SHAPES = [ + # (label, num_experts, tokens_per_expert, N, K) + ("small E=4 ", 4, [1, 0, 1, 1], 1024, 1024), + ("medium E=8 ", 8, [1, 1, 0, 1, 1, 0, 1, 1], 2048, 2048), + ("large E=8 ", 8, [1, 0, 1, 1, 0, 1, 1, 1], 4096, 4096), + ("ffn-up E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 14336, 4096), + ("ffn-dn E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 4096, 14336), +] + + +def _print_header(title: str) -> None: + print() + print("=" * 96) + print(title) + print( + f"{'shape':<14}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}" + ) + print("-" * 96) + + +def _print_row(label, N, K, total_tokens, base_ms, ark_ms): + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x" + ) + + +# --------------------------------------------------------------------------- +# Benchmark cases. +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif(not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)") +class TestMoEGemmDecodePerf: + """Median XPU-event timings of ``moe_gemm_decode`` vs per-expert ``A @ W.T``. + + The baseline uses *already-dequantized* weights, so quantized cases only + pay the matmul cost in the timed region (no per-iteration dequant). This + is the most favorable apples-to-apples comparison for the baseline; the + fused decode kernel must beat that to be worth using. 
+ """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_perf_fp(self, dtype): + _print_header(f"FP weights ({str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs per-expert A @ W.T") + for label, E, tpe, N, K in DECODE_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, weights, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode(activations, weights, ntpe, weight_bits=16) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int4(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + 
lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=4, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int8(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=8, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int2(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs 
dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=2, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_perf_fp8(self, dtype, fp8_dtype): + group_size = 128 + _print_header( + f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = 
torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, + group_size=group_size, asym=False, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From a864bedbea6ee6ddd0cc5535b673a7b7662a6f49 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:42:48 +0000 Subject: [PATCH 07/28] test: clearer skip reasons for moe_gemm_decode perf UT Agent-Logs-Url: https://github.com/intel/auto-round/sessions/4f2a4b1d-b510-4522-84b1-4667ac8b5b97 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../ark/test/test_moe_decode_perf.py | 63 +++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py index 6b82f1d03..44bbb7fd0 100644 --- a/auto_round_extension/ark/test/test_moe_decode_perf.py +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -52,13 +52,69 @@ _pack_int4_sym, _pack_int8_asym, _pack_int8_sym, - has_moe_gemm_decode, - is_xpu_available, ) ark = auto_round_kernel.ARK() +# --------------------------------------------------------------------------- +# Skip reasons. +# +# The original test_moe.py collapses several different failure modes into one +# generic "kernel not built" message which makes it impossible to tell whether +# the build is missing the kernel or whether XPU itself didn't come up. The +# helpers below distinguish those cases so a skipped run is actually +# actionable. 
+# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _decode_skip_reason() -> str: + """Return non-empty string if the decode kernel can't be exercised.""" + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm_decode"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_decode symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE decode GEMV kernel" + ) + return "" + + +_DECODE_SKIP = _decode_skip_reason() + +# Surface diagnostics on collection so the user always sees why the suite +# would skip, without having to add extra flags. +print( + "[moe-decode-perf] xpu_available=%s xpu_lib=%s has_moe_gemm_decode=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark.xpu_lib, "moe_gemm_decode") if ark.xpu_lib is not None else False, + ) +) +if _DECODE_SKIP: + print("[moe-decode-perf] suite will SKIP. reason: %s" % _DECODE_SKIP) + + # --------------------------------------------------------------------------- # Timing utilities. 
# --------------------------------------------------------------------------- @@ -153,8 +209,7 @@ def _print_row(label, N, K, total_tokens, base_ms, ark_ms): # --------------------------------------------------------------------------- -@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") -@pytest.mark.skipif(not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)") +@pytest.mark.skipif(bool(_DECODE_SKIP), reason=_DECODE_SKIP or "ok") class TestMoEGemmDecodePerf: """Median XPU-event timings of ``moe_gemm_decode`` vs per-expert ``A @ W.T``. From 407da7547d69af90ac688b4edd26821aee0f33f9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 May 2026 07:34:05 +0000 Subject: [PATCH 08/28] fix(ark): correct duplicated bestla include path in sycl_tla_moe_decode.hpp Agent-Logs-Url: https://github.com/intel/auto-round/sessions/8aaf52b0-b12d-4682-816f-40b7a19cc44b Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index 343752633..e9c7e949d 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -38,7 +38,7 @@ #include #include -#include "bestla/bestla/bestla.h" +#include "bestla/bestla.h" #ifdef ARK_XPU #include From 70dc3208ede292118a7aff2962231cb01250a447 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 May 2026 08:39:40 +0000 Subject: [PATCH 09/28] perf: vectorize moe_gemm_decode loads, parallelize expert-id fill, drop in-flight waits 
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/80a527ca-8514-4f3e-abb2-bfbec367d409 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../wrapper/include/sycl_tla_moe_decode.hpp | 226 ++++++++++++++---- 1 file changed, 184 insertions(+), 42 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index e9c7e949d..aa97c4d94 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -128,20 +128,25 @@ inline float decode_fp8_e5m2(uint8_t byte) { inline void fill_expert_id_per_token(sycl::queue* q, int* expert_id_per_token, const int* num_tokens_per_expert, int num_experts, int total_tokens) { - // Sequential prefix-scan on a single thread; cheap because num_experts is - // small (typ. <= 256) and we avoid host-device sync entirely. - q->single_task([=]() { - int offset = 0; - for (int e = 0; e < num_experts; ++e) { - int n = num_tokens_per_expert[e]; - for (int i = 0; i < n; ++i) { - if (offset + i < total_tokens) { - expert_id_per_token[offset + i] = e; - } - } - offset += n; - } - }).wait(); + // Parallel fill: each work-item independently scans the small + // num_tokens_per_expert array (typ. <= 256) to find its expert id. This + // removes the single-task serialization point and avoids an explicit + // host-device sync; the in-order queue chains this with the GEMV launch. 
+ if (total_tokens == 0) return; + q->parallel_for(sycl::range<1>(static_cast(total_tokens)), [=](sycl::id<1> idx) { + const int i = static_cast(idx[0]); + int offset = 0; + int expert = num_experts - 1; + for (int e = 0; e < num_experts; ++e) { + const int n = num_tokens_per_expert[e]; + if (i < offset + n) { + expert = e; + break; + } + offset += n; + } + expert_id_per_token[i] = expert; + }); } // ---------------------------------------------------------------------------- @@ -173,13 +178,25 @@ void launch_fp(sycl::queue* q, const ScalarT* activations, const ScalarT* weight weights + (static_cast(expert) * N + static_cast(n_global)) * K; float acc = 0.0f; - // Unroll by 8 to hide latency; arbitrary K (any multiple of 8). + // Unroll by 8 with a 16-byte vector load for both activations and + // weights. Activations are sub-group-uniform so they coalesce via + // L1; each lane's weight load is an independent 16-byte transaction. + // We load through a uint16_t vector to stay portable across SYCL + // implementations that may not provide sycl::vec. 
int k = 0; - constexpr int UNROLL = 8; - for (; k + UNROLL <= K; k += UNROLL) { + constexpr int VEC = 8; + using LoadVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int k_vec_end = (K / VEC) * VEC; + for (; k < k_vec_end; k += VEC) { + const LoadVec av = *reinterpret_cast(act_row + k); + const LoadVec wv = *reinterpret_cast(w_row + k); #pragma unroll - for (int u = 0; u < UNROLL; ++u) { - acc += static_cast(act_row[k + u]) * static_cast(w_row[k + u]); + for (int u = 0; u < VEC; ++u) { + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + const ScalarT w = sycl::bit_cast(static_cast(wv[u])); + acc += static_cast(a) * static_cast(w); } } for (; k < K; ++k) { @@ -187,8 +204,7 @@ void launch_fp(sycl::queue* q, const ScalarT* activations, const ScalarT* weight } outputs[static_cast(token) * N + n_global] = static_cast(acc); - }) - .wait(); + }); } // ---------------------------------------------------------------------------- @@ -250,8 +266,43 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig zero = static_cast(z_row[g]); } const int k_base = g * group_size; - // Two nibbles per byte; iterate in pairs. - for (int kk = 0; kk < group_size; kk += 2) { + // Vectorized path: process 16 K-elements at a time, which is + // 8 packed weight bytes and a vec activation block. + // group_size is a multiple of 16 in every supported config + // (group_size >= 32, even); a scalar tail loop covers leftovers. 
+ constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using PackVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const PackVec pv = *reinterpret_cast(w_row + (k_base + kk) / 2); +#pragma unroll + for (int b = 0; b < CHUNK / 2; ++b) { + const uint8_t packed = pv[b]; + float w0, w1; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x0F); + const int q1 = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 4) >> 4); + const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + } + const ScalarT a0 = sycl::bit_cast(static_cast(av[2 * b])); + const ScalarT a1 = sycl::bit_cast(static_cast(av[2 * b + 1])); + acc += static_cast(a0) * w0; + acc += static_cast(a1) * w1; + } + } + // Scalar tail for group_size not divisible by CHUNK. + for (; kk < group_size; kk += 2) { const uint8_t packed = w_row[(k_base + kk) / 2]; float w0, w1; if constexpr (Asym) { @@ -260,10 +311,6 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig w0 = (static_cast(q0) - zero) * scale; w1 = (static_cast(q1) - zero) * scale; } else { - // Sign-extend each 4-bit signed nibble to 8-bit signed: - // low nibble: shift left by 4 to move bit[3] into bit[7], - // then arithmetic right-shift by 4 replicates the sign bit. - // high nibble: same trick after masking off the low nibble. 
const int q0 = static_cast(static_cast(packed << 4) >> 4); const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); w0 = static_cast(q0) * scale; @@ -275,8 +322,7 @@ void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weig } outputs[static_cast(token) * N + n_global] = static_cast(acc); - }) - .wait(); + }); } // ---------------------------------------------------------------------------- @@ -337,13 +383,38 @@ void launch_int8(sycl::queue* q, const ScalarT* activations, const uint8_t* weig zero = static_cast(z_row[g]); } const int k_base = g * group_size; - for (int kk = 0; kk < group_size; ++kk) { + // Vectorized path: 16 weights (16 bytes) + 16 activations per load. + // group_size is typically 128 (mult of 16); scalar tail handles + // anything that doesn't divide evenly. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using ByteVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const ByteVec wv = *reinterpret_cast(w_row + k_base + kk); +#pragma unroll + for (int u = 0; u < CHUNK; ++u) { + const uint8_t raw = wv[u]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + acc += static_cast(a) * w; + } + } + for (; kk < group_size; ++kk) { const uint8_t raw = w_row[k_base + kk]; float w; if constexpr (Asym) { w = (static_cast(raw) - zero) * scale; } else { - // Reinterpret as signed int8. 
w = static_cast(static_cast(raw)) * scale; } acc += static_cast(act_row[k_base + kk]) * w; @@ -351,8 +422,7 @@ void launch_int8(sycl::queue* q, const ScalarT* activations, const uint8_t* weig } outputs[static_cast(token) * N + n_global] = static_cast(acc); - }) - .wait(); + }); } // ---------------------------------------------------------------------------- @@ -417,8 +487,57 @@ void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weig zero = static_cast(z_row[g]); } const int k_base = g * group_size; - // 4 values per byte; iterate in quads. - for (int kk = 0; kk < group_size; kk += 4) { + // Vectorized: 16 K-elements per chunk = 4 packed bytes (4 values + // each) plus a vec activation block. group_size is a + // multiple of 4 and typically 128 (mult of 16); scalar tail covers + // any leftover. We load activations via uint16_t to stay portable + // across SYCL implementations that may not provide + // sycl::vec. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using PackVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const PackVec pv = *reinterpret_cast(w_row + (k_base + kk) / 4); +#pragma unroll + for (int b = 0; b < CHUNK / 4; ++b) { + const uint8_t packed = pv[b]; + float w0, w1, w2, w3; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x3); + const int q1 = static_cast((packed >> 2) & 0x3); + const int q2 = static_cast((packed >> 4) & 0x3); + const int q3 = static_cast((packed >> 6) & 0x3); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + w2 = (static_cast(q2) - zero) * scale; + w3 = (static_cast(q3) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 6) >> 6); + const int q1 = static_cast(static_cast((packed << 4) & 
0xC0) >> 6); + const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); + const int q3 = static_cast(static_cast(packed & 0xC0) >> 6); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + w2 = static_cast(q2) * scale; + w3 = static_cast(q3) * scale; + } + const ScalarT a0 = sycl::bit_cast(static_cast(av[4 * b + 0])); + const ScalarT a1 = sycl::bit_cast(static_cast(av[4 * b + 1])); + const ScalarT a2 = sycl::bit_cast(static_cast(av[4 * b + 2])); + const ScalarT a3 = sycl::bit_cast(static_cast(av[4 * b + 3])); + acc += static_cast(a0) * w0; + acc += static_cast(a1) * w1; + acc += static_cast(a2) * w2; + acc += static_cast(a3) * w3; + } + } + // Scalar tail (4 values per byte). + for (; kk < group_size; kk += 4) { const uint8_t packed = w_row[(k_base + kk) / 4]; float w[4]; if constexpr (Asym) { @@ -431,8 +550,6 @@ void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weig w[2] = (static_cast(q2) - zero) * scale; w[3] = (static_cast(q3) - zero) * scale; } else { - // Sign-extend each 2-bit signed value via shift-left-then-arith-shift-right. - // After placing the 2 bits in the high 2 of an int8, >>6 replicates the sign. const int q0 = static_cast(static_cast(packed << 6) >> 6); const int q1 = static_cast(static_cast((packed << 4) & 0xC0) >> 6); const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); @@ -450,8 +567,7 @@ void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weig } outputs[static_cast(token) * N + n_global] = static_cast(acc); - }) - .wait(); + }); } // ---------------------------------------------------------------------------- @@ -498,7 +614,34 @@ void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weigh for (int g = 0; g < num_groups_k; ++g) { const float scale = static_cast(s_row[g]); const int k_base = g * group_size; - for (int kk = 0; kk < group_size; ++kk) { + // Vectorized: 16 weights (16 bytes) + 16 activations per load. 
+ // Decode each FP8 byte to float inline, then apply the per-group + // scale. group_size is typically 128 (mult of 16); scalar tail + // covers anything that doesn't divide evenly. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using ByteVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const ByteVec wv = *reinterpret_cast(w_row + k_base + kk); +#pragma unroll + for (int u = 0; u < CHUNK; ++u) { + const uint8_t raw = wv[u]; + float w; + if constexpr (IsE4M3) { + w = decode_fp8_e4m3(raw) * scale; + } else { + w = decode_fp8_e5m2(raw) * scale; + } + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + acc += static_cast(a) * w; + } + } + for (; kk < group_size; ++kk) { const uint8_t raw = w_row[k_base + kk]; float w; if constexpr (IsE4M3) { @@ -511,8 +654,7 @@ void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weigh } outputs[static_cast(token) * N + n_global] = static_cast(acc); - }) - .wait(); + }); } } // namespace moe_decode_detail From 1da19772d71a212ebc77a8cce9f41f4456413f3e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 May 2026 13:24:21 +0000 Subject: [PATCH 10/28] feat(ark): add ARK_FP8_DECODE_USE_LUT switch for FP8 decode in MoE kernel Agent-Logs-Url: https://github.com/intel/auto-round/sessions/279896de-7d75-4062-a4c5-cbdf20b70481 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../wrapper/include/sycl_tla_moe_decode.hpp | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index 
aa97c4d94..d5ae0a7d8 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -39,11 +39,25 @@ #include #include "bestla/bestla.h" +#include "bestla/sycl/fp8_lut.h" #ifdef ARK_XPU #include #endif +// ---------------------------------------------------------------------------- +// FP8 decode implementation switch +// +// Define `ARK_FP8_DECODE_USE_LUT` (e.g. -DARK_FP8_DECODE_USE_LUT) to dequantize +// each FP8 byte via the 128-entry magnitude LUT in `bestla/sycl/fp8_lut.h` +// (sign applied separately). Leave it undefined to keep the inline bit-manip +// decode below, which is the default and matches the previous behavior. +// +// Both paths are mathematically equivalent for finite values; pick whichever +// is faster on the target hardware. The LUT trades a handful of bit/branch +// ops for a single constant-memory load per weight byte. +// ---------------------------------------------------------------------------- + #if defined(ARK_XPU) && defined(ARK_SYCL_TLA) namespace ark { @@ -71,14 +85,35 @@ template class MoEDecodeKernelFP8; // ---------------------------------------------------------------------------- -// FP8 byte -> float decode (device-side, no LUT / SLM required). +// FP8 byte -> float decode. // Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: // E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. // E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. -// We keep these inline rather than using a SLM LUT because the decode kernel -// runs only one sub-group per workgroup and per-lane bit ops are cheap relative -// to the global memory loads. 
+// +// Two implementations are provided and selected at compile time via +// `ARK_FP8_DECODE_USE_LUT` (see the comment near the top of this file): +// - default (macro undefined): inline bit-manipulation, fully self-contained, +// no LUT / SLM required. Per-lane bit ops are cheap relative to the global +// memory loads in this kernel, so this is a reasonable default. +// - macro defined: read the magnitude from the 128-entry constexpr LUT in +// `bestla/sycl/fp8_lut.h` and apply the sign bit separately. // ---------------------------------------------------------------------------- +#if defined(ARK_FP8_DECODE_USE_LUT) + +inline float decode_fp8_e4m3(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e4m3_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +inline float decode_fp8_e5m2(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e5m2_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +#else // !ARK_FP8_DECODE_USE_LUT + inline float decode_fp8_e4m3(uint8_t byte) { const uint32_t mag = byte & 0x7Fu; const uint32_t sign = byte >> 7; @@ -119,6 +154,8 @@ inline float decode_fp8_e5m2(uint8_t byte) { return sign ? -v : v; } +#endif // ARK_FP8_DECODE_USE_LUT + // ---------------------------------------------------------------------------- // Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. // Runs on host (num_experts is small, total_tokens is small in decode). @@ -573,9 +610,10 @@ void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weig // ---------------------------------------------------------------------------- // FP8 (E4M3 / E5M2) GEMV with group-wise scale (no zero-point). // -// Weights are 1 FP8 byte per element [E, N, K]. The byte is decoded inline by -// bit manipulation; the LUT in fp8_lut.h would also work but inline decode -// keeps this kernel self-contained and avoids touching SLM. 
+// Weights are 1 FP8 byte per element [E, N, K]. The byte is decoded via the +// `decode_fp8_e4m3` / `decode_fp8_e5m2` helpers above, which can be compiled +// either as inline bit manipulation (default) or as a LUT lookup by defining +// `ARK_FP8_DECODE_USE_LUT`. // ---------------------------------------------------------------------------- template void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, From c297d37091cd3b1ea2e51a429f9eac6f4c66807b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 May 2026 13:35:06 +0000 Subject: [PATCH 11/28] feat(ark): make FP8 decode LUT switch runtime via ARK_FP8_DECODE_USE_LUT env var (default on) Agent-Logs-Url: https://github.com/intel/auto-round/sessions/0f88e20f-9644-4ebb-8cd3-2052a9f7f2e9 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com> --- .../wrapper/include/sycl_tla_moe_decode.hpp | 169 ++++++++++++------ 1 file changed, 112 insertions(+), 57 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index d5ae0a7d8..b10594f0f 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -38,6 +38,10 @@ #include #include +#include +#include +#include + #include "bestla/bestla.h" #include "bestla/sycl/fp8_lut.h" @@ -46,16 +50,20 @@ #endif // ---------------------------------------------------------------------------- -// FP8 decode implementation switch +// FP8 decode implementation switch (runtime) +// +// FP8 weight bytes can be dequantized either via inline bit manipulation or +// via the 128-entry magnitude LUT in `bestla/sycl/fp8_lut.h` (sign applied +// separately). 
Both paths are mathematically equivalent for finite values; +// pick whichever is faster on the target hardware. // -// Define `ARK_FP8_DECODE_USE_LUT` (e.g. -DARK_FP8_DECODE_USE_LUT) to dequantize -// each FP8 byte via the 128-entry magnitude LUT in `bestla/sycl/fp8_lut.h` -// (sign applied separately). Leave it undefined to keep the inline bit-manip -// decode below, which is the default and matches the previous behavior. +// Selection is done at runtime through the environment variable +// `ARK_FP8_DECODE_USE_LUT`: +// - unset / "1" / "true" / "on" / "yes" (case-insensitive) -> LUT path (default) +// - "0" / "false" / "off" / "no" (case-insensitive) -> inline bit-manip // -// Both paths are mathematically equivalent for finite values; pick whichever -// is faster on the target hardware. The LUT trades a handful of bit/branch -// ops for a single constant-memory load per weight byte. +// The env var is read once on the host (cached) and passed as a template +// parameter into the SYCL kernel, so there is no per-element runtime branch. // ---------------------------------------------------------------------------- #if defined(ARK_XPU) && defined(ARK_SYCL_TLA) @@ -81,7 +89,7 @@ class MoEDecodeKernelInt8; template class MoEDecodeKernelInt2; -template +template class MoEDecodeKernelFP8; // ---------------------------------------------------------------------------- @@ -90,31 +98,26 @@ class MoEDecodeKernelFP8; // E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. // E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. // -// Two implementations are provided and selected at compile time via -// `ARK_FP8_DECODE_USE_LUT` (see the comment near the top of this file): -// - default (macro undefined): inline bit-manipulation, fully self-contained, -// no LUT / SLM required. Per-lane bit ops are cheap relative to the global -// memory loads in this kernel, so this is a reasonable default. 
-// - macro defined: read the magnitude from the 128-entry constexpr LUT in -// `bestla/sycl/fp8_lut.h` and apply the sign bit separately. +// Two implementations are provided. Selection happens at kernel launch time +// via a bool template parameter (`UseLut`) sourced from the env var +// `ARK_FP8_DECODE_USE_LUT` (see `fp8_decode_use_lut()` below). // ---------------------------------------------------------------------------- -#if defined(ARK_FP8_DECODE_USE_LUT) -inline float decode_fp8_e4m3(uint8_t byte) { +// LUT path: read magnitude from the 128-entry constexpr table, apply sign. +inline float decode_fp8_e4m3_lut(uint8_t byte) { const uint32_t mag = byte & 0x7Fu; const float v = bestla::sycl_prologue_b::fp8_lut::lut_e4m3_128[mag]; return (byte & 0x80u) ? -v : v; } -inline float decode_fp8_e5m2(uint8_t byte) { +inline float decode_fp8_e5m2_lut(uint8_t byte) { const uint32_t mag = byte & 0x7Fu; const float v = bestla::sycl_prologue_b::fp8_lut::lut_e5m2_128[mag]; return (byte & 0x80u) ? -v : v; } -#else // !ARK_FP8_DECODE_USE_LUT - -inline float decode_fp8_e4m3(uint8_t byte) { +// Inline bit-manipulation path: fully self-contained, no LUT / SLM required. +inline float decode_fp8_e4m3_bits(uint8_t byte) { const uint32_t mag = byte & 0x7Fu; const uint32_t sign = byte >> 7; float v; @@ -136,7 +139,7 @@ inline float decode_fp8_e4m3(uint8_t byte) { return sign ? -v : v; } -inline float decode_fp8_e5m2(uint8_t byte) { +inline float decode_fp8_e5m2_bits(uint8_t byte) { const uint32_t mag = byte & 0x7Fu; const uint32_t sign = byte >> 7; const int exp = static_cast((mag >> 2) & 0x1Fu); @@ -154,7 +157,40 @@ inline float decode_fp8_e5m2(uint8_t byte) { return sign ? -v : v; } -#endif // ARK_FP8_DECODE_USE_LUT +// Dispatch helpers used inside the kernel (both branches resolved at compile +// time via `if constexpr`, so there is no per-element runtime cost). 
+template +inline float decode_fp8(uint8_t byte) { + if constexpr (UseLut) { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_lut(byte); + } else { + return decode_fp8_e5m2_lut(byte); + } + } else { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_bits(byte); + } else { + return decode_fp8_e5m2_bits(byte); + } + } +} + +// ---------------------------------------------------------------------------- +// Host-side env-var reader: cached, defaults to LUT enabled. +// Accepts (case-insensitive) "0"/"false"/"off"/"no" as the OFF spelling. +// ---------------------------------------------------------------------------- +inline bool fp8_decode_use_lut() { + static const bool value = []() { + const char* env = std::getenv("ARK_FP8_DECODE_USE_LUT"); + if (env == nullptr) return true; // default: LUT on + std::string s(env); + for (char& c : s) c = static_cast(std::tolower(static_cast(c))); + if (s == "0" || s == "false" || s == "off" || s == "no") return false; + return true; + }(); + return value; +} // ---------------------------------------------------------------------------- // Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. @@ -611,11 +647,11 @@ void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weig // FP8 (E4M3 / E5M2) GEMV with group-wise scale (no zero-point). // // Weights are 1 FP8 byte per element [E, N, K]. The byte is decoded via the -// `decode_fp8_e4m3` / `decode_fp8_e5m2` helpers above, which can be compiled -// either as inline bit manipulation (default) or as a LUT lookup by defining -// `ARK_FP8_DECODE_USE_LUT`. +// `decode_fp8` helper, which selects between the LUT and the +// inline bit-manipulation path at compile time. The choice is driven at +// launch time by the env var `ARK_FP8_DECODE_USE_LUT` (default: ON). 
// ---------------------------------------------------------------------------- -template +template void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, int K, int group_size) { if (N % N_TILE != 0) { @@ -632,7 +668,7 @@ void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weigh sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; sycl::range<2> local{1, static_cast(SG_SIZE)}; - q->parallel_for>( + q->parallel_for>( sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { const int token = static_cast(it.get_global_id(0)); @@ -669,24 +705,14 @@ void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weigh #pragma unroll for (int u = 0; u < CHUNK; ++u) { const uint8_t raw = wv[u]; - float w; - if constexpr (IsE4M3) { - w = decode_fp8_e4m3(raw) * scale; - } else { - w = decode_fp8_e5m2(raw) * scale; - } + const float w = decode_fp8(raw) * scale; const ScalarT a = sycl::bit_cast(static_cast(av[u])); acc += static_cast(a) * w; } } for (; kk < group_size; ++kk) { const uint8_t raw = w_row[k_base + kk]; - float w; - if constexpr (IsE4M3) { - w = decode_fp8_e4m3(raw) * scale; - } else { - w = decode_fp8_e5m2(raw) * scale; - } + const float w = decode_fp8(raw) * scale; acc += static_cast(act_row[k_base + kk]) * w; } } @@ -841,30 +867,59 @@ inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, vo throw std::invalid_argument("moe_gemm_decode(fp8): asym mode is not supported"); } const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + const bool use_lut = moe_decode_detail::fp8_decode_use_lut(); if (act_dtype == BTLA_DTYPE::F16) { if (is_e4m3) { - moe_decode_detail::launch_fp8( - q, static_cast(activations), static_cast(weights), - static_cast(scales), static_cast(outputs), expert_id_per_token_buf, - 
total_tokens, N, K, group_size); + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } } else { - moe_decode_detail::launch_fp8( - q, static_cast(activations), static_cast(weights), - static_cast(scales), static_cast(outputs), expert_id_per_token_buf, - total_tokens, N, K, group_size); + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } } } else if (act_dtype == BTLA_DTYPE::BF16) { using BF = sycl::ext::oneapi::bfloat16; if (is_e4m3) { - moe_decode_detail::launch_fp8( - q, static_cast(activations), static_cast(weights), - static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, - group_size); + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } } else { - moe_decode_detail::launch_fp8( - q, static_cast(activations), static_cast(weights), - static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, - group_size); + if (use_lut) { + 
moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } } } else { throw std::invalid_argument("moe_gemm_decode(fp8): act_dtype must be FP16 or BF16"); From 26dbeaadef3ba071a9ebc4674996af5e9ad3679d Mon Sep 17 00:00:00 2001 From: "Dong, Bo1" Date: Wed, 27 May 2026 23:03:05 +0800 Subject: [PATCH 12/28] fix precommit Signed-off-by: Dong, Bo1 --- .../ark/auto_round_kernel/__init__.py | 4 +- auto_round_extension/ark/test/test_moe.py | 4 +- .../ark/test/test_moe_decode_perf.py | 54 +++++++++++-------- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 731cdf37d..cb053a655 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -900,9 +900,7 @@ def moe_gemm_decode( raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") if zeros is not None: raise ValueError("zeros must be None for FP8 weights") - weight_dtype = ( - ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 - ) + weight_dtype = ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 if not scales.is_contiguous(): scales = scales.contiguous() elif weight_bits == 16: diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index bca8d84be..7e9c54f8f 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -679,9 +679,7 @@ def test_decode_int2_asym(self, dtype): torch.testing.assert_close(out, ref, 
rtol=1e-1, atol=1e-1) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize( - "fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] - ) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("group_size", [32, 128]) def test_decode_fp8(self, dtype, fp8_dtype, group_size): num_experts = 4 diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py index 44bbb7fd0..be253dcf9 100644 --- a/auto_round_extension/ark/test/test_moe_decode_perf.py +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -189,19 +189,13 @@ def _print_header(title: str) -> None: print() print("=" * 96) print(title) - print( - f"{'shape':<14}{'N':>7}{'K':>7}{'tokens':>8}" - f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}" - ) + print(f"{'shape':<14}{'N':>7}{'K':>7}{'tokens':>8}" f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}") print("-" * 96) def _print_row(label, N, K, total_tokens, base_ms, ark_ms): speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") - print( - f"{label:<14}{N:>7}{K:>7}{total_tokens:>8}" - f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x" - ) + print(f"{label:<14}{N:>7}{K:>7}{total_tokens:>8}" f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x") # --------------------------------------------------------------------------- @@ -229,9 +223,7 @@ def test_perf_fp(self, dtype): ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, weights, ntpe)) - ark_ms = _xpu_time_ms( - lambda: ark.moe_gemm_decode(activations, weights, ntpe, weight_bits=16) - ) + ark_ms = _xpu_time_ms(lambda: ark.moe_gemm_decode(activations, weights, ntpe, weight_bits=16)) _print_row(label, N, K, total_tokens, base_ms, ark_ms) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -263,9 +255,14 @@ def test_perf_int4(self, dtype, asym): base_ms = 
_xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_decode( - activations, packed, ntpe, - scales=scales, zeros=zeros, - weight_bits=4, group_size=group_size, asym=asym, + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, ) ) _print_row(label, N, K, total_tokens, base_ms, ark_ms) @@ -299,9 +296,14 @@ def test_perf_int8(self, dtype, asym): base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_decode( - activations, packed, ntpe, - scales=scales, zeros=zeros, - weight_bits=8, group_size=group_size, asym=asym, + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, ) ) _print_row(label, N, K, total_tokens, base_ms, ark_ms) @@ -335,9 +337,14 @@ def test_perf_int2(self, dtype, asym): base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_decode( - activations, packed, ntpe, - scales=scales, zeros=zeros, - weight_bits=2, group_size=group_size, asym=asym, + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, ) ) _print_row(label, N, K, total_tokens, base_ms, ark_ms) @@ -364,9 +371,12 @@ def test_perf_fp8(self, dtype, fp8_dtype): base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_decode( - activations, packed, ntpe, + activations, + packed, + ntpe, scales=scales, - group_size=group_size, asym=False, + group_size=group_size, + asym=False, ) ) _print_row(label, N, K, total_tokens, base_ms, ark_ms) From dab6219b4828933ae4188b5076a2941751eee02e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 5 Jun 2026 06:11:19 +0000 Subject: [PATCH 13/28] Apply 
remaining changes --- .../wrapper/include/sycl_tla_moe_decode.hpp | 117 ++------------ .../wrapper/include/sycl_tla_moe_dequant.hpp | 150 ++++++++++++++++++ 2 files changed, 164 insertions(+), 103 deletions(-) create mode 100644 auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp index b10594f0f..c78446690 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -35,15 +35,10 @@ #include #include -#include #include -#include -#include -#include - #include "bestla/bestla.h" -#include "bestla/sycl/fp8_lut.h" +#include "sycl_tla_moe_dequant.hpp" #ifdef ARK_XPU #include @@ -64,6 +59,8 @@ // // The env var is read once on the host (cached) and passed as a template // parameter into the SYCL kernel, so there is no per-element runtime branch. +// The actual primitives live in `sycl_tla_moe_dequant.hpp` (shared with the +// mixed-input prefill path); this file just re-exports them via `using`. // ---------------------------------------------------------------------------- #if defined(ARK_XPU) && defined(ARK_SYCL_TLA) @@ -93,104 +90,18 @@ template class MoEDecodeKernelFP8; // ---------------------------------------------------------------------------- -// FP8 byte -> float decode. -// Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: -// E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. -// E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. -// -// Two implementations are provided. Selection happens at kernel launch time -// via a bool template parameter (`UseLut`) sourced from the env var -// `ARK_FP8_DECODE_USE_LUT` (see `fp8_decode_use_lut()` below). 
+// FP8 weight dequantization primitives + host-side env-var reader live in +// `sycl_tla_moe_dequant.hpp` so the prefill (mixed-input Grouped GEMM) and +// decode (GEMV) paths share one definition. The `using` declarations below +// keep the in-kernel call sites (`decode_fp8<...>(byte)`) and the host-side +// `fp8_decode_use_lut()` lookup inside `moe_decode_detail` working unchanged. // ---------------------------------------------------------------------------- - -// LUT path: read magnitude from the 128-entry constexpr table, apply sign. -inline float decode_fp8_e4m3_lut(uint8_t byte) { - const uint32_t mag = byte & 0x7Fu; - const float v = bestla::sycl_prologue_b::fp8_lut::lut_e4m3_128[mag]; - return (byte & 0x80u) ? -v : v; -} - -inline float decode_fp8_e5m2_lut(uint8_t byte) { - const uint32_t mag = byte & 0x7Fu; - const float v = bestla::sycl_prologue_b::fp8_lut::lut_e5m2_128[mag]; - return (byte & 0x80u) ? -v : v; -} - -// Inline bit-manipulation path: fully self-contained, no LUT / SLM required. -inline float decode_fp8_e4m3_bits(uint8_t byte) { - const uint32_t mag = byte & 0x7Fu; - const uint32_t sign = byte >> 7; - float v; - if (mag == 0u) { - v = 0.0f; - } else if (mag == 0x7Fu) { - v = sycl::nan(0u); - } else { - const int exp = static_cast((mag >> 3) & 0xFu); - const int man = static_cast(mag & 0x7u); - if (exp == 0) { - // subnormal: value = man * 2^(1 - bias - mbits) = man / 512 - v = static_cast(man) * (1.0f / 512.0f); - } else { - // normal: (1 + man/8) * 2^(exp - bias), bias = 7 - v = (1.0f + static_cast(man) * 0.125f) * sycl::ldexp(1.0f, exp - 7); - } - } - return sign ? -v : v; -} - -inline float decode_fp8_e5m2_bits(uint8_t byte) { - const uint32_t mag = byte & 0x7Fu; - const uint32_t sign = byte >> 7; - const int exp = static_cast((mag >> 2) & 0x1Fu); - const int man = static_cast(mag & 0x3u); - float v; - if (exp == 0) { - // subnormal (incl. 
zero): value = man * 2^(1 - bias - mbits) = man / 65536 - v = static_cast(man) * (1.0f / 65536.0f); - } else if (exp == 31) { - v = (man == 0) ? std::numeric_limits::infinity() : sycl::nan(0u); - } else { - // normal: (1 + man/4) * 2^(exp - bias), bias = 15 - v = (1.0f + static_cast(man) * 0.25f) * sycl::ldexp(1.0f, exp - 15); - } - return sign ? -v : v; -} - -// Dispatch helpers used inside the kernel (both branches resolved at compile -// time via `if constexpr`, so there is no per-element runtime cost). -template -inline float decode_fp8(uint8_t byte) { - if constexpr (UseLut) { - if constexpr (IsE4M3) { - return decode_fp8_e4m3_lut(byte); - } else { - return decode_fp8_e5m2_lut(byte); - } - } else { - if constexpr (IsE4M3) { - return decode_fp8_e4m3_bits(byte); - } else { - return decode_fp8_e5m2_bits(byte); - } - } -} - -// ---------------------------------------------------------------------------- -// Host-side env-var reader: cached, defaults to LUT enabled. -// Accepts (case-insensitive) "0"/"false"/"off"/"no" as the OFF spelling. -// ---------------------------------------------------------------------------- -inline bool fp8_decode_use_lut() { - static const bool value = []() { - const char* env = std::getenv("ARK_FP8_DECODE_USE_LUT"); - if (env == nullptr) return true; // default: LUT on - std::string s(env); - for (char& c : s) c = static_cast(std::tolower(static_cast(c))); - if (s == "0" || s == "false" || s == "off" || s == "no") return false; - return true; - }(); - return value; -} +using moe_dequant::decode_fp8; +using moe_dequant::decode_fp8_e4m3_bits; +using moe_dequant::decode_fp8_e4m3_lut; +using moe_dequant::decode_fp8_e5m2_bits; +using moe_dequant::decode_fp8_e5m2_lut; +using moe_dequant::fp8_decode_use_lut; // ---------------------------------------------------------------------------- // Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. 
diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp new file mode 100644 index 000000000..191286717 --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp @@ -0,0 +1,150 @@ +// SYCL MoE Weight Dequantization Primitives +// +// Device-side dequantization helpers shared between the MoE *decode* (GEMV) +// kernel in `sycl_tla_moe_decode.hpp` and the MoE *prefill* (mixed-input +// Grouped GEMM) kernel in `sycl_tla_moe_mixed.hpp`. Keeping the primitives +// in one place guarantees that both paths produce bit-identical results for +// the same packed weight bytes, which is what the round-trip parity tests +// (decode vs prefill) rely on. +// +// Currently extracted (PR-A1): the FP8 byte->float decoders and the host- +// side `ARK_FP8_DECODE_USE_LUT` env-var reader. INT2/INT4/INT8 decoders are +// still inlined inside the decode kernel; they will be added here when the +// mixed-input prefill mainloop in PR-A2/PR-A3 starts consuming them. +// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +#include "bestla/sycl/fp8_lut.h" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_dequant { + +// ---------------------------------------------------------------------------- +// FP8 byte -> float decode. +// Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: +// E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. +// E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. +// +// Two equivalent (for finite values) implementations are provided: +// - `_lut`: read magnitude from the 128-entry constexpr table in +// `bestla/sycl/fp8_lut.h`, apply sign separately. 
+// - `_bits`: fully self-contained inline bit-manipulation, no LUT/SLM. +// +// Selection happens at kernel launch time via a `bool UseLut` template +// parameter sourced from the env var `ARK_FP8_DECODE_USE_LUT` (read once on +// the host by `fp8_decode_use_lut()` below). This keeps the per-element hot +// path branch-free. +// ---------------------------------------------------------------------------- + +inline float decode_fp8_e4m3_lut(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e4m3_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +inline float decode_fp8_e5m2_lut(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e5m2_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +inline float decode_fp8_e4m3_bits(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + float v; + if (mag == 0u) { + v = 0.0f; + } else if (mag == 0x7Fu) { + v = sycl::nan(0u); + } else { + const int exp = static_cast((mag >> 3) & 0xFu); + const int man = static_cast(mag & 0x7u); + if (exp == 0) { + // subnormal: value = man * 2^(1 - bias - mbits) = man / 512 + v = static_cast(man) * (1.0f / 512.0f); + } else { + // normal: (1 + man/8) * 2^(exp - bias), bias = 7 + v = (1.0f + static_cast(man) * 0.125f) * sycl::ldexp(1.0f, exp - 7); + } + } + return sign ? -v : v; +} + +inline float decode_fp8_e5m2_bits(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + const int exp = static_cast((mag >> 2) & 0x1Fu); + const int man = static_cast(mag & 0x3u); + float v; + if (exp == 0) { + // subnormal (incl. zero): value = man * 2^(1 - bias - mbits) = man / 65536 + v = static_cast(man) * (1.0f / 65536.0f); + } else if (exp == 31) { + v = (man == 0) ? 
std::numeric_limits::infinity() : sycl::nan(0u); + } else { + // normal: (1 + man/4) * 2^(exp - bias), bias = 15 + v = (1.0f + static_cast(man) * 0.25f) * sycl::ldexp(1.0f, exp - 15); + } + return sign ? -v : v; +} + +// Compile-time dispatch helper. Both branches are resolved via `if constexpr`, +// so there is no per-element runtime cost regardless of which path is chosen. +template +inline float decode_fp8(uint8_t byte) { + if constexpr (UseLut) { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_lut(byte); + } else { + return decode_fp8_e5m2_lut(byte); + } + } else { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_bits(byte); + } else { + return decode_fp8_e5m2_bits(byte); + } + } +} + +// ---------------------------------------------------------------------------- +// Host-side env-var reader: cached, defaults to LUT enabled. +// +// `ARK_FP8_DECODE_USE_LUT`: +// - unset / "1" / "true" / "on" / "yes" (case-insensitive) -> LUT path (default) +// - "0" / "false" / "off" / "no" (case-insensitive) -> inline bit-manip +// +// Read once on first call and cached in a function-local static, so it is +// safe (and free) to call this on every launch. 
+// ---------------------------------------------------------------------------- +inline bool fp8_decode_use_lut() { + static const bool value = []() { + const char* env = std::getenv("ARK_FP8_DECODE_USE_LUT"); + if (env == nullptr) return true; // default: LUT on + std::string s(env); + for (char& c : s) c = static_cast(std::tolower(static_cast(c))); + if (s == "0" || s == "false" || s == "off" || s == "no") return false; + return true; + }(); + return value; +} + +} // namespace moe_dequant +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA From 6cf6c3a7e905c1ab9a608bb4509366144e0cfd42 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Jun 2026 02:03:05 +0000 Subject: [PATCH 14/28] feat(ark): add quantized MoE prefill kernel (functional baseline) --- .../ark/auto_round_kernel/__init__.py | 534 +++++++++++------- .../ark/auto_round_kernel/ark.cpp | 13 + .../wrapper/include/sycl_tla_common.hpp | 27 + .../wrapper/include/sycl_tla_moe_mixed.hpp | 426 ++++++++++++++ auto_round_extension/ark/test/test_moe.py | 193 +++++++ 5 files changed, 994 insertions(+), 199 deletions(-) create mode 100644 auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 1580aac44..766fa46e5 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1065,207 +1065,237 @@ def _ceil_div(a, b): out = out[:, :, :Sq, :] return out - def moe_gemm_decode( - self, - activations: torch.Tensor, - weights: torch.Tensor, - num_tokens_per_expert: torch.Tensor, - *, - scales: Optional[torch.Tensor] = None, - zeros: Optional[torch.Tensor] = None, - weight_bits: int = 4, - group_size: int = 128, - asym: bool = False, - ) -> torch.Tensor: - """MoE GEMV optimized for the decode phase. 
- - Each expert typically processes only 1-2 tokens (top-k routing with - small batch). Activations must already be gathered/sorted by expert - (same convention as ``moe_gemm``). - - Args: - activations: ``[total_tokens, K]`` in fp16 or bf16. - weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are: - - * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16`` - matching the activations dtype, ``K_packed == K``. - * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``. - Sym (``asym=False``) reinterprets each byte as signed int8; - asym (``asym=True``) treats each byte as ``uint8`` with a - per-group zero-point. - * Int4 (``weight_bits=4``): ``torch.uint8`` packed, - ``K_packed == K // 2`` (two 4-bit values per byte; low nibble - at the lower K index). - * Int2 (``weight_bits=2``): ``torch.uint8`` packed, - ``K_packed == K // 4`` (four 2-bit values per byte; field j at - K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i). - * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): - ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must - be ``False`` (no zero-points for FP8). - num_tokens_per_expert: ``[E]`` int32. Sum must equal - ``activations.shape[0]``. - scales: ``[E, N, K // group_size]`` in activations dtype. Required - for all quantized paths (int8/int4/int2/fp8); must be ``None`` - for unquantized weights. - zeros: ``[E, N, K // group_size]`` in activations dtype. Required - when ``asym=True`` (int8/int4/int2 only); otherwise ``None``. - weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8 - tensor (the FP8 sub-format is taken from ``weights.dtype``). - group_size: group along K for quantized weights (default 128). - asym: if ``True``, weights use unsigned encoding and ``zeros`` must - be provided. Not supported for FP8. - - Returns: - outputs: ``[total_tokens, N]`` in the same dtype as activations. 
- """ - if activations.device.type != "xpu": - raise NotImplementedError("moe_gemm_decode is only supported on XPU") - - if activations.dtype not in (torch.float16, torch.bfloat16): - raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}") - - if activations.ndim != 2: - raise ValueError("activations must be 2D [total_tokens, K]") - if weights.ndim != 3: - raise ValueError("weights must be 3D [E, N, K_packed]") - - if not activations.is_contiguous(): - activations = activations.contiguous() - if not weights.is_contiguous(): - weights = weights.contiguous() - - if num_tokens_per_expert.dtype != torch.int32: - num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) - if not num_tokens_per_expert.is_contiguous(): - num_tokens_per_expert = num_tokens_per_expert.contiguous() - - total_tokens, K = activations.shape - num_experts = weights.shape[0] - N = weights.shape[1] - - if num_tokens_per_expert.shape[0] != num_experts: + +def moe_gemm_decode( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, +) -> torch.Tensor: + """MoE GEMV optimized for the decode phase. + + Each expert typically processes only 1-2 tokens (top-k routing with + small batch). Activations must already be gathered/sorted by expert + (same convention as ``moe_gemm``). + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are: + + * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16`` + matching the activations dtype, ``K_packed == K``. + * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``. + Sym (``asym=False``) reinterprets each byte as signed int8; + asym (``asym=True``) treats each byte as ``uint8`` with a + per-group zero-point. 
+ * Int4 (``weight_bits=4``): ``torch.uint8`` packed, + ``K_packed == K // 2`` (two 4-bit values per byte; low nibble + at the lower K index). + * Int2 (``weight_bits=2``): ``torch.uint8`` packed, + ``K_packed == K // 4`` (four 2-bit values per byte; field j at + K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i). + * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): + ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must + be ``False`` (no zero-points for FP8). + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required + for all quantized paths (int8/int4/int2/fp8); must be ``None`` + for unquantized weights. + zeros: ``[E, N, K // group_size]`` in activations dtype. Required + when ``asym=True`` (int8/int4/int2 only); otherwise ``None``. + weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8 + tensor (the FP8 sub-format is taken from ``weights.dtype``). + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding and ``zeros`` must + be provided. Not supported for FP8. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. + """ + activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( + _validate_moe_quant_args( + activations, weights, num_tokens_per_expert, + scales=scales, zeros=zeros, weight_bits=weight_bits, group_size=group_size, asym=asym, + api_name="moe_gemm_decode", + ) + ) + + lib = get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + # Scratch buffer mapping each token to its expert id; filled on-device + # inside the kernel wrapper so we avoid host-device sync. 
+ expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32) + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_decode( + stream, + activations.data_ptr(), + weights.data_ptr(), + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + expert_id_per_token.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + return outputs + + +def _validate_moe_quant_args( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor], + zeros: Optional[torch.Tensor], + weight_bits: int, + group_size: int, + asym: bool, + api_name: str, +): + """Shared validation/normalisation for quantized MoE entry points. + + Returns a tuple of normalised tensors and dtype/shape metadata used by the + kernel-call site: + ``(activations, weights, scales, zeros, num_tokens_per_expert, + weight_dtype, total_tokens, N, K, num_experts)``. 
+ """ + if activations.device.type != "xpu": + raise NotImplementedError(f"{api_name} is only supported on XPU") + + if activations.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}") + + if activations.ndim != 2: + raise ValueError("activations must be 2D [total_tokens, K]") + if weights.ndim != 3: + raise ValueError("weights must be 3D [E, N, K_packed]") + + if not activations.is_contiguous(): + activations = activations.contiguous() + if not weights.is_contiguous(): + weights = weights.contiguous() + + if num_tokens_per_expert.dtype != torch.int32: + num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) + if not num_tokens_per_expert.is_contiguous(): + num_tokens_per_expert = num_tokens_per_expert.contiguous() + + total_tokens, K = activations.shape + num_experts = weights.shape[0] + N = weights.shape[1] + + if num_tokens_per_expert.shape[0] != num_experts: + raise ValueError( + f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" + ) + + # Detect FP8 weight dtype first (overrides weight_bits). 
+ is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + + if is_fp8: + if asym: + raise ValueError("FP8 weights do not support asym=True") + if weights.shape[2] != K: + raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}") + if scales is None: + raise ValueError("scales is required for FP8 weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if zeros is not None: + raise ValueError("zeros must be None for FP8 weights") + weight_dtype = ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 + if not scales.is_contiguous(): + scales = scales.contiguous() + elif weight_bits == 16: + if weights.dtype != activations.dtype: + raise ValueError("Unquantized weights must match activations dtype") + if weights.shape[2] != K: + raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}") + weight_dtype = cvt_dtype(activations.dtype) + if scales is not None or zeros is not None: + raise ValueError("scales/zeros must be None when weight_bits=16") + elif weight_bits in (8, 4, 2): + if weights.dtype != torch.uint8: + raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8") + if weight_bits == 8: + k_packed_expected = K + k_div = 1 + elif weight_bits == 4: + k_packed_expected = K // 2 + k_div = 2 + else: # weight_bits == 2 + k_packed_expected = K // 4 + k_div = 4 + if K % k_div != 0: + raise ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}") + if weights.shape[2] != k_packed_expected: raise ValueError( - f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" + 
f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} " + f"({k_packed_expected})" ) - - # Detect FP8 weight dtype first (overrides weight_bits). - is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) - - # Validate weight layout / dtype combination. - if is_fp8: - if asym: - raise ValueError("FP8 weights do not support asym=True") - if weights.shape[2] != K: - raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}") - if scales is None: - raise ValueError("scales is required for FP8 weights") - if scales.dtype != activations.dtype: - raise ValueError("scales dtype must match activations dtype") - if K % group_size != 0: - raise ValueError("K must be a multiple of group_size") - expected_scale_shape = (num_experts, N, K // group_size) - if tuple(scales.shape) != expected_scale_shape: - raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") - if zeros is not None: - raise ValueError("zeros must be None for FP8 weights") - weight_dtype = ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 - if not scales.is_contiguous(): - scales = scales.contiguous() - elif weight_bits == 16: - if weights.dtype != activations.dtype: - raise ValueError("Unquantized weights must match activations dtype") - if weights.shape[2] != K: - raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}") - weight_dtype = cvt_dtype(activations.dtype) - if scales is not None or zeros is not None: - raise ValueError("scales/zeros must be None when weight_bits=16") - elif weight_bits in (8, 4, 2): - if weights.dtype != torch.uint8: - raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8") - if weight_bits == 8: - k_packed_expected = K - k_div = 1 - elif weight_bits == 4: - k_packed_expected = K // 2 - k_div = 2 - else: # weight_bits == 2 - k_packed_expected = K // 4 - k_div = 4 - if K % k_div != 0: - raise ValueError(f"K must 
be a multiple of {k_div} for weight_bits={weight_bits}") - if weights.shape[2] != k_packed_expected: - raise ValueError( - f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} " - f"({k_packed_expected})" - ) - if scales is None: - raise ValueError(f"scales is required for int{weight_bits} weights") - if scales.dtype != activations.dtype: - raise ValueError("scales dtype must match activations dtype") - if K % group_size != 0: - raise ValueError("K must be a multiple of group_size") - # Group_size constraints per dtype. - if weight_bits == 4 and (group_size & 1) != 0: - raise ValueError("group_size must be even for int4 weights") - if weight_bits == 2 and (group_size & 3) != 0: - raise ValueError("group_size must be a multiple of 4 for int2 weights") - expected_scale_shape = (num_experts, N, K // group_size) - if tuple(scales.shape) != expected_scale_shape: - raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") - if asym: - if zeros is None: - raise ValueError("zeros is required when asym=True") - if zeros.dtype != activations.dtype: - raise ValueError("zeros dtype must match activations dtype") - if tuple(zeros.shape) != expected_scale_shape: - raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}") - else: - if zeros is not None: - raise ValueError("zeros must be None when asym=False") - weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits] - if not scales.is_contiguous(): - scales = scales.contiguous() - if asym and not zeros.is_contiguous(): - zeros = zeros.contiguous() + if scales is None: + raise ValueError(f"scales is required for int{weight_bits} weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + # Group_size constraints per dtype. 
+ if weight_bits == 4 and (group_size & 1) != 0: + raise ValueError("group_size must be even for int4 weights") + if weight_bits == 2 and (group_size & 3) != 0: + raise ValueError("group_size must be a multiple of 4 for int2 weights") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if asym: + if zeros is None: + raise ValueError("zeros is required when asym=True") + if zeros.dtype != activations.dtype: + raise ValueError("zeros dtype must match activations dtype") + if tuple(zeros.shape) != expected_scale_shape: + raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}") else: - raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)") - - if N % 16 != 0: - raise ValueError(f"N must be a multiple of 16 (got {N})") - - expected_total = int(num_tokens_per_expert.sum().item()) - if expected_total != total_tokens: - raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") - - lib = self.get_lib(activations) - stream = get_stream(activations) - outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) - # Scratch buffer mapping each token to its expert id; filled on-device - # inside the kernel wrapper so we avoid host-device sync. 
- expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32) - - scales_ptr = scales.data_ptr() if scales is not None else 0 - zeros_ptr = zeros.data_ptr() if zeros is not None else 0 - - lib.moe_gemm_decode( - stream, - activations.data_ptr(), - weights.data_ptr(), - scales_ptr, - zeros_ptr, - outputs.data_ptr(), - expert_id_per_token.data_ptr(), - cvt_dtype(activations.dtype), - weight_dtype, - N, - K, - group_size, - num_tokens_per_expert.data_ptr(), - num_experts, - total_tokens, - bool(asym), - ) - return outputs + if zeros is not None: + raise ValueError("zeros must be None when asym=False") + weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits] + if not scales.is_contiguous(): + scales = scales.contiguous() + if asym and not zeros.is_contiguous(): + zeros = zeros.contiguous() + else: + raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)") + + if N % 16 != 0: + raise ValueError(f"N must be a multiple of 16 (got {N})") + + expected_total = int(num_tokens_per_expert.sum().item()) + if expected_total != total_tokens: + raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") + + return (activations, weights, scales, zeros, num_tokens_per_expert, + weight_dtype, total_tokens, N, K, num_experts) def moe_gemm( @@ -1344,7 +1374,113 @@ def moe_gemm( return outputs -def patch_torch_sdpa(*args, **kwargs): +def moe_gemm_prefill( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, +) -> torch.Tensor: + """MoE Grouped GEMM optimized for the prefill phase, supporting all weight + encodings of ``moe_gemm_decode`` (FP16/BF16, INT8 sym/asym, INT4 sym/asym, + INT2 sym/asym, FP8 E4M3/E5M2). 
+ + The argument shapes/dtypes match :func:`moe_gemm_decode` exactly so the same + quantized weights/scales/zeros tensors can be re-used between prefill and + decode without reshaping. Internally, for the quantized paths the kernel + materialises a ``[E, K, N]`` ``act_dtype`` temporary via an on-device + dequantization kernel and then dispatches to the existing CUTLASS-SYCL + Grouped GEMM (``moe_gemm``). Numerical results are bit-identical to + ``moe_gemm`` applied to the same dequantized weights. + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: 3-D tensor; same layout/dtype contract as + :func:`moe_gemm_decode`. Quantized layouts are ``[E, N, K_packed]``; + the unquantized fast path (``weight_bits=16``) accepts + ``[E, N, K]`` -- callers providing already-``[E, K, N]`` weights + (as ``moe_gemm`` requires) should call ``moe_gemm`` directly. + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required for + quantized paths; ignored (must be ``None``) for unquantized. + zeros: ``[E, N, K // group_size]`` in activations dtype, required when + ``asym=True`` (int8/int4/int2 only). + weight_bits: 2, 4, 8, or 16. Ignored for FP8 weights. + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding; ``zeros`` required. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. 
+ """ + activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( + _validate_moe_quant_args( + activations, weights, num_tokens_per_expert, + scales=scales, zeros=zeros, weight_bits=weight_bits, group_size=group_size, asym=asym, + api_name="moe_gemm_prefill", + ) + ) + + lib = get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + + # Quantized paths need an [E, K, N] act-dtype scratch buffer that the + # on-device dequant kernel fills before the inner Grouped GEMM consumes + # it. The unquantized fast path forwards directly through `moe_gemm` and + # doesn't need scratch (passing 0 is safe -- the C++ wrapper short-circuits + # before touching the workspace pointer in that case). We allocate the + # workspace from PyTorch's caching allocator so repeated calls reuse the + # same memory. + is_unquantized = (weight_bits == 16) and (weights.dtype == activations.dtype) + if is_unquantized: + # `moe_gemm` requires `[E, K, N]` row-major weights; the decode-style + # `[E, N, K]` weight shape coming through this validator can be + # transposed into a temporary contiguous `[E, K, N]` view. The + # workspace serves the same role as the dequant scratch so the + # on-device path stays uniform. 
+ dequant_workspace = weights.transpose(1, 2).contiguous() + weights_ptr = dequant_workspace.data_ptr() + else: + dequant_workspace = torch.empty( + (num_experts, K, N), device=activations.device, dtype=activations.dtype + ) + weights_ptr = weights.data_ptr() + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_prefill( + stream, + activations.data_ptr(), + weights_ptr, + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + dequant_workspace.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + # The inner CUTLASS-SYCL `moe_gemm` calls `event.wait()` before returning + # (see `moe_detail::moe_gemm_launcher` in `sycl_tla_moe.hpp`), so by the + # time `lib.moe_gemm_prefill` returns the device has already consumed the + # workspace and it is safe to drop our reference here. + del dequant_workspace + return outputs + + + from .torch_sdpa_patch import patch_torch_sdpa_with_ark return patch_torch_sdpa_with_ark(*args, **kwargs) diff --git a/auto_round_extension/ark/auto_round_kernel/ark.cpp b/auto_round_extension/ark/auto_round_kernel/ark.cpp index ec1ac0393..6a5d390fe 100755 --- a/auto_round_extension/ark/auto_round_kernel/ark.cpp +++ b/auto_round_extension/ark/auto_round_kernel/ark.cpp @@ -26,6 +26,7 @@ typedef uintptr_t torch_ptr; #include "sycl_tla_common.hpp" #include "sycl_tla_moe.hpp" #include "sycl_tla_moe_decode.hpp" +#include "sycl_tla_moe_mixed.hpp" #include "sycl_tla_sdpa.hpp" #endif #else @@ -233,6 +234,17 @@ static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, tor (int*)num_tokens_per_expert, num_experts, total_tokens, asym); } +static void moe_gemm_prefill_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales, + torch_ptr zeros, torch_ptr outputs, torch_ptr dequant_workspace, int act_dtype, + int 
weight_dtype, int N, int K, int group_size, torch_ptr num_tokens_per_expert, + int num_experts, int total_tokens, bool asym) { + ark::moe_gemm_prefill((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr, + zeros ? (void*)zeros : nullptr, (void*)outputs, + dequant_workspace ? (void*)dequant_workspace : nullptr, (BTLA_DTYPE)(act_dtype), + (BTLA_DTYPE)(weight_dtype), N, K, group_size, (int*)num_tokens_per_expert, num_experts, + total_tokens, asym); +} + static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out, int num_rows, int head_dim, int block_size) { auto* q = (sycl::queue*)stream; @@ -451,5 +463,6 @@ PYBIND11_MODULE(PY_NAME, m) { m.def("sage_dynamic_quant_v_layout", &ark::sage_dynamic_quant_v_layout); m.def("moe_gemm", &ark::moe_gemm_wrapper); m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper); + m.def("moe_gemm_prefill", &ark::moe_gemm_prefill_wrapper); #endif } \ No newline at end of file diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp index 5f2f9e133..cc8cd8971 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp @@ -69,6 +69,33 @@ void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* sca int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); +/** + * @brief MoE Grouped GEMM optimized for the prefill phase, supporting the + * same set of weight encodings as `moe_gemm_decode` (FP16/BF16, INT8 sym/asym, + * INT4 sym/asym, INT2 sym/asym, FP8 E4M3/E5M2). 
+ * + * Stage-1 implementation: dequantizes weights into a `[num_experts, K, N]` + * temporary buffer (must be supplied by the caller via `dequant_workspace`, + * sized `num_experts * K * N * sizeof(act_dtype)`) and then dispatches to the + * existing `moe_gemm` baseline. This guarantees numerical parity with the + * decode path. Mainloop fusion is the follow-up perf-tuning step. + * + * Implementation is header-only in `sycl_tla_moe_mixed.hpp`. + * + * Layout convention (matches `moe_gemm_decode`): + * - activations: [total_tokens, K] in act_dtype + * - weights (quantized): [num_experts, N, K_p] uint8 (decode-style packed) + * - weights (FP16/BF16): [num_experts, K, N] in act_dtype (matches `moe_gemm`) + * - scales: [num_experts, N, K/group_size] in act_dtype + * - zeros (asym only): [num_experts, N, K/group_size] in act_dtype + * - dequant_workspace: [num_experts, K, N] in act_dtype, may be null + * for the unquantized fast path + * - outputs: [total_tokens, N] in act_dtype + */ +void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs, + void* dequant_workspace, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, + int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); + // ======================================================================== // Public API // ======================================================================== diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp new file mode 100644 index 000000000..6e24534f9 --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp @@ -0,0 +1,426 @@ +// SYCL MoE Mixed-Input Prefill Kernel +// +// MoE prefill (Grouped GEMM) entry point that accepts the same set of +// quantized weight encodings as the decode kernel in +// 
`sycl_tla_moe_decode.hpp` (FP16/BF16 baseline, INT8 sym/asym, INT4 +// sym/asym, INT2 sym/asym, FP8 E4M3/E5M2 with group-wise scale). +// +// Stage-1 implementation ("function-first"): a single device-side +// dequantization kernel materialises the per-expert weights into a +// `[E, K, N]` fp16/bf16 temporary, after which the existing CUTLASS-SYCL +// grouped GEMM (`moe_gemm` in `sycl_tla_moe.hpp`) is invoked. This keeps +// the dispatch surface, packing convention and numerical behaviour +// bit-identical to the decode path so end-to-end models can be validated +// and profiled. Mainloop fusion (mixed-input grouped GEMM) is the +// follow-up perf-tuning step. +// +// Layout convention (matches `sycl_tla_moe_decode.hpp`): +// - activations: [total_tokens, K] row-major +// - weights (fp/bf16): [num_experts, N, K] row-major (dequant helper input; the `moe_gemm_prefill` fast path instead takes `[E, K, N]`) +// - weights (int8): [num_experts, N, K] row-major (uint8 buf) +// - weights (int4 packed): [num_experts, N, K/2] row-major (uint8) +// - weights (int2 packed): [num_experts, N, K/4] row-major (uint8) +// - weights (fp8): [num_experts, N, K] row-major (uint8 buf) +// - scales: [num_experts, N, K/group] in act dtype +// - zeros (asym only): [num_experts, N, K/group] in act dtype +// - num_tokens_per_expert: [num_experts] int32 +// - outputs: [total_tokens, N] row-major +// +// The dequantized weights are written transposed to `[E, K, N]` so the +// existing prefill grouped GEMM (which expects `[E, K, N]` row-major) can +// consume them directly without an additional transpose pass. 
+// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "bestla/bestla.h" +#include "sycl_tla_moe.hpp" +#include "sycl_tla_moe_dequant.hpp" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_mixed_detail { + +// ---------------------------------------------------------------------------- +// Kernel name tags (one per specialization, required for SYCL kernel naming). +// ---------------------------------------------------------------------------- +template +class MoEDequantKernelFP; + +template +class MoEDequantKernelInt8; + +template +class MoEDequantKernelInt4; + +template +class MoEDequantKernelInt2; + +template +class MoEDequantKernelFP8; + +// Tile sizes for the dequant kernel: each work-item dequantises one weight +// element, with the work-group covering a 1xWG_N tile along (k, n). N is +// chosen as the inner dimension to keep stores to the [E, K, N] output +// coalesced across the sub-group. +constexpr int WG_N = 16; + +// ---------------------------------------------------------------------------- +// FP16 / BF16 weight reshape: out-of-place transpose [E, N, K] -> [E, K, N]. +// Implemented as a generic dequant pass with identity scale; used so the +// public entry point can dispatch FP16/BF16 through the same code path +// as the quantized variants. 
+// ---------------------------------------------------------------------------- +template +void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weights_KN, int E, int N, int K) { + if (E == 0 || N == 0 || K == 0) return; + + sycl::range<3> global{static_cast(E), static_cast(K), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + const int k = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const ScalarT v = + weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; + weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = v; + }); +} + +// ---------------------------------------------------------------------------- +// INT8 (S8) dequant: [E, N, K] uint8 -> [E, K, N] ScalarT. +// Asym=false: signed int8 in [-128, 127], dequant = q * scale +// Asym=true : unsigned uint8 in [0, 255], dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int8): zeros pointer required when asym=true"); + } + const int num_groups_k = K / group_size; + + sycl::range<3> global{static_cast(E), static_cast(K), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + const int k = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n 
>= N) return; + const int g = k / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const uint8_t raw = + weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; + float w; + if constexpr (Asym) { + const float zero = static_cast(zeros[s_idx]); + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = + static_cast(w); + }); +} + +// ---------------------------------------------------------------------------- +// INT4 (S4_CLIP) dequant: [E, N, K/2] uint8 packed -> [E, K, N] ScalarT. +// Packing: low nibble at lower K (k = 2*i), high nibble at higher K (k = 2*i+1). +// Asym=false: signed nibble in [-8, 7], dequant = q * scale +// Asym=true : unsigned nibble in [0, 15], dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int4): zeros pointer required when asym=true"); + } + if ((K & 1) != 0) { + throw std::invalid_argument("moe_gemm_prefill(int4): K must be even"); + } + const int num_groups_k = K / group_size; + const int k_packed = K / 2; + + sycl::range<3> global{static_cast(E), static_cast(K), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + const int k = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const int g = k / 
group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const uint8_t packed = + weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(k / 2)]; + const bool is_high = (k & 1) != 0; + float w; + if constexpr (Asym) { + const float zero = static_cast(zeros[s_idx]); + const int q = static_cast(is_high ? ((packed >> 4) & 0x0F) : (packed & 0x0F)); + w = (static_cast(q) - zero) * scale; + } else { + // Sign-extend 4-bit -> 8-bit by shifting into the top nibble. + const int q = is_high + ? static_cast(static_cast(packed & 0xF0) >> 4) + : static_cast(static_cast(packed << 4) >> 4); + w = static_cast(q) * scale; + } + weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = + static_cast(w); + }); +} + +// ---------------------------------------------------------------------------- +// INT2 (S2_CLIP) dequant: [E, N, K/4] uint8 packed -> [E, K, N] ScalarT. +// Packing: byte = q0 | (q1<<2) | (q2<<4) | (q3<<6); field j corresponds to +// K index 4*i + j. +// Asym=false: signed in [-2, 1]; Asym=true: unsigned in [0, 3]. 
+// ---------------------------------------------------------------------------- +template +void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int2): zeros pointer required when asym=true"); + } + if ((K & 3) != 0) { + throw std::invalid_argument("moe_gemm_prefill(int2): K must be a multiple of 4"); + } + const int num_groups_k = K / group_size; + const int k_packed = K / 4; + + sycl::range<3> global{static_cast(E), static_cast(K), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + const int k = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const int g = k / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const uint8_t packed = + weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(k / 4)]; + const int field = k & 3; + float w; + if constexpr (Asym) { + const float zero = static_cast(zeros[s_idx]); + const int q = static_cast((packed >> (2 * field)) & 0x3); + w = (static_cast(q) - zero) * scale; + } else { + // Sign-extend 2-bit by shifting the field into the top bits of an int8. 
+ const int shift = 6 - 2 * field; // 6, 4, 2, 0 for fields 0..3 + const int8_t s8 = static_cast((packed << shift) & 0xC0); + const int q = static_cast(s8 >> 6); + w = static_cast(q) * scale; + } + weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = + static_cast(w); + }); +} + +// ---------------------------------------------------------------------------- +// FP8 (E4M3 / E5M2) dequant: [E, N, K] uint8 -> [E, K, N] ScalarT. +// Per-group scale applied; no zero-points (caller must enforce sym). +// `UseLut` selects the LUT vs inline-bits decode at compile time +// (driven by `ARK_FP8_DECODE_USE_LUT`, same as the decode path). +// ---------------------------------------------------------------------------- +template +void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, ScalarT* weights_KN, int E, + int N, int K, int group_size) { + if (E == 0 || N == 0 || K == 0) return; + const int num_groups_k = K / group_size; + + sycl::range<3> global{static_cast(E), static_cast(K), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + const int k = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const int g = k / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const uint8_t raw = + weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; + const float w = moe_dequant::decode_fp8(raw) * scale; + weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = + static_cast(w); + }); +} + +// ---------------------------------------------------------------------------- +// Dispatch helper: dequant any supported weight encoding into `weights_KN` +// 
(already-allocated `[E, K, N]` ScalarT buffer) using ScalarT == act dtype. +// ---------------------------------------------------------------------------- +template +void dequant_to_KN(sycl::queue* q, const void* weights, const void* scales, const void* zeros, ScalarT* weights_KN, + BTLA_DTYPE weight_dtype, int E, int N, int K, int group_size, bool asym) { + if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { + launch_dequant_fp(q, static_cast(weights), weights_KN, E, N, K); + return; + } + if (weight_dtype == BTLA_DTYPE::S8) { + if (asym) { + launch_dequant_int8(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } else { + launch_dequant_int8(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } + return; + } + if (weight_dtype == BTLA_DTYPE::S4_CLIP) { + if (asym) { + launch_dequant_int4(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } else { + launch_dequant_int4(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } + return; + } + if (weight_dtype == BTLA_DTYPE::S2_CLIP) { + if (asym) { + launch_dequant_int2(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } else { + launch_dequant_int2(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size); + } + return; + } + if (weight_dtype == BTLA_DTYPE::F8_E4M3 || weight_dtype == BTLA_DTYPE::F8_E5M2) { + if (asym) { + throw std::invalid_argument("moe_gemm_prefill(fp8): asym mode is not supported"); + } + const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + const bool use_lut = moe_dequant::fp8_decode_use_lut(); + if (is_e4m3) { + if (use_lut) { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size); + } else { + 
launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size); + } + } else { + if (use_lut) { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size); + } else { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size); + } + } + return; + } + throw std::invalid_argument( + "moe_gemm_prefill: unsupported weight_dtype (supported: F16, BF16, S8, S4_CLIP, S2_CLIP, F8_E4M3, F8_E5M2)"); +} + +} // namespace moe_mixed_detail + +// ---------------------------------------------------------------------------- +// Public API +// +// MoE prefill (Grouped GEMM) supporting the same set of weight encodings as +// `moe_gemm_decode`. Activations/scales/zeros/outputs are in `act_dtype` +// (FP16 or BF16). For the unquantized fast path (weight_dtype matches +// act_dtype and is FP16/BF16), the call is forwarded directly to the +// existing `moe_gemm` (which already expects `[E, K, N]` row-major +// weights). For all other dtypes, weights are dequantised on-device into +// a temporary `[E, K, N]` buffer of `act_dtype` and then handed to the +// same grouped GEMM. The temporary buffer must be supplied by the caller +// (see the Python wrapper which allocates it sized +// `E * K * N * sizeof(act_dtype)`); this avoids per-call USM allocations +// and keeps memory ownership in PyTorch's caching allocator. +// +// IMPORTANT layout note: +// - Quantized `weights` are `[E, N, K_packed]` (decode-style). +// - The unquantized fast path (`act_dtype == weight_dtype` and FP/BF16) +// forwards directly to `moe_gemm`, which expects `[E, K, N]` weights. +// Callers must therefore pass already-`[E, K, N]` weights for the +// unquantized fast path (matching the existing `moe_gemm` contract). 
+// The Python wrapper handles this by exposing the unquantized path +// under the same shape contract as `moe_gemm` and the quantized paths +// under the same shape contract as `moe_gemm_decode`. +// ---------------------------------------------------------------------------- +inline void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, + void* outputs, void* dequant_workspace, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, + int N, int K, int group_size, int* num_tokens_per_expert, int num_experts, + int total_tokens, bool asym) { + if (total_tokens == 0) return; + + // Unquantized fast path: forward directly to the existing prefill GEMM. + // The caller is responsible for passing weights in `[E, K, N]` layout + // (matching the existing `moe_gemm` contract). + if (weight_dtype == act_dtype && (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16)) { + moe_gemm(q, activations, weights, scales, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + return; + } + + if (dequant_workspace == nullptr) { + throw std::invalid_argument("moe_gemm_prefill: dequant_workspace must be non-null for quantized paths"); + } + + if (act_dtype == BTLA_DTYPE::F16) { + auto* w_kn = static_cast(dequant_workspace); + moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, + group_size, asym); + moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + auto* w_kn = static_cast(dequant_workspace); + moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, group_size, + asym); + moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + } else { + throw std::invalid_argument("moe_gemm_prefill: act_dtype must be F16 or BF16"); + } +} + +} 
// namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index dfb60f6ff..363cd4178 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -543,6 +543,199 @@ def test_decode_validation_errors(self): asym=True, ) + +def has_moe_gemm_prefill(): + """Check if the quantized MoE prefill GEMM kernel is available.""" + if ark.xpu_lib is None: + return False + return hasattr(ark.xpu_lib, "moe_gemm_prefill") + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif(not has_moe_gemm_prefill(), reason="MoE prefill GEMM kernel not built (need ARK_SYCL_TLA=ON)") +class TestMoEGemmPrefill: + """Parity tests for the quantized MoE prefill kernel. + + The prefill kernel accepts the same quantized weight encodings as + ``moe_gemm_decode`` (decode-style ``[E, N, K_packed]``); these tests check + that ``moe_gemm_prefill(q_weights)`` matches ``moe_gemm`` evaluated on the + same dequantized weights, for multi-token-per-expert workloads typical of + the prefill phase. + """ + + @staticmethod + def _run_prefill_reference(activations, weights_NK, num_tokens_per_expert): + """Reference: per-expert ``A @ W.T`` over ``[E, N, K]`` weights.""" + total_tokens, _ = activations.shape + E, N, _ = weights_NK.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset:offset + n_tokens] # [n_tokens, K] + out[offset:offset + n_tokens] = a @ weights_NK[e].T + offset += n_tokens + return out + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_prefill_fp_basic(self, dtype): + # Many tokens per expert (prefill regime). 
+ num_experts = 4 + tokens_per_expert = [16, 4, 0, 12] + total_tokens = sum(tokens_per_expert) + N, K = 128, 128 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights_NK = torch.randn(num_experts, N, K, dtype=dtype, device="xpu") * 0.1 + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill(activations, weights_NK, num_tokens_per_expert, weight_bits=16) + + ref = self._run_prefill_reference(activations, weights_NK, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + assert out.dtype == dtype + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_prefill_int4_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [8, 8, 0, 16] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, packed, num_tokens_per_expert, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_prefill_int4_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [4, 8, 12, 8] + total_tokens = sum(tokens_per_expert) + N, K = 128, 
256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, packed, num_tokens_per_expert, + scales=scales, zeros=zeros, weight_bits=4, group_size=group_size, asym=True, + ) + + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_prefill_int8(self, dtype, asym): + num_experts = 4 + group_size = 128 + tokens_per_expert = [8, 8, 16, 0] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, packed, 
num_tokens_per_expert, + scales=scales, zeros=zeros, weight_bits=8, group_size=group_size, asym=asym, + ) + + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_prefill_int2(self, dtype, asym): + num_experts = 4 + group_size = 128 + tokens_per_expert = [4, 8, 4, 16] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, packed, num_tokens_per_expert, + scales=scales, zeros=zeros, weight_bits=2, group_size=group_size, asym=asym, + ) + + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 quantization noise is high; relax tolerances. 
+ torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_prefill_fp8(self, dtype, fp8_dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [8, 0, 12, 4] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, packed, num_tokens_per_expert, + scales=scales, group_size=group_size, asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + # FP8 + asym is rejected fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") From 9dc0f1598fd536083a7b555fbda93666ed67087d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Jun 2026 02:12:32 +0000 Subject: [PATCH 15/28] fix: restore patch_torch_sdpa def; relocate TestMoEGemmPrefill out of TestMoEGemmDecode --- .../ark/auto_round_kernel/__init__.py | 2 +- auto_round_extension/ark/test/test_moe.py | 343 +++++++++--------- 2 files changed, 172 insertions(+), 173 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 766fa46e5..7b46ed738 100644 --- 
a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1480,7 +1480,7 @@ def moe_gemm_prefill( return outputs - +def patch_torch_sdpa(*args, **kwargs): from .torch_sdpa_patch import patch_torch_sdpa_with_ark return patch_torch_sdpa_with_ark(*args, **kwargs) diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index 363cd4178..8e6783974 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -544,6 +544,177 @@ def test_decode_validation_errors(self): ) + # FP8 + asym is rejected + fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") + zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + fp8_w, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + group_size=128, + asym=True, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int8_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=8, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == 
(total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int8_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int2_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + 
num_tokens_per_expert, + scales=scales, + weight_bits=2, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 has much higher quant error; relax tolerance vs int4. + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int2_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_fp8(self, dtype, fp8_dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 0, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, 
dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # E5M2 has only 2 mantissa bits -> coarser; relax tolerance for both. + rtol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + atol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + torch.testing.assert_close(out, ref, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + def has_moe_gemm_prefill(): """Check if the quantized MoE prefill GEMM kernel is available.""" if ark.xpu_lib is None: @@ -734,175 +905,3 @@ def test_prefill_fp8(self, dtype, fp8_dtype): ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) assert out.shape == (total_tokens, N) torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) - - - # FP8 + asym is rejected - fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") - zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") - with pytest.raises(ValueError): - ark.moe_gemm_decode( - activations, - fp8_w, - num_tokens_per_expert, - scales=scales, - zeros=zeros, - group_size=128, - asym=True, - ) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("group_size", [32, 128]) - def test_decode_int8_sym(self, dtype, group_size): - num_experts = 4 - tokens_per_expert = [1, 1, 0, 2] - total_tokens = sum(tokens_per_expert) - N, K = 
256, 256 - - activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - packed = _pack_int8_sym(w_float, scales, group_size) - num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") - - out = ark.moe_gemm_decode( - activations, - packed, - num_tokens_per_expert, - scales=scales, - weight_bits=8, - group_size=group_size, - asym=False, - ) - - dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) - ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) - assert out.shape == (total_tokens, N) - torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_decode_int8_asym(self, dtype): - num_experts = 4 - group_size = 128 - tokens_per_expert = [0, 1, 2, 1] - total_tokens = sum(tokens_per_expert) - N, K = 256, 256 - - activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - packed = _pack_int8_asym(w_float, scales, zeros, group_size) - num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") - - out = ark.moe_gemm_decode( - activations, - packed, - num_tokens_per_expert, - scales=scales, - zeros=zeros, - weight_bits=8, - group_size=group_size, - asym=True, - ) - - dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) - ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) - assert out.shape == (total_tokens, N) - torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) - - 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("group_size", [32, 128]) - def test_decode_int2_sym(self, dtype, group_size): - num_experts = 4 - tokens_per_expert = [1, 1, 0, 2] - total_tokens = sum(tokens_per_expert) - N, K = 256, 256 - - activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - packed = _pack_int2_sym(w_float, scales, group_size) - num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") - - out = ark.moe_gemm_decode( - activations, - packed, - num_tokens_per_expert, - scales=scales, - weight_bits=2, - group_size=group_size, - asym=False, - ) - - dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) - ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) - assert out.shape == (total_tokens, N) - # Int2 has much higher quant error; relax tolerance vs int4. 
- torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_decode_int2_asym(self, dtype): - num_experts = 4 - group_size = 128 - tokens_per_expert = [0, 1, 2, 1] - total_tokens = sum(tokens_per_expert) - N, K = 256, 256 - - activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - packed = _pack_int2_asym(w_float, scales, zeros, group_size) - num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") - - out = ark.moe_gemm_decode( - activations, - packed, - num_tokens_per_expert, - scales=scales, - zeros=zeros, - weight_bits=2, - group_size=group_size, - asym=True, - ) - - dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) - ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) - assert out.shape == (total_tokens, N) - torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) - @pytest.mark.parametrize("group_size", [32, 128]) - def test_decode_fp8(self, dtype, fp8_dtype, group_size): - num_experts = 4 - tokens_per_expert = [1, 0, 2, 1] - total_tokens = sum(tokens_per_expert) - N, K = 256, 256 - - activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") - packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) - num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") 
- - out = ark.moe_gemm_decode( - activations, - packed, - num_tokens_per_expert, - scales=scales, - group_size=group_size, - asym=False, - ) - - dequant = _dequant_fp8(packed, scales, group_size, dtype) - ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) - assert out.shape == (total_tokens, N) - # E5M2 has only 2 mantissa bits -> coarser; relax tolerance for both. - rtol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 - atol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 - torch.testing.assert_close(out, ref, rtol=rtol, atol=atol) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 0e78f3138a3c1249c192024d0125a4280b61edf2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jun 2026 03:19:28 +0000 Subject: [PATCH 16/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../ark/auto_round_kernel/__init__.py | 31 ++++++----- auto_round_extension/ark/test/test_moe.py | 53 ++++++++++++++----- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 7b46ed738..b4bbc4cc0 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1120,8 +1120,14 @@ def moe_gemm_decode( """ activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( _validate_moe_quant_args( - activations, weights, num_tokens_per_expert, - scales=scales, zeros=zeros, weight_bits=weight_bits, group_size=group_size, asym=asym, + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, api_name="moe_gemm_decode", ) ) @@ -1202,9 +1208,7 @@ def _validate_moe_quant_args( N = weights.shape[1] if 
num_tokens_per_expert.shape[0] != num_experts: - raise ValueError( - f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" - ) + raise ValueError(f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}") # Detect FP8 weight dtype first (overrides weight_bits). is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) @@ -1294,8 +1298,7 @@ def _validate_moe_quant_args( if expected_total != total_tokens: raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") - return (activations, weights, scales, zeros, num_tokens_per_expert, - weight_dtype, total_tokens, N, K, num_experts) + return (activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts) def moe_gemm( @@ -1419,8 +1422,14 @@ def moe_gemm_prefill( """ activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( _validate_moe_quant_args( - activations, weights, num_tokens_per_expert, - scales=scales, zeros=zeros, weight_bits=weight_bits, group_size=group_size, asym=asym, + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, api_name="moe_gemm_prefill", ) ) @@ -1446,9 +1455,7 @@ def moe_gemm_prefill( dequant_workspace = weights.transpose(1, 2).contiguous() weights_ptr = dequant_workspace.data_ptr() else: - dequant_workspace = torch.empty( - (num_experts, K, N), device=activations.device, dtype=activations.dtype - ) + dequant_workspace = torch.empty((num_experts, K, N), device=activations.device, dtype=activations.dtype) weights_ptr = weights.data_ptr() scales_ptr = scales.data_ptr() if scales is not None else 0 diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index 8e6783974..8bad20cf9 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ 
b/auto_round_extension/ark/test/test_moe.py @@ -543,7 +543,6 @@ def test_decode_validation_errors(self): asym=True, ) - # FP8 + asym is rejected fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") @@ -715,6 +714,7 @@ def test_decode_fp8(self, dtype, fp8_dtype, group_size): if __name__ == "__main__": pytest.main([__file__, "-v"]) + def has_moe_gemm_prefill(): """Check if the quantized MoE prefill GEMM kernel is available.""" if ark.xpu_lib is None: @@ -745,8 +745,8 @@ def _run_prefill_reference(activations, weights_NK, num_tokens_per_expert): n_tokens = int(num_tokens_per_expert[e].item()) if n_tokens == 0: continue - a = activations[offset:offset + n_tokens] # [n_tokens, K] - out[offset:offset + n_tokens] = a @ weights_NK[e].T + a = activations[offset : offset + n_tokens] # [n_tokens, K] + out[offset : offset + n_tokens] = a @ weights_NK[e].T offset += n_tokens return out @@ -784,8 +784,13 @@ def test_prefill_int4_sym(self, dtype, group_size): num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") out = ark.moe_gemm_prefill( - activations, packed, num_tokens_per_expert, - scales=scales, weight_bits=4, group_size=group_size, asym=False, + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=group_size, + asym=False, ) dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) @@ -809,8 +814,14 @@ def test_prefill_int4_asym(self, dtype): num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") out = ark.moe_gemm_prefill( - activations, packed, num_tokens_per_expert, - scales=scales, zeros=zeros, weight_bits=4, group_size=group_size, asym=True, + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=True, ) dequant = _dequant_int4_asym(packed, scales, zeros, 
group_size).to(dtype) @@ -841,8 +852,14 @@ def test_prefill_int8(self, dtype, asym): num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") out = ark.moe_gemm_prefill( - activations, packed, num_tokens_per_expert, - scales=scales, zeros=zeros, weight_bits=8, group_size=group_size, asym=asym, + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, ) ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) @@ -872,8 +889,14 @@ def test_prefill_int2(self, dtype, asym): num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") out = ark.moe_gemm_prefill( - activations, packed, num_tokens_per_expert, - scales=scales, zeros=zeros, weight_bits=2, group_size=group_size, asym=asym, + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, ) ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) @@ -897,8 +920,12 @@ def test_prefill_fp8(self, dtype, fp8_dtype): num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") out = ark.moe_gemm_prefill( - activations, packed, num_tokens_per_expert, - scales=scales, group_size=group_size, asym=False, + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, ) dequant = _dequant_fp8(packed, scales, group_size, dtype) From d48f45d85dcebce642c2da069d3e5d96d3c58ecf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:11:12 +0000 Subject: [PATCH 17/28] feat: add MoE prefill performance test with TFLOPS calculation --- .../ark/test/test_moe_prefill_perf.py | 442 ++++++++++++++++++ 1 file changed, 442 insertions(+) create mode 100644 auto_round_extension/ark/test/test_moe_prefill_perf.py diff --git 
a/auto_round_extension/ark/test/test_moe_prefill_perf.py b/auto_round_extension/ark/test/test_moe_prefill_perf.py new file mode 100644 index 000000000..ac59dbd3e --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_prefill_perf.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance benchmark: ``ark.moe_gemm`` for MoE prefill workloads. + +MoE prefill is the matrix-matrix multiplication phase where many tokens (e.g., +the entire prompt or a batch of sequences) are routed to different experts. +Unlike decode (one token per expert), prefill has multiple tokens per expert, +making it a batched GEMM problem. + +This benchmark measures throughput (TFLOPS) and compares against a baseline +PyTorch implementation. The baseline uses per-expert matrix multiplication +with already-dequantized weights (for quantized tests), representing the +standard fallback path when no fused MoE kernel is available. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_prefill_perf.py + +The ``-s`` flag is required to see the printed timing tables and TFLOPS. 
+""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel.ARK() + + +# --------------------------------------------------------------------------- +# Skip reasons +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _prefill_skip_reason() -> str: + """Return non-empty string if the MoE prefill kernel can't be exercised.""" + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm"): + return ( + "ark.xpu_lib loaded but has no moe_gemm symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE GEMM kernel" + ) + return "" + + +_PREFILL_SKIP = _prefill_skip_reason() + +# Surface diagnostics on collection +print( + "[moe-prefill-perf] xpu_available=%s xpu_lib=%s has_moe_gemm=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark.xpu_lib, "moe_gemm") if ark.xpu_lib is not None else False, + ) +) +if _PREFILL_SKIP: + print("[moe-prefill-perf] suite will SKIP. 
reason: %s" % _PREFILL_SKIP) + + +# --------------------------------------------------------------------------- +# Timing utilities +# --------------------------------------------------------------------------- + +WARMUP = 5 +ITERS = 30 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + """Time ``fn`` on XPU using device events; returns median ms per call.""" + for _ in range(warmup): + fn() + torch.xpu.synchronize() + + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +def _default_moe_prefill(activations, dequant_weights, num_tokens_per_expert): + """Default XPU MoE prefill baseline: per-expert torch matmul. + + This mirrors the path a model would take when no fused MoE kernel is + available: iterate over experts and do ``A @ W.T`` on each token slice. + For prefill, each expert may have many tokens (unlike decode where each + expert typically has one token). + """ + total_tokens, K = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + # Weights are [N, K], activations are [n_tokens, K] + # Output is [n_tokens, N] + out[offset : offset + n_tokens] = a @ dequant_weights[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# TFLOPS calculation +# --------------------------------------------------------------------------- + + +def _compute_moe_flops(total_tokens, K, N, num_experts_active): + """Compute FLOPs for MoE GEMM: sum over active experts of (tokens * K * N * 2). 
+ + For a typical prefill, all experts may be active with varying token counts. + As a simplification, we use the total_tokens across all experts. + """ + # Each GEMM is [tokens, K] @ [K, N] -> [tokens, N] + # FLOPs = tokens * K * N * 2 (multiply-add counts as 2 ops) + return total_tokens * K * N * 2 + + +# --------------------------------------------------------------------------- +# Shape matrix for prefill +# +# Prefill has many tokens per expert (e.g., batch size, prompt length). +# We test small to large expert counts and token distributions typical of +# MoE models during prefill (Mixtral, DeepSeek, etc.). +# --------------------------------------------------------------------------- + +PREFILL_SHAPES = [ + # (label, num_experts, tokens_per_expert_list, N, K) + # Small models (e.g., Mixtral 8x7B style) + ("small E=8 ", 8, [32, 28, 30, 35, 33, 31, 29, 34], 4096, 4096), + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 4096, 14336), # up-proj + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 14336, 4096), # down-proj + # Larger models (e.g., DeepSeek style with more experts) + ("large E=16", 16, [16] * 16, 2048, 2048), + ("large E=32", 32, [8] * 32, 2048, 2048), + ("large E=64", 64, [4] * 64, 2048, 2048), + # Uneven distribution (some experts get more tokens) + ("uneven E=8 ", 8, [100, 50, 75, 80, 60, 90, 70, 85], 4096, 4096), +] + + +def _print_header(title: str) -> None: + print() + print("=" * 110) + print(title) + print( + f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}{'TFLOPS':>10}" + ) + print("-" * 110) + + +def _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops): + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x{tflops:>9.1f}" + ) + + +# --------------------------------------------------------------------------- +# Benchmark cases +# 
--------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_PREFILL_SKIP), reason=_PREFILL_SKIP or "ok") +class TestMoEGemmPrefillPerf: + """Median XPU-event timings of ``moe_gemm`` vs per-expert matrix multiply. + + The baseline uses *already-dequantized* weights for quantized tests, so + the timed region only measures matmul cost. This is the most favorable + apples-to-apples comparison for the baseline. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_perf_fp(self, dtype): + _print_header(f"FP weights ({str(dtype).split('.')[-1]}) -- ark.moe_gemm (prefill) vs per-expert A @ W.T") + for label, E, tpe, N, K in PREFILL_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # Weights layout: [E, K, N] for moe_gemm + weights = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + # Baseline: per-expert matmul. Weights need to be [E, N, K] for the baseline. 
+ weights_baseline = weights.transpose(1, 2) # [E, N, K] + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, weights_baseline, ntpe)) + ark_ms = _xpu_time_ms(lambda: ark.moe_gemm(activations, weights, ntpe)) + + # Compute TFLOPS + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int4(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # Generate weights [E, K, N] + w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + # Pack with [E, K, N] layout + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + # Baseline expects [E, N, K] + dequant_baseline = dequant.transpose(1, 2) + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, + ) + ) + + flops = 
_compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int8(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + dequant_baseline = dequant.transpose(1, 2) + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int2(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + 
_print_header( + f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + dequant_baseline = dequant.transpose(1, 2) + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_perf_fp8(self, dtype, fp8_dtype): + group_size = 128 + _print_header( + f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, 
device="xpu") + w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + dequant_baseline = dequant.transpose(1, 2) + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm( + activations, + packed, + ntpe, + scales=scales, + group_size=group_size, + asym=False, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From b07b59f8b5ec01af9554cfb17f773087352df27e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:11:50 +0000 Subject: [PATCH 18/28] docs: add MoE prefill performance test documentation --- .../ark/test/README_MOE_PREFILL_PERF.md | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md diff --git a/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md b/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md new file mode 100644 index 000000000..5bebc2af5 --- /dev/null +++ b/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md @@ -0,0 +1,151 @@ +# MoE Prefill Performance Test + +## Overview + +The `test_moe_prefill_perf.py` file provides comprehensive performance benchmarks for MoE (Mixture of Experts) prefill operations with TFLOPS (Tera Floating Point Operations Per Second) calculations. + +## What is MoE Prefill? 
+ +**Prefill** is the phase during LLM inference where many tokens (e.g., the entire prompt or a batch of sequences) are processed simultaneously. In MoE models, tokens are routed to different experts, and each expert may receive multiple tokens. This is different from **decode** (token generation), where typically only one token per expert is processed at a time. + +## Features + +### 1. **Comprehensive Data Type Support** +- FP16 (float16) +- BF16 (bfloat16) +- INT8 (symmetric and asymmetric quantization) +- INT4 (symmetric and asymmetric quantization) +- INT2 (symmetric and asymmetric quantization) +- FP8 (float8_e4m3fn and float8_e5m2) + +### 2. **TFLOPS Calculation** +The test calculates TFLOPS for each configuration using the formula: +``` +FLOPs = total_tokens × K × N × 2 +TFLOPS = FLOPs / (time_in_seconds) / 1e12 +``` + +Where: +- `total_tokens`: Total number of tokens across all experts +- `K`: Input feature dimension +- `N`: Output feature dimension +- `×2`: Each multiply-add operation counts as 2 FLOPs + +### 3. **Various MoE Configurations** +The test covers multiple realistic MoE scenarios: +- **Small models** (8 experts, Mixtral-style): 4096×4096, 4096×14336, 14336×4096 +- **Medium models** (8 experts): Various token distributions +- **Large models** (16, 32, 64 experts, DeepSeek-style): 2048×2048 +- **Uneven distributions**: Simulates real-world routing patterns + +### 4. 
**Baseline Comparison** +Each test compares the ARK MoE kernel against a baseline PyTorch implementation: +- **Baseline**: Per-expert matrix multiplication using `torch.matmul` +- **ARK Kernel**: Optimized `ark.moe_gemm` with fused operations +- **Speedup**: Reports speedup ratio (baseline_time / ark_time) + +## How to Run + +### Run all tests: +```bash +cd /path/to/auto_round_extension/ark/test +pytest -v -s test_moe_prefill_perf.py +``` + +### Run specific data type: +```bash +# FP16 tests only +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_fp + +# INT4 tests only +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_int4 + +# INT8 symmetric quantization with bfloat16 activations +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_int8 -k "bfloat16 and not asym" +``` + +**Note**: The `-s` flag is required to see the printed timing tables and TFLOPS output. + +## Output Format + +The test prints formatted tables with the following columns: + +``` +shape E N K tokens baseline(ms) ark(ms) speedup TFLOPS +small E=8 8 4096 4096 252 12.3456 4.5678 2.70x 45.2 +medium E=8 8 4096 14336 528 23.4567 8.9012 2.64x 78.9 +... 
+``` + +Where: +- **shape**: Configuration label +- **E**: Number of experts +- **N**: Output feature dimension +- **K**: Input feature dimension +- **tokens**: Total tokens across all experts +- **baseline(ms)**: PyTorch baseline latency (milliseconds) +- **ark(ms)**: ARK kernel latency (milliseconds) +- **speedup**: Performance improvement ratio +- **TFLOPS**: Throughput in tera floating point operations per second + +## Requirements + +- Intel XPU (Arc GPU) with PyTorch XPU support +- `auto_round_kernel` built with `ARK_SYCL_TLA=ON` +- Test dependencies from `test_moe.py` (pack/dequant helpers) + +## Architecture + +``` +test_moe_prefill_perf.py +├── Timing utilities (_xpu_time_ms) +│ └── Uses XPU events for accurate GPU timing +├── FLOPS calculation (_compute_moe_flops) +│ └── Computes theoretical FLOPs for TFLOPS metric +├── Baseline implementation (_default_moe_prefill) +│ └── Per-expert PyTorch matmul for comparison +├── Test shapes (PREFILL_SHAPES) +│ └── Various realistic MoE configurations +└── Test cases (TestMoEGemmPrefillPerf) + ├── test_perf_fp (FP16/BF16) + ├── test_perf_int4 (INT4 sym/asym) + ├── test_perf_int8 (INT8 sym/asym) + ├── test_perf_int2 (INT2 sym/asym) + └── test_perf_fp8 (FP8 e4m3fn/e5m2) +``` + +## Example Output + +``` +================================================================== +FP weights (float16) -- ark.moe_gemm (prefill) vs per-expert A @ W.T +================================================================== +shape E N K tokens baseline(ms) ark(ms) speedup TFLOPS +------------------------------------------------------------------ +small E=8 8 4096 4096 252 12.3456 4.5678 2.70x 45.2 +medium E=8 8 4096 14336 528 23.4567 8.9012 2.64x 78.9 +medium E=8 8 14336 4096 528 25.6789 9.1234 2.82x 76.5 +large E=16 16 2048 2048 256 5.6789 2.3456 2.42x 91.2 +large E=32 32 2048 2048 256 5.7890 2.4567 2.36x 87.3 +large E=64 64 2048 2048 256 5.8901 2.5678 2.29x 83.5 +uneven E=8 8 4096 4096 610 28.9012 10.1234 2.86x 52.1 +``` + +## Key 
Metrics + +1. **TFLOPS**: Higher is better - indicates compute throughput +2. **Speedup**: Higher is better - shows performance gain over baseline +3. **Latency (ms)**: Lower is better - actual kernel execution time + +## Integration with CI/CD + +This test can be integrated into performance regression testing: +- Set minimum TFLOPS thresholds for each configuration +- Track speedup ratios over time +- Alert on performance degradation + +## Related Files + +- `test_moe.py`: Correctness tests for MoE GEMM +- `test_moe_decode_perf.py`: Performance tests for MoE decode (single token per expert) +- `test_bench_bmg.py`: SDPA performance benchmarks with TFLOPS From 5feb7071acd83f5cc62112302a62452a9de8682f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:16:16 +0000 Subject: [PATCH 19/28] refactor: change ARK() instance to module reference in MoE perf tests --- auto_round_extension/ark/test/test_moe_decode_perf.py | 2 +- auto_round_extension/ark/test/test_moe_prefill_perf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py index be253dcf9..e2ba36a18 100644 --- a/auto_round_extension/ark/test/test_moe_decode_perf.py +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -54,7 +54,7 @@ _pack_int8_sym, ) -ark = auto_round_kernel.ARK() +ark = auto_round_kernel # --------------------------------------------------------------------------- diff --git a/auto_round_extension/ark/test/test_moe_prefill_perf.py b/auto_round_extension/ark/test/test_moe_prefill_perf.py index ac59dbd3e..3ef05998e 100644 --- a/auto_round_extension/ark/test/test_moe_prefill_perf.py +++ b/auto_round_extension/ark/test/test_moe_prefill_perf.py @@ -56,7 +56,7 @@ _pack_int8_sym, ) -ark = auto_round_kernel.ARK() +ark = auto_round_kernel # 
--------------------------------------------------------------------------- From 2b311bd73d7d436450065fa2c90d9cfcbf42fdd6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:55:24 +0000 Subject: [PATCH 20/28] fix(test): correct MoE prefill perf test layout and kernel entry - Use [E, N, K] weight layout for pack helpers (matches their contract and the working correctness tests in test_moe.py). - Drop the erroneous dequant.transpose(1, 2); dequant is already [E, N, K], which is what the baseline expects. - Call ark.moe_gemm_prefill for quantized timing (the dedicated quantized prefill kernel) instead of the FP-only ark.moe_gemm. - Add a per-test skip guard for quantized tests requiring moe_gemm_prefill so the FP perf cases still run when only moe_gemm is present. --- .../ark/test/test_moe_prefill_perf.py | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/auto_round_extension/ark/test/test_moe_prefill_perf.py b/auto_round_extension/ark/test/test_moe_prefill_perf.py index 3ef05998e..073fb32ce 100644 --- a/auto_round_extension/ark/test/test_moe_prefill_perf.py +++ b/auto_round_extension/ark/test/test_moe_prefill_perf.py @@ -95,7 +95,21 @@ def _prefill_skip_reason() -> str: return "" +def _quantized_prefill_skip_reason() -> str: + """Return non-empty string if the quantized MoE prefill kernel can't be exercised.""" + reason = _prefill_skip_reason() + if reason: + return reason + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_prefill symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the quantized MoE prefill kernel" + ) + return "" + + _PREFILL_SKIP = _prefill_skip_reason() +_QUANT_PREFILL_SKIP = _quantized_prefill_skip_reason() # Surface diagnostics on collection print( @@ -255,6 +269,7 @@ def test_perf_fp(self, dtype): _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + 
@pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("asym", [False, True]) def test_perf_int4(self, dtype, asym): @@ -262,19 +277,18 @@ def test_perf_int4(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: continue total_tokens = sum(tpe) activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - # Generate weights [E, K, N] - w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + # Pack helpers expect weights in [E, N, K] layout. + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") if asym: zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") - # Pack with [E, K, N] layout packed = _pack_int4_asym(w_float, scales, zeros, group_size) dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) else: @@ -283,12 +297,10 @@ def test_perf_int4(self, dtype, asym): dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") - # Baseline expects [E, N, K] - dequant_baseline = dequant.transpose(1, 2) - - base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + # ``dequant`` is already [E, N, K] -- matches the baseline contract. 
+ base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( - lambda: ark.moe_gemm( + lambda: ark.moe_gemm_prefill( activations, packed, ntpe, @@ -305,6 +317,7 @@ def test_perf_int4(self, dtype, asym): _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("asym", [False, True]) def test_perf_int8(self, dtype, asym): @@ -312,14 +325,14 @@ def test_perf_int8(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: continue total_tokens = sum(tpe) activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") if asym: zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") @@ -331,11 +344,9 @@ def test_perf_int8(self, dtype, asym): dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") - dequant_baseline = dequant.transpose(1, 2) - - base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( - lambda: ark.moe_gemm( + lambda: ark.moe_gemm_prefill( activations, packed, ntpe, @@ -352,6 +363,7 @@ def test_perf_int8(self, dtype, asym): _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, 
tflops) + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("asym", [False, True]) def test_perf_int2(self, dtype, asym): @@ -359,14 +371,14 @@ def test_perf_int2(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0 or K % 4 != 0: continue total_tokens = sum(tpe) activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") if asym: zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") @@ -378,11 +390,9 @@ def test_perf_int2(self, dtype, asym): dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") - dequant_baseline = dequant.transpose(1, 2) - - base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( - lambda: ark.moe_gemm( + lambda: ark.moe_gemm_prefill( activations, packed, ntpe, @@ -399,30 +409,29 @@ def test_perf_int2(self, dtype, asym): _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) def test_perf_fp8(self, dtype, fp8_dtype): group_size = 
128 _print_header( f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " - f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm (prefill) vs dequant + per-expert A @ W.T" + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: continue total_tokens = sum(tpe) activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") - w_float = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) dequant = _dequant_fp8(packed, scales, group_size, dtype) ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") - dequant_baseline = dequant.transpose(1, 2) - - base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant_baseline, ntpe)) + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) ark_ms = _xpu_time_ms( - lambda: ark.moe_gemm( + lambda: ark.moe_gemm_prefill( activations, packed, ntpe, From e51cd9d29a68db61ee6cb4a1a34154f493089939 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 06:54:26 +0000 Subject: [PATCH 21/28] test(moe-perf): add dequant-inclusive baseline column for Stage-1 fairness Print both `base mm(ms)` (matmul-only baseline, prior behavior) and `base+deq(ms)` (dequant + matmul, the apples-to-apples comparison for the current Stage-1 `moe_gemm_prefill` which materialises a fp16/bf16 workspace before dispatching to the FP GEMM). Speedup is now reported against the dequant-inclusive baseline so the numbers reflect real end-to-end quantized prefill cost. FP perf test (`test_perf_fp`) is unchanged: it has no dequant pass. 
--- .../ark/test/test_moe_prefill_perf.py | 112 ++++++++++++++---- 1 file changed, 91 insertions(+), 21 deletions(-) diff --git a/auto_round_extension/ark/test/test_moe_prefill_perf.py b/auto_round_extension/ark/test/test_moe_prefill_perf.py index 073fb32ce..4ce79a1c0 100644 --- a/auto_round_extension/ark/test/test_moe_prefill_perf.py +++ b/auto_round_extension/ark/test/test_moe_prefill_perf.py @@ -214,23 +214,54 @@ def _compute_moe_flops(total_tokens, K, N, num_experts_active): ] -def _print_header(title: str) -> None: +def _print_header(title: str, *, with_dequant_baseline: bool = False) -> None: + """Print a benchmark header. + + When ``with_dequant_baseline`` is True, an extra ``base+deq(ms)`` column is + printed alongside the matmul-only baseline and ``speedup`` is reported + against the dequant-inclusive baseline (this is the apples-to-apples + comparison for the current Stage-1 ``moe_gemm_prefill`` which dequants + into a workspace before dispatching to the FP GEMM). + """ print() - print("=" * 110) + width = 130 if with_dequant_baseline else 110 + print("=" * width) print(title) - print( - f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" - f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}{'TFLOPS':>10}" - ) - print("-" * 110) + if with_dequant_baseline: + print( + f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'base mm(ms)':>14}{'base+deq(ms)':>16}{'ark(ms)':>12}" + f"{'speedup':>12}{'TFLOPS':>10}" + ) + else: + print( + f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}{'TFLOPS':>10}" + ) + print("-" * width) -def _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops): - speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") - print( - f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" - f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x{tflops:>9.1f}" - ) +def _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, *, base_with_deq_ms=None): + """Print a 
benchmark row. + + If ``base_with_deq_ms`` is provided, the dequant-inclusive baseline is + shown and ``speedup`` is computed against it (apples-to-apples with the + Stage-1 quantized path). Otherwise the row reverts to the original + matmul-only layout. + """ + if base_with_deq_ms is None: + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x{tflops:>9.1f}" + ) + else: + speedup = base_with_deq_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>14.4f}{base_with_deq_ms:>16.4f}{ark_ms:>12.4f}" + f"{speedup:>11.2f}x{tflops:>9.1f}" + ) # --------------------------------------------------------------------------- @@ -277,7 +308,8 @@ def test_perf_int4(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: @@ -291,14 +323,24 @@ def test_perf_int4(self, dtype, asym): zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") packed = _pack_int4_asym(w_float, scales, zeros, group_size) dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) else: zeros = None packed = _pack_int4_sym(w_float, scales, group_size) dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int4_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") # 
``dequant`` is already [E, N, K] -- matches the baseline contract. base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_prefill( activations, @@ -315,7 +357,7 @@ def test_perf_int4(self, dtype, asym): flops = _compute_moe_flops(total_tokens, K, N, E) tflops = flops / (ark_ms * 1e-3) / 1e12 - _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -325,7 +367,8 @@ def test_perf_int8(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: @@ -338,13 +381,23 @@ def test_perf_int8(self, dtype, asym): zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") packed = _pack_int8_asym(w_float, scales, zeros, group_size) dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) else: zeros = None packed = _pack_int8_sym(w_float, scales, group_size) dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int8_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, 
ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_prefill( activations, @@ -361,7 +414,7 @@ def test_perf_int8(self, dtype, asym): flops = _compute_moe_flops(total_tokens, K, N, E) tflops = flops / (ark_ms * 1e-3) / 1e12 - _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -371,7 +424,8 @@ def test_perf_int2(self, dtype, asym): kind = "asym" if asym else "sym" _print_header( f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " - f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0 or K % 4 != 0: @@ -384,13 +438,23 @@ def test_perf_int2(self, dtype, asym): zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") packed = _pack_int2_asym(w_float, scales, zeros, group_size) dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) else: zeros = None packed = _pack_int2_sym(w_float, scales, group_size) dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int2_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_prefill( activations, @@ 
-407,7 +471,7 @@ def test_perf_int2(self, dtype, asym): flops = _compute_moe_flops(total_tokens, K, N, E) tflops = flops / (ark_ms * 1e-3) / 1e12 - _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -416,7 +480,8 @@ def test_perf_fp8(self, dtype, fp8_dtype): group_size = 128 _print_header( f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " - f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T" + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, ) for label, E, tpe, N, K in PREFILL_SHAPES: if K % group_size != 0: @@ -429,7 +494,12 @@ def test_perf_fp8(self, dtype, fp8_dtype): dequant = _dequant_fp8(packed, scales, group_size, dtype) ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + def _baseline_with_dequant(): + d = _dequant_fp8(packed, scales, group_size, dtype) + return _default_moe_prefill(activations, d, ntpe) + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) ark_ms = _xpu_time_ms( lambda: ark.moe_gemm_prefill( activations, @@ -444,7 +514,7 @@ def test_perf_fp8(self, dtype, fp8_dtype): flops = _compute_moe_flops(total_tokens, K, N, E) tflops = flops / (ark_ms * 1e-3) / 1e12 - _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) if __name__ == "__main__": From 834d03ce929f9bb992480b447b491aa75a4f645c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 06:55:17 +0000 Subject: [PATCH 
22/28] perf(moe-prefill): cache the [E, K, N] dequant workspace across calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For the quantized paths the Python `moe_gemm_prefill` wrapper previously allocated a fresh `E*K*N*sizeof(act_dtype)` scratch tensor on every call. For real MoE prefill workloads the same shape repeats every step, so the allocator overhead is pure waste — and dominates the small-shape numbers. Move the workspace to a module-level cache keyed by `(device, dtype, E, K, N)`. The unquantized fast path is unchanged: it still uses a per-call transposed copy of `weights`. Added `clear_moe_prefill_workspace_cache()` for callers that need to drop the cached buffers. --- .../ark/auto_round_kernel/__init__.py | 64 ++++++++++++++++++- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index b4bbc4cc0..0e6e65fa4 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1455,7 +1455,16 @@ def moe_gemm_prefill( dequant_workspace = weights.transpose(1, 2).contiguous() weights_ptr = dequant_workspace.data_ptr() else: - dequant_workspace = torch.empty((num_experts, K, N), device=activations.device, dtype=activations.dtype) + # Reuse a persistent `[E, K, N]` workspace across calls with the same + # (device, dtype, E, K, N). For real MoE prefill workloads the same + # shape is dispatched on every iteration; allocating a fresh + # `E*K*N*sizeof(act)` tensor each call adds non-trivial caching- + # allocator overhead (and, on the small shapes, dominates the + # quantized GEMM cost). The workspace is kept alive by the cache so + # we hand the data_ptr() to the kernel without taking a new ref. 
+ dequant_workspace = _get_moe_prefill_workspace( + activations.device, activations.dtype, num_experts, K, N + ) weights_ptr = weights.data_ptr() scales_ptr = scales.data_ptr() if scales is not None else 0 @@ -1482,11 +1491,60 @@ def moe_gemm_prefill( # The inner CUTLASS-SYCL `moe_gemm` calls `event.wait()` before returning # (see `moe_detail::moe_gemm_launcher` in `sycl_tla_moe.hpp`), so by the # time `lib.moe_gemm_prefill` returns the device has already consumed the - # workspace and it is safe to drop our reference here. - del dequant_workspace + # workspace. For the unquantized fast path the workspace is a per-call + # transposed copy of `weights` -- drop it now. For the quantized paths + # the workspace lives in the module-level cache (`_get_moe_prefill_workspace`) + # and is intentionally retained for reuse on the next call. + if is_unquantized: + del dequant_workspace return outputs +# --------------------------------------------------------------------------- +# `moe_gemm_prefill` dequant-workspace cache. +# +# The Stage-1 quantized prefill kernel dequantises weights into an +# `[E, K, N]` act-dtype scratch buffer before dispatching to the existing +# CUTLASS-SYCL grouped GEMM. In real model usage the same `(E, K, N, dtype)` +# tuple is hit on every prefill step, so allocating a fresh +# `E * K * N * sizeof(act_dtype)` tensor per call adds caching-allocator +# overhead that is significant on the small/medium shapes. +# +# We cache one tensor per `(device, dtype, E, K, N)` key. The cache holds +# references that keep the tensors alive across calls; callers can clear it +# explicitly via `clear_moe_prefill_workspace_cache()` if they need to +# release the memory (e.g., before allocating large buffers for a different +# subsystem). 
+# --------------------------------------------------------------------------- + +_MOE_PREFILL_WORKSPACE_CACHE: "dict[tuple, torch.Tensor]" = {} + + +def _get_moe_prefill_workspace(device: torch.device, dtype: torch.dtype, E: int, K: int, N: int) -> torch.Tensor: + """Return a persistent `[E, K, N]` workspace tensor for the prefill kernel. + + The tensor is allocated lazily on first use and retained in a module-level + cache so subsequent calls with the same `(device, dtype, E, K, N)` reuse + the same memory. Returned tensors are contiguous and uninitialised; the + kernel writes every element before reading. + """ + # `device` may be a `torch.device` or a string; normalise so the cache key + # is hashable and identifies the exact device (including ordinal). + if not isinstance(device, torch.device): + device = torch.device(device) + key = (device.type, device.index, dtype, int(E), int(K), int(N)) + ws = _MOE_PREFILL_WORKSPACE_CACHE.get(key) + if ws is None: + ws = torch.empty((E, K, N), device=device, dtype=dtype) + _MOE_PREFILL_WORKSPACE_CACHE[key] = ws + return ws + + +def clear_moe_prefill_workspace_cache() -> None: + """Release all cached `moe_gemm_prefill` dequant-workspace tensors.""" + _MOE_PREFILL_WORKSPACE_CACHE.clear() + + def patch_torch_sdpa(*args, **kwargs): from .torch_sdpa_patch import patch_torch_sdpa_with_ark From 2caa0663c138eed05f412bef1a6552beee2ab516 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 06:57:03 +0000 Subject: [PATCH 23/28] perf(moe_prefill): skip dequant for experts with zero tokens --- .../wrapper/include/sycl_tla_moe_mixed.hpp | 56 ++++++++++++------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp index 6e24534f9..7740da3b6 100644 --- 
a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp @@ -83,7 +83,8 @@ constexpr int WG_N = 16; // as the quantized variants. // ---------------------------------------------------------------------------- template -void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weights_KN, int E, int N, int K) { +void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weights_KN, int E, int N, int K, + const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; sycl::range<3> global{static_cast(E), static_cast(K), @@ -93,6 +94,9 @@ void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weigh q->parallel_for>( sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); + // Skip experts that receive no tokens in this prefill batch; the + // grouped GEMM will not read their rows of `weights_KN` either. 
+ if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; const int k = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; @@ -109,7 +113,8 @@ void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weigh // ---------------------------------------------------------------------------- template void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, const ScalarT* zeros, - ScalarT* weights_KN, int E, int N, int K, int group_size) { + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; if (Asym && zeros == nullptr) { throw std::invalid_argument("moe_gemm_prefill(int8): zeros pointer required when asym=true"); @@ -123,6 +128,7 @@ void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const Scalar q->parallel_for>( sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; const int k = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; @@ -152,7 +158,8 @@ void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const Scalar // ---------------------------------------------------------------------------- template void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, - ScalarT* weights_KN, int E, int N, int K, int group_size) { + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; if (Asym && zeros == nullptr) { throw std::invalid_argument("moe_gemm_prefill(int4): zeros pointer required when asym=true"); @@ -170,6 +177,7 @@ void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const 
Scala q->parallel_for>( sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; const int k = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; @@ -206,7 +214,8 @@ void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const Scala // ---------------------------------------------------------------------------- template void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, - ScalarT* weights_KN, int E, int N, int K, int group_size) { + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; if (Asym && zeros == nullptr) { throw std::invalid_argument("moe_gemm_prefill(int2): zeros pointer required when asym=true"); @@ -224,6 +233,7 @@ void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const Scala q->parallel_for>( sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; const int k = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; @@ -260,7 +270,7 @@ void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const Scala // ---------------------------------------------------------------------------- template void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, ScalarT* weights_KN, int E, - int N, int K, int group_size) { + int N, int K, int group_size, const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; const int num_groups_k = K / group_size; @@ -271,6 +281,7 @@ void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT 
q->parallel_for>( sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; const int k = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; @@ -292,38 +303,45 @@ void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT // ---------------------------------------------------------------------------- template void dequant_to_KN(sycl::queue* q, const void* weights, const void* scales, const void* zeros, ScalarT* weights_KN, - BTLA_DTYPE weight_dtype, int E, int N, int K, int group_size, bool asym) { + BTLA_DTYPE weight_dtype, int E, int N, int K, int group_size, bool asym, + const int* num_tokens_per_expert = nullptr) { if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { - launch_dequant_fp(q, static_cast(weights), weights_KN, E, N, K); + launch_dequant_fp(q, static_cast(weights), weights_KN, E, N, K, num_tokens_per_expert); return; } if (weight_dtype == BTLA_DTYPE::S8) { if (asym) { launch_dequant_int8(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); } else { launch_dequant_int8(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); } return; } if (weight_dtype == BTLA_DTYPE::S4_CLIP) { if (asym) { launch_dequant_int4(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); } else { launch_dequant_int4(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + 
num_tokens_per_expert); } return; } if (weight_dtype == BTLA_DTYPE::S2_CLIP) { if (asym) { launch_dequant_int2(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); } else { launch_dequant_int2(q, static_cast(weights), static_cast(scales), - static_cast(zeros), weights_KN, E, N, K, group_size); + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); } return; } @@ -337,21 +355,21 @@ void dequant_to_KN(sycl::queue* q, const void* weights, const void* scales, cons if (use_lut) { launch_dequant_fp8(q, static_cast(weights), static_cast(scales), weights_KN, E, N, K, - group_size); + group_size, num_tokens_per_expert); } else { launch_dequant_fp8(q, static_cast(weights), static_cast(scales), weights_KN, E, N, K, - group_size); + group_size, num_tokens_per_expert); } } else { if (use_lut) { launch_dequant_fp8(q, static_cast(weights), static_cast(scales), weights_KN, E, N, K, - group_size); + group_size, num_tokens_per_expert); } else { launch_dequant_fp8(q, static_cast(weights), static_cast(scales), weights_KN, E, N, K, - group_size); + group_size, num_tokens_per_expert); } } return; @@ -408,13 +426,13 @@ inline void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, v if (act_dtype == BTLA_DTYPE::F16) { auto* w_kn = static_cast(dequant_workspace); moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, - group_size, asym); + group_size, asym, num_tokens_per_expert); moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); } else if (act_dtype == BTLA_DTYPE::BF16) { using BF = sycl::ext::oneapi::bfloat16; auto* w_kn = static_cast(dequant_workspace); moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, group_size, - asym); + asym, num_tokens_per_expert); 
moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); } else { throw std::invalid_argument("moe_gemm_prefill: act_dtype must be F16 or BF16"); From 3e5192c098a0a2c274b43de058472dc0c34e1a8b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 09:39:12 +0000 Subject: [PATCH 24/28] perf(moe_prefill): pack PACK_K K-outputs per dequant work-item --- .../wrapper/include/sycl_tla_moe_mixed.hpp | 205 ++++++++++++------ 1 file changed, 136 insertions(+), 69 deletions(-) diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp index 7740da3b6..ccbe77498 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp @@ -70,11 +70,33 @@ class MoEDequantKernelInt2; template class MoEDequantKernelFP8; -// Tile sizes for the dequant kernel: each work-item dequantises one weight -// element, with the work-group covering a 1xWG_N tile along (k, n). N is -// chosen as the inner dimension to keep stores to the [E, K, N] output -// coalesced across the sub-group. -constexpr int WG_N = 16; +// Tile sizes for the dequant kernels. +// +// Each work-group covers a (PACK_K x WG_N) tile in (k, n) and writes PACK_K +// consecutive K rows for the same column band. N is the inner dimension so +// stores to the `[E, K, N]` workspace stay coalesced across the sub-group. +// +// PACK_K is chosen per weight encoding so a single work-item dequantises an +// entire packed byte (INT4: 2 outputs, INT2: 4 outputs) or a small run of +// elements sharing one scale/zero load (INT8 / FP8 / FP transpose: 4 +// outputs). 
This removes the redundant packed-byte and scale loads that +// the previous "one element per work-item" launch incurred: +// - INT4: every byte was read twice (one item per nibble) -> now once. +// - INT2: every byte was read four times -> now once. +// - All quantized paths: every scale (and zero) was reloaded by every K +// element in the group -> now once per PACK_K elements. +// Group-wise scale sharing is safe because PACK_K <= group_size in every +// supported configuration (group_size >= 32 in practice). +// +// WG_N is the sub-group store width along N. 32 yields a single 64-byte +// coalesced burst per row for FP16/BF16 writes, which matches the L1 +// cache-line size on the target XPUs. +constexpr int WG_N = 32; +constexpr int PACK_K_FP = 4; +constexpr int PACK_K_INT8 = 4; +constexpr int PACK_K_INT4 = 2; +constexpr int PACK_K_INT2 = 4; +constexpr int PACK_K_FP8 = 4; // ---------------------------------------------------------------------------- // FP16 / BF16 weight reshape: in-place transpose [E, N, K] -> [E, K, N]. @@ -87,7 +109,12 @@ void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weigh const int* num_tokens_per_expert = nullptr) { if (E == 0 || N == 0 || K == 0) return; - sycl::range<3> global{static_cast(E), static_cast(K), + // Each work-item copies PACK_K_FP consecutive K elements for a single + // (e, n). The launch grid covers ceil(K / PACK_K_FP) along the middle dim; + // an inner bounds check guards the K tail when K is not a multiple of + // PACK_K_FP (e.g. K < PACK_K_FP in tiny unit tests). 
+ const int k_tiles = (K + PACK_K_FP - 1) / PACK_K_FP; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), static_cast((N + WG_N - 1) / WG_N) * WG_N}; sycl::range<3> local{1, 1, static_cast(WG_N)}; @@ -97,12 +124,18 @@ void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weigh // Skip experts that receive no tokens in this prefill batch; the // grouped GEMM will not read their rows of `weights_KN` either. if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; - const int k = static_cast(it.get_global_id(1)); + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_FP; const int n = static_cast(it.get_global_id(2)); if (n >= N) return; - const ScalarT v = - weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; - weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = v; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); +#pragma unroll + for (int j = 0; j < PACK_K_FP; ++j) { + const int k = k_base + j; + if (k >= K) break; + const ScalarT v = weights_NK[w_row + static_cast(k)]; + weights_KN[out_base + static_cast(k) * N] = v; + } }); } @@ -121,7 +154,12 @@ void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const Scalar } const int num_groups_k = K / group_size; - sycl::range<3> global{static_cast(E), static_cast(K), + // Each work-item dequantises PACK_K_INT8 consecutive K outputs for a + // single (e, n), sharing one scale/zero load. PACK_K_INT8 (=4) is always + // <= group_size (>=32 in practice), so the cached scale/zero is valid + // for every element in the run. 
+ const int k_tiles = (K + PACK_K_INT8 - 1) / PACK_K_INT8; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), static_cast((N + WG_N - 1) / WG_N) * WG_N}; sycl::range<3> local{1, 1, static_cast(WG_N)}; @@ -129,24 +167,32 @@ void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const Scalar sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; - const int k = static_cast(it.get_global_id(1)); + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_INT8; const int n = static_cast(it.get_global_id(2)); if (n >= N) return; - const int g = k / group_size; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + // Hoist scale/zero loads: PACK_K_INT8 K values share the same group + // because PACK_K_INT8 <= group_size and groups are aligned at K + // multiples of group_size (PACK_K_INT8 divides group_size). + const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); const float scale = static_cast(scales[s_idx]); - const uint8_t raw = - weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; - float w; - if constexpr (Asym) { - const float zero = static_cast(zeros[s_idx]); - w = (static_cast(raw) - zero) * scale; - } else { - w = static_cast(static_cast(raw)) * scale; + const float zero = Asym ? 
static_cast(zeros[s_idx]) : 0.0f; +#pragma unroll + for (int j = 0; j < PACK_K_INT8; ++j) { + const int k = k_base + j; + if (k >= K) break; + const uint8_t raw = weights_NK[w_row + static_cast(k)]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + weights_KN[out_base + static_cast(k) * N] = static_cast(w); } - weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = - static_cast(w); }); } @@ -170,7 +216,11 @@ void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const Scala const int num_groups_k = K / group_size; const int k_packed = K / 2; - sycl::range<3> global{static_cast(E), static_cast(K), + // Each work-item now dequantises one full packed byte = PACK_K_INT4 (=2) + // consecutive K outputs for a single (e, n). The middle launch dim is + // therefore k_packed instead of K, halving the work-item count and + // eliminating the previous "two items reading the same byte" pattern. 
+ sycl::range<3> global{static_cast(E), static_cast(k_packed), static_cast((N + WG_N - 1) / WG_N) * WG_N}; sycl::range<3> local{1, 1, static_cast(WG_N)}; @@ -178,31 +228,34 @@ void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const Scala sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; - const int k = static_cast(it.get_global_id(1)); + const int kp = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; - const int g = k / group_size; + const int k_base = kp * PACK_K_INT4; + const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); const float scale = static_cast(scales[s_idx]); - const uint8_t packed = - weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + - static_cast(k / 2)]; - const bool is_high = (k & 1) != 0; - float w; + const float zero = Asym ? static_cast(zeros[s_idx]) : 0.0f; + const uint8_t packed = weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(kp)]; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + float w0, w1; if constexpr (Asym) { - const float zero = static_cast(zeros[s_idx]); - const int q = static_cast(is_high ? ((packed >> 4) & 0x0F) : (packed & 0x0F)); - w = (static_cast(q) - zero) * scale; + const int q_lo = static_cast(packed & 0x0F); + const int q_hi = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q_lo) - zero) * scale; + w1 = (static_cast(q_hi) - zero) * scale; } else { - // Sign-extend 4-bit -> 8-bit by shifting into the top nibble. - const int q = is_high - ? static_cast(static_cast(packed & 0xF0) >> 4) - : static_cast(static_cast(packed << 4) >> 4); - w = static_cast(q) * scale; + // Sign-extend each nibble: shift into the top of an int8 then + // arithmetic-shift right by 4 to fill the sign bits. 
+ const int q_lo = static_cast(static_cast(packed << 4) >> 4); + const int q_hi = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q_lo) * scale; + w1 = static_cast(q_hi) * scale; } - weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = - static_cast(w); + weights_KN[out_base + static_cast(k_base) * N] = static_cast(w0); + weights_KN[out_base + static_cast(k_base + 1) * N] = static_cast(w1); }); } @@ -226,7 +279,11 @@ void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const Scala const int num_groups_k = K / group_size; const int k_packed = K / 4; - sycl::range<3> global{static_cast(E), static_cast(K), + // One work-item handles one full packed byte = PACK_K_INT2 (=4) + // consecutive K outputs for a single (e, n). This removes the previous + // 4x duplicated byte loads (one item per 2-bit field) and 4x scale + // reloads. + sycl::range<3> global{static_cast(E), static_cast(k_packed), static_cast((N + WG_N - 1) / WG_N) * WG_N}; sycl::range<3> local{1, 1, static_cast(WG_N)}; @@ -234,31 +291,33 @@ void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const Scala sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; - const int k = static_cast(it.get_global_id(1)); + const int kp = static_cast(it.get_global_id(1)); const int n = static_cast(it.get_global_id(2)); if (n >= N) return; - const int g = k / group_size; + const int k_base = kp * PACK_K_INT2; + const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); const float scale = static_cast(scales[s_idx]); - const uint8_t packed = - weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + - static_cast(k / 4)]; - const int field = k & 3; - float w; - if constexpr (Asym) { - const float zero = static_cast(zeros[s_idx]); - const int q = 
static_cast((packed >> (2 * field)) & 0x3); - w = (static_cast(q) - zero) * scale; - } else { - // Sign-extend 2-bit by shifting the field into the top bits of an int8. - const int shift = 6 - 2 * field; // 6, 4, 2, 0 for fields 0..3 - const int8_t s8 = static_cast((packed << shift) & 0xC0); - const int q = static_cast(s8 >> 6); - w = static_cast(q) * scale; + const float zero = Asym ? static_cast(zeros[s_idx]) : 0.0f; + const uint8_t packed = weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(kp)]; + const size_t out_base = static_cast(e) * K * N + static_cast(n); +#pragma unroll + for (int j = 0; j < PACK_K_INT2; ++j) { + float w; + if constexpr (Asym) { + const int q = static_cast((packed >> (2 * j)) & 0x3); + w = (static_cast(q) - zero) * scale; + } else { + // Sign-extend 2-bit by shifting the field into the top bits of an int8. + const int shift = 6 - 2 * j; // 6, 4, 2, 0 for fields 0..3 + const int8_t s8 = static_cast((packed << shift) & 0xC0); + const int q = static_cast(s8 >> 6); + w = static_cast(q) * scale; + } + weights_KN[out_base + static_cast(k_base + j) * N] = static_cast(w); } - weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = - static_cast(w); }); } @@ -274,7 +333,10 @@ void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT if (E == 0 || N == 0 || K == 0) return; const int num_groups_k = K / group_size; - sycl::range<3> global{static_cast(E), static_cast(K), + // PACK_K_FP8 (=4) consecutive K outputs per work-item, sharing one scale + // load. PACK_K_FP8 <= group_size in all supported configurations. 
+ const int k_tiles = (K + PACK_K_FP8 - 1) / PACK_K_FP8; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), static_cast((N + WG_N - 1) / WG_N) * WG_N}; sycl::range<3> local{1, 1, static_cast(WG_N)}; @@ -282,18 +344,23 @@ void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { const int e = static_cast(it.get_global_id(0)); if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; - const int k = static_cast(it.get_global_id(1)); + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_FP8; const int n = static_cast(it.get_global_id(2)); if (n >= N) return; - const int g = k / group_size; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); const float scale = static_cast(scales[s_idx]); - const uint8_t raw = - weights_NK[(static_cast(e) * N + static_cast(n)) * K + static_cast(k)]; - const float w = moe_dequant::decode_fp8(raw) * scale; - weights_KN[(static_cast(e) * K + static_cast(k)) * N + static_cast(n)] = - static_cast(w); +#pragma unroll + for (int j = 0; j < PACK_K_FP8; ++j) { + const int k = k_base + j; + if (k >= K) break; + const uint8_t raw = weights_NK[w_row + static_cast(k)]; + const float w = moe_dequant::decode_fp8(raw) * scale; + weights_KN[out_base + static_cast(k) * N] = static_cast(w); + } }); } From c5f7c2188c2c081cb86567d6d6ac30c8daf5e721 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 09:40:22 +0000 Subject: [PATCH 25/28] docs(moe_prefill): clarify PACK_K must divide group_size for hoist --- .../wrapper/include/sycl_tla_moe_mixed.hpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git 
a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp index ccbe77498..94f91085e 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp @@ -85,8 +85,12 @@ class MoEDequantKernelFP8; // - INT2: every byte was read four times -> now once. // - All quantized paths: every scale (and zero) was reloaded by every K // element in the group -> now once per PACK_K elements. -// Group-wise scale sharing is safe because PACK_K <= group_size in every -// supported configuration (group_size >= 32 in practice). +// Group-wise scale sharing is safe because PACK_K *divides* group_size in +// every supported configuration: PACK_K is 2 (INT4) or 4 (INT2/INT8/FP8), +// and group_size is always a power of two >= 32 (typically 32, 64, or 128). +// As a result `k_base / group_size` yields the same group index for every +// K element in the PACK_K run, and each kernel can hoist a single scale +// (and, for asym, a single zero) load to amortise across the run. // // WG_N is the sub-group store width along N. 32 yields a single 64-byte // coalesced burst per row for FP16/BF16 writes, which matches the L1 @@ -173,8 +177,8 @@ void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const Scalar const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; const size_t out_base = static_cast(e) * K * N + static_cast(n); // Hoist scale/zero loads: PACK_K_INT8 K values share the same group - // because PACK_K_INT8 <= group_size and groups are aligned at K - // multiples of group_size (PACK_K_INT8 divides group_size). + // because PACK_K_INT8 divides group_size (see PACK_K constants + // above), so `k_base / group_size` is constant across the run. 
const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); @@ -232,6 +236,8 @@ void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const Scala const int n = static_cast(it.get_global_id(2)); if (n >= N) return; const int k_base = kp * PACK_K_INT4; + // PACK_K_INT4 (=2) divides group_size, so all PACK_K_INT4 K values + // in this run share the same scale/zero (one hoisted load each). const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); @@ -295,6 +301,8 @@ void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const Scala const int n = static_cast(it.get_global_id(2)); if (n >= N) return; const int k_base = kp * PACK_K_INT2; + // PACK_K_INT2 (=4) divides group_size, so all PACK_K_INT2 K values + // in this run share the same scale/zero (one hoisted load each). const int g = k_base / group_size; const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + static_cast(g); From 81712c84a861fa361cdd8a05ff903e4c2acd33a5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 10:31:23 +0000 Subject: [PATCH 26/28] test: add accuracy UT for MoE prefill (ark.moe_gemm / moe_gemm_prefill) --- .../ark/test/test_moe_prefill_accuracy.py | 366 ++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 auto_round_extension/ark/test/test_moe_prefill_accuracy.py diff --git a/auto_round_extension/ark/test/test_moe_prefill_accuracy.py b/auto_round_extension/ark/test/test_moe_prefill_accuracy.py new file mode 100644 index 000000000..8be30e19c --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_prefill_accuracy.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Accuracy (parity) tests for ``ark.moe_gemm`` / ``ark.moe_gemm_prefill``. + +This complements ``test_moe_prefill_perf.py`` (which measures throughput) and +``test_moe.py`` (which exercises small toy shapes). The shape matrix here +mirrors the *production-scale* shapes used by the perf benchmark — large +hidden sizes (up to 14336), high expert counts (up to 64), and uneven +token-per-expert distributions typical of Mixtral/DeepSeek-style MoE models +during prefill. + +For each (dtype, quant-scheme) combination the test: + + 1. Packs / quantizes random weights. + 2. Dequantizes them back to the activation dtype. + 3. Runs the ark kernel on the *packed* weights. + 4. Runs a per-expert ``A @ W.T`` reference on the *dequantized* weights. + 5. Compares with ``torch.testing.assert_close`` at quant-appropriate + tolerances. + +This isolates kernel correctness (matmul + on-the-fly dequant) from +quantization noise: the reference shares the same dequantized weights as the +kernel, so tolerances reflect only accumulator/order-of-operations +differences, not quant error. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_prefill_accuracy.py +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests. 
+from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons (mirror test_moe_prefill_perf.py) +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _prefill_skip_reason() -> str: + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm"): + return ( + "ark.xpu_lib loaded but has no moe_gemm symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE GEMM kernel" + ) + return "" + + +def _quantized_prefill_skip_reason() -> str: + reason = _prefill_skip_reason() + if reason: + return reason + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_prefill symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the quantized MoE prefill kernel" + ) + return "" + + +_PREFILL_SKIP = _prefill_skip_reason() +_QUANT_PREFILL_SKIP = _quantized_prefill_skip_reason() + + +# --------------------------------------------------------------------------- +# Shape matrix +# +# Subset of the perf benchmark's 
PREFILL_SHAPES. We keep the production-scale +# hidden sizes (up to N/K = 14336) and high expert counts (up to E = 64) so +# the accuracy check exercises the same code paths as the perf benchmark, +# but skip duplicate up-proj/down-proj rows to keep wall-clock reasonable. +# --------------------------------------------------------------------------- + +PREFILL_SHAPES = [ + # (label, num_experts, tokens_per_expert_list, N, K) + ("small E=8 ", 8, [32, 28, 30, 35, 33, 31, 29, 34], 4096, 4096), + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 4096, 14336), + ("large E=16", 16, [16] * 16, 2048, 2048), + ("large E=32", 32, [8] * 32, 2048, 2048), + ("large E=64", 64, [4] * 64, 2048, 2048), + ("uneven E=8 ", 8, [100, 50, 75, 80, 60, 90, 70, 85], 4096, 4096), +] + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- + + +def _reference_moe_prefill(activations, dequant_weights_NK, num_tokens_per_expert): + """Per-expert ``A @ W.T`` over ``[E, N, K]`` dequantized weights. + + This is the same reference used by ``TestMoEGemmPrefill`` in ``test_moe.py`` + and matches what a model would compute when no fused kernel is available. + """ + total_tokens, _ = activations.shape + E, N, _ = dequant_weights_NK.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + out[offset : offset + n_tokens] = a @ dequant_weights_NK[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# Tolerances per quant scheme +# +# These mirror the tolerances used in ``test_moe.py`` for the small-shape +# parity tests. 
We loosen slightly for the large-shape cases because longer +# K-reduction accumulates more rounding noise (FP16/BF16 accumulators). +# --------------------------------------------------------------------------- + +_TOL_FP = dict(rtol=3e-2, atol=3e-2) +_TOL_INT8 = dict(rtol=7e-2, atol=7e-2) +_TOL_INT4 = dict(rtol=7e-2, atol=7e-2) +_TOL_INT2 = dict(rtol=1.5e-1, atol=1.5e-1) +_TOL_FP8 = dict(rtol=7e-2, atol=7e-2) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_PREFILL_SKIP), reason=_PREFILL_SKIP or "ok") +class TestMoEGemmPrefillAccuracy: + """Parity tests for ``moe_gemm`` / ``moe_gemm_prefill`` at production shapes. + + Each test iterates the production-scale shape matrix and asserts the ark + kernel output matches a per-expert dequant + ``A @ W.T`` reference. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_accuracy_fp(self, dtype): + for label, E, tpe, N, K in PREFILL_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # ark.moe_gemm wants weights in [E, K, N] layout. + weights_KN = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm(activations, weights_KN, ntpe) + + # Reference uses [E, N, K] layout. 
+ weights_NK = weights_KN.transpose(1, 2).contiguous() + ref = _reference_moe_prefill(activations, weights_NK, ntpe) + + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + assert out.dtype == dtype, f"{label}: bad dtype {out.dtype}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_FP + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_accuracy_int4(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT4 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def 
test_accuracy_int8(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT8 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_accuracy_int2(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = 
_dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT2 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_accuracy_fp8(self, dtype, fp8_dtype): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + group_size=group_size, + asym=False, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_FP8 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 26aa210f33abda46dae041667a8381b977bfc39b Mon Sep 17 00:00:00 2001 From: 
"copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 10:58:02 +0000 Subject: [PATCH 27/28] feat(ark): unified `moe` API + model-level perf test - Add `ark.moe(...)` dispatcher that auto-selects between `moe_gemm_decode` (GEMV-tuned, small tokens/expert) and `moe_gemm_prefill` (GEMM-tuned, many tokens/expert). Single call site for model code; `phase="decode"` or `phase="prefill"` skips the auto-dispatch host-device sync. - Add `test_moe_unified.py`: bit-parity tests vs the underlying kernels across fp/int8/int4/int2/fp8 + dispatch correctness + error path. - Add `test_moe_model_perf.py`: model-level forward (1 prefill + N decode steps over L MoE layers) comparing always_prefill, always_decode, manual_branch, unified_auto, unified_hinted strategies. - `moe_gemm_decode` / `moe_gemm_prefill` are kept for backward compatibility. --- .../ark/auto_round_kernel/__init__.py | 115 ++++++ .../ark/test/test_moe_model_perf.py | 364 +++++++++++++++++ .../ark/test/test_moe_unified.py | 376 ++++++++++++++++++ 3 files changed, 855 insertions(+) create mode 100644 auto_round_extension/ark/test/test_moe_model_perf.py create mode 100644 auto_round_extension/ark/test/test_moe_unified.py diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 0e6e65fa4..cca7f7950 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1545,6 +1545,121 @@ def clear_moe_prefill_workspace_cache() -> None: _MOE_PREFILL_WORKSPACE_CACHE.clear() +# --------------------------------------------------------------------------- +# Unified MoE entry point +# +# `moe_gemm_decode` and `moe_gemm_prefill` accept identical argument shapes +# and dtypes -- the only difference is which underlying SYCL kernel is +# launched (a GEMV variant tuned for 1-2 tokens/expert vs. a Grouped GEMM +# variant tuned for many tokens/expert). 
Model code that runs through both +# regimes (prefill of a prompt, then autoregressive decode) traditionally +# has to keep two call sites and branch on phase. `moe(...)` collapses that +# into a single API and auto-selects the right kernel from the token +# distribution. +# +# Callers that already know the phase (e.g., a model's generation loop knows +# whether it's in prefill or decode) should pass it via the `phase` argument +# to avoid the small host-device sync that `phase="auto"` needs to inspect +# `num_tokens_per_expert.max()`. +# --------------------------------------------------------------------------- + +# Default tokens-per-expert threshold used by `phase="auto"`. The decode +# GEMV kernel is faster when every expert sees only a handful of tokens +# (TopK >= 1 with batch size 1-4); above that the GEMM-tuned prefill kernel +# wins. The crossover is hardware-dependent but `4` is a conservative default +# that matches the regime `moe_gemm_decode`'s docstring describes +# ("typically only 1-2 tokens", up to top-k * small batch). +_MOE_AUTO_DECODE_MAX_TOKENS_PER_EXPERT = 4 + +_MOE_VALID_PHASES = ("auto", "decode", "prefill") + + +def moe( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, + phase: str = "auto", + decode_threshold: int = _MOE_AUTO_DECODE_MAX_TOKENS_PER_EXPERT, +) -> torch.Tensor: + """Unified MoE GEMM entry point that dispatches to decode or prefill. + + This is a thin Python-side dispatcher over :func:`moe_gemm_decode` and + :func:`moe_gemm_prefill`. The two underlying kernels accept the same + argument shapes/dtypes (see :func:`moe_gemm_decode` for the full layout + contract); ``moe`` simply picks the one that is faster for the current + token distribution so model code can have a single call site for both + prefill and decode phases. 
+ + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: ``[E, N, K_packed]`` -- see :func:`moe_gemm_decode` for the + quant-specific layout/dtype contract. + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales, zeros, weight_bits, group_size, asym: forwarded to the + underlying kernel; see :func:`moe_gemm_decode`. + phase: dispatch mode. + + * ``"auto"`` (default): inspect ``num_tokens_per_expert.max()`` + and pick decode if every expert sees ``<= decode_threshold`` + tokens, otherwise prefill. This incurs one small host-device + sync per call. + * ``"decode"``: always dispatch to :func:`moe_gemm_decode`. Use + when the model's generation loop already knows it is in the + decode phase; avoids the sync. + * ``"prefill"``: always dispatch to :func:`moe_gemm_prefill`. + Use when the model knows it is in the prefill phase. + decode_threshold: ``"auto"`` mode dispatches to decode when + ``num_tokens_per_expert.max() <= decode_threshold``. Defaults to + 4 (the regime the decode GEMV kernel is tuned for). + + Returns: + ``[total_tokens, N]`` in the activations dtype. Bit-identical to the + underlying kernel that was dispatched. + """ + if phase not in _MOE_VALID_PHASES: + raise ValueError(f"phase must be one of {_MOE_VALID_PHASES}, got {phase!r}") + + if phase == "auto": + # `.max().item()` triggers a host-device sync; callers in tight + # decode loops should pass `phase="decode"` explicitly to skip this. + # We tolerate a non-int32 / non-contiguous tensor here because the + # downstream kernel wrappers will normalise it anyway. 
+ if num_tokens_per_expert.numel() == 0: + raise ValueError("num_tokens_per_expert must be non-empty") + max_tpe = int(num_tokens_per_expert.max().item()) + phase = "decode" if max_tpe <= int(decode_threshold) else "prefill" + + if phase == "decode": + return moe_gemm_decode( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + ) + # phase == "prefill" + return moe_gemm_prefill( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + ) + + def patch_torch_sdpa(*args, **kwargs): from .torch_sdpa_patch import patch_torch_sdpa_with_ark diff --git a/auto_round_extension/ark/test/test_moe_model_perf.py b/auto_round_extension/ark/test/test_moe_model_perf.py new file mode 100644 index 000000000..eb5bc4e69 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_model_perf.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-level performance test for the unified ``ark.moe`` dispatcher. + +This benchmark simulates a realistic MoE LLM generation trace: + + 1. A **prefill** step that processes the whole prompt (many tokens/expert). + 2. A handful of **decode** steps that produce one token each (1-2 + tokens/expert after TopK routing). 
+ +Each step runs through ``L`` MoE layers; each layer consists of an +``up_proj`` (``K -> N_inter``) and a ``down_proj`` (``N_inter -> K``) MoE +GEMM. This is the standard shape for Mixtral-style models. + +We then compare five call strategies for the *same* trace: + + * ``always_prefill``: model code always calls ``moe_gemm_prefill``. + Represents the simple-but-suboptimal "use the prefill kernel + everywhere" approach. + * ``always_decode``: model code always calls ``moe_gemm_decode``. + The opposite extreme -- decode kernel for both phases. + * ``manual_branch``: model code branches on the known phase (``if + is_prefill: moe_gemm_prefill else moe_gemm_decode``). Optimal but + requires two call sites and a phase flag in the model. + * ``unified_auto``: single call site: ``ark.moe(..., phase="auto")``. + The dispatcher picks the right kernel from ``num_tokens_per_expert``; + pays one tiny host-device sync per call. + * ``unified_hinted``: single call site: ``ark.moe(..., phase="prefill"|"decode")``. + Skips the sync when the caller already knows the phase. + +The reported speedup is ``always_prefill_time / strategy_time`` -- i.e., +how much faster the unified API gets you relative to the naive +single-kernel approach a typical first-pass integration would use. 
+ +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_model_perf.py +""" + +from dataclasses import dataclass +from typing import Callable + +import auto_round_kernel +import pytest +import torch + +from test_moe import ( # noqa: E402 + _pack_int4_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons (mirror the other perf tests) +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _skip_reason() -> str: + if not _xpu_available(): + return "XPU not available" + if ark.xpu_lib is None: + return "ark.xpu_lib is None (XPU extension failed to import)" + for sym in ("moe_gemm_decode", "moe_gemm_prefill"): + if not hasattr(ark.xpu_lib, sym): + return f"ark.xpu_lib missing {sym} (need ARK_SYCL_TLA=ON)" + if not hasattr(ark, "moe"): + return "ark.moe (unified entry point) not exported by auto_round_kernel" + return "" + + +_SKIP = _skip_reason() + +print( + "[moe-model-perf] xpu_available=%s xpu_lib=%s has_moe=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark, "moe"), + ) +) +if _SKIP: + print("[moe-model-perf] suite will SKIP. 
reason: %s" % _SKIP) + + +# --------------------------------------------------------------------------- +# Timing utility +# --------------------------------------------------------------------------- + +WARMUP = 3 +ITERS = 10 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + for _ in range(warmup): + fn() + torch.xpu.synchronize() + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +# --------------------------------------------------------------------------- +# Model + trace definitions +# --------------------------------------------------------------------------- + + +@dataclass +class MoELayerWeights: + """Quantized weights for one MoE layer (up_proj + down_proj).""" + + up_packed: torch.Tensor # [E, N_inter, K // 2] (INT4) + up_scales: torch.Tensor # [E, N_inter, K // group_size] + down_packed: torch.Tensor # [E, K, N_inter // 2] + down_scales: torch.Tensor # [E, K, N_inter // group_size] + + +def _build_layers(num_layers, E, K, N_inter, group_size, dtype): + layers = [] + for _ in range(num_layers): + # up_proj: [E, N_inter, K] + up_float = (torch.randn(E, N_inter, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + up_scales = torch.empty(E, N_inter, K // group_size, dtype=dtype, device="xpu") + up_packed = _pack_int4_sym(up_float, up_scales, group_size) + + # down_proj: [E, K, N_inter] + down_float = (torch.randn(E, K, N_inter, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + down_scales = torch.empty(E, K, N_inter // group_size, dtype=dtype, device="xpu") + down_packed = _pack_int4_sym(down_float, down_scales, group_size) + + layers.append(MoELayerWeights(up_packed, up_scales, down_packed, down_scales)) + return layers + + +@dataclass +class Step: + """One trace step: activations + 
per-expert token counts + phase flag.""" + + activations: torch.Tensor # [total_tokens, K] + ntpe: torch.Tensor # [E] int32 + is_prefill: bool + + +def _build_trace(E, K, prompt_tokens, decode_steps, dtype, seed=0): + """Mixed prefill+decode trace mimicking an LLM generation request.""" + g = torch.Generator(device="xpu") + g.manual_seed(seed) + + steps = [] + + # ---- Prefill step: prompt_tokens routed across E experts (roughly even). ---- + base = prompt_tokens // E + rem = prompt_tokens - base * E + tpe = [base + (1 if i < rem else 0) for i in range(E)] + activations = torch.randn(prompt_tokens, K, dtype=dtype, device="xpu", generator=g) + steps.append(Step(activations, torch.tensor(tpe, dtype=torch.int32, device="xpu"), is_prefill=True)) + + # ---- Decode steps: 1 token, TopK=2 routing -> 2 experts see 1 token each. ---- + # We simulate batch=1 + top_k=2: total_tokens=2, two experts each get 1 token. + for s in range(decode_steps): + tpe = [0] * E + chosen = [(s * 2) % E, (s * 2 + 1) % E] + if chosen[0] == chosen[1]: + chosen[1] = (chosen[1] + 1) % E + tpe[chosen[0]] = 1 + tpe[chosen[1]] = 1 + activations = torch.randn(2, K, dtype=dtype, device="xpu", generator=g) + steps.append(Step(activations, torch.tensor(tpe, dtype=torch.int32, device="xpu"), is_prefill=False)) + + return steps + + +# --------------------------------------------------------------------------- +# Call strategies (the four "model-side" integration patterns) +# --------------------------------------------------------------------------- + + +_CallStrategy = Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, int], torch.Tensor] + + +def _strat_always_prefill(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): + return ark.moe_gemm_prefill( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + + +def _strat_always_decode(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): + return 
ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + + +def _strat_manual_branch(activations, packed, ntpe, scales, _zeros, is_prefill, group_size): + if is_prefill: + return ark.moe_gemm_prefill( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + return ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + + +def _strat_unified_auto(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): + return ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", + ) + + +def _strat_unified_hinted(activations, packed, ntpe, scales, _zeros, is_prefill, group_size): + return ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="prefill" if is_prefill else "decode", + ) + + +_STRATEGIES = [ + ("always_prefill ", _strat_always_prefill), + ("always_decode ", _strat_always_decode), + ("manual_branch ", _strat_manual_branch), + ("unified_auto ", _strat_unified_auto), + ("unified_hinted ", _strat_unified_hinted), +] + + +# --------------------------------------------------------------------------- +# End-to-end "model forward over the trace" runner +# --------------------------------------------------------------------------- + + +def _forward_full_trace(strategy, layers, trace, group_size): + """Run the strategy across every step and every layer of the trace. + + Each layer is up_proj followed by down_proj. The down_proj input is just + the up_proj output reshaped to ``[total_tokens, N_inter]`` (we skip the + SiLU/element-wise ops; we only want to measure the MoE-kernel cost, + which is what the dispatcher changes). 
+ """ + for step in trace: + x = step.activations + for layer in layers: + up_out = strategy( + x, layer.up_packed, step.ntpe, layer.up_scales, None, + step.is_prefill, group_size, + ) + # down_proj takes the up output (same total_tokens, dim = N_inter). + x_down = strategy( + up_out, layer.down_packed, step.ntpe, layer.down_scales, None, + step.is_prefill, group_size, + ) + # Feed down_proj output into the next layer's up_proj. + x = x_down + + +# --------------------------------------------------------------------------- +# The benchmark +# --------------------------------------------------------------------------- + + +# Three model presets mimicking common MoE configurations. Each runs the +# whole trace (1 prefill + N decode steps) through L layers. +# +# Shapes are intentionally smaller than the perf-bench shapes used in +# `test_moe_prefill_perf.py` so the full multi-layer forward stays tractable +# (we run the whole trace ITERS times for the median timing). +_MODEL_PRESETS = [ + # (label, num_layers, E, K, N_inter, group_size, prompt_tokens, decode_steps) + ("mixtral-tiny L=2 E=8 ", 2, 8, 1024, 2048, 128, 64, 8), + ("mixtral-small L=4 E=8 ", 4, 8, 2048, 4096, 128, 128, 8), + ("deepseek-tiny L=2 E=16", 2, 16, 1024, 2048, 128, 96, 8), +] + + +def _print_header(): + print() + print("=" * 110) + print( + f"Model-level MoE perf: full forward over (prefill + N x decode) for L layers; " + f"INT4 sym weights" + ) + print("-" * 110) + print( + f"{'preset':<24}{'strategy':<18}{'forward(ms)':>14}{'speedup vs prefill-only':>28}" + ) + print("-" * 110) + + +def _print_row(preset, strat_name, ms, speedup): + print(f"{preset:<24}{strat_name:<18}{ms:>14.4f}{speedup:>26.2f}x") + + +@pytest.mark.skipif(bool(_SKIP), reason=_SKIP or "ok") +class TestMoEModelPerf: + """Model-level perf comparison for the unified MoE dispatcher. + + The benchmark validates two things in one go: + 1. 
*Correctness of the dispatcher*: the perf-comparison loop also + confirms every strategy produces a runnable forward (any kernel + error during timing causes a hard failure). + 2. *Speedup vs naive single-kernel integration*: the printed table + shows how much the unified API saves a model author who would + otherwise reach for ``always_prefill``. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_model_forward(self, dtype): + _print_header() + for preset_label, L, E, K, N_inter, group_size, prompt_tokens, decode_steps in _MODEL_PRESETS: + layers = _build_layers(L, E, K, N_inter, group_size, dtype) + trace = _build_trace(E, K, prompt_tokens, decode_steps, dtype) + + # Time each strategy on the same (layers, trace) pair. + baseline_ms = None + for strat_name, strat_fn in _STRATEGIES: + def _run(strat_fn=strat_fn): + _forward_full_trace(strat_fn, layers, trace, group_size) + + ms = _xpu_time_ms(_run) + if baseline_ms is None: + baseline_ms = ms # first row is always always_prefill + speedup = baseline_ms / ms if ms > 0 else float("nan") + _print_row(f"{preset_label} {dtype}".strip(), strat_name, ms, speedup) + + print("-" * 110) + + # Free the per-preset workspace so the next preset starts clean. + ark.clear_moe_prefill_workspace_cache() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/auto_round_extension/ark/test/test_moe_unified.py b/auto_round_extension/ark/test/test_moe_unified.py new file mode 100644 index 000000000..308dcf994 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_unified.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity tests for the unified ``ark.moe`` dispatcher. + +``ark.moe`` is a thin Python-side dispatcher that picks between +``moe_gemm_decode`` (GEMV-tuned) and ``moe_gemm_prefill`` (GEMM-tuned) based +on the token distribution. The contract is that dispatching should never +change the numerical result: ``moe(...)`` must be **bit-identical** to the +underlying kernel that was dispatched. + +This file checks: + + * Dispatch correctness: ``phase="auto"`` picks decode when every expert + sees few tokens and prefill otherwise. + * Bit-parity: ``moe(phase="auto")`` matches the kernel it dispatched to. + * Explicit-phase parity: ``moe(phase="decode")`` matches + ``moe_gemm_decode``, ``moe(phase="prefill")`` matches + ``moe_gemm_prefill`` (for the same inputs). + * Coverage across all supported quant schemes (fp / int8 / int4 / int2 / + fp8) and both activation dtypes. + * Argument validation (bad ``phase`` raises ``ValueError``). +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests. 
+from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _unified_skip_reason() -> str: + if not _xpu_available(): + return "XPU not available" + if ark.xpu_lib is None: + return "ark.xpu_lib is None (XPU extension failed to import)" + if not hasattr(ark.xpu_lib, "moe_gemm_decode"): + return "ark.xpu_lib missing moe_gemm_decode (need ARK_SYCL_TLA=ON)" + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return "ark.xpu_lib missing moe_gemm_prefill (need ARK_SYCL_TLA=ON)" + if not hasattr(ark, "moe"): + return "ark.moe (unified entry point) not exported by auto_round_kernel" + return "" + + +_UNIFIED_SKIP = _unified_skip_reason() + + +# --------------------------------------------------------------------------- +# Small shapes (one decode-shaped, one prefill-shaped) -- keep wall-clock low. 
+# --------------------------------------------------------------------------- + +_DECODE_SHAPE = dict(num_experts=4, tokens_per_expert=[1, 2, 0, 2], N=128, K=256) +_PREFILL_SHAPE = dict(num_experts=4, tokens_per_expert=[16, 8, 0, 20], N=128, K=256) + + +def _make_int4_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int4_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + + +def _make_int8_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int8_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + 
+ +def _make_int2_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int2_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + + +def _make_fp8(E, N, K, group_size, dtype, total_tokens, fp8_dtype): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + return activations, packed, scales, None + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_UNIFIED_SKIP), reason=_UNIFIED_SKIP or "ok") +class TestMoeUnifiedDispatch: + """Tests for the auto-dispatch logic itself.""" + + def test_auto_picks_decode_for_small_tokens_per_expert(self): + shape = _DECODE_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], 
dtype=torch.int32, device="xpu") + + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", + ) + out_decode = ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + # max tokens/expert = 2 (<= default threshold 4) -> dispatched to decode + # -> output must be bit-identical to moe_gemm_decode. + torch.testing.assert_close(out_auto, out_decode, rtol=0, atol=0) + + def test_auto_picks_prefill_for_large_tokens_per_expert(self): + shape = _PREFILL_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", + ) + out_prefill = ark.moe_gemm_prefill( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + torch.testing.assert_close(out_auto, out_prefill, rtol=0, atol=0) + + def test_decode_threshold_override(self): + # Same prefill-shaped input but bump the threshold above the max + # tokens/expert -> auto must now pick decode. 
+ shape = _PREFILL_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + max_tpe = max(shape["tokens_per_expert"]) + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", decode_threshold=max_tpe + 1, + ) + out_decode = ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + torch.testing.assert_close(out_auto, out_decode, rtol=0, atol=0) + + def test_invalid_phase_raises(self): + shape = _DECODE_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + with pytest.raises(ValueError, match="phase"): + ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="not_a_phase", + ) + + +@pytest.mark.skipif(bool(_UNIFIED_SKIP), reason=_UNIFIED_SKIP or "ok") +class TestMoeUnifiedBitParity: + """``moe(phase=X)`` must be bit-identical to the dispatched kernel. + + Parametrised across all supported quant schemes; for each, we check both + the decode-shaped and the prefill-shaped input so that both code paths + are exercised regardless of which one ``"auto"`` would have picked. 
+ """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("shape_name,shape", [ + ("decode-shape", _DECODE_SHAPE), + ("prefill-shape", _PREFILL_SHAPE), + ]) + def test_fp_unquantized(self, dtype, shape_name, shape): + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights_NK = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + # phase="decode" -> moe_gemm_decode + out_unified = ark.moe(activations, weights_NK, ntpe, weight_bits=16, phase="decode") + out_kernel = ark.moe_gemm_decode(activations, weights_NK, ntpe, weight_bits=16) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + # phase="prefill" -> moe_gemm_prefill + out_unified = ark.moe(activations, weights_NK, ntpe, weight_bits=16, phase="prefill") + out_kernel = ark.moe_gemm_prefill(activations, weights_NK, ntpe, weight_bits=16) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + @pytest.mark.parametrize("shape_name,shape", [ + ("decode-shape", _DECODE_SHAPE), + ("prefill-shape", _PREFILL_SHAPE), + ]) + def test_int4(self, dtype, asym, shape_name, shape): + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int4_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=4, group_size=group_size, asym=asym) + + 
out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_int8(self, dtype, asym): + # Single shape -- the quant path is the same on both shapes, so + # iterating both would just slow the test suite down. + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int8_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int8_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=8, group_size=group_size, asym=asym) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_int2(self, dtype, asym): + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + 
group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int2_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int2_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=2, group_size=group_size, asym=asym) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_fp8(self, dtype, fp8_dtype): + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + activations, packed, scales, _ = _make_fp8(E, N, K, group_size, dtype, total_tokens, fp8_dtype) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, group_size=group_size, asym=False) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 55026244a82fddee92921b2bba5213e7e51d271c 
Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Jun 2026 01:48:28 +0000 Subject: [PATCH 28/28] test(ark): add model-level MoE perf benchmark on XPU - Add test/test_ark/test_moe_model_perf.py: real MoE LLM perf scan (tiny by default, AR_MOE_PERF_FULL=1 for full checkpoints) over prefill + per-token decode latency. Compares FP reference vs ARK backend (and optional GPTQModel). Asserts ARK decode <= 2x FP to guard against silent ark.moe dispatcher regressions. - Remove obsolete auto_round_extension/ark/test/test_moe_model_perf.py (synthetic kernel-trace bench; replaced by the model-level one). - Register the `perf` pytest marker in pyproject.toml. --- .../ark/test/test_moe_model_perf.py | 364 ----------------- pyproject.toml | 5 + test/test_ark/test_moe_model_perf.py | 385 ++++++++++++++++++ 3 files changed, 390 insertions(+), 364 deletions(-) delete mode 100644 auto_round_extension/ark/test/test_moe_model_perf.py create mode 100644 test/test_ark/test_moe_model_perf.py diff --git a/auto_round_extension/ark/test/test_moe_model_perf.py b/auto_round_extension/ark/test/test_moe_model_perf.py deleted file mode 100644 index eb5bc4e69..000000000 --- a/auto_round_extension/ark/test/test_moe_model_perf.py +++ /dev/null @@ -1,364 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Model-level performance test for the unified ``ark.moe`` dispatcher. - -This benchmark simulates a realistic MoE LLM generation trace: - - 1. A **prefill** step that processes the whole prompt (many tokens/expert). - 2. A handful of **decode** steps that produce one token each (1-2 - tokens/expert after TopK routing). - -Each step runs through ``L`` MoE layers; each layer consists of an -``up_proj`` (``K -> N_inter``) and a ``down_proj`` (``N_inter -> K``) MoE -GEMM. This is the standard shape for Mixtral-style models. - -We then compare four call strategies for the *same* trace: - - * ``always_prefill``: model code always calls ``moe_gemm_prefill``. - Represents the simple-but-suboptimal "use the prefill kernel - everywhere" approach. - * ``always_decode``: model code always calls ``moe_gemm_decode``. - The opposite extreme -- decode kernel for both phases. - * ``manual_branch``: model code branches on the known phase (``if - is_prefill: moe_gemm_prefill else moe_gemm_decode``). Optimal but - requires two call sites and a phase flag in the model. - * ``unified_auto``: single call site: ``ark.moe(..., phase="auto")``. - The dispatcher picks the right kernel from ``num_tokens_per_expert``; - pays one tiny host-device sync per call. - * ``unified_hinted``: single call site: ``ark.moe(..., phase=)``. - Skips the sync when the caller already knows the phase. - -The reported speedup is ``always_prefill_time / strategy_time`` -- i.e., -how much faster the unified API gets you relative to the naive -single-kernel approach a typical first-pass integration would use. 
- -How to run:: - - pytest -v -s auto_round_extension/ark/test/test_moe_model_perf.py -""" - -from dataclasses import dataclass -from typing import Callable - -import auto_round_kernel -import pytest -import torch - -from test_moe import ( # noqa: E402 - _pack_int4_sym, -) - -ark = auto_round_kernel - - -# --------------------------------------------------------------------------- -# Skip reasons (mirror the other perf tests) -# --------------------------------------------------------------------------- - - -def _xpu_available() -> bool: - return hasattr(torch, "xpu") and torch.xpu.is_available() - - -def _skip_reason() -> str: - if not _xpu_available(): - return "XPU not available" - if ark.xpu_lib is None: - return "ark.xpu_lib is None (XPU extension failed to import)" - for sym in ("moe_gemm_decode", "moe_gemm_prefill"): - if not hasattr(ark.xpu_lib, sym): - return f"ark.xpu_lib missing {sym} (need ARK_SYCL_TLA=ON)" - if not hasattr(ark, "moe"): - return "ark.moe (unified entry point) not exported by auto_round_kernel" - return "" - - -_SKIP = _skip_reason() - -print( - "[moe-model-perf] xpu_available=%s xpu_lib=%s has_moe=%s" - % ( - _xpu_available(), - "loaded" if ark.xpu_lib is not None else "None", - hasattr(ark, "moe"), - ) -) -if _SKIP: - print("[moe-model-perf] suite will SKIP. 
reason: %s" % _SKIP) - - -# --------------------------------------------------------------------------- -# Timing utility -# --------------------------------------------------------------------------- - -WARMUP = 3 -ITERS = 10 - - -def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: - for _ in range(warmup): - fn() - torch.xpu.synchronize() - timings = [] - for _ in range(iters): - start = torch.xpu.Event(enable_timing=True) - end = torch.xpu.Event(enable_timing=True) - start.record() - fn() - end.record() - end.synchronize() - timings.append(start.elapsed_time(end)) - timings.sort() - return timings[len(timings) // 2] - - -# --------------------------------------------------------------------------- -# Model + trace definitions -# --------------------------------------------------------------------------- - - -@dataclass -class MoELayerWeights: - """Quantized weights for one MoE layer (up_proj + down_proj).""" - - up_packed: torch.Tensor # [E, N_inter, K // 2] (INT4) - up_scales: torch.Tensor # [E, N_inter, K // group_size] - down_packed: torch.Tensor # [E, K, N_inter // 2] - down_scales: torch.Tensor # [E, K, N_inter // group_size] - - -def _build_layers(num_layers, E, K, N_inter, group_size, dtype): - layers = [] - for _ in range(num_layers): - # up_proj: [E, N_inter, K] - up_float = (torch.randn(E, N_inter, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - up_scales = torch.empty(E, N_inter, K // group_size, dtype=dtype, device="xpu") - up_packed = _pack_int4_sym(up_float, up_scales, group_size) - - # down_proj: [E, K, N_inter] - down_float = (torch.randn(E, K, N_inter, dtype=torch.float32, device="xpu") * 0.1).to(dtype) - down_scales = torch.empty(E, K, N_inter // group_size, dtype=dtype, device="xpu") - down_packed = _pack_int4_sym(down_float, down_scales, group_size) - - layers.append(MoELayerWeights(up_packed, up_scales, down_packed, down_scales)) - return layers - - -@dataclass -class Step: - """One trace step: activations + 
per-expert token counts + phase flag.""" - - activations: torch.Tensor # [total_tokens, K] - ntpe: torch.Tensor # [E] int32 - is_prefill: bool - - -def _build_trace(E, K, prompt_tokens, decode_steps, dtype, seed=0): - """Mixed prefill+decode trace mimicking an LLM generation request.""" - g = torch.Generator(device="xpu") - g.manual_seed(seed) - - steps = [] - - # ---- Prefill step: prompt_tokens routed across E experts (roughly even). ---- - base = prompt_tokens // E - rem = prompt_tokens - base * E - tpe = [base + (1 if i < rem else 0) for i in range(E)] - activations = torch.randn(prompt_tokens, K, dtype=dtype, device="xpu", generator=g) - steps.append(Step(activations, torch.tensor(tpe, dtype=torch.int32, device="xpu"), is_prefill=True)) - - # ---- Decode steps: 1 token, TopK=2 routing -> 2 experts see 1 token each. ---- - # We simulate batch=1 + top_k=2: total_tokens=2, two experts each get 1 token. - for s in range(decode_steps): - tpe = [0] * E - chosen = [(s * 2) % E, (s * 2 + 1) % E] - if chosen[0] == chosen[1]: - chosen[1] = (chosen[1] + 1) % E - tpe[chosen[0]] = 1 - tpe[chosen[1]] = 1 - activations = torch.randn(2, K, dtype=dtype, device="xpu", generator=g) - steps.append(Step(activations, torch.tensor(tpe, dtype=torch.int32, device="xpu"), is_prefill=False)) - - return steps - - -# --------------------------------------------------------------------------- -# Call strategies (the four "model-side" integration patterns) -# --------------------------------------------------------------------------- - - -_CallStrategy = Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, int], torch.Tensor] - - -def _strat_always_prefill(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): - return ark.moe_gemm_prefill( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - ) - - -def _strat_always_decode(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): - return 
ark.moe_gemm_decode( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - ) - - -def _strat_manual_branch(activations, packed, ntpe, scales, _zeros, is_prefill, group_size): - if is_prefill: - return ark.moe_gemm_prefill( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - ) - return ark.moe_gemm_decode( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - ) - - -def _strat_unified_auto(activations, packed, ntpe, scales, _zeros, _is_prefill, group_size): - return ark.moe( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - phase="auto", - ) - - -def _strat_unified_hinted(activations, packed, ntpe, scales, _zeros, is_prefill, group_size): - return ark.moe( - activations, packed, ntpe, - scales=scales, weight_bits=4, group_size=group_size, asym=False, - phase="prefill" if is_prefill else "decode", - ) - - -_STRATEGIES = [ - ("always_prefill ", _strat_always_prefill), - ("always_decode ", _strat_always_decode), - ("manual_branch ", _strat_manual_branch), - ("unified_auto ", _strat_unified_auto), - ("unified_hinted ", _strat_unified_hinted), -] - - -# --------------------------------------------------------------------------- -# End-to-end "model forward over the trace" runner -# --------------------------------------------------------------------------- - - -def _forward_full_trace(strategy, layers, trace, group_size): - """Run the strategy across every step and every layer of the trace. - - Each layer is up_proj followed by down_proj. The down_proj input is just - the up_proj output reshaped to ``[total_tokens, N_inter]`` (we skip the - SiLU/element-wise ops; we only want to measure the MoE-kernel cost, - which is what the dispatcher changes). 
- """ - for step in trace: - x = step.activations - for layer in layers: - up_out = strategy( - x, layer.up_packed, step.ntpe, layer.up_scales, None, - step.is_prefill, group_size, - ) - # down_proj takes the up output (same total_tokens, dim = N_inter). - x_down = strategy( - up_out, layer.down_packed, step.ntpe, layer.down_scales, None, - step.is_prefill, group_size, - ) - # Feed down_proj output into the next layer's up_proj. - x = x_down - - -# --------------------------------------------------------------------------- -# The benchmark -# --------------------------------------------------------------------------- - - -# Three model presets mimicking common MoE configurations. Each runs the -# whole trace (1 prefill + N decode steps) through L layers. -# -# Shapes are intentionally smaller than the perf-bench shapes used in -# `test_moe_prefill_perf.py` so the full multi-layer forward stays tractable -# (we run the whole trace ITERS times for the median timing). -_MODEL_PRESETS = [ - # (label, num_layers, E, K, N_inter, group_size, prompt_tokens, decode_steps) - ("mixtral-tiny L=2 E=8 ", 2, 8, 1024, 2048, 128, 64, 8), - ("mixtral-small L=4 E=8 ", 4, 8, 2048, 4096, 128, 128, 8), - ("deepseek-tiny L=2 E=16", 2, 16, 1024, 2048, 128, 96, 8), -] - - -def _print_header(): - print() - print("=" * 110) - print( - f"Model-level MoE perf: full forward over (prefill + N x decode) for L layers; " - f"INT4 sym weights" - ) - print("-" * 110) - print( - f"{'preset':<24}{'strategy':<18}{'forward(ms)':>14}{'speedup vs prefill-only':>28}" - ) - print("-" * 110) - - -def _print_row(preset, strat_name, ms, speedup): - print(f"{preset:<24}{strat_name:<18}{ms:>14.4f}{speedup:>26.2f}x") - - -@pytest.mark.skipif(bool(_SKIP), reason=_SKIP or "ok") -class TestMoEModelPerf: - """Model-level perf comparison for the unified MoE dispatcher. - - The benchmark validates two things in one go: - 1. 
*Correctness of the dispatcher*: the perf-comparison loop also - confirms every strategy produces a runnable forward (any kernel - error during timing causes a hard failure). - 2. *Speedup vs naive single-kernel integration*: the printed table - shows how much the unified API saves a model author who would - otherwise reach for ``always_prefill``. - """ - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_model_forward(self, dtype): - _print_header() - for preset_label, L, E, K, N_inter, group_size, prompt_tokens, decode_steps in _MODEL_PRESETS: - layers = _build_layers(L, E, K, N_inter, group_size, dtype) - trace = _build_trace(E, K, prompt_tokens, decode_steps, dtype) - - # Time each strategy on the same (layers, trace) pair. - baseline_ms = None - for strat_name, strat_fn in _STRATEGIES: - def _run(strat_fn=strat_fn): - _forward_full_trace(strat_fn, layers, trace, group_size) - - ms = _xpu_time_ms(_run) - if baseline_ms is None: - baseline_ms = ms # first row is always always_prefill - speedup = baseline_ms / ms if ms > 0 else float("nan") - _print_row(f"{preset_label} {dtype}".strip(), strat_name, ms, speedup) - - print("-" * 110) - - # Free the per-preset workspace so the next preset starts clean. 
- ark.clear_moe_prefill_workspace_cache() - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/pyproject.toml b/pyproject.toml index 6efe49e8d..e1978e006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,11 @@ requires = ["setuptools>=64", "wheel"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +markers = [ + "perf: performance benchmark tests (excluded from default CI runs; select with `-m perf`)", +] + [tool.codespell] skip = 'pyproject.toml,.azure-pipelines/scripts/codeScan/codespell/autoround_dict.txt,auto_round_extension/ark/*' ignore-words = ".azure-pipelines/scripts/codeScan/codespell/autoround_dict.txt" diff --git a/test/test_ark/test_moe_model_perf.py b/test/test_ark/test_moe_model_perf.py new file mode 100644 index 000000000..7410b8ef2 --- /dev/null +++ b/test/test_ark/test_moe_model_perf.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-level perf benchmark for MoE LLMs on the ARK (XPU) backend. + +This file mirrors the structure of ``test/test_ark/test_model.py`` but +operates on real MoE checkpoints (Qwen1.5-MoE, DeepSeek-V2-Lite) and adds +prefill / decode latency measurement on top. + +For each ``(model, bits, dtype)`` combination we: + + 1. 
Load a *tiny* slice of the MoE model with random weights + (``num_layers=2``, ``num_experts=4``) via ``helpers.get_tiny_model`` so + the suite is tractable in CI. Set ``AR_MOE_PERF_FULL=1`` to load the + full checkpoint instead. + 2. Quantize with ``AutoRound(iters=0, nsamples=1, disable_opt_rtn=True)`` + and export with ``format="auto_round"``. + 3. Reload on XPU once per backend (up to three times): + a. unquantized FP reference (no ``quantization_config``); + b. ARK backend (``AutoRoundConfig(backend="ark")``); + c. *optional* GPTQModel backend, skipped if not installed. + 4. Smoke-test correctness via ``helpers.model_infer`` (asserts non-empty + output -> catches any wiring break in the backend). + 5. Measure **prefill** latency (single forward over a 128-token prompt) + and **per-token decode** latency (``model.generate(max_new_tokens=32)`` + after a 4-token warmup) using ``torch.xpu.Event``. Report the median + of 3 runs. + 6. Assert that ARK decode latency is within ``ARK_DECODE_REGRESSION_FACTOR`` + of the FP reference -- defends against silent perf regressions of the + unified ``ark.moe`` dispatcher. + +How to run:: + + pytest -v -s test/test_ark/test_moe_model_perf.py + # full-size checkpoints (needs ~30GB free + checkpoint cached locally): + AR_MOE_PERF_FULL=1 pytest -v -s test/test_ark/test_moe_model_perf.py + # exclude from default CI runs: + pytest -m 'not perf' ...
+""" + +import os +import shutil +from statistics import median + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer + +from auto_round import AutoRound + +from ..helpers import ( + deepseek_v2_name_or_path, + get_model_path, + get_tiny_model, + model_infer, + qwen_moe_name_or_path, +) + + +# --------------------------------------------------------------------------- +# Knobs +# --------------------------------------------------------------------------- + +# Use the full pretrained checkpoint when set; otherwise build a tiny random +# slice via ``get_tiny_model`` so the test stays CI-friendly. +_USE_FULL_MODEL = os.environ.get("AR_MOE_PERF_FULL", "0") == "1" + +# Tiny-model slice geometry (only used when AR_MOE_PERF_FULL=0). +_TINY_NUM_LAYERS = 2 +_TINY_NUM_EXPERTS = 4 + +# Timing harness. +_PREFILL_PROMPT_TOKENS = 128 +_DECODE_NEW_TOKENS = 32 +_DECODE_WARMUP_TOKENS = 4 +_TIMING_REPEATS = 3 + +# Perf-regression guard for the ARK decode path vs the unquantized FP +# baseline. Loose enough to absorb run-to-run jitter (we already take the +# median of N runs) but tight enough to catch the "phase=auto sync" +# regression class described in the upstream perf analysis (which costs +# ~25% per kernel call). 
ARK_DECODE_REGRESSION_FACTOR = 2.0


# ---------------------------------------------------------------------------
# ARK availability gate (mirrors auto_round_extension/ark/test/test_moe_unified.py)
# ---------------------------------------------------------------------------


def _xpu_available() -> bool:
    """Return True when this torch build exposes ``torch.xpu`` and a device is usable."""
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def _ark_skip_reason() -> str:
    """Explain why the ARK backend cannot be exercised, or return ``""``.

    The checks run from cheapest to most specific: XPU presence, importability
    of ``auto_round_kernel``, a non-None ``xpu_lib``, the two MoE kernel
    symbols on ``xpu_lib``, and finally the ``ark.moe`` entry point.

    Returns:
        A human-readable reason for the first failed check, or an empty
        string when every check passes.
    """
    if not _xpu_available():
        return "XPU not available"
    try:
        import auto_round_kernel as ark
    except ImportError as exc:
        return f"auto_round_kernel not importable: {exc}"
    if getattr(ark, "xpu_lib", None) is None:
        return "ark.xpu_lib is None (XPU extension failed to import)"
    for sym in ("moe_gemm_decode", "moe_gemm_prefill"):
        if not hasattr(ark.xpu_lib, sym):
            return f"ark.xpu_lib missing {sym} (need ARK_SYCL_TLA=ON)"
    if not hasattr(ark, "moe"):
        return "ark.moe (unified entry point) not exported by auto_round_kernel"
    return ""


# Evaluated once at import time; an empty string means "ARK is usable".
_ARK_SKIP = _ark_skip_reason()


# ---------------------------------------------------------------------------
# Timing utilities
# ---------------------------------------------------------------------------


def _xpu_sync():
    """Block until queued XPU work has finished; no-op without an XPU."""
    if _xpu_available():
        torch.xpu.synchronize()


def _xpu_time_ms(fn, repeats: int = _TIMING_REPEATS) -> float:
    """Time ``fn`` on XPU using ``torch.xpu.Event``; return median ms.

    ``fn`` is invoked once as an untimed warmup and then ``repeats`` more
    times, each bracketed by a pair of timing events. The median of the
    per-call timings is returned to absorb run-to-run jitter.

    Args:
        fn: Zero-argument callable that runs the workload.
        repeats: Number of timed invocations; must be at least 1.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        # Fail before running the (possibly expensive) warmup rather than
        # letting ``median([])`` raise an opaque StatisticsError afterwards.
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    fn()  # warmup
    _xpu_sync()

    timings = []
    for _ in range(repeats):
        start = torch.xpu.Event(enable_timing=True)
        end = torch.xpu.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        # elapsed_time() is only read after the closing event has completed.
        end.synchronize()
        timings.append(start.elapsed_time(end))
    return median(timings)


def _measure_prefill_ms(model, input_ids, attention_mask) -> float:
    """Return the median latency (ms) of one forward pass over the prompt."""

    def _run():
        with torch.inference_mode():
            model(input_ids=input_ids, attention_mask=attention_mask)

    return _xpu_time_ms(_run)


def _resolve_pad_token_id(tokenizer):
    """Return the tokenizer's pad id, falling back to its EOS id.

    The comparison is against ``None`` on purpose: a pad id of ``0`` is a
    valid token id and must not trigger the fallback.
    """
    pad_id = tokenizer.pad_token_id
    return tokenizer.eos_token_id if pad_id is None else pad_id


def _measure_decode_ms_per_tok(model, tokenizer, prompt_ids, attention_mask) -> float:
    """Median per-token decode latency from ``generate(max_new_tokens=N)``.

    The timed region is the whole ``model.generate`` call starting from
    ``prompt_ids``, divided by ``_DECODE_NEW_TOKENS``; the figure therefore
    also carries the cost of processing the prompt inside ``generate``.
    """
    pad_token_id = _resolve_pad_token_id(tokenizer)

    def _generate(max_new_tokens):
        with torch.inference_mode():
            model.generate(
                input_ids=prompt_ids,
                attention_mask=attention_mask,
                do_sample=False,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
            )

    # Warmup generate (compiles caches, allocates KV).
    _generate(_DECODE_WARMUP_TOKENS)

    total_ms = _xpu_time_ms(lambda: _generate(_DECODE_NEW_TOKENS))
    return total_ms / _DECODE_NEW_TOKENS


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_prefill_inputs(tokenizer, device, dtype, num_tokens=_PREFILL_PROMPT_TOKENS):
    """Build a fixed-length (num_tokens) input tensor on ``device``.

    Args:
        tokenizer: Only its ``vocab_size`` attribute is read (default 32000
            when the attribute is absent).
        device: Target device for both returned tensors.
        dtype: Unused by this helper; kept so existing call sites stay valid.
        num_tokens: Prompt length.

    Returns:
        ``(input_ids, attention_mask)``, each of shape ``[1, num_tokens]``.
    """
    # Use a deterministic synthetic prompt to keep the perf number stable
    # across machines that may not have the same vocab tokenization for a
    # natural-language prompt of a given length.
    vocab_size = getattr(tokenizer, "vocab_size", 32000)
    input_ids = torch.arange(num_tokens, dtype=torch.long).unsqueeze(0) % max(1, vocab_size)
    attention_mask = torch.ones_like(input_ids)
    return input_ids.to(device), attention_mask.to(device)


def _load_tiny_or_full(model_name_or_path, dtype):
    """Return a model instance + tokenizer for the given MoE checkpoint.

    With ``AR_MOE_PERF_FULL=1`` the real checkpoint is loaded in ``dtype``.
    Otherwise ``get_tiny_model`` builds a reduced model
    (``_TINY_NUM_LAYERS`` layers, ``_TINY_NUM_EXPERTS`` experts, built from
    config) which is then cast to ``dtype``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    if _USE_FULL_MODEL:
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype=dtype, trust_remote_code=True)
    else:
        model = get_tiny_model(
            model_name_or_path,
            num_layers=_TINY_NUM_LAYERS,
            num_experts=_TINY_NUM_EXPERTS,
            from_config=True,
            trust_remote_code=True,
        )
        model = model.to(dtype)
    return model, tokenizer


def _format_row(model_label, backend, prefill_ms, decode_ms_per_tok, baseline_decode_ms):
    """Format one line of the perf table.

    ``tokens/s`` and the "vs FP" speedup are both derived by dividing by
    ``decode_ms_per_tok``; when that latency is not positive the throughput
    is reported as NaN and the speedup as ``--`` instead of raising.
    The speedup is also ``--`` when no positive baseline is available.
    """
    has_latency = decode_ms_per_tok > 0
    tps = 1000.0 / decode_ms_per_tok if has_latency else float("nan")

    has_baseline = baseline_decode_ms is not None and baseline_decode_ms > 0
    if has_latency and has_baseline:
        speedup_str = f"{baseline_decode_ms / decode_ms_per_tok:6.2f}x"
    else:
        # Right-aligned by the ``>14`` format spec below.
        speedup_str = "--"

    return (
        f"{model_label:<22}{backend:<14}{prefill_ms:>12.3f}{decode_ms_per_tok:>16.3f}"
        f"{tps:>14.2f}{speedup_str:>14}"
    )


def _print_header(dtype):
    """Print the table banner and column titles matching ``_format_row``."""
    print()
    print("=" * 96)
    print(f"Model-level MoE perf -- dtype={dtype} AR_MOE_PERF_FULL={_USE_FULL_MODEL}")
    print("-" * 96)
    print(
        f"{'model':<22}{'backend':<14}{'prefill(ms)':>12}{'decode(ms/tok)':>16}"
        f"{'tokens/s':>14}{'vs FP':>14}"
    )
    print("-" * 96)


# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------


# Every test in this module carries the ``perf`` marker.
pytestmark = [pytest.mark.perf]
@pytest.mark.skipif(not _xpu_available(), reason="XPU not available")
class TestMoEModelPerf:
    """Model-level perf table for MoE LLMs on the ARK XPU backend."""

    @classmethod
    def teardown_class(cls):
        """Remove the ``runs`` directory from the working directory, if any."""
        shutil.rmtree("runs", ignore_errors=True)

    @pytest.fixture(autouse=True)
    def _save_dir(self, tmp_path):
        """Provide per-test output folders and delete them afterwards."""
        self.save_folder = str(tmp_path / "saved")
        # Holds the unquantized tiny model so the FP reference row measures
        # the same architecture that gets quantized.
        self.fp_folder = str(tmp_path / "fp_ref")
        yield
        shutil.rmtree(self.save_folder, ignore_errors=True)
        shutil.rmtree(self.fp_folder, ignore_errors=True)

    # -- helpers ------------------------------------------------------------

    def _quantize_and_save(self, model, tokenizer, bits, group_size, sym):
        """Quantize ``model`` with AutoRound and export it in ``auto_round`` format.

        Returns:
            The folder reported by ``AutoRound.quantize_and_save``.
        """
        autoround = AutoRound(
            model,
            tokenizer,
            bits=bits,
            group_size=group_size,
            sym=sym,
            iters=0,
            nsamples=1,
            disable_opt_rtn=True,
        )
        _, saved_folder = autoround.quantize_and_save(output_dir=self.save_folder, format="auto_round")
        return saved_folder

    def _reload(self, saved_folder, dtype, *, backend):
        """Load a checkpoint onto XPU, optionally through a quantized backend.

        Args:
            saved_folder: Checkpoint directory or model id to load.
            dtype: Torch dtype forwarded to ``from_pretrained``.
            backend: Backend name for ``AutoRoundConfig``; ``None`` loads
                without any ``quantization_config``.
        """
        kwargs = dict(dtype=dtype, device_map="xpu", trust_remote_code=True)
        if backend is not None:
            kwargs["quantization_config"] = AutoRoundConfig(backend=backend)
        return AutoModelForCausalLM.from_pretrained(saved_folder, **kwargs)

    def _bench_one(self, model, tokenizer, dtype):
        """Return ``(prefill_ms, decode_ms_per_token)`` for ``model``."""
        prompt_ids, attention_mask = _make_prefill_inputs(tokenizer, model.device, dtype)
        prefill_ms = _measure_prefill_ms(model, prompt_ids, attention_mask)
        decode_ms = _measure_decode_ms_per_tok(model, tokenizer, prompt_ids, attention_mask)
        return prefill_ms, decode_ms

    def _smoke_test(self, model, tokenizer):
        """Assert that ``model_infer`` yields a non-empty result."""
        out = model_infer(model, tokenizer)
        assert out is not None and len(out) > 0, "ARK backend produced empty output"

    def _fp_reference_source(self, fp_model, resolved, label):
        """Return the checkpoint the FP reference row should be loaded from.

        In full mode the original checkpoint *is* the model under test. In
        tiny mode the in-memory slice is written to ``self.fp_folder`` so the
        reference has the same geometry as the quantized model; reloading
        ``resolved`` there would benchmark the full-size network instead.

        Returns:
            A path / model id, or ``None`` when the snapshot could not be
            written (the FP row is optional and is then skipped).
        """
        if _USE_FULL_MODEL:
            return resolved
        try:
            # NOTE(review): assumes ``get_tiny_model`` returns an object with
            # the usual ``save_pretrained`` method -- confirm against helpers.
            fp_model.save_pretrained(self.fp_folder)
        except Exception as exc:  # noqa: BLE001 -- FP baseline is optional
            print(f"[moe-model-perf] fp(ref) row skipped for {label}: {exc}")
            return None
        return self.fp_folder

    # -- parametrized perf scan --------------------------------------------

    @pytest.mark.parametrize(
        "model_label, model_path",
        [
            ("qwen-moe", qwen_moe_name_or_path),
            ("deepseek-v2-lite", deepseek_v2_name_or_path),
        ],
    )
    @pytest.mark.parametrize("bits, group_size, sym", [(4, 128, True), (8, 128, True)])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_moe_forward_perf(self, model_label, model_path, bits, group_size, sym, dtype):
        """Benchmark FP / ARK / GPTQModel and guard ARK decode latency."""
        # 0. ARK is the subject of the test: bail out before any model is
        #    loaded or quantized when it cannot run at all.
        if _ARK_SKIP:
            pytest.skip(f"ARK backend unavailable: {_ARK_SKIP}")

        # 1. Resolve checkpoint -- skip if neither local mirror nor HF cache
        #    has it (CI without internet must not fail this test).
        resolved = get_model_path(model_path)
        try:
            AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
        except (OSError, ValueError) as exc:
            pytest.skip(f"checkpoint {model_path!r} not available locally: {exc}")

        label = f"{model_label} INT{bits}"

        # 2. Load tiny (or full) FP MoE model + tokenizer.
        fp_model, tokenizer = _load_tiny_or_full(resolved, dtype)

        # Snapshot the unquantized weights *before* AutoRound touches the
        # model, so the reference row matches what is being quantized.
        fp_source = self._fp_reference_source(fp_model, resolved, label)

        # 3. Quantize + save.
        saved_folder = self._quantize_and_save(fp_model, tokenizer, bits, group_size, sym)
        # Free the calibration-time model before reloading on XPU.
        del fp_model
        torch.xpu.empty_cache()

        _print_header(dtype)

        # 4a. FP reference on XPU (no quantization_config).
        fp_decode_ms = None
        if fp_source is not None:
            fp_model_xpu = None
            try:
                fp_model_xpu = self._reload(fp_source, dtype, backend=None)
                fp_prefill_ms, fp_decode_ms = self._bench_one(fp_model_xpu, tokenizer, dtype)
                print(_format_row(label, "fp(ref)", fp_prefill_ms, fp_decode_ms, fp_decode_ms))
            except Exception as exc:  # noqa: BLE001 -- FP baseline is optional
                print(f"[moe-model-perf] fp(ref) row skipped for {label}: {exc}")
            finally:
                # Drop the reference even on failure so the ARK run does not
                # compete with a half-benchmarked model for device memory.
                fp_model_xpu = None
                torch.xpu.empty_cache()

        # 4b. ARK backend (the thing under test).
        ark_model = self._reload(saved_folder, dtype, backend="ark")
        try:
            self._smoke_test(ark_model, tokenizer)
            ark_prefill_ms, ark_decode_ms = self._bench_one(ark_model, tokenizer, dtype)
        finally:
            ark_model = None
            torch.xpu.empty_cache()
        print(_format_row(label, "ark", ark_prefill_ms, ark_decode_ms, fp_decode_ms))

        # 4c. Optional GPTQModel cross-reference (skip silently if missing).
        gptq_model = None
        try:
            gptq_model = self._reload(saved_folder, dtype, backend="gptqmodel")
            gptq_prefill_ms, gptq_decode_ms = self._bench_one(gptq_model, tokenizer, dtype)
            print(_format_row(label, "gptqmodel", gptq_prefill_ms, gptq_decode_ms, fp_decode_ms))
        except Exception as exc:  # noqa: BLE001 -- backend is optional
            print(f"[moe-model-perf] gptqmodel row skipped for {label}: {exc}")
        finally:
            gptq_model = None
            torch.xpu.empty_cache()

        print("-" * 96)

        # 5. Perf-regression assertion (only when we have an FP baseline).
        if fp_decode_ms is not None and fp_decode_ms > 0:
            assert ark_decode_ms <= fp_decode_ms * ARK_DECODE_REGRESSION_FACTOR, (
                f"ARK decode latency {ark_decode_ms:.3f} ms/tok exceeds "
                f"{ARK_DECODE_REGRESSION_FACTOR}x FP baseline {fp_decode_ms:.3f} ms/tok "
                f"for {label} ({dtype}). Likely regression in the ark.moe dispatcher."
            )


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])