diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index b7b80dc6d..cca7f7950 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -1066,6 +1066,241 @@ def _ceil_div(a, b): return out +def moe_gemm_decode( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, +) -> torch.Tensor: + """MoE GEMV optimized for the decode phase. + + Each expert typically processes only 1-2 tokens (top-k routing with + small batch). Activations must already be gathered/sorted by expert + (same convention as ``moe_gemm``). + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are: + + * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16`` + matching the activations dtype, ``K_packed == K``. + * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``. + Sym (``asym=False``) reinterprets each byte as signed int8; + asym (``asym=True``) treats each byte as ``uint8`` with a + per-group zero-point. + * Int4 (``weight_bits=4``): ``torch.uint8`` packed, + ``K_packed == K // 2`` (two 4-bit values per byte; low nibble + at the lower K index). + * Int2 (``weight_bits=2``): ``torch.uint8`` packed, + ``K_packed == K // 4`` (four 2-bit values per byte; field j at + K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i). + * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): + ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must + be ``False`` (no zero-points for FP8). + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required + for all quantized paths (int8/int4/int2/fp8); must be ``None`` + for unquantized weights. + zeros: ``[E, N, K // group_size]`` in activations dtype. Required + when ``asym=True`` (int8/int4/int2 only); otherwise ``None``. + weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8 + tensor (the FP8 sub-format is taken from ``weights.dtype``). + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding and ``zeros`` must + be provided. Not supported for FP8. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. + """ + activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( + _validate_moe_quant_args( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + api_name="moe_gemm_decode", + ) + ) + + lib = get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + # Scratch buffer mapping each token to its expert id; filled on-device + # inside the kernel wrapper so we avoid host-device sync. + expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32) + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_decode( + stream, + activations.data_ptr(), + weights.data_ptr(), + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + expert_id_per_token.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + return outputs + + +def _validate_moe_quant_args( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor], + zeros: Optional[torch.Tensor], + weight_bits: int, + group_size: int, + asym: bool, + api_name: str, +): + """Shared validation/normalisation for quantized MoE entry points. + + Returns a tuple of normalised tensors and dtype/shape metadata used by the + kernel-call site: + ``(activations, weights, scales, zeros, num_tokens_per_expert, + weight_dtype, total_tokens, N, K, num_experts)``. + """ + if activations.device.type != "xpu": + raise NotImplementedError(f"{api_name} is only supported on XPU") + + if activations.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}") + + if activations.ndim != 2: + raise ValueError("activations must be 2D [total_tokens, K]") + if weights.ndim != 3: + raise ValueError("weights must be 3D [E, N, K_packed]") + + if not activations.is_contiguous(): + activations = activations.contiguous() + if not weights.is_contiguous(): + weights = weights.contiguous() + + if num_tokens_per_expert.dtype != torch.int32: + num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) + if not num_tokens_per_expert.is_contiguous(): + num_tokens_per_expert = num_tokens_per_expert.contiguous() + + total_tokens, K = activations.shape + num_experts = weights.shape[0] + N = weights.shape[1] + + if num_tokens_per_expert.shape[0] != num_experts: + raise ValueError(f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}") + + # Detect FP8 weight dtype first (overrides weight_bits). + is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + + if is_fp8: + if asym: + raise ValueError("FP8 weights do not support asym=True") + if weights.shape[2] != K: + raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}") + if scales is None: + raise ValueError("scales is required for FP8 weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if zeros is not None: + raise ValueError("zeros must be None for FP8 weights") + weight_dtype = ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 + if not scales.is_contiguous(): + scales = scales.contiguous() + elif weight_bits == 16: + if weights.dtype != activations.dtype: + raise ValueError("Unquantized weights must match activations dtype") + if weights.shape[2] != K: + raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}") + weight_dtype = cvt_dtype(activations.dtype) + if scales is not None or zeros is not None: + raise ValueError("scales/zeros must be None when weight_bits=16") + elif weight_bits in (8, 4, 2): + if weights.dtype != torch.uint8: + raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8") + if weight_bits == 8: + k_packed_expected = K + k_div = 1 + elif weight_bits == 4: + k_packed_expected = K // 2 + k_div = 2 + else: # weight_bits == 2 + k_packed_expected = K // 4 + k_div = 4 + if K % k_div != 0: + raise ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}") + if weights.shape[2] != k_packed_expected: + raise ValueError( + f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} " + f"({k_packed_expected})" + ) + if scales is None: + raise ValueError(f"scales is required for int{weight_bits} weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + # Group_size constraints per dtype. + if weight_bits == 4 and (group_size & 1) != 0: + raise ValueError("group_size must be even for int4 weights") + if weight_bits == 2 and (group_size & 3) != 0: + raise ValueError("group_size must be a multiple of 4 for int2 weights") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if asym: + if zeros is None: + raise ValueError("zeros is required when asym=True") + if zeros.dtype != activations.dtype: + raise ValueError("zeros dtype must match activations dtype") + if tuple(zeros.shape) != expected_scale_shape: + raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}") + else: + if zeros is not None: + raise ValueError("zeros must be None when asym=False") + weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits] + if not scales.is_contiguous(): + scales = scales.contiguous() + if asym and not zeros.is_contiguous(): + zeros = zeros.contiguous() + else: + raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)") + + if N % 16 != 0: + raise ValueError(f"N must be a multiple of 16 (got {N})") + + expected_total = int(num_tokens_per_expert.sum().item()) + if expected_total != total_tokens: + raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") + + return (activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts) + + def moe_gemm( activations: torch.Tensor, weights: torch.Tensor, @@ -1142,6 +1377,289 @@ def moe_gemm( return outputs +def moe_gemm_prefill( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, +) -> torch.Tensor: + """MoE Grouped GEMM optimized for the prefill phase, supporting all weight + encodings of ``moe_gemm_decode`` (FP16/BF16, INT8 sym/asym, INT4 sym/asym, + INT2 sym/asym, FP8 E4M3/E5M2). + + The argument shapes/dtypes match :func:`moe_gemm_decode` exactly so the same + quantized weights/scales/zeros tensors can be re-used between prefill and + decode without reshaping. Internally, for the quantized paths the kernel + materialises a ``[E, K, N]`` ``act_dtype`` temporary via an on-device + dequantization kernel and then dispatches to the existing CUTLASS-SYCL + Grouped GEMM (``moe_gemm``). Numerical results are bit-identical to + ``moe_gemm`` applied to the same dequantized weights. + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: 3-D tensor; same layout/dtype contract as + :func:`moe_gemm_decode`. Quantized layouts are ``[E, N, K_packed]``; + the unquantized fast path (``weight_bits=16``) accepts + ``[E, N, K]`` -- callers providing already-``[E, K, N]`` weights + (as ``moe_gemm`` requires) should call ``moe_gemm`` directly. + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required for + quantized paths; ignored (must be ``None``) for unquantized. + zeros: ``[E, N, K // group_size]`` in activations dtype, required when + ``asym=True`` (int8/int4/int2 only). + weight_bits: 2, 4, 8, or 16. Ignored for FP8 weights. + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding; ``zeros`` required. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. + """ + activations, weights, scales, zeros, num_tokens_per_expert, weight_dtype, total_tokens, N, K, num_experts = ( + _validate_moe_quant_args( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + api_name="moe_gemm_prefill", + ) + ) + + lib = get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + + # Quantized paths need an [E, K, N] act-dtype scratch buffer that the + # on-device dequant kernel fills before the inner Grouped GEMM consumes + # it. The unquantized fast path forwards directly through `moe_gemm` and + # doesn't need scratch (passing 0 is safe -- the C++ wrapper short-circuits + # before touching the workspace pointer in that case). We allocate the + # workspace from PyTorch's caching allocator so repeated calls reuse the + # same memory. + is_unquantized = (weight_bits == 16) and (weights.dtype == activations.dtype) + if is_unquantized: + # `moe_gemm` requires `[E, K, N]` row-major weights; the decode-style + # `[E, N, K]` weight shape coming through this validator can be + # transposed into a temporary contiguous `[E, K, N]` view. The + # workspace serves the same role as the dequant scratch so the + # on-device path stays uniform. + dequant_workspace = weights.transpose(1, 2).contiguous() + weights_ptr = dequant_workspace.data_ptr() + else: + # Reuse a persistent `[E, K, N]` workspace across calls with the same + # (device, dtype, E, K, N). For real MoE prefill workloads the same + # shape is dispatched on every iteration; allocating a fresh + # `E*K*N*sizeof(act)` tensor each call adds non-trivial caching- + # allocator overhead (and, on the small shapes, dominates the + # quantized GEMM cost). The workspace is kept alive by the cache so + # we hand the data_ptr() to the kernel without taking a new ref. + dequant_workspace = _get_moe_prefill_workspace( + activations.device, activations.dtype, num_experts, K, N + ) + weights_ptr = weights.data_ptr() + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_prefill( + stream, + activations.data_ptr(), + weights_ptr, + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + dequant_workspace.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + # The inner CUTLASS-SYCL `moe_gemm` calls `event.wait()` before returning + # (see `moe_detail::moe_gemm_launcher` in `sycl_tla_moe.hpp`), so by the + # time `lib.moe_gemm_prefill` returns the device has already consumed the + # workspace. For the unquantized fast path the workspace is a per-call + # transposed copy of `weights` -- drop it now. For the quantized paths + # the workspace lives in the module-level cache (`_get_moe_prefill_workspace`) + # and is intentionally retained for reuse on the next call. + if is_unquantized: + del dequant_workspace + return outputs + + +# --------------------------------------------------------------------------- +# `moe_gemm_prefill` dequant-workspace cache. +# +# The Stage-1 quantized prefill kernel dequantises weights into an +# `[E, K, N]` act-dtype scratch buffer before dispatching to the existing +# CUTLASS-SYCL grouped GEMM. In real model usage the same `(E, K, N, dtype)` +# tuple is hit on every prefill step, so allocating a fresh +# `E * K * N * sizeof(act_dtype)` tensor per call adds caching-allocator +# overhead that is significant on the small/medium shapes. +# +# We cache one tensor per `(device, dtype, E, K, N)` key. The cache holds +# references that keep the tensors alive across calls; callers can clear it +# explicitly via `clear_moe_prefill_workspace_cache()` if they need to +# release the memory (e.g., before allocating large buffers for a different +# subsystem). +# --------------------------------------------------------------------------- + +_MOE_PREFILL_WORKSPACE_CACHE: "dict[tuple, torch.Tensor]" = {} + + +def _get_moe_prefill_workspace(device: torch.device, dtype: torch.dtype, E: int, K: int, N: int) -> torch.Tensor: + """Return a persistent `[E, K, N]` workspace tensor for the prefill kernel. + + The tensor is allocated lazily on first use and retained in a module-level + cache so subsequent calls with the same `(device, dtype, E, K, N)` reuse + the same memory. Returned tensors are contiguous and uninitialised; the + kernel writes every element before reading. + """ + # `device` may be a `torch.device` or a string; normalise so the cache key + # is hashable and identifies the exact device (including ordinal). + if not isinstance(device, torch.device): + device = torch.device(device) + key = (device.type, device.index, dtype, int(E), int(K), int(N)) + ws = _MOE_PREFILL_WORKSPACE_CACHE.get(key) + if ws is None: + ws = torch.empty((E, K, N), device=device, dtype=dtype) + _MOE_PREFILL_WORKSPACE_CACHE[key] = ws + return ws + + +def clear_moe_prefill_workspace_cache() -> None: + """Release all cached `moe_gemm_prefill` dequant-workspace tensors.""" + _MOE_PREFILL_WORKSPACE_CACHE.clear() + + +# --------------------------------------------------------------------------- +# Unified MoE entry point +# +# `moe_gemm_decode` and `moe_gemm_prefill` accept identical argument shapes +# and dtypes -- the only difference is which underlying SYCL kernel is +# launched (a GEMV variant tuned for 1-2 tokens/expert vs. a Grouped GEMM +# variant tuned for many tokens/expert). Model code that runs through both +# regimes (prefill of a prompt, then autoregressive decode) traditionally +# has to keep two call sites and branch on phase. `moe(...)` collapses that +# into a single API and auto-selects the right kernel from the token +# distribution. +# +# Callers that already know the phase (e.g., a model's generation loop knows +# whether it's in prefill or decode) should pass it via the `phase` argument +# to avoid the small host-device sync that `phase="auto"` needs to inspect +# `num_tokens_per_expert.max()`. +# --------------------------------------------------------------------------- + +# Default tokens-per-expert threshold used by `phase="auto"`. The decode +# GEMV kernel is faster when every expert sees only a handful of tokens +# (TopK >= 1 with batch size 1-4); above that the GEMM-tuned prefill kernel +# wins. The crossover is hardware-dependent but `4` is a conservative default +# that matches the regime `moe_gemm_decode`'s docstring describes +# ("typically only 1-2 tokens", up to top-k * small batch). +_MOE_AUTO_DECODE_MAX_TOKENS_PER_EXPERT = 4 + +_MOE_VALID_PHASES = ("auto", "decode", "prefill") + + +def moe( + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, + phase: str = "auto", + decode_threshold: int = _MOE_AUTO_DECODE_MAX_TOKENS_PER_EXPERT, +) -> torch.Tensor: + """Unified MoE GEMM entry point that dispatches to decode or prefill. + + This is a thin Python-side dispatcher over :func:`moe_gemm_decode` and + :func:`moe_gemm_prefill`. The two underlying kernels accept the same + argument shapes/dtypes (see :func:`moe_gemm_decode` for the full layout + contract); ``moe`` simply picks the one that is faster for the current + token distribution so model code can have a single call site for both + prefill and decode phases. + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: ``[E, N, K_packed]`` -- see :func:`moe_gemm_decode` for the + quant-specific layout/dtype contract. + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales, zeros, weight_bits, group_size, asym: forwarded to the + underlying kernel; see :func:`moe_gemm_decode`. + phase: dispatch mode. + + * ``"auto"`` (default): inspect ``num_tokens_per_expert.max()`` + and pick decode if every expert sees ``<= decode_threshold`` + tokens, otherwise prefill. This incurs one small host-device + sync per call. + * ``"decode"``: always dispatch to :func:`moe_gemm_decode`. Use + when the model's generation loop already knows it is in the + decode phase; avoids the sync. + * ``"prefill"``: always dispatch to :func:`moe_gemm_prefill`. + Use when the model knows it is in the prefill phase. + decode_threshold: ``"auto"`` mode dispatches to decode when + ``num_tokens_per_expert.max() <= decode_threshold``. Defaults to + 4 (the regime the decode GEMV kernel is tuned for). + + Returns: + ``[total_tokens, N]`` in the activations dtype. Bit-identical to the + underlying kernel that was dispatched. + """ + if phase not in _MOE_VALID_PHASES: + raise ValueError(f"phase must be one of {_MOE_VALID_PHASES}, got {phase!r}") + + if phase == "auto": + # `.max().item()` triggers a host-device sync; callers in tight + # decode loops should pass `phase="decode"` explicitly to skip this. + # We tolerate a non-int32 / non-contiguous tensor here because the + # downstream kernel wrappers will normalise it anyway. + if num_tokens_per_expert.numel() == 0: + raise ValueError("num_tokens_per_expert must be non-empty") + max_tpe = int(num_tokens_per_expert.max().item()) + phase = "decode" if max_tpe <= int(decode_threshold) else "prefill" + + if phase == "decode": + return moe_gemm_decode( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + ) + # phase == "prefill" + return moe_gemm_prefill( + activations, + weights, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=weight_bits, + group_size=group_size, + asym=asym, + ) + + def patch_torch_sdpa(*args, **kwargs): from .torch_sdpa_patch import patch_torch_sdpa_with_ark diff --git a/auto_round_extension/ark/auto_round_kernel/ark.cpp b/auto_round_extension/ark/auto_round_kernel/ark.cpp index dbf0eac16..6a5d390fe 100755 --- a/auto_round_extension/ark/auto_round_kernel/ark.cpp +++ b/auto_round_extension/ark/auto_round_kernel/ark.cpp @@ -25,6 +25,8 @@ typedef uintptr_t torch_ptr; // Only include declarations, implementations are in separate .cpp files #include "sycl_tla_common.hpp" #include "sycl_tla_moe.hpp" +#include "sycl_tla_moe_decode.hpp" +#include "sycl_tla_moe_mixed.hpp" #include "sycl_tla_sdpa.hpp" #endif #else @@ -222,6 +224,27 @@ static void moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr (void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts); } +static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales, + torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf, + int act_dtype, int weight_dtype, int N, int K, int group_size, + torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) { + ark::moe_gemm_decode((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr, + zeros ? (void*)zeros : nullptr, (void*)outputs, (int*)expert_id_per_token_buf, + (BTLA_DTYPE)(act_dtype), (BTLA_DTYPE)(weight_dtype), N, K, group_size, + (int*)num_tokens_per_expert, num_experts, total_tokens, asym); +} + +static void moe_gemm_prefill_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales, + torch_ptr zeros, torch_ptr outputs, torch_ptr dequant_workspace, int act_dtype, + int weight_dtype, int N, int K, int group_size, torch_ptr num_tokens_per_expert, + int num_experts, int total_tokens, bool asym) { + ark::moe_gemm_prefill((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr, + zeros ? (void*)zeros : nullptr, (void*)outputs, + dequant_workspace ? (void*)dequant_workspace : nullptr, (BTLA_DTYPE)(act_dtype), + (BTLA_DTYPE)(weight_dtype), N, K, group_size, (int*)num_tokens_per_expert, num_experts, + total_tokens, asym); +} + static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out, int num_rows, int head_dim, int block_size) { auto* q = (sycl::queue*)stream; @@ -439,5 +462,7 @@ PYBIND11_MODULE(PY_NAME, m) { m.def("sage_dynamic_quant_layout", &ark::sage_dynamic_quant_layout); m.def("sage_dynamic_quant_v_layout", &ark::sage_dynamic_quant_v_layout); m.def("moe_gemm", &ark::moe_gemm_wrapper); + m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper); + m.def("moe_gemm_prefill", &ark::moe_gemm_prefill_wrapper); #endif } \ No newline at end of file diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp index 806dcb0ca..cc8cd8971 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp @@ -35,6 +35,67 @@ namespace ark { void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N, int K, int* num_tokens_per_expert, int num_experts); +/** + * @brief MoE GEMV optimized for the decode phase (M per expert is typically + * 1-2 tokens). Supports unquantized FP16/BF16 weights and int4 (S4_CLIP) + * weights with group-wise scales and optional zero-points. + * + * Implementation is header-only in `sycl_tla_moe_decode.hpp`. + * + * @param q SYCL queue + * @param activations [total_tokens, K] in `act_dtype` + * @param weights Unquantized: [num_experts, N, K] in act_dtype + * Int4: packed [num_experts, N, K/2] uint8 + * @param scales [num_experts, N, K/group_size] (act_dtype), + * ignored when weight_dtype is FP16/BF16 + * @param zeros [num_experts, N, K/group_size] (act_dtype) or + * nullptr; required when asym==true + * @param outputs [total_tokens, N] in act_dtype + * @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device) + * @param act_dtype BTLA_DTYPE::F16 or BTLA_DTYPE::BF16 + * @param weight_dtype BTLA_DTYPE::F16/BF16/S4_CLIP + * @param N Output feature dim (must be multiple of 16) + * @param K Input feature dim + * @param group_size Quantization group along K (int4 only); must + * divide K and be even. Default 128. + * @param num_tokens_per_expert [num_experts] int32 + * @param num_experts Number of experts + * @param total_tokens Sum of num_tokens_per_expert (== rows of + * activations / outputs) + * @param asym Whether int4 weights are asymmetric + * (zeros required when true). + */ +void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs, + int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, + int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); + +/** + * @brief MoE Grouped GEMM optimized for the prefill phase, supporting the + * same set of weight encodings as `moe_gemm_decode` (FP16/BF16, INT8 sym/asym, + * INT4 sym/asym, INT2 sym/asym, FP8 E4M3/E5M2). + * + * Stage-1 implementation: dequantizes weights into a `[num_experts, K, N]` + * temporary buffer (must be supplied by the caller via `dequant_workspace`, + * sized `num_experts * K * N * sizeof(act_dtype)`) and then dispatches to the + * existing `moe_gemm` baseline. This guarantees numerical parity with the + * decode path. Mainloop fusion is the follow-up perf-tuning step. + * + * Implementation is header-only in `sycl_tla_moe_mixed.hpp`. + * + * Layout convention (matches `moe_gemm_decode`): + * - activations: [total_tokens, K] in act_dtype + * - weights (quantized): [num_experts, N, K_p] uint8 (decode-style packed) + * - weights (FP16/BF16): [num_experts, K, N] in act_dtype (matches `moe_gemm`) + * - scales: [num_experts, N, K/group_size] in act_dtype + * - zeros (asym only): [num_experts, N, K/group_size] in act_dtype + * - dequant_workspace: [num_experts, K, N] in act_dtype, may be null + * for the unquantized fast path + * - outputs: [total_tokens, N] in act_dtype + */ +void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs, + void* dequant_workspace, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, + int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); + // ======================================================================== // Public API // ======================================================================== diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp new file mode 100644 index 000000000..c78446690 --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -0,0 +1,847 @@ +// SYCL MoE Decode Kernel +// +// GEMV-style MoE kernel optimized for the decode phase, where each expert +// typically processes only 1-2 tokens (top-k routing with batch size 1). +// +// Layout convention (caller already sorted activations per expert, +// identical to the prefill `moe_gemm` interface): +// - activations: [total_tokens, K] row-major +// - weights (fp/bf16): [num_experts, N, K] row-major +// - weights (int8): [num_experts, N, K] row-major, one +// int8 per byte (sym: signed -128..127; +// asym: unsigned 0..255 with zero-point) +// - weights (int4 packed): [num_experts, N, K/2] row-major, two +// 4-bit values per byte (low nibble at lower K) +// - weights (int2 packed): [num_experts, N, K/4] row-major, four +// 2-bit values per byte (field j at K index +// 4*i+j is bits [2j+1:2j]) +// - weights (fp8): [num_experts, N, K] row-major, one +// FP8 byte per weight (E4M3 / E5M2); scales +// applied per-group, no zero-points +// - scales: [num_experts, N, K/group_size] +// - zeros (asym only): [num_experts, N, K/group_size] +// - num_tokens_per_expert: [num_experts] int32 +// - outputs: [total_tokens, N] +// +// Target: Intel BMG (Xe2), sub_group_size = 16. One sub-group per (token, N-tile) +// with N_TILE == SG_SIZE: each lane independently computes one output element, +// so no cross-lane reduction is needed and activation reads are coalesced across +// the sub-group through the L1 cache. +// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "bestla/bestla.h" +#include "sycl_tla_moe_dequant.hpp" + +#ifdef ARK_XPU +#include +#endif + +// ---------------------------------------------------------------------------- +// FP8 decode implementation switch (runtime) +// +// FP8 weight bytes can be dequantized either via inline bit manipulation or +// via the 128-entry magnitude LUT in `bestla/sycl/fp8_lut.h` (sign applied +// separately). Both paths are mathematically equivalent for finite values; +// pick whichever is faster on the target hardware. +// +// Selection is done at runtime through the environment variable +// `ARK_FP8_DECODE_USE_LUT`: +// - unset / "1" / "true" / "on" / "yes" (case-insensitive) -> LUT path (default) +// - "0" / "false" / "off" / "no" (case-insensitive) -> inline bit-manip +// +// The env var is read once on the host (cached) and passed as a template +// parameter into the SYCL kernel, so there is no per-element runtime branch. +// The actual primitives live in `sycl_tla_moe_dequant.hpp` (shared with the +// mixed-input prefill path); this file just re-exports them via `using`. +// ---------------------------------------------------------------------------- + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_decode_detail { + +constexpr int SG_SIZE = 16; +constexpr int N_TILE = SG_SIZE; // one output element per sub-group lane + +// ---------------------------------------------------------------------------- +// Kernel name tags (one per specialization, required for SYCL kernel naming) +// ---------------------------------------------------------------------------- +template +class MoEDecodeKernelFP; + +template +class MoEDecodeKernelInt4; + +template +class MoEDecodeKernelInt8; + +template +class MoEDecodeKernelInt2; + +template +class MoEDecodeKernelFP8; + +// ---------------------------------------------------------------------------- +// FP8 weight dequantization primitives + host-side env-var reader live in +// `sycl_tla_moe_dequant.hpp` so the prefill (mixed-input Grouped GEMM) and +// decode (GEMV) paths share one definition. The `using` declarations below +// keep the in-kernel call sites (`decode_fp8<...>(byte)`) and the host-side +// `fp8_decode_use_lut()` lookup inside `moe_decode_detail` working unchanged. +// ---------------------------------------------------------------------------- +using moe_dequant::decode_fp8; +using moe_dequant::decode_fp8_e4m3_bits; +using moe_dequant::decode_fp8_e4m3_lut; +using moe_dequant::decode_fp8_e5m2_bits; +using moe_dequant::decode_fp8_e5m2_lut; +using moe_dequant::fp8_decode_use_lut; + +// ---------------------------------------------------------------------------- +// Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. +// Runs on host (num_experts is small, total_tokens is small in decode). +// Caller-managed buffer (USM device allocation) keeps host noise out of the +// hot path; here we just fill it via a tiny SYCL kernel for simplicity. +// ---------------------------------------------------------------------------- +inline void fill_expert_id_per_token(sycl::queue* q, int* expert_id_per_token, + const int* num_tokens_per_expert, int num_experts, + int total_tokens) { + // Parallel fill: each work-item independently scans the small + // num_tokens_per_expert array (typ. <= 256) to find its expert id. This + // removes the single-task serialization point and avoids an explicit + // host-device sync; the in-order queue chains this with the GEMV launch. + if (total_tokens == 0) return; + q->parallel_for(sycl::range<1>(static_cast(total_tokens)), [=](sycl::id<1> idx) { + const int i = static_cast(idx[0]); + int offset = 0; + int expert = num_experts - 1; + for (int e = 0; e < num_experts; ++e) { + const int n = num_tokens_per_expert[e]; + if (i < offset + n) { + expert = e; + break; + } + offset += n; + } + expert_id_per_token[i] = expert; + }); +} + +// ---------------------------------------------------------------------------- +// FP16 / BF16 baseline GEMV (no quantization). +// ---------------------------------------------------------------------------- +template +void launch_fp(sycl::queue* q, const ScalarT* activations, const ScalarT* weights, ScalarT* outputs, + const int* expert_id_per_token, int total_tokens, int N, int K) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode: N must be a multiple of 16"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + const ScalarT* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + + float acc = 0.0f; + // Unroll by 8 with a 16-byte vector load for both activations and + // weights. Activations are sub-group-uniform so they coalesce via + // L1; each lane's weight load is an independent 16-byte transaction. + // We load through a uint16_t vector to stay portable across SYCL + // implementations that may not provide sycl::vec. + int k = 0; + constexpr int VEC = 8; + using LoadVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int k_vec_end = (K / VEC) * VEC; + for (; k < k_vec_end; k += VEC) { + const LoadVec av = *reinterpret_cast(act_row + k); + const LoadVec wv = *reinterpret_cast(w_row + k); +#pragma unroll + for (int u = 0; u < VEC; ++u) { + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + const ScalarT w = sycl::bit_cast(static_cast(wv[u])); + acc += static_cast(a) * static_cast(w); + } + } + for (; k < K; ++k) { + acc += static_cast(act_row[k]) * static_cast(w_row[k]); + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }); +} + +// ---------------------------------------------------------------------------- +// INT4 (S4_CLIP) GEMV with group-wise dequantization. +// +// Asym=false: signed nibble in [-8, 7], dequant = q * scale +// Asym=true : unsigned nibble in [0, 15], dequant = (q - zero) * scale +// +// Packing: two 4-bit values per byte; the value at k = 2*i is the LOW nibble +// of byte i, the value at k = 2*i+1 is the HIGH nibble. This matches the +// existing CPU/XPU `packq` layout for S4_CLIP weights. +// ---------------------------------------------------------------------------- +template +void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): N must be a multiple of 16"); + } + if (K % group_size != 0 || (group_size & 1) != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): K must be a multiple of group_size and group_size must be even"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int4): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 2; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // Vectorized path: process 16 K-elements at a time, which is + // 8 packed weight bytes and a vec activation block. + // group_size is a multiple of 16 in every supported config + // (group_size >= 32, even); a scalar tail loop covers leftovers. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using PackVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const PackVec pv = *reinterpret_cast(w_row + (k_base + kk) / 2); +#pragma unroll + for (int b = 0; b < CHUNK / 2; ++b) { + const uint8_t packed = pv[b]; + float w0, w1; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x0F); + const int q1 = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 4) >> 4); + const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + } + const ScalarT a0 = sycl::bit_cast(static_cast(av[2 * b])); + const ScalarT a1 = sycl::bit_cast(static_cast(av[2 * b + 1])); + acc += static_cast(a0) * w0; + acc += static_cast(a1) * w1; + } + } + // Scalar tail for group_size not divisible by CHUNK. + for (; kk < group_size; kk += 2) { + const uint8_t packed = w_row[(k_base + kk) / 2]; + float w0, w1; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x0F); + const int q1 = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 4) >> 4); + const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w0; + acc += static_cast(act_row[k_base + kk + 1]) * w1; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }); +} + +// ---------------------------------------------------------------------------- +// INT8 (S8) GEMV with group-wise dequantization. +// +// Asym=false: signed byte in [-128, 127], dequant = q * scale +// Asym=true : unsigned byte in [0, 255], dequant = (q - zero) * scale +// +// Weights are stored as raw uint8 bytes (1 byte per weight). The same buffer +// type is used for sym and asym; the only difference is the sign interpretation +// performed at decode time. +// ---------------------------------------------------------------------------- +template +void launch_int8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): K must be a multiple of group_size"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int8): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // Vectorized path: 16 weights (16 bytes) + 16 activations per load. + // group_size is typically 128 (mult of 16); scalar tail handles + // anything that doesn't divide evenly. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using ByteVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const ByteVec wv = *reinterpret_cast(w_row + k_base + kk); +#pragma unroll + for (int u = 0; u < CHUNK; ++u) { + const uint8_t raw = wv[u]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + acc += static_cast(a) * w; + } + } + for (; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }); +} + +// ---------------------------------------------------------------------------- +// INT2 (S2_CLIP) GEMV with group-wise dequantization. +// +// Packing: 4 values per byte. The value at K index 4*i + j is stored in +// bits [2j+1 : 2j] of byte i (i.e. byte = q0 | (q1<<2) | (q2<<4) | (q3<<6)). +// +// Asym=false: signed 2-bit value in [-2, 1]; dequant = q * scale +// Asym=true : unsigned 2-bit value in [0, 3]; dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): N must be a multiple of 16"); + } + if ((K & 0x3) != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): K must be a multiple of 4"); + } + if (K % group_size != 0 || (group_size & 0x3) != 0) { + throw std::invalid_argument( + "moe_gemm_decode(int2): K must be a multiple of group_size and group_size must be a multiple of 4"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int2): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 4; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // Vectorized: 16 K-elements per chunk = 4 packed bytes (4 values + // each) plus a vec activation block. group_size is a + // multiple of 4 and typically 128 (mult of 16); scalar tail covers + // any leftover. We load activations via uint16_t to stay portable + // across SYCL implementations that may not provide + // sycl::vec. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using PackVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const PackVec pv = *reinterpret_cast(w_row + (k_base + kk) / 4); +#pragma unroll + for (int b = 0; b < CHUNK / 4; ++b) { + const uint8_t packed = pv[b]; + float w0, w1, w2, w3; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x3); + const int q1 = static_cast((packed >> 2) & 0x3); + const int q2 = static_cast((packed >> 4) & 0x3); + const int q3 = static_cast((packed >> 6) & 0x3); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + w2 = (static_cast(q2) - zero) * scale; + w3 = (static_cast(q3) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 6) >> 6); + const int q1 = static_cast(static_cast((packed << 4) & 0xC0) >> 6); + const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); + const int q3 = static_cast(static_cast(packed & 0xC0) >> 6); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + w2 = static_cast(q2) * scale; + w3 = static_cast(q3) * scale; + } + const ScalarT a0 = sycl::bit_cast(static_cast(av[4 * b + 0])); + const ScalarT a1 = sycl::bit_cast(static_cast(av[4 * b + 1])); + const ScalarT a2 = sycl::bit_cast(static_cast(av[4 * b + 2])); + const ScalarT a3 = sycl::bit_cast(static_cast(av[4 * b + 3])); + acc += static_cast(a0) * w0; + acc += static_cast(a1) * w1; + acc += static_cast(a2) * w2; + acc += static_cast(a3) * w3; + } + } + // Scalar tail (4 values per byte). + for (; kk < group_size; kk += 4) { + const uint8_t packed = w_row[(k_base + kk) / 4]; + float w[4]; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x3); + const int q1 = static_cast((packed >> 2) & 0x3); + const int q2 = static_cast((packed >> 4) & 0x3); + const int q3 = static_cast((packed >> 6) & 0x3); + w[0] = (static_cast(q0) - zero) * scale; + w[1] = (static_cast(q1) - zero) * scale; + w[2] = (static_cast(q2) - zero) * scale; + w[3] = (static_cast(q3) - zero) * scale; + } else { + const int q0 = static_cast(static_cast(packed << 6) >> 6); + const int q1 = static_cast(static_cast((packed << 4) & 0xC0) >> 6); + const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); + const int q3 = static_cast(static_cast(packed & 0xC0) >> 6); + w[0] = static_cast(q0) * scale; + w[1] = static_cast(q1) * scale; + w[2] = static_cast(q2) * scale; + w[3] = static_cast(q3) * scale; + } + acc += static_cast(act_row[k_base + kk + 0]) * w[0]; + acc += static_cast(act_row[k_base + kk + 1]) * w[1]; + acc += static_cast(act_row[k_base + kk + 2]) * w[2]; + acc += static_cast(act_row[k_base + kk + 3]) * w[3]; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }); +} + +// ---------------------------------------------------------------------------- +// FP8 (E4M3 / E5M2) GEMV with group-wise scale (no zero-point). +// +// Weights are 1 FP8 byte per element [E, N, K]. The byte is decoded via the +// `decode_fp8` helper, which selects between the LUT and the +// inline bit-manipulation path at compile time. The choice is driven at +// launch time by the env var `ARK_FP8_DECODE_USE_LUT` (default: ON). +// ---------------------------------------------------------------------------- +template +void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): K must be a multiple of group_size"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + const int k_base = g * group_size; + // Vectorized: 16 weights (16 bytes) + 16 activations per load. + // Decode each FP8 byte to float inline, then apply the per-group + // scale. group_size is typically 128 (mult of 16); scalar tail + // covers anything that doesn't divide evenly. + constexpr int CHUNK = 16; + using ActVec = sycl::vec; + using ByteVec = sycl::vec; + static_assert(sizeof(ScalarT) == sizeof(uint16_t), + "ScalarT must be a 16-bit floating type"); + const int chunk_end = (group_size / CHUNK) * CHUNK; + int kk = 0; + for (; kk < chunk_end; kk += CHUNK) { + const ActVec av = *reinterpret_cast(act_row + k_base + kk); + const ByteVec wv = *reinterpret_cast(w_row + k_base + kk); +#pragma unroll + for (int u = 0; u < CHUNK; ++u) { + const uint8_t raw = wv[u]; + const float w = decode_fp8(raw) * scale; + const ScalarT a = sycl::bit_cast(static_cast(av[u])); + acc += static_cast(a) * w; + } + } + for (; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + const float w = decode_fp8(raw) * scale; + acc += static_cast(act_row[k_base + kk]) * w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }); +} + +} // namespace moe_decode_detail + +// ---------------------------------------------------------------------------- +// Public API +// +// weight_dtype: +// BTLA_DTYPE::F16 / BF16 : weights stored as [E, N, K] in matching +// floating dtype, no scales/zeros needed +// BTLA_DTYPE::S8 : int8 weights [E, N, K] (uint8 buffer, +// interpreted as signed when asym==false, +// unsigned with zero-points when asym==true) +// BTLA_DTYPE::S4_CLIP : packed int4 weights [E, N, K/2] (uint8), +// scales [E, N, K/group_size] in act dtype, +// zeros optional (asym==true requires it) +// BTLA_DTYPE::S2_CLIP : packed int2 weights [E, N, K/4] (uint8), +// 4 values per byte, sym/asym like int4 +// BTLA_DTYPE::F8_E4M3 / F8_E5M2 : FP8 weights [E, N, K] (uint8 buffer), +// group-wise scales, no zero-points +// act_dtype: F16 or BF16 (must match scales/outputs dtype) +// ---------------------------------------------------------------------------- +inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, + void* outputs, int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, + BTLA_DTYPE weight_dtype, int N, int K, int group_size, int* num_tokens_per_expert, + int num_experts, int total_tokens, bool asym) { + moe_decode_detail::fill_expert_id_per_token(q, expert_id_per_token_buf, num_tokens_per_expert, num_experts, + total_tokens); + + if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { + if (weight_dtype != act_dtype) { + throw std::invalid_argument("moe_gemm_decode: unquantized weight_dtype must match act_dtype"); + } + if (act_dtype == BTLA_DTYPE::F16) { + moe_decode_detail::launch_fp(q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K); + } else { + moe_decode_detail::launch_fp( + q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S4_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int4): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S8) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int8): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S2_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int2): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::F8_E4M3 || weight_dtype == BTLA_DTYPE::F8_E5M2) { + if (asym) { + throw std::invalid_argument("moe_gemm_decode(fp8): asym mode is not supported"); + } + const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + const bool use_lut = moe_decode_detail::fp8_decode_use_lut(); + if (act_dtype == BTLA_DTYPE::F16) { + if (is_e4m3) { + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } + } else { + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (is_e4m3) { + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } + } else { + if (use_lut) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } + } + } else { + throw std::invalid_argument("moe_gemm_decode(fp8): act_dtype must be FP16 or BF16"); + } + return; + } + + throw std::invalid_argument( + "moe_gemm_decode: unsupported weight_dtype (supported: F16, BF16, S8, S4_CLIP, S2_CLIP, F8_E4M3, F8_E5M2)"); +} + +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp new file mode 100644 index 000000000..191286717 --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_dequant.hpp @@ -0,0 +1,150 @@ +// SYCL MoE Weight Dequantization Primitives +// +// Device-side dequantization helpers shared between the MoE *decode* (GEMV) +// kernel in `sycl_tla_moe_decode.hpp` and the MoE *prefill* (mixed-input +// Grouped GEMM) kernel in `sycl_tla_moe_mixed.hpp`. Keeping the primitives +// in one place guarantees that both paths produce bit-identical results for +// the same packed weight bytes, which is what the round-trip parity tests +// (decode vs prefill) rely on. +// +// Currently extracted (PR-A1): the FP8 byte->float decoders and the host- +// side `ARK_FP8_DECODE_USE_LUT` env-var reader. INT2/INT4/INT8 decoders are +// still inlined inside the decode kernel; they will be added here when the +// mixed-input prefill mainloop in PR-A2/PR-A3 starts consuming them. +// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +#include "bestla/sycl/fp8_lut.h" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_dequant { + +// ---------------------------------------------------------------------------- +// FP8 byte -> float decode. +// Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: +// E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. +// E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. +// +// Two equivalent (for finite values) implementations are provided: +// - `_lut`: read magnitude from the 128-entry constexpr table in +// `bestla/sycl/fp8_lut.h`, apply sign separately. +// - `_bits`: fully self-contained inline bit-manipulation, no LUT/SLM. +// +// Selection happens at kernel launch time via a `bool UseLut` template +// parameter sourced from the env var `ARK_FP8_DECODE_USE_LUT` (read once on +// the host by `fp8_decode_use_lut()` below). This keeps the per-element hot +// path branch-free. +// ---------------------------------------------------------------------------- + +inline float decode_fp8_e4m3_lut(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e4m3_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +inline float decode_fp8_e5m2_lut(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const float v = bestla::sycl_prologue_b::fp8_lut::lut_e5m2_128[mag]; + return (byte & 0x80u) ? -v : v; +} + +inline float decode_fp8_e4m3_bits(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + float v; + if (mag == 0u) { + v = 0.0f; + } else if (mag == 0x7Fu) { + v = sycl::nan(0u); + } else { + const int exp = static_cast((mag >> 3) & 0xFu); + const int man = static_cast(mag & 0x7u); + if (exp == 0) { + // subnormal: value = man * 2^(1 - bias - mbits) = man / 512 + v = static_cast(man) * (1.0f / 512.0f); + } else { + // normal: (1 + man/8) * 2^(exp - bias), bias = 7 + v = (1.0f + static_cast(man) * 0.125f) * sycl::ldexp(1.0f, exp - 7); + } + } + return sign ? -v : v; +} + +inline float decode_fp8_e5m2_bits(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + const int exp = static_cast((mag >> 2) & 0x1Fu); + const int man = static_cast(mag & 0x3u); + float v; + if (exp == 0) { + // subnormal (incl. zero): value = man * 2^(1 - bias - mbits) = man / 65536 + v = static_cast(man) * (1.0f / 65536.0f); + } else if (exp == 31) { + v = (man == 0) ? std::numeric_limits::infinity() : sycl::nan(0u); + } else { + // normal: (1 + man/4) * 2^(exp - bias), bias = 15 + v = (1.0f + static_cast(man) * 0.25f) * sycl::ldexp(1.0f, exp - 15); + } + return sign ? -v : v; +} + +// Compile-time dispatch helper. Both branches are resolved via `if constexpr`, +// so there is no per-element runtime cost regardless of which path is chosen. +template +inline float decode_fp8(uint8_t byte) { + if constexpr (UseLut) { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_lut(byte); + } else { + return decode_fp8_e5m2_lut(byte); + } + } else { + if constexpr (IsE4M3) { + return decode_fp8_e4m3_bits(byte); + } else { + return decode_fp8_e5m2_bits(byte); + } + } +} + +// ---------------------------------------------------------------------------- +// Host-side env-var reader: cached, defaults to LUT enabled. +// +// `ARK_FP8_DECODE_USE_LUT`: +// - unset / "1" / "true" / "on" / "yes" (case-insensitive) -> LUT path (default) +// - "0" / "false" / "off" / "no" (case-insensitive) -> inline bit-manip +// +// Read once on first call and cached in a function-local static, so it is +// safe (and free) to call this on every launch. +// ---------------------------------------------------------------------------- +inline bool fp8_decode_use_lut() { + static const bool value = []() { + const char* env = std::getenv("ARK_FP8_DECODE_USE_LUT"); + if (env == nullptr) return true; // default: LUT on + std::string s(env); + for (char& c : s) c = static_cast(std::tolower(static_cast(c))); + if (s == "0" || s == "false" || s == "off" || s == "no") return false; + return true; + }(); + return value; +} + +} // namespace moe_dequant +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp new file mode 100644 index 000000000..94f91085e --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_mixed.hpp @@ -0,0 +1,519 @@ +// SYCL MoE Mixed-Input Prefill Kernel +// +// MoE prefill (Grouped GEMM) entry point that accepts the same set of +// quantized weight encodings as the decode kernel in +// `sycl_tla_moe_decode.hpp` (FP16/BF16 baseline, INT8 sym/asym, INT4 +// sym/asym, INT2 sym/asym, FP8 E4M3/E5M2 with group-wise scale). +// +// Stage-1 implementation ("function-first"): a single device-side +// dequantization kernel materialises the per-expert weights into a +// `[E, K, N]` fp16/bf16 temporary, after which the existing CUTLASS-SYCL +// grouped GEMM (`moe_gemm` in `sycl_tla_moe.hpp`) is invoked. This keeps +// the dispatch surface, packing convention and numerical behaviour +// bit-identical to the decode path so end-to-end models can be validated +// and profiled. Mainloop fusion (mixed-input grouped GEMM) is the +// follow-up perf-tuning step. +// +// Layout convention (matches `sycl_tla_moe_decode.hpp`): +// - activations: [total_tokens, K] row-major +// - weights (fp/bf16): [num_experts, N, K] row-major +// - weights (int8): [num_experts, N, K] row-major (uint8 buf) +// - weights (int4 packed): [num_experts, N, K/2] row-major (uint8) +// - weights (int2 packed): [num_experts, N, K/4] row-major (uint8) +// - weights (fp8): [num_experts, N, K] row-major (uint8 buf) +// - scales: [num_experts, N, K/group] in act dtype +// - zeros (asym only): [num_experts, N, K/group] in act dtype +// - num_tokens_per_expert: [num_experts] int32 +// - outputs: [total_tokens, N] row-major +// +// The dequantized weights are written transposed to `[E, K, N]` so the +// existing prefill grouped GEMM (which expects `[E, K, N]` row-major) can +// consume them directly without an additional transpose pass. +// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "bestla/bestla.h" +#include "sycl_tla_moe.hpp" +#include "sycl_tla_moe_dequant.hpp" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_mixed_detail { + +// ---------------------------------------------------------------------------- +// Kernel name tags (one per specialization, required for SYCL kernel naming). +// ---------------------------------------------------------------------------- +template +class MoEDequantKernelFP; + +template +class MoEDequantKernelInt8; + +template +class MoEDequantKernelInt4; + +template +class MoEDequantKernelInt2; + +template +class MoEDequantKernelFP8; + +// Tile sizes for the dequant kernels. +// +// Each work-group covers a (PACK_K x WG_N) tile in (k, n) and writes PACK_K +// consecutive K rows for the same column band. N is the inner dimension so +// stores to the `[E, K, N]` workspace stay coalesced across the sub-group. +// +// PACK_K is chosen per weight encoding so a single work-item dequantises an +// entire packed byte (INT4: 2 outputs, INT2: 4 outputs) or a small run of +// elements sharing one scale/zero load (INT8 / FP8 / FP transpose: 4 +// outputs). This removes the redundant packed-byte and scale loads that +// the previous "one element per work-item" launch incurred: +// - INT4: every byte was read twice (one item per nibble) -> now once. +// - INT2: every byte was read four times -> now once. +// - All quantized paths: every scale (and zero) was reloaded by every K +// element in the group -> now once per PACK_K elements. +// Group-wise scale sharing is safe because PACK_K *divides* group_size in +// every supported configuration: PACK_K is 2 (INT4) or 4 (INT2/INT8/FP8), +// and group_size is always a power of two >= 32 (typically 32, 64, or 128). +// As a result `k_base / group_size` yields the same group index for every +// K element in the PACK_K run, and each kernel can hoist a single scale +// (and, for asym, a single zero) load to amortise across the run. +// +// WG_N is the sub-group store width along N. 32 yields a single 64-byte +// coalesced burst per row for FP16/BF16 writes, which matches the L1 +// cache-line size on the target XPUs. +constexpr int WG_N = 32; +constexpr int PACK_K_FP = 4; +constexpr int PACK_K_INT8 = 4; +constexpr int PACK_K_INT4 = 2; +constexpr int PACK_K_INT2 = 4; +constexpr int PACK_K_FP8 = 4; + +// ---------------------------------------------------------------------------- +// FP16 / BF16 weight reshape: in-place transpose [E, N, K] -> [E, K, N]. +// Implemented as a generic dequant pass with identity scale; used so the +// public entry point can dispatch FP16/BF16 through the same code path +// as the quantized variants. +// ---------------------------------------------------------------------------- +template +void launch_dequant_fp(sycl::queue* q, const ScalarT* weights_NK, ScalarT* weights_KN, int E, int N, int K, + const int* num_tokens_per_expert = nullptr) { + if (E == 0 || N == 0 || K == 0) return; + + // Each work-item copies PACK_K_FP consecutive K elements for a single + // (e, n). The launch grid covers ceil(K / PACK_K_FP) along the middle dim; + // an inner bounds check guards the K tail when K is not a multiple of + // PACK_K_FP (e.g. K < PACK_K_FP in tiny unit tests). + const int k_tiles = (K + PACK_K_FP - 1) / PACK_K_FP; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + // Skip experts that receive no tokens in this prefill batch; the + // grouped GEMM will not read their rows of `weights_KN` either. + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_FP; + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); +#pragma unroll + for (int j = 0; j < PACK_K_FP; ++j) { + const int k = k_base + j; + if (k >= K) break; + const ScalarT v = weights_NK[w_row + static_cast(k)]; + weights_KN[out_base + static_cast(k) * N] = v; + } + }); +} + +// ---------------------------------------------------------------------------- +// INT8 (S8) dequant: [E, N, K] uint8 -> [E, K, N] ScalarT. +// Asym=false: signed int8 in [-128, 127], dequant = q * scale +// Asym=true : unsigned uint8 in [0, 255], dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_dequant_int8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int8): zeros pointer required when asym=true"); + } + const int num_groups_k = K / group_size; + + // Each work-item dequantises PACK_K_INT8 consecutive K outputs for a + // single (e, n), sharing one scale/zero load. PACK_K_INT8 (=4) is always + // <= group_size (>=32 in practice), so the cached scale/zero is valid + // for every element in the run. + const int k_tiles = (K + PACK_K_INT8 - 1) / PACK_K_INT8; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_INT8; + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + // Hoist scale/zero loads: PACK_K_INT8 K values share the same group + // because PACK_K_INT8 divides group_size (see PACK_K constants + // above), so `k_base / group_size` is constant across the run. + const int g = k_base / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const float zero = Asym ? static_cast(zeros[s_idx]) : 0.0f; +#pragma unroll + for (int j = 0; j < PACK_K_INT8; ++j) { + const int k = k_base + j; + if (k >= K) break; + const uint8_t raw = weights_NK[w_row + static_cast(k)]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + w = static_cast(static_cast(raw)) * scale; + } + weights_KN[out_base + static_cast(k) * N] = static_cast(w); + } + }); +} + +// ---------------------------------------------------------------------------- +// INT4 (S4_CLIP) dequant: [E, N, K/2] uint8 packed -> [E, K, N] ScalarT. +// Packing: low nibble at lower K (k = 2*i), high nibble at higher K (k = 2*i+1). +// Asym=false: signed nibble in [-8, 7], dequant = q * scale +// Asym=true : unsigned nibble in [0, 15], dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_dequant_int4(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int4): zeros pointer required when asym=true"); + } + if ((K & 1) != 0) { + throw std::invalid_argument("moe_gemm_prefill(int4): K must be even"); + } + const int num_groups_k = K / group_size; + const int k_packed = K / 2; + + // Each work-item now dequantises one full packed byte = PACK_K_INT4 (=2) + // consecutive K outputs for a single (e, n). The middle launch dim is + // therefore k_packed instead of K, halving the work-item count and + // eliminating the previous "two items reading the same byte" pattern. + sycl::range<3> global{static_cast(E), static_cast(k_packed), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; + const int kp = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const int k_base = kp * PACK_K_INT4; + // PACK_K_INT4 (=2) divides group_size, so all PACK_K_INT4 K values + // in this run share the same scale/zero (one hoisted load each). + const int g = k_base / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const float zero = Asym ? static_cast(zeros[s_idx]) : 0.0f; + const uint8_t packed = weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(kp)]; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + float w0, w1; + if constexpr (Asym) { + const int q_lo = static_cast(packed & 0x0F); + const int q_hi = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q_lo) - zero) * scale; + w1 = (static_cast(q_hi) - zero) * scale; + } else { + // Sign-extend each nibble: shift into the top of an int8 then + // arithmetic-shift right by 4 to fill the sign bits. + const int q_lo = static_cast(static_cast(packed << 4) >> 4); + const int q_hi = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q_lo) * scale; + w1 = static_cast(q_hi) * scale; + } + weights_KN[out_base + static_cast(k_base) * N] = static_cast(w0); + weights_KN[out_base + static_cast(k_base + 1) * N] = static_cast(w1); + }); +} + +// ---------------------------------------------------------------------------- +// INT2 (S2_CLIP) dequant: [E, N, K/4] uint8 packed -> [E, K, N] ScalarT. +// Packing: byte = q0 | (q1<<2) | (q2<<4) | (q3<<6); field j corresponds to +// K index 4*i + j. +// Asym=false: signed in [-2, 1]; Asym=true: unsigned in [0, 3]. +// ---------------------------------------------------------------------------- +template +void launch_dequant_int2(sycl::queue* q, const uint8_t* weights_NKp, const ScalarT* scales, const ScalarT* zeros, + ScalarT* weights_KN, int E, int N, int K, int group_size, + const int* num_tokens_per_expert = nullptr) { + if (E == 0 || N == 0 || K == 0) return; + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_prefill(int2): zeros pointer required when asym=true"); + } + if ((K & 3) != 0) { + throw std::invalid_argument("moe_gemm_prefill(int2): K must be a multiple of 4"); + } + const int num_groups_k = K / group_size; + const int k_packed = K / 4; + + // One work-item handles one full packed byte = PACK_K_INT2 (=4) + // consecutive K outputs for a single (e, n). This removes the previous + // 4x duplicated byte loads (one item per 2-bit field) and 4x scale + // reloads. + sycl::range<3> global{static_cast(E), static_cast(k_packed), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; + const int kp = static_cast(it.get_global_id(1)); + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const int k_base = kp * PACK_K_INT2; + // PACK_K_INT2 (=4) divides group_size, so all PACK_K_INT2 K values + // in this run share the same scale/zero (one hoisted load each). + const int g = k_base / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); + const float zero = Asym ? static_cast(zeros[s_idx]) : 0.0f; + const uint8_t packed = weights_NKp[(static_cast(e) * N + static_cast(n)) * k_packed + + static_cast(kp)]; + const size_t out_base = static_cast(e) * K * N + static_cast(n); +#pragma unroll + for (int j = 0; j < PACK_K_INT2; ++j) { + float w; + if constexpr (Asym) { + const int q = static_cast((packed >> (2 * j)) & 0x3); + w = (static_cast(q) - zero) * scale; + } else { + // Sign-extend 2-bit by shifting the field into the top bits of an int8. + const int shift = 6 - 2 * j; // 6, 4, 2, 0 for fields 0..3 + const int8_t s8 = static_cast((packed << shift) & 0xC0); + const int q = static_cast(s8 >> 6); + w = static_cast(q) * scale; + } + weights_KN[out_base + static_cast(k_base + j) * N] = static_cast(w); + } + }); +} + +// ---------------------------------------------------------------------------- +// FP8 (E4M3 / E5M2) dequant: [E, N, K] uint8 -> [E, K, N] ScalarT. +// Per-group scale applied; no zero-points (caller must enforce sym). +// `UseLut` selects the LUT vs inline-bits decode at compile time +// (driven by `ARK_FP8_DECODE_USE_LUT`, same as the decode path). +// ---------------------------------------------------------------------------- +template +void launch_dequant_fp8(sycl::queue* q, const uint8_t* weights_NK, const ScalarT* scales, ScalarT* weights_KN, int E, + int N, int K, int group_size, const int* num_tokens_per_expert = nullptr) { + if (E == 0 || N == 0 || K == 0) return; + const int num_groups_k = K / group_size; + + // PACK_K_FP8 (=4) consecutive K outputs per work-item, sharing one scale + // load. PACK_K_FP8 <= group_size in all supported configurations. + const int k_tiles = (K + PACK_K_FP8 - 1) / PACK_K_FP8; + sycl::range<3> global{static_cast(E), static_cast(k_tiles), + static_cast((N + WG_N - 1) / WG_N) * WG_N}; + sycl::range<3> local{1, 1, static_cast(WG_N)}; + + q->parallel_for>( + sycl::nd_range<3>(global, local), [=](sycl::nd_item<3> it) { + const int e = static_cast(it.get_global_id(0)); + if (num_tokens_per_expert != nullptr && num_tokens_per_expert[e] == 0) return; + const int k_base = static_cast(it.get_global_id(1)) * PACK_K_FP8; + const int n = static_cast(it.get_global_id(2)); + if (n >= N) return; + const size_t w_row = (static_cast(e) * N + static_cast(n)) * K; + const size_t out_base = static_cast(e) * K * N + static_cast(n); + const int g = k_base / group_size; + const size_t s_idx = (static_cast(e) * N + static_cast(n)) * num_groups_k + + static_cast(g); + const float scale = static_cast(scales[s_idx]); +#pragma unroll + for (int j = 0; j < PACK_K_FP8; ++j) { + const int k = k_base + j; + if (k >= K) break; + const uint8_t raw = weights_NK[w_row + static_cast(k)]; + const float w = moe_dequant::decode_fp8(raw) * scale; + weights_KN[out_base + static_cast(k) * N] = static_cast(w); + } + }); +} + +// ---------------------------------------------------------------------------- +// Dispatch helper: dequant any supported weight encoding into `weights_KN` +// (already-allocated `[E, K, N]` ScalarT buffer) using ScalarT == act dtype. +// ---------------------------------------------------------------------------- +template +void dequant_to_KN(sycl::queue* q, const void* weights, const void* scales, const void* zeros, ScalarT* weights_KN, + BTLA_DTYPE weight_dtype, int E, int N, int K, int group_size, bool asym, + const int* num_tokens_per_expert = nullptr) { + if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { + launch_dequant_fp(q, static_cast(weights), weights_KN, E, N, K, num_tokens_per_expert); + return; + } + if (weight_dtype == BTLA_DTYPE::S8) { + if (asym) { + launch_dequant_int8(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } else { + launch_dequant_int8(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } + return; + } + if (weight_dtype == BTLA_DTYPE::S4_CLIP) { + if (asym) { + launch_dequant_int4(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } else { + launch_dequant_int4(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } + return; + } + if (weight_dtype == BTLA_DTYPE::S2_CLIP) { + if (asym) { + launch_dequant_int2(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } else { + launch_dequant_int2(q, static_cast(weights), static_cast(scales), + static_cast(zeros), weights_KN, E, N, K, group_size, + num_tokens_per_expert); + } + return; + } + if (weight_dtype == BTLA_DTYPE::F8_E4M3 || weight_dtype == BTLA_DTYPE::F8_E5M2) { + if (asym) { + throw std::invalid_argument("moe_gemm_prefill(fp8): asym mode is not supported"); + } + const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + const bool use_lut = moe_dequant::fp8_decode_use_lut(); + if (is_e4m3) { + if (use_lut) { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size, num_tokens_per_expert); + } else { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size, num_tokens_per_expert); + } + } else { + if (use_lut) { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size, num_tokens_per_expert); + } else { + launch_dequant_fp8(q, static_cast(weights), + static_cast(scales), weights_KN, E, N, K, + group_size, num_tokens_per_expert); + } + } + return; + } + throw std::invalid_argument( + "moe_gemm_prefill: unsupported weight_dtype (supported: F16, BF16, S8, S4_CLIP, S2_CLIP, F8_E4M3, F8_E5M2)"); +} + +} // namespace moe_mixed_detail + +// ---------------------------------------------------------------------------- +// Public API +// +// MoE prefill (Grouped GEMM) supporting the same set of weight encodings as +// `moe_gemm_decode`. Activations/scales/zeros/outputs are in `act_dtype` +// (FP16 or BF16). For the unquantized fast path (weight_dtype matches +// act_dtype and is FP16/BF16), the call is forwarded directly to the +// existing `moe_gemm` (which already expects `[E, K, N]` row-major +// weights). For all other dtypes, weights are dequantised on-device into +// a temporary `[E, K, N]` buffer of `act_dtype` and then handed to the +// same grouped GEMM. The temporary buffer must be supplied by the caller +// (see the Python wrapper which allocates it sized +// `E * K * N * sizeof(act_dtype)`); this avoids per-call USM allocations +// and keeps memory ownership in PyTorch's caching allocator. +// +// IMPORTANT layout note: +// - Quantized `weights` are `[E, N, K_packed]` (decode-style). +// - The unquantized fast path (`act_dtype == weight_dtype` and FP/BF16) +// forwards directly to `moe_gemm`, which expects `[E, K, N]` weights. +// Callers must therefore pass already-`[E, K, N]` weights for the +// unquantized fast path (matching the existing `moe_gemm` contract). +// The Python wrapper handles this by exposing the unquantized path +// under the same shape contract as `moe_gemm` and the quantized paths +// under the same shape contract as `moe_gemm_decode`. +// ---------------------------------------------------------------------------- +inline void moe_gemm_prefill(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, + void* outputs, void* dequant_workspace, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, + int N, int K, int group_size, int* num_tokens_per_expert, int num_experts, + int total_tokens, bool asym) { + if (total_tokens == 0) return; + + // Unquantized fast path: forward directly to the existing prefill GEMM. + // The caller is responsible for passing weights in `[E, K, N]` layout + // (matching the existing `moe_gemm` contract). + if (weight_dtype == act_dtype && (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16)) { + moe_gemm(q, activations, weights, scales, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + return; + } + + if (dequant_workspace == nullptr) { + throw std::invalid_argument("moe_gemm_prefill: dequant_workspace must be non-null for quantized paths"); + } + + if (act_dtype == BTLA_DTYPE::F16) { + auto* w_kn = static_cast(dequant_workspace); + moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, + group_size, asym, num_tokens_per_expert); + moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + auto* w_kn = static_cast(dequant_workspace); + moe_mixed_detail::dequant_to_KN(q, weights, scales, zeros, w_kn, weight_dtype, num_experts, N, K, group_size, + asym, num_tokens_per_expert); + moe_gemm(q, activations, w_kn, /*scales=*/nullptr, outputs, act_dtype, N, K, num_tokens_per_expert, num_experts); + } else { + throw std::invalid_argument("moe_gemm_prefill: act_dtype must be F16 or BF16"); + } +} + +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md b/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md new file mode 100644 index 000000000..5bebc2af5 --- /dev/null +++ b/auto_round_extension/ark/test/README_MOE_PREFILL_PERF.md @@ -0,0 +1,151 @@ +# MoE Prefill Performance Test + +## Overview + +The `test_moe_prefill_perf.py` file provides comprehensive performance benchmarks for MoE (Mixture of Experts) prefill operations with TFLOPS (Tera Floating Point Operations Per Second) calculations. + +## What is MoE Prefill? + +**Prefill** is the phase during LLM inference where many tokens (e.g., the entire prompt or a batch of sequences) are processed simultaneously. In MoE models, tokens are routed to different experts, and each expert may receive multiple tokens. This is different from **decode** (token generation), where typically only one token per expert is processed at a time. + +## Features + +### 1. **Comprehensive Data Type Support** +- FP16 (float16) +- BF16 (bfloat16) +- INT8 (symmetric and asymmetric quantization) +- INT4 (symmetric and asymmetric quantization) +- INT2 (symmetric and asymmetric quantization) +- FP8 (float8_e4m3fn and float8_e5m2) + +### 2. **TFLOPS Calculation** +The test calculates TFLOPS for each configuration using the formula: +``` +FLOPs = total_tokens × K × N × 2 +TFLOPS = FLOPs / (time_in_seconds) / 1e12 +``` + +Where: +- `total_tokens`: Total number of tokens across all experts +- `K`: Input feature dimension +- `N`: Output feature dimension +- `×2`: Each multiply-add operation counts as 2 FLOPs + +### 3. **Various MoE Configurations** +The test covers multiple realistic MoE scenarios: +- **Small models** (8 experts, Mixtral-style): 4096×4096, 4096×14336, 14336×4096 +- **Medium models** (8 experts): Various token distributions +- **Large models** (16, 32, 64 experts, DeepSeek-style): 2048×2048 +- **Uneven distributions**: Simulates real-world routing patterns + +### 4. **Baseline Comparison** +Each test compares the ARK MoE kernel against a baseline PyTorch implementation: +- **Baseline**: Per-expert matrix multiplication using `torch.matmul` +- **ARK Kernel**: Optimized `ark.moe_gemm` with fused operations +- **Speedup**: Reports speedup ratio (baseline_time / ark_time) + +## How to Run + +### Run all tests: +```bash +cd /path/to/auto_round_extension/ark/test +pytest -v -s test_moe_prefill_perf.py +``` + +### Run specific data type: +```bash +# FP16 tests only +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_fp + +# INT4 tests only +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_int4 + +# INT8 symmetric quantization with bfloat16 activations +pytest -v -s test_moe_prefill_perf.py::TestMoEGemmPrefillPerf::test_perf_int8 -k "bfloat16 and not asym" +``` + +**Note**: The `-s` flag is required to see the printed timing tables and TFLOPS output. + +## Output Format + +The test prints formatted tables with the following columns: + +``` +shape E N K tokens baseline(ms) ark(ms) speedup TFLOPS +small E=8 8 4096 4096 252 12.3456 4.5678 2.70x 45.2 +medium E=8 8 4096 14336 528 23.4567 8.9012 2.64x 78.9 +... +``` + +Where: +- **shape**: Configuration label +- **E**: Number of experts +- **N**: Output feature dimension +- **K**: Input feature dimension +- **tokens**: Total tokens across all experts +- **baseline(ms)**: PyTorch baseline latency (milliseconds) +- **ark(ms)**: ARK kernel latency (milliseconds) +- **speedup**: Performance improvement ratio +- **TFLOPS**: Throughput in tera floating point operations per second + +## Requirements + +- Intel XPU (Arc GPU) with PyTorch XPU support +- `auto_round_kernel` built with `ARK_SYCL_TLA=ON` +- Test dependencies from `test_moe.py` (pack/dequant helpers) + +## Architecture + +``` +test_moe_prefill_perf.py +├── Timing utilities (_xpu_time_ms) +│ └── Uses XPU events for accurate GPU timing +├── FLOPS calculation (_compute_moe_flops) +│ └── Computes theoretical FLOPs for TFLOPS metric +├── Baseline implementation (_default_moe_prefill) +│ └── Per-expert PyTorch matmul for comparison +├── Test shapes (PREFILL_SHAPES) +│ └── Various realistic MoE configurations +└── Test cases (TestMoEGemmPrefillPerf) + ├── test_perf_fp (FP16/BF16) + ├── test_perf_int4 (INT4 sym/asym) + ├── test_perf_int8 (INT8 sym/asym) + ├── test_perf_int2 (INT2 sym/asym) + └── test_perf_fp8 (FP8 e4m3fn/e5m2) +``` + +## Example Output + +``` +================================================================== +FP weights (float16) -- ark.moe_gemm (prefill) vs per-expert A @ W.T +================================================================== +shape E N K tokens baseline(ms) ark(ms) speedup TFLOPS +------------------------------------------------------------------ +small E=8 8 4096 4096 252 12.3456 4.5678 2.70x 45.2 +medium E=8 8 4096 14336 528 23.4567 8.9012 2.64x 78.9 +medium E=8 8 14336 4096 528 25.6789 9.1234 2.82x 76.5 +large E=16 16 2048 2048 256 5.6789 2.3456 2.42x 91.2 +large E=32 32 2048 2048 256 5.7890 2.4567 2.36x 87.3 +large E=64 64 2048 2048 256 5.8901 2.5678 2.29x 83.5 +uneven E=8 8 4096 4096 610 28.9012 10.1234 2.86x 52.1 +``` + +## Key Metrics + +1. **TFLOPS**: Higher is better - indicates compute throughput +2. **Speedup**: Higher is better - shows performance gain over baseline +3. **Latency (ms)**: Lower is better - actual kernel execution time + +## Integration with CI/CD + +This test can be integrated into performance regression testing: +- Set minimum TFLOPS thresholds for each configuration +- Track speedup ratios over time +- Alert on performance degradation + +## Related Files + +- `test_moe.py`: Correctness tests for MoE GEMM +- `test_moe_decode_perf.py`: Performance tests for MoE decode (single token per expert) +- `test_bench_bmg.py`: SDPA performance benchmarks with TFLOPS diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index 360b603b6..8bad20cf9 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -39,6 +39,13 @@ def has_moe_gemm(): return hasattr(ark.xpu_lib, "moe_gemm") +def has_moe_gemm_decode(): + """Check if MoE decode GEMV kernel is available.""" + if ark.xpu_lib is None: + return False + return hasattr(ark.xpu_lib, "moe_gemm_decode") + + @pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") @pytest.mark.skipif(not has_moe_gemm(), reason="MOE GEMM kernel not built (need ARK_SYCL_TLA=ON)") class TestMoEGemm: @@ -166,5 +173,762 @@ def test_moe_gemm_various_sizes(self, N, K): print(f"MOE GEMM test passed for N={N}, K={K}") +# --------------------------------------------------------------------------- +# Decode-path tests (M per expert is typically 1-2, mirrors top-k routing +# after the activations have been gathered/sorted by the upper layer). +# --------------------------------------------------------------------------- + + +def _pack_int4_sym(w_float, scales, group_size): + """Quantize a [E, N, K] fp tensor to symmetric int4 packed [E, N, K/2]. + + scales is filled in-place with [E, N, K/group_size] values. + """ + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 7.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / (s.to(w.dtype).unsqueeze(-1))), -8, 7).to(torch.int8) + q = q.reshape(E, N, K) + # Pack two nibbles per byte: low nibble at lower K, high nibble at higher K. + q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _pack_int4_asym(w_float, scales, zeros, group_size): + """Quantize to asymmetric int4 (range [0, 15]); returns packed weights.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 15.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 15) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 15).to(torch.int32) + q = q.reshape(E, N, K) + q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _dequant_int4_sym(packed, scales, group_size): + """Inverse of _pack_int4_sym. Returns [E, N, K] in scales.dtype.""" + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int8) + high = ((packed >> 4) & 0x0F).to(torch.int8) + # Sign extend 4-bit -> 8-bit + low = torch.where(low >= 8, low - 16, low) + high = torch.where(high >= 8, high - 16, high) + q = torch.empty(E, N, K, dtype=torch.int8, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int4_asym(packed, scales, zeros, group_size): + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int32) + high = ((packed >> 4) & 0x0F).to(torch.int32) + q = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +# --------------------------------------------------------------------------- +# Int8 / Int2 / FP8 helpers (decode-path). +# --------------------------------------------------------------------------- + + +def _pack_int8_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> int8 (signed, [-127, 127]); fills scales in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 127.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -127, 127).to(torch.int8) + # Reinterpret as uint8 with no value change. + return q.reshape(E, N, K).view(torch.uint8).contiguous() + + +def _pack_int8_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> uint8 ([0, 255]); fills scales/zeros in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 255.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 255) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 255).to(torch.int32) + return q.reshape(E, N, K).to(torch.uint8).contiguous() + + +def _dequant_int8_sym(packed_u8, scales, group_size): + # Reinterpret uint8 bytes as signed int8. + q = packed_u8.view(torch.int8).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int8_asym(packed_u8, scales, zeros, group_size): + q = packed_u8.to(torch.int32).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_int2_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> packed int2 (signed [-2, 1]); shape [E, N, K/4]. + + Packing: byte = q0 | (q1<<2) | (q2<<4) | (q3<<6), where the j-th 2-bit + field corresponds to K index 4*i + j. + """ + E, N, K = w_float.shape + assert K % 4 == 0, "K must be a multiple of 4 for int2 packing" + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + # Symmetric int2 has signed range [-2, 1] (i.e. clip at 2 and -2 but -2 inclusive). + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 2.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -2, 1).to(torch.int32) + q = q.reshape(E, N, K) + # Pack 4 values per byte. + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _pack_int2_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> packed int2 (unsigned [0, 3]); shape [E, N, K/4].""" + E, N, K = w_float.shape + assert K % 4 == 0 + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 3.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 3) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 3).to(torch.int32) + q = q.reshape(E, N, K) + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _dequant_int2_sym(packed, scales, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + # Sign-extend 2-bit (>=2 means negative). + fields = torch.where(fields >= 2, fields - 4, fields).to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + return (fields * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int2_asym(packed, scales, zeros, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + fields = fields.to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + deq = (fields - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_fp8(w_float, scales, group_size, fp8_dtype): + """Quantize [E, N, K] fp -> FP8 (e4m3fn/e5m2) with per-group scale. + + Scales are filled in-place; ``w_float`` is divided by the scale, then cast + to ``fp8_dtype`` (rounding handled by torch). Returns the FP8 tensor. + """ + assert fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + # Pick fp8 max representable magnitude for the chosen format. + if fp8_dtype == torch.float8_e4m3fn: + fp8_max = 448.0 + else: # e5m2 + fp8_max = 57344.0 + s = (absmax / fp8_max).squeeze(-1).to(scales.dtype) + scales.copy_(s) + scaled = (w / s.to(w.dtype).unsqueeze(-1)).reshape(E, N, K) + # Clamp to fp8 representable range before cast to avoid Inf/NaN. + scaled = scaled.clamp(-fp8_max, fp8_max) + return scaled.to(fp8_dtype).contiguous() + + +def _dequant_fp8(packed_fp8, scales, group_size, out_dtype): + """Reference dequant: cast fp8 -> out_dtype and multiply per-group scale.""" + E, N, K = packed_fp8.shape + w = packed_fp8.to(out_dtype).reshape(E, N, K // group_size, group_size) + return (w * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _moe_decode_reference(activations, dequant_weights, num_tokens_per_expert): + """Reference: each token is matmul'd against its routed expert's weights.""" + total_tokens, K = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] # [n_tokens, K] + w = dequant_weights[e] # [N, K] + out[offset : offset + n_tokens] = a @ w.T + offset += n_tokens + return out + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif(not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)") +class TestMoEGemmDecode: + """Unit tests for the MoE decode GEMV kernel. + + The activations layout follows the same convention as ``moe_gemm``: the + upper layer has already gathered/sorted tokens per expert, so the kernel + only needs ``num_tokens_per_expert`` (no top-k indices). + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_fp_basic(self, dtype): + num_experts = 4 + # One token per expert with one zero-token expert -> typical top-k=3 + # decode pattern after gather. + tokens_per_expert = [1, 0, 1, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 128 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = torch.randn(num_experts, N, K, dtype=dtype, device="xpu") * 0.1 + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode(activations, weights, num_tokens_per_expert, weight_bits=16) + + ref = _moe_decode_reference(activations, weights, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + assert out.dtype == dtype + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int4_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int4_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + def test_decode_validation_errors(self): + """Sanity-check that Python-side validation catches misuse.""" + num_experts = 2 + activations = torch.randn(2, 128, dtype=torch.float16, device="xpu") + num_tokens_per_expert = torch.tensor([1, 1], dtype=torch.int32, device="xpu") + + # N must be a multiple of 16 + bad_weights = torch.randn(num_experts, 17, 128, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode(activations, bad_weights, num_tokens_per_expert, weight_bits=16) + + # weight_bits=4 requires uint8 packed weights + bad_packed = torch.randn(num_experts, 64, 64, dtype=torch.float16, device="xpu") + scales = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, bad_packed, num_tokens_per_expert, scales=scales, weight_bits=4, group_size=128 + ) + + # asym=True without zeros must error + packed = torch.zeros(num_experts, 64, 64, dtype=torch.uint8, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=128, + asym=True, + ) + + # FP8 + asym is rejected + fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") + zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + fp8_w, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + group_size=128, + asym=True, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int8_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=8, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int8_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int2_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=2, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 has much higher quant error; relax tolerance vs int4. + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int2_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_fp8(self, dtype, fp8_dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 0, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # E5M2 has only 2 mantissa bits -> coarser; relax tolerance for both. + rtol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + atol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + torch.testing.assert_close(out, ref, rtol=rtol, atol=atol) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) + + +def has_moe_gemm_prefill(): + """Check if the quantized MoE prefill GEMM kernel is available.""" + if ark.xpu_lib is None: + return False + return hasattr(ark.xpu_lib, "moe_gemm_prefill") + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif(not has_moe_gemm_prefill(), reason="MoE prefill GEMM kernel not built (need ARK_SYCL_TLA=ON)") +class TestMoEGemmPrefill: + """Parity tests for the quantized MoE prefill kernel. + + The prefill kernel accepts the same quantized weight encodings as + ``moe_gemm_decode`` (decode-style ``[E, N, K_packed]``); these tests check + that ``moe_gemm_prefill(q_weights)`` matches ``moe_gemm`` evaluated on the + same dequantized weights, for multi-token-per-expert workloads typical of + the prefill phase. + """ + + @staticmethod + def _run_prefill_reference(activations, weights_NK, num_tokens_per_expert): + """Reference: per-expert ``A @ W.T`` over ``[E, N, K]`` weights.""" + total_tokens, _ = activations.shape + E, N, _ = weights_NK.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] # [n_tokens, K] + out[offset : offset + n_tokens] = a @ weights_NK[e].T + offset += n_tokens + return out + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_prefill_fp_basic(self, dtype): + # Many tokens per expert (prefill regime). + num_experts = 4 + tokens_per_expert = [16, 4, 0, 12] + total_tokens = sum(tokens_per_expert) + N, K = 128, 128 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights_NK = torch.randn(num_experts, N, K, dtype=dtype, device="xpu") * 0.1 + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill(activations, weights_NK, num_tokens_per_expert, weight_bits=16) + + ref = self._run_prefill_reference(activations, weights_NK, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + assert out.dtype == dtype + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_prefill_int4_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [8, 8, 0, 16] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_prefill_int4_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [4, 8, 12, 8] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_prefill_int8(self, dtype, asym): + num_experts = 4 + group_size = 128 + tokens_per_expert = [8, 8, 16, 0] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_prefill_int2(self, dtype, asym): + num_experts = 4 + group_size = 128 + tokens_per_expert = [4, 8, 4, 16] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 quantization noise is high; relax tolerances. + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_prefill_fp8(self, dtype, fp8_dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [8, 0, 12, 4] + total_tokens = sum(tokens_per_expert) + N, K = 128, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = self._run_prefill_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py new file mode 100644 index 000000000..e2ba36a18 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance comparison: ``ark.moe_gemm_decode`` vs default XPU MoE. + +The "default XPU MoE implementation" used as the baseline is the standard +per-expert PyTorch matmul loop (the same approach ``_moe_decode_reference`` +uses in ``test_moe.py``). For quantized formats the weights are dequantized +once up-front (outside the timed region), so the baseline measures only the +matmul cost on XPU. This is what models fall back to when no fused decode +kernel is available. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_decode_perf.py + +The ``-s`` flag is required to see the printed timing tables. +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse the existing pack/dequant helpers from the correctness tests so that +# the benchmarked path matches what the unit tests already validate. +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons. +# +# The original test_moe.py collapses several different failure modes into one +# generic "kernel not built" message which makes it impossible to tell whether +# the build is missing the kernel or whether XPU itself didn't come up. The +# helpers below distinguish those cases so a skipped run is actually +# actionable. +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _decode_skip_reason() -> str: + """Return non-empty string if the decode kernel can't be exercised.""" + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm_decode"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_decode symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE decode GEMV kernel" + ) + return "" + + +_DECODE_SKIP = _decode_skip_reason() + +# Surface diagnostics on collection so the user always sees why the suite +# would skip, without having to add extra flags. +print( + "[moe-decode-perf] xpu_available=%s xpu_lib=%s has_moe_gemm_decode=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark.xpu_lib, "moe_gemm_decode") if ark.xpu_lib is not None else False, + ) +) +if _DECODE_SKIP: + print("[moe-decode-perf] suite will SKIP. reason: %s" % _DECODE_SKIP) + + +# --------------------------------------------------------------------------- +# Timing utilities. +# --------------------------------------------------------------------------- + +# Warmup / iteration counts kept modest so the suite is still UT-shaped +# (finishes in seconds) but large enough for stable medians. +WARMUP = 5 +ITERS = 30 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + """Time ``fn`` on XPU using device events; returns median ms per call.""" + for _ in range(warmup): + fn() + torch.xpu.synchronize() + + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +def _default_moe_decode(activations, dequant_weights, num_tokens_per_expert): + """Default XPU MoE decode baseline: per-expert torch matmul loop. + + This mirrors the path a model would take when no fused MoE decode kernel + is available: gather/sort tokens by expert (done by the caller), then + iterate over experts and do a plain ``A @ W.T`` on each slice. + """ + total_tokens, _ = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + out[offset : offset + n_tokens] = a @ dequant_weights[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# Shape matrix. +# +# Picked to cover small / medium / large MoE expert FFNs (Mixtral-style +# 4096x14336 down-projection is the upper bound; smaller shapes catch +# launch-overhead-dominated cases). ``tokens_per_expert`` follows the +# expected decode-phase pattern (top-k routing with batch=1: each active +# expert sees one token). +# --------------------------------------------------------------------------- + +DECODE_SHAPES = [ + # (label, num_experts, tokens_per_expert, N, K) + ("small E=4 ", 4, [1, 0, 1, 1], 1024, 1024), + ("medium E=8 ", 8, [1, 1, 0, 1, 1, 0, 1, 1], 2048, 2048), + ("large E=8 ", 8, [1, 0, 1, 1, 0, 1, 1, 1], 4096, 4096), + ("ffn-up E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 14336, 4096), + ("ffn-dn E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 4096, 14336), +] + + +def _print_header(title: str) -> None: + print() + print("=" * 96) + print(title) + print(f"{'shape':<14}{'N':>7}{'K':>7}{'tokens':>8}" f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}") + print("-" * 96) + + +def _print_row(label, N, K, total_tokens, base_ms, ark_ms): + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print(f"{label:<14}{N:>7}{K:>7}{total_tokens:>8}" f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x") + + +# --------------------------------------------------------------------------- +# Benchmark cases. +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_DECODE_SKIP), reason=_DECODE_SKIP or "ok") +class TestMoEGemmDecodePerf: + """Median XPU-event timings of ``moe_gemm_decode`` vs per-expert ``A @ W.T``. + + The baseline uses *already-dequantized* weights, so quantized cases only + pay the matmul cost in the timed region (no per-iteration dequant). This + is the most favorable apples-to-apples comparison for the baseline; the + fused decode kernel must beat that to be worth using. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_perf_fp(self, dtype): + _print_header(f"FP weights ({str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs per-expert A @ W.T") + for label, E, tpe, N, K in DECODE_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, weights, ntpe)) + ark_ms = _xpu_time_ms(lambda: ark.moe_gemm_decode(activations, weights, ntpe, weight_bits=16)) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int4(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int8(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int2(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_perf_fp8(self, dtype, fp8_dtype): + group_size = 128 + _print_header( + f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, + packed, + ntpe, + scales=scales, + group_size=group_size, + asym=False, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/auto_round_extension/ark/test/test_moe_prefill_accuracy.py b/auto_round_extension/ark/test/test_moe_prefill_accuracy.py new file mode 100644 index 000000000..8be30e19c --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_prefill_accuracy.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Accuracy (parity) tests for ``ark.moe_gemm`` / ``ark.moe_gemm_prefill``. + +This complements ``test_moe_prefill_perf.py`` (which measures throughput) and +``test_moe.py`` (which exercises small toy shapes). The shape matrix here +mirrors the *production-scale* shapes used by the perf benchmark — large +hidden sizes (up to 14336), high expert counts (up to 64), and uneven +token-per-expert distributions typical of Mixtral/DeepSeek-style MoE models +during prefill. + +For each (dtype, quant-scheme) combination the test: + + 1. Packs / quantizes random weights. + 2. Dequantizes them back to the activation dtype. + 3. Runs the ark kernel on the *packed* weights. + 4. Runs a per-expert ``A @ W.T`` reference on the *dequantized* weights. + 5. Compares with ``torch.testing.assert_close`` at quant-appropriate + tolerances. + +This isolates kernel correctness (matmul + on-the-fly dequant) from +quantization noise: the reference shares the same dequantized weights as the +kernel, so tolerances reflect only accumulator/order-of-operations +differences, not quant error. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_prefill_accuracy.py +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests. +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons (mirror test_moe_prefill_perf.py) +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _prefill_skip_reason() -> str: + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm"): + return ( + "ark.xpu_lib loaded but has no moe_gemm symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE GEMM kernel" + ) + return "" + + +def _quantized_prefill_skip_reason() -> str: + reason = _prefill_skip_reason() + if reason: + return reason + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_prefill symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the quantized MoE prefill kernel" + ) + return "" + + +_PREFILL_SKIP = _prefill_skip_reason() +_QUANT_PREFILL_SKIP = _quantized_prefill_skip_reason() + + +# --------------------------------------------------------------------------- +# Shape matrix +# +# Subset of the perf benchmark's PREFILL_SHAPES. We keep the production-scale +# hidden sizes (up to N/K = 14336) and high expert counts (up to E = 64) so +# the accuracy check exercises the same code paths as the perf benchmark, +# but skip duplicate up-proj/down-proj rows to keep wall-clock reasonable. +# --------------------------------------------------------------------------- + +PREFILL_SHAPES = [ + # (label, num_experts, tokens_per_expert_list, N, K) + ("small E=8 ", 8, [32, 28, 30, 35, 33, 31, 29, 34], 4096, 4096), + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 4096, 14336), + ("large E=16", 16, [16] * 16, 2048, 2048), + ("large E=32", 32, [8] * 32, 2048, 2048), + ("large E=64", 64, [4] * 64, 2048, 2048), + ("uneven E=8 ", 8, [100, 50, 75, 80, 60, 90, 70, 85], 4096, 4096), +] + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- + + +def _reference_moe_prefill(activations, dequant_weights_NK, num_tokens_per_expert): + """Per-expert ``A @ W.T`` over ``[E, N, K]`` dequantized weights. + + This is the same reference used by ``TestMoEGemmPrefill`` in ``test_moe.py`` + and matches what a model would compute when no fused kernel is available. + """ + total_tokens, _ = activations.shape + E, N, _ = dequant_weights_NK.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + out[offset : offset + n_tokens] = a @ dequant_weights_NK[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# Tolerances per quant scheme +# +# These mirror the tolerances used in ``test_moe.py`` for the small-shape +# parity tests. We loosen slightly for the large-shape cases because longer +# K-reduction accumulates more rounding noise (FP16/BF16 accumulators). +# --------------------------------------------------------------------------- + +_TOL_FP = dict(rtol=3e-2, atol=3e-2) +_TOL_INT8 = dict(rtol=7e-2, atol=7e-2) +_TOL_INT4 = dict(rtol=7e-2, atol=7e-2) +_TOL_INT2 = dict(rtol=1.5e-1, atol=1.5e-1) +_TOL_FP8 = dict(rtol=7e-2, atol=7e-2) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_PREFILL_SKIP), reason=_PREFILL_SKIP or "ok") +class TestMoEGemmPrefillAccuracy: + """Parity tests for ``moe_gemm`` / ``moe_gemm_prefill`` at production shapes. + + Each test iterates the production-scale shape matrix and asserts the ark + kernel output matches a per-expert dequant + ``A @ W.T`` reference. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_accuracy_fp(self, dtype): + for label, E, tpe, N, K in PREFILL_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # ark.moe_gemm wants weights in [E, K, N] layout. + weights_KN = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm(activations, weights_KN, ntpe) + + # Reference uses [E, N, K] layout. + weights_NK = weights_KN.transpose(1, 2).contiguous() + ref = _reference_moe_prefill(activations, weights_NK, ntpe) + + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + assert out.dtype == dtype, f"{label}: bad dtype {out.dtype}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_FP + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_accuracy_int4(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT4 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_accuracy_int8(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT8 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_accuracy_int2(self, dtype, asym): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_INT2 + ) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_accuracy_fp8(self, dtype, fp8_dtype): + group_size = 128 + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + group_size=group_size, + asym=False, + ) + + ref = _reference_moe_prefill(activations, dequant, ntpe) + assert out.shape == (total_tokens, N), f"{label}: bad shape {out.shape}" + torch.testing.assert_close( + out, ref, msg=lambda m, lbl=label: f"[{lbl}] {m}", **_TOL_FP8 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/auto_round_extension/ark/test/test_moe_prefill_perf.py b/auto_round_extension/ark/test/test_moe_prefill_perf.py new file mode 100644 index 000000000..4ce79a1c0 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_prefill_perf.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance benchmark: ``ark.moe_gemm`` for MoE prefill workloads. + +MoE prefill is the matrix-matrix multiplication phase where many tokens (e.g., +the entire prompt or a batch of sequences) are routed to different experts. +Unlike decode (one token per expert), prefill has multiple tokens per expert, +making it a batched GEMM problem. + +This benchmark measures throughput (TFLOPS) and compares against a baseline +PyTorch implementation. The baseline uses per-expert matrix multiplication +with already-dequantized weights (for quantized tests), representing the +standard fallback path when no fused MoE kernel is available. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_prefill_perf.py + +The ``-s`` flag is required to see the printed timing tables and TFLOPS. +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _prefill_skip_reason() -> str: + """Return non-empty string if the MoE prefill kernel can't be exercised.""" + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm"): + return ( + "ark.xpu_lib loaded but has no moe_gemm symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE GEMM kernel" + ) + return "" + + +def _quantized_prefill_skip_reason() -> str: + """Return non-empty string if the quantized MoE prefill kernel can't be exercised.""" + reason = _prefill_skip_reason() + if reason: + return reason + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_prefill symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the quantized MoE prefill kernel" + ) + return "" + + +_PREFILL_SKIP = _prefill_skip_reason() +_QUANT_PREFILL_SKIP = _quantized_prefill_skip_reason() + +# Surface diagnostics on collection +print( + "[moe-prefill-perf] xpu_available=%s xpu_lib=%s has_moe_gemm=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark.xpu_lib, "moe_gemm") if ark.xpu_lib is not None else False, + ) +) +if _PREFILL_SKIP: + print("[moe-prefill-perf] suite will SKIP. reason: %s" % _PREFILL_SKIP) + + +# --------------------------------------------------------------------------- +# Timing utilities +# --------------------------------------------------------------------------- + +WARMUP = 5 +ITERS = 30 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + """Time ``fn`` on XPU using device events; returns median ms per call.""" + for _ in range(warmup): + fn() + torch.xpu.synchronize() + + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +def _default_moe_prefill(activations, dequant_weights, num_tokens_per_expert): + """Default XPU MoE prefill baseline: per-expert torch matmul. + + This mirrors the path a model would take when no fused MoE kernel is + available: iterate over experts and do ``A @ W.T`` on each token slice. + For prefill, each expert may have many tokens (unlike decode where each + expert typically has one token). + """ + total_tokens, K = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + # Weights are [N, K], activations are [n_tokens, K] + # Output is [n_tokens, N] + out[offset : offset + n_tokens] = a @ dequant_weights[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# TFLOPS calculation +# --------------------------------------------------------------------------- + + +def _compute_moe_flops(total_tokens, K, N, num_experts_active): + """Compute FLOPs for MoE GEMM: sum over active experts of (tokens * K * N * 2). + + For a typical prefill, all experts may be active with varying token counts. + As a simplification, we use the total_tokens across all experts. + """ + # Each GEMM is [tokens, K] @ [K, N] -> [tokens, N] + # FLOPs = tokens * K * N * 2 (multiply-add counts as 2 ops) + return total_tokens * K * N * 2 + + +# --------------------------------------------------------------------------- +# Shape matrix for prefill +# +# Prefill has many tokens per expert (e.g., batch size, prompt length). +# We test small to large expert counts and token distributions typical of +# MoE models during prefill (Mixtral, DeepSeek, etc.). +# --------------------------------------------------------------------------- + +PREFILL_SHAPES = [ + # (label, num_experts, tokens_per_expert_list, N, K) + # Small models (e.g., Mixtral 8x7B style) + ("small E=8 ", 8, [32, 28, 30, 35, 33, 31, 29, 34], 4096, 4096), + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 4096, 14336), # up-proj + ("medium E=8 ", 8, [64, 60, 68, 72, 65, 63, 70, 66], 14336, 4096), # down-proj + # Larger models (e.g., DeepSeek style with more experts) + ("large E=16", 16, [16] * 16, 2048, 2048), + ("large E=32", 32, [8] * 32, 2048, 2048), + ("large E=64", 64, [4] * 64, 2048, 2048), + # Uneven distribution (some experts get more tokens) + ("uneven E=8 ", 8, [100, 50, 75, 80, 60, 90, 70, 85], 4096, 4096), +] + + +def _print_header(title: str, *, with_dequant_baseline: bool = False) -> None: + """Print a benchmark header. + + When ``with_dequant_baseline`` is True, an extra ``base+deq(ms)`` column is + printed alongside the matmul-only baseline and ``speedup`` is reported + against the dequant-inclusive baseline (this is the apples-to-apples + comparison for the current Stage-1 ``moe_gemm_prefill`` which dequants + into a workspace before dispatching to the FP GEMM). + """ + print() + width = 130 if with_dequant_baseline else 110 + print("=" * width) + print(title) + if with_dequant_baseline: + print( + f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'base mm(ms)':>14}{'base+deq(ms)':>16}{'ark(ms)':>12}" + f"{'speedup':>12}{'TFLOPS':>10}" + ) + else: + print( + f"{'shape':<14}{'E':>4}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}{'TFLOPS':>10}" + ) + print("-" * width) + + +def _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, *, base_with_deq_ms=None): + """Print a benchmark row. + + If ``base_with_deq_ms`` is provided, the dequant-inclusive baseline is + shown and ``speedup`` is computed against it (apples-to-apples with the + Stage-1 quantized path). Otherwise the row reverts to the original + matmul-only layout. + """ + if base_with_deq_ms is None: + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x{tflops:>9.1f}" + ) + else: + speedup = base_with_deq_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{E:>4}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>14.4f}{base_with_deq_ms:>16.4f}{ark_ms:>12.4f}" + f"{speedup:>11.2f}x{tflops:>9.1f}" + ) + + +# --------------------------------------------------------------------------- +# Benchmark cases +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_PREFILL_SKIP), reason=_PREFILL_SKIP or "ok") +class TestMoEGemmPrefillPerf: + """Median XPU-event timings of ``moe_gemm`` vs per-expert matrix multiply. + + The baseline uses *already-dequantized* weights for quantized tests, so + the timed region only measures matmul cost. This is the most favorable + apples-to-apples comparison for the baseline. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_perf_fp(self, dtype): + _print_header(f"FP weights ({str(dtype).split('.')[-1]}) -- ark.moe_gemm (prefill) vs per-expert A @ W.T") + for label, E, tpe, N, K in PREFILL_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # Weights layout: [E, K, N] for moe_gemm + weights = (torch.randn(E, K, N, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + # Baseline: per-expert matmul. Weights need to be [E, N, K] for the baseline. + weights_baseline = weights.transpose(1, 2) # [E, N, K] + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, weights_baseline, ntpe)) + ark_ms = _xpu_time_ms(lambda: ark.moe_gemm(activations, weights, ntpe)) + + # Compute TFLOPS + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int4(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + # Pack helpers expect weights in [E, N, K] layout. + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int4_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + # ``dequant`` is already [E, N, K] -- matches the baseline contract. + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=asym, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int8(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int8_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=asym, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int2(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + + def _baseline_with_dequant(): + d = _dequant_int2_sym(packed, scales, group_size).to(dtype) + return _default_moe_prefill(activations, d, ntpe) + + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=asym, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) + + @pytest.mark.skipif(bool(_QUANT_PREFILL_SKIP), reason=_QUANT_PREFILL_SKIP or "ok") + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_perf_fp8(self, dtype, fp8_dtype): + group_size = 128 + _print_header( + f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_prefill (prefill) vs dequant + per-expert A @ W.T", + with_dequant_baseline=True, + ) + for label, E, tpe, N, K in PREFILL_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + def _baseline_with_dequant(): + d = _dequant_fp8(packed, scales, group_size, dtype) + return _default_moe_prefill(activations, d, ntpe) + + base_ms = _xpu_time_ms(lambda: _default_moe_prefill(activations, dequant, ntpe)) + base_deq_ms = _xpu_time_ms(_baseline_with_dequant) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_prefill( + activations, + packed, + ntpe, + scales=scales, + group_size=group_size, + asym=False, + ) + ) + + flops = _compute_moe_flops(total_tokens, K, N, E) + tflops = flops / (ark_ms * 1e-3) / 1e12 + + _print_row(label, E, N, K, total_tokens, base_ms, ark_ms, tflops, base_with_deq_ms=base_deq_ms) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/auto_round_extension/ark/test/test_moe_unified.py b/auto_round_extension/ark/test/test_moe_unified.py new file mode 100644 index 000000000..308dcf994 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_unified.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity tests for the unified ``ark.moe`` dispatcher. + +``ark.moe`` is a thin Python-side dispatcher that picks between +``moe_gemm_decode`` (GEMV-tuned) and ``moe_gemm_prefill`` (GEMM-tuned) based +on the token distribution. The contract is that dispatching should never +change the numerical result: ``moe(...)`` must be **bit-identical** to the +underlying kernel that was dispatched. + +This file checks: + + * Dispatch correctness: ``phase="auto"`` picks decode when every expert + sees few tokens and prefill otherwise. + * Bit-parity: ``moe(phase="auto")`` matches the kernel it dispatched to. + * Explicit-phase parity: ``moe(phase="decode")`` matches + ``moe_gemm_decode``, ``moe(phase="prefill")`` matches + ``moe_gemm_prefill`` (for the same inputs). + * Coverage across all supported quant schemes (fp / int8 / int4 / int2 / + fp8) and both activation dtypes. + * Argument validation (bad ``phase`` raises ``ValueError``). +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse pack/dequant helpers from the correctness tests. +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel + + +# --------------------------------------------------------------------------- +# Skip reasons +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _unified_skip_reason() -> str: + if not _xpu_available(): + return "XPU not available" + if ark.xpu_lib is None: + return "ark.xpu_lib is None (XPU extension failed to import)" + if not hasattr(ark.xpu_lib, "moe_gemm_decode"): + return "ark.xpu_lib missing moe_gemm_decode (need ARK_SYCL_TLA=ON)" + if not hasattr(ark.xpu_lib, "moe_gemm_prefill"): + return "ark.xpu_lib missing moe_gemm_prefill (need ARK_SYCL_TLA=ON)" + if not hasattr(ark, "moe"): + return "ark.moe (unified entry point) not exported by auto_round_kernel" + return "" + + +_UNIFIED_SKIP = _unified_skip_reason() + + +# --------------------------------------------------------------------------- +# Small shapes (one decode-shaped, one prefill-shaped) -- keep wall-clock low. +# --------------------------------------------------------------------------- + +_DECODE_SHAPE = dict(num_experts=4, tokens_per_expert=[1, 2, 0, 2], N=128, K=256) +_PREFILL_SHAPE = dict(num_experts=4, tokens_per_expert=[16, 8, 0, 20], N=128, K=256) + + +def _make_int4_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int4_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + + +def _make_int8_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int8_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + + +def _make_int2_sym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + return activations, packed, scales, None + + +def _make_int2_asym(E, N, K, group_size, dtype, total_tokens): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + return activations, packed, scales, zeros + + +def _make_fp8(E, N, K, group_size, dtype, total_tokens, fp8_dtype): + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + return activations, packed, scales, None + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_UNIFIED_SKIP), reason=_UNIFIED_SKIP or "ok") +class TestMoeUnifiedDispatch: + """Tests for the auto-dispatch logic itself.""" + + def test_auto_picks_decode_for_small_tokens_per_expert(self): + shape = _DECODE_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", + ) + out_decode = ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + # max tokens/expert = 2 (<= default threshold 4) -> dispatched to decode + # -> output must be bit-identical to moe_gemm_decode. + torch.testing.assert_close(out_auto, out_decode, rtol=0, atol=0) + + def test_auto_picks_prefill_for_large_tokens_per_expert(self): + shape = _PREFILL_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", + ) + out_prefill = ark.moe_gemm_prefill( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + torch.testing.assert_close(out_auto, out_prefill, rtol=0, atol=0) + + def test_decode_threshold_override(self): + # Same prefill-shaped input but bump the threshold above the max + # tokens/expert -> auto must now pick decode. + shape = _PREFILL_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + max_tpe = max(shape["tokens_per_expert"]) + out_auto = ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="auto", decode_threshold=max_tpe + 1, + ) + out_decode = ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + ) + torch.testing.assert_close(out_auto, out_decode, rtol=0, atol=0) + + def test_invalid_phase_raises(self): + shape = _DECODE_SHAPE + total_tokens = sum(shape["tokens_per_expert"]) + E, N, K = shape["num_experts"], shape["N"], shape["K"] + group_size = 128 + dtype = torch.float16 + + activations, packed, scales, _ = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + with pytest.raises(ValueError, match="phase"): + ark.moe( + activations, packed, ntpe, + scales=scales, weight_bits=4, group_size=group_size, asym=False, + phase="not_a_phase", + ) + + +@pytest.mark.skipif(bool(_UNIFIED_SKIP), reason=_UNIFIED_SKIP or "ok") +class TestMoeUnifiedBitParity: + """``moe(phase=X)`` must be bit-identical to the dispatched kernel. + + Parametrised across all supported quant schemes; for each, we check both + the decode-shaped and the prefill-shaped input so that both code paths + are exercised regardless of which one ``"auto"`` would have picked. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("shape_name,shape", [ + ("decode-shape", _DECODE_SHAPE), + ("prefill-shape", _PREFILL_SHAPE), + ]) + def test_fp_unquantized(self, dtype, shape_name, shape): + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights_NK = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + # phase="decode" -> moe_gemm_decode + out_unified = ark.moe(activations, weights_NK, ntpe, weight_bits=16, phase="decode") + out_kernel = ark.moe_gemm_decode(activations, weights_NK, ntpe, weight_bits=16) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + # phase="prefill" -> moe_gemm_prefill + out_unified = ark.moe(activations, weights_NK, ntpe, weight_bits=16, phase="prefill") + out_kernel = ark.moe_gemm_prefill(activations, weights_NK, ntpe, weight_bits=16) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + @pytest.mark.parametrize("shape_name,shape", [ + ("decode-shape", _DECODE_SHAPE), + ("prefill-shape", _PREFILL_SHAPE), + ]) + def test_int4(self, dtype, asym, shape_name, shape): + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int4_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int4_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=4, group_size=group_size, asym=asym) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_int8(self, dtype, asym): + # Single shape -- the quant path is the same on both shapes, so + # iterating both would just slow the test suite down. + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int8_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int8_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=8, group_size=group_size, asym=asym) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_int2(self, dtype, asym): + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + if asym: + activations, packed, scales, zeros = _make_int2_asym(E, N, K, group_size, dtype, total_tokens) + else: + activations, packed, scales, zeros = _make_int2_sym(E, N, K, group_size, dtype, total_tokens) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, zeros=zeros, weight_bits=2, group_size=group_size, asym=asym) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_fp8(self, dtype, fp8_dtype): + shape = _PREFILL_SHAPE + E, N, K = shape["num_experts"], shape["N"], shape["K"] + total_tokens = sum(shape["tokens_per_expert"]) + group_size = 128 + activations, packed, scales, _ = _make_fp8(E, N, K, group_size, dtype, total_tokens, fp8_dtype) + ntpe = torch.tensor(shape["tokens_per_expert"], dtype=torch.int32, device="xpu") + + kwargs = dict(scales=scales, group_size=group_size, asym=False) + out_unified = ark.moe(activations, packed, ntpe, phase="prefill", **kwargs) + out_kernel = ark.moe_gemm_prefill(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + out_unified = ark.moe(activations, packed, ntpe, phase="decode", **kwargs) + out_kernel = ark.moe_gemm_decode(activations, packed, ntpe, **kwargs) + torch.testing.assert_close(out_unified, out_kernel, rtol=0, atol=0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/pyproject.toml b/pyproject.toml index 6efe49e8d..e1978e006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,11 @@ requires = ["setuptools>=64", "wheel"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +markers = [ + "perf: performance benchmark tests (excluded from default CI runs; select with `-m perf`)", +] + [tool.codespell] skip = 'pyproject.toml,.azure-pipelines/scripts/codeScan/codespell/autoround_dict.txt,auto_round_extension/ark/*' ignore-words = ".azure-pipelines/scripts/codeScan/codespell/autoround_dict.txt" diff --git a/test/test_ark/test_moe_model_perf.py b/test/test_ark/test_moe_model_perf.py new file mode 100644 index 000000000..7410b8ef2 --- /dev/null +++ b/test/test_ark/test_moe_model_perf.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-level perf benchmark for MoE LLMs on the ARK (XPU) backend. + +This file mirrors the structure of ``test/test_ark/test_model.py`` but +operates on real MoE checkpoints (Qwen1.5-MoE, DeepSeek-V2-Lite) and adds +prefill / decode latency measurement on top. + +For each ``(model, bits, dtype)`` combination we: + + 1. Load a *tiny* slice of the MoE model with random weights + (``num_layers=2``, ``num_experts=4``) via ``helpers.get_tiny_model`` so + the suite is tractable in CI. Set ``AR_MOE_PERF_FULL=1`` to load the + full checkpoint instead. + 2. Quantize with ``AutoRound(iters=0, nsamples=1, disable_opt_rtn=True)`` + and export with ``format="auto_round"``. + 3. Reload twice on XPU: + a. unquantized FP reference (no ``quantization_config``); + b. ARK backend (``AutoRoundConfig(backend="ark")``); + c. *optional* GPTQModel backend, skipped if not installed. + 4. Smoke-test correctness via ``helpers.model_infer`` (asserts non-empty + output -> catches any wiring break in the backend). + 5. Measure **prefill** latency (single forward over a 128-token prompt) + and **per-token decode** latency (``model.generate(max_new_tokens=32)`` + after a 4-token warmup) using ``torch.xpu.Event``. Report the median + of 3 runs. + 6. Assert that ARK decode latency is within ``ARK_DECODE_REGRESSION_FACTOR`` + of the FP reference -- defends against silent perf regressions of the + unified ``ark.moe`` dispatcher. + +How to run:: + + pytest -v -s test/test_ark/test_moe_model_perf.py + # full-size checkpoints (needs ~30GB free + checkpoint cached locally): + AR_MOE_PERF_FULL=1 pytest -v -s test/test_ark/test_moe_model_perf.py + # exclude from default CI runs: + pytest -m 'not perf' ... +""" + +import os +import shutil +from statistics import median + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer + +from auto_round import AutoRound + +from ..helpers import ( + deepseek_v2_name_or_path, + get_model_path, + get_tiny_model, + model_infer, + qwen_moe_name_or_path, +) + + +# --------------------------------------------------------------------------- +# Knobs +# --------------------------------------------------------------------------- + +# Use the full pretrained checkpoint when set; otherwise build a tiny random +# slice via ``get_tiny_model`` so the test stays CI-friendly. +_USE_FULL_MODEL = os.environ.get("AR_MOE_PERF_FULL", "0") == "1" + +# Tiny-model slice geometry (only used when AR_MOE_PERF_FULL=0). +_TINY_NUM_LAYERS = 2 +_TINY_NUM_EXPERTS = 4 + +# Timing harness. +_PREFILL_PROMPT_TOKENS = 128 +_DECODE_NEW_TOKENS = 32 +_DECODE_WARMUP_TOKENS = 4 +_TIMING_REPEATS = 3 + +# Perf-regression guard for the ARK decode path vs the unquantized FP +# baseline. Loose enough to absorb run-to-run jitter (we already take the +# median of N runs) but tight enough to catch the "phase=auto sync" +# regression class described in the upstream perf analysis (which costs +# ~25% per kernel call). +ARK_DECODE_REGRESSION_FACTOR = 2.0 + + +# --------------------------------------------------------------------------- +# ARK availability gate (mirrors auto_round_extension/ark/test/test_moe_unified.py) +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _ark_skip_reason() -> str: + if not _xpu_available(): + return "XPU not available" + try: + import auto_round_kernel as ark + except ImportError as exc: + return f"auto_round_kernel not importable: {exc}" + if getattr(ark, "xpu_lib", None) is None: + return "ark.xpu_lib is None (XPU extension failed to import)" + for sym in ("moe_gemm_decode", "moe_gemm_prefill"): + if not hasattr(ark.xpu_lib, sym): + return f"ark.xpu_lib missing {sym} (need ARK_SYCL_TLA=ON)" + if not hasattr(ark, "moe"): + return "ark.moe (unified entry point) not exported by auto_round_kernel" + return "" + + +_ARK_SKIP = _ark_skip_reason() + + +# --------------------------------------------------------------------------- +# Timing utilities +# --------------------------------------------------------------------------- + + +def _xpu_sync(): + if _xpu_available(): + torch.xpu.synchronize() + + +def _xpu_time_ms(fn, repeats: int = _TIMING_REPEATS) -> float: + """Time ``fn`` on XPU using ``torch.xpu.Event``; return median ms. + + The function is invoked ``repeats`` times after one warmup call. The + median is returned to absorb run-to-run jitter from the XPU runtime. + """ + fn() # warmup + _xpu_sync() + timings = [] + for _ in range(repeats): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + return median(timings) + + +def _measure_prefill_ms(model, input_ids, attention_mask) -> float: + def _run(): + with torch.inference_mode(): + model(input_ids=input_ids, attention_mask=attention_mask) + + return _xpu_time_ms(_run) + + +def _measure_decode_ms_per_tok(model, tokenizer, prompt_ids, attention_mask) -> float: + """Median per-token decode latency from ``generate(max_new_tokens=N)``.""" + # Warmup generate (compiles caches, allocates KV). + with torch.inference_mode(): + model.generate( + input_ids=prompt_ids, + attention_mask=attention_mask, + do_sample=False, + max_new_tokens=_DECODE_WARMUP_TOKENS, + pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, + ) + + def _run(): + with torch.inference_mode(): + model.generate( + input_ids=prompt_ids, + attention_mask=attention_mask, + do_sample=False, + max_new_tokens=_DECODE_NEW_TOKENS, + pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, + ) + + total_ms = _xpu_time_ms(_run) + return total_ms / _DECODE_NEW_TOKENS + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_prefill_inputs(tokenizer, device, dtype, num_tokens=_PREFILL_PROMPT_TOKENS): + """Build a fixed-length (num_tokens) input tensor on ``device``.""" + # Use a deterministic synthetic prompt to keep the perf number stable + # across machines that may not have the same vocab tokenization for a + # natural-language prompt of a given length. + vocab_size = getattr(tokenizer, "vocab_size", 32000) + input_ids = torch.arange(num_tokens, dtype=torch.long).unsqueeze(0) % max(1, vocab_size) + attention_mask = torch.ones_like(input_ids) + return input_ids.to(device), attention_mask.to(device) + + +def _load_tiny_or_full(model_name_or_path, dtype): + """Return a model instance + tokenizer for the given MoE checkpoint. + + Tiny path uses ``get_tiny_model`` (random weights, 2 layers, 4 experts) + so the suite is tractable in CI. Full path loads the real checkpoint + when ``AR_MOE_PERF_FULL=1``. + """ + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + if _USE_FULL_MODEL: + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype=dtype, trust_remote_code=True) + else: + model = get_tiny_model( + model_name_or_path, + num_layers=_TINY_NUM_LAYERS, + num_experts=_TINY_NUM_EXPERTS, + from_config=True, + trust_remote_code=True, + ) + model = model.to(dtype) + return model, tokenizer + + +def _format_row(model_label, backend, prefill_ms, decode_ms_per_tok, baseline_decode_ms): + tps = 1000.0 / decode_ms_per_tok if decode_ms_per_tok > 0 else float("nan") + if baseline_decode_ms is None or baseline_decode_ms <= 0: + speedup_str = " --" + else: + speedup_str = f"{baseline_decode_ms / decode_ms_per_tok:6.2f}x" + return ( + f"{model_label:<22}{backend:<14}{prefill_ms:>12.3f}{decode_ms_per_tok:>16.3f}" + f"{tps:>14.2f}{speedup_str:>14}" + ) + + +def _print_header(dtype): + print() + print("=" * 96) + print(f"Model-level MoE perf -- dtype={dtype} AR_MOE_PERF_FULL={_USE_FULL_MODEL}") + print("-" * 96) + print( + f"{'model':<22}{'backend':<14}{'prefill(ms)':>12}{'decode(ms/tok)':>16}" + f"{'tokens/s':>14}{'vs FP':>14}" + ) + print("-" * 96) + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + + +pytestmark = [pytest.mark.perf] + + +@pytest.mark.skipif(not _xpu_available(), reason="XPU not available") +class TestMoEModelPerf: + """Model-level perf table for MoE LLMs on the ARK XPU backend.""" + + @classmethod + def teardown_class(cls): + shutil.rmtree("runs", ignore_errors=True) + + @pytest.fixture(autouse=True) + def _save_dir(self, tmp_path): + self.save_folder = str(tmp_path / "saved") + yield + shutil.rmtree(self.save_folder, ignore_errors=True) + + # -- helpers ------------------------------------------------------------ + + def _quantize_and_save(self, model, tokenizer, bits, group_size, sym): + autoround = AutoRound( + model, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=0, + nsamples=1, + disable_opt_rtn=True, + ) + _, saved_folder = autoround.quantize_and_save(output_dir=self.save_folder, format="auto_round") + return saved_folder + + def _reload(self, saved_folder, dtype, *, backend): + kwargs = dict(dtype=dtype, device_map="xpu", trust_remote_code=True) + if backend is not None: + kwargs["quantization_config"] = AutoRoundConfig(backend=backend) + return AutoModelForCausalLM.from_pretrained(saved_folder, **kwargs) + + def _bench_one(self, model, tokenizer, dtype): + prompt_ids, attention_mask = _make_prefill_inputs(tokenizer, model.device, dtype) + prefill_ms = _measure_prefill_ms(model, prompt_ids, attention_mask) + decode_ms = _measure_decode_ms_per_tok(model, tokenizer, prompt_ids, attention_mask) + return prefill_ms, decode_ms + + def _smoke_test(self, model, tokenizer): + out = model_infer(model, tokenizer) + assert out is not None and len(out) > 0, "ARK backend produced empty output" + + # -- parametrized perf scan -------------------------------------------- + + @pytest.mark.parametrize( + "model_label, model_path", + [ + ("qwen-moe", qwen_moe_name_or_path), + ("deepseek-v2-lite", deepseek_v2_name_or_path), + ], + ) + @pytest.mark.parametrize("bits, group_size, sym", [(4, 128, True), (8, 128, True)]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_moe_forward_perf(self, model_label, model_path, bits, group_size, sym, dtype): + # 1. Resolve checkpoint -- skip if neither local mirror nor HF cache + # has it (CI without internet must not fail this test). + resolved = get_model_path(model_path) + try: + tokenizer_probe = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True) + except (OSError, ValueError) as exc: + pytest.skip(f"checkpoint {model_path!r} not available locally: {exc}") + del tokenizer_probe + + # 2. Load tiny (or full) FP MoE model + tokenizer. + fp_model, tokenizer = _load_tiny_or_full(resolved, dtype) + + # 3. Quantize + save. + saved_folder = self._quantize_and_save(fp_model, tokenizer, bits, group_size, sym) + # Free the calibration-time model before reloading on XPU. + del fp_model + torch.xpu.empty_cache() + + _print_header(dtype) + label = f"{model_label} INT{bits}" + + # 4a. FP reference on XPU (no quantization_config). + fp_decode_ms = None + try: + fp_model_xpu = self._reload(resolved, dtype, backend=None) + fp_prefill_ms, fp_decode_ms = self._bench_one(fp_model_xpu, tokenizer, dtype) + print(_format_row(label, "fp(ref)", fp_prefill_ms, fp_decode_ms, fp_decode_ms)) + del fp_model_xpu + torch.xpu.empty_cache() + except Exception as exc: # noqa: BLE001 -- FP baseline is optional + print(f"[moe-model-perf] fp(ref) row skipped for {label}: {exc}") + + # 4b. ARK backend (the thing under test). + if _ARK_SKIP: + pytest.skip(f"ARK backend unavailable: {_ARK_SKIP}") + ark_model = self._reload(saved_folder, dtype, backend="ark") + self._smoke_test(ark_model, tokenizer) + ark_prefill_ms, ark_decode_ms = self._bench_one(ark_model, tokenizer, dtype) + print(_format_row(label, "ark", ark_prefill_ms, ark_decode_ms, fp_decode_ms)) + del ark_model + torch.xpu.empty_cache() + + # 4c. Optional GPTQModel cross-reference (skip silently if missing). + try: + gptq_model = self._reload(saved_folder, dtype, backend="gptqmodel") + gptq_prefill_ms, gptq_decode_ms = self._bench_one(gptq_model, tokenizer, dtype) + print(_format_row(label, "gptqmodel", gptq_prefill_ms, gptq_decode_ms, fp_decode_ms)) + del gptq_model + torch.xpu.empty_cache() + except Exception as exc: # noqa: BLE001 -- backend is optional + print(f"[moe-model-perf] gptqmodel row skipped for {label}: {exc}") + + print("-" * 96) + + # 5. Perf-regression assertion (only when we have an FP baseline). + if fp_decode_ms is not None and fp_decode_ms > 0: + assert ark_decode_ms <= fp_decode_ms * ARK_DECODE_REGRESSION_FACTOR, ( + f"ARK decode latency {ark_decode_ms:.3f} ms/tok exceeds " + f"{ARK_DECODE_REGRESSION_FACTOR}x FP baseline {fp_decode_ms:.3f} ms/tok " + f"for {label} ({dtype}). Likely regression in the ark.moe dispatcher." + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])