From 1c371e257a37f6298acd03be9434d47d96508a92 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 23 Jun 2026 15:21:25 -0700 Subject: [PATCH] [executorch][cuda] gemma4_31b: fuse gate/up MLP projections (default-on) Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode. --- .../gemma4_31b/cuda_source_transformations.py | 107 ++++++++++++++++++ examples/models/gemma4_31b/export.py | 9 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 666d0c44e9d..6609178e084 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -30,6 +30,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -110,6 +111,105 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: + """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. + + Identical math to ``down(gelu(gate(x)) * up(x))``: the single + ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim, + which is then split. One W4A8 matmul (and one activation-quant of ``x``) + instead of two. + """ + h = self.gate_up_proj(x) + gate = h[..., : self.intermediate_size] + up = h[..., self.intermediate_size :] + return self.down_proj(F.gelu(gate, approximate="tanh") * up) + + +def _concat_coalesced_int4_along_n(a, b): + """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim. + + qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the + coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8 + dp4a matvec reads each output row's qdata/scale/zero independently, so + out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return CudaCoalescedInt4Tensor( + torch.cat([a.qdata, b.qdata], dim=0), + torch.cat([a.scale, b.scale], dim=0), + torch.cat([a.zero_point, b.zero_point], dim=0), + a.block_size, + torch.Size([a.shape[0] + b.shape[0], a.shape[1]]), + None, + a.activation_dtype, + ) + + +def _is_fuseable_int4_pair(gate_w, up_w) -> bool: + """True iff gate/up are both coalesced-int4 with matching K + block_size. + + Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K + weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale`` + is unused on this path but we require it absent so the concat stays exact. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return ( + isinstance(gate_w, CudaCoalescedInt4Tensor) + and isinstance(up_w, CudaCoalescedInt4Tensor) + and list(gate_w.block_size) == list(up_w.block_size) + and gate_w.shape[1] == up_w.shape[1] + and gate_w.act_pre_scale is None + and up_w.act_pre_scale is None + ) + + +def _fuse_gate_up_proj(model: nn.Module) -> None: + """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``. + + gate and up share the same input, so the unfused path quantizes ``x`` to + int8 twice and launches two W4A8 matvecs per layer. Fusing the weights + into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are + unchanged, so the win is launch + activation-quant overhead (decode is + launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer + with a non-int4 weight is left as two matmuls (still correct). + + Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e. + inside ``_export_cuda``), and is independent of TurboQuant. + """ + n_fused = 0 + n_skipped = 0 + for layer in model.layers: + mlp = getattr(layer, "mlp", None) + if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")): + continue + gate_w = mlp.gate_proj.weight + up_w = mlp.up_proj.weight + if not _is_fuseable_int4_pair(gate_w, up_w): + n_skipped += 1 + continue + inter = up_w.shape[0] + hidden = up_w.shape[1] + fused_w = _concat_coalesced_int4_along_n(gate_w, up_w) + + # Container built on meta to avoid materializing a dense + # [2*inter, hidden] weight before we overwrite it with fused_w. + gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + gate_up.weight = nn.Parameter(fused_w, requires_grad=False) + mlp.gate_up_proj = gate_up + mlp.intermediate_size = inter + del mlp.gate_proj + del mlp.up_proj + mlp.forward = types.MethodType(_fused_mlp_forward, mlp) + n_fused += 1 + + msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers" + if n_skipped: + msg += f" ({n_skipped} skipped: non-int4 weights)" + print(msg) + + def cuda_source_transformations( model: nn.Module, *, @@ -117,6 +217,11 @@ def cuda_source_transformations( ) -> None: """Apply CUDA source transformations to a Gemma 4 31B model in place. + Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one + activation-quant + one W4A8 matvec per layer instead of two; Q4_K + coalesced-int4 layers only — other quant types are left untouched). + Optionally also swaps full-attention KV caches for TurboQuant TQ4. + Args: model: ``Gemma4_31B`` instance to transform. use_turboquant: When True, swap full-attention layers' KV caches @@ -125,6 +230,8 @@ def cuda_source_transformations( ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are unaffected. """ + _fuse_gate_up_proj(model) + if not use_turboquant: return diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index d9e16bc34df..b2b2264178a 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -182,12 +182,11 @@ def _export_cuda( materialize_runtime_buffers(model, dtype=torch.bfloat16) - if use_turboquant: - from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( - cuda_source_transformations, - ) + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) - cuda_source_transformations(model, use_turboquant=True) + cuda_source_transformations(model, use_turboquant=use_turboquant) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).