Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions examples/models/gemma4_31b/cuda_source_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb
from executorch.extension.llm.modules.turboquant import TurboQuantKVCache
Expand Down Expand Up @@ -110,13 +111,117 @@ def _turboquant_attention_forward(
return self.o_proj(y)


def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement for ``Gemma4MLP.forward`` that uses one fused projection.

    ``gate_up_proj`` yields ``[gate | up]`` packed along the last dimension:
    the first ``intermediate_size`` channels are the gate branch and the
    remaining channels are the up branch. The result is the same
    ``down(gelu(gate(x)) * up(x))`` computation as the unfused MLP, but it
    needs a single W4A8 matmul (and a single activation-quant of ``x``)
    rather than two.

    Args:
        x: Input activations fed to ``self.gate_up_proj``.

    Returns:
        ``self.down_proj`` applied to ``gelu_tanh(gate) * up``.
    """
    fused = self.gate_up_proj(x)
    split_at = self.intermediate_size
    # Split the packed output back into its two halves on the last dim.
    gate_part, up_part = fused[..., :split_at], fused[..., split_at:]
    activated = F.gelu(gate_part, approximate="tanh")
    return self.down_proj(activated * up_part)


def _concat_coalesced_int4_along_n(a, b):
    """Stack two ``CudaCoalescedInt4Tensor`` weights along the output (N) dim.

    In the coalesced layout qdata is ``[N, K/2]`` and scale/zero_point are
    ``[N, n_groups]``, i.e. every component is indexed by output row first.
    The W4A8 dp4a matvec consumes each output row's qdata/scale/zero
    independently, so a plain dim-0 concatenation is exact: rows ``[:N_a]``
    of the result reproduce ``a`` and rows ``[N_a:]`` reproduce ``b``
    bit-for-bit.

    ``block_size`` and ``activation_dtype`` are taken from ``a``; callers are
    expected to have checked that ``b`` is compatible.

    Args:
        a: Weight whose rows come first in the fused tensor.
        b: Weight whose rows are appended after ``a``'s.

    Returns:
        A new ``CudaCoalescedInt4Tensor`` of logical shape
        ``[a.shape[0] + b.shape[0], a.shape[1]]``.
    """
    from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor

    stacked_qdata = torch.cat((a.qdata, b.qdata), dim=0)
    stacked_scale = torch.cat((a.scale, b.scale), dim=0)
    stacked_zero_point = torch.cat((a.zero_point, b.zero_point), dim=0)
    total_rows = a.shape[0] + b.shape[0]
    fused_shape = torch.Size((total_rows, a.shape[1]))

    return CudaCoalescedInt4Tensor(
        stacked_qdata,
        stacked_scale,
        stacked_zero_point,
        a.block_size,
        fused_shape,
        None,  # presumably the act_pre_scale slot, left unset -- TODO confirm
        a.activation_dtype,
    )


def _is_fuseable_int4_pair(gate_w, up_w) -> bool:
    """Return True iff gate/up weights can be fused by a row-wise concat.

    Both weights must be ``CudaCoalescedInt4Tensor`` with matching K
    (``shape[1]``), ``block_size`` and ``activation_dtype``, and neither may
    carry an ``act_pre_scale``.

    Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K
    weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone).
    ``act_pre_scale`` is unused on this path but we require it absent so the
    concat stays exact.

    Args:
        gate_w: Weight of the MLP's ``gate_proj``.
        up_w: Weight of the MLP's ``up_proj``.

    Returns:
        True when the pair is safe to fuse, False otherwise.
    """
    from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor

    # Type gate first: the attribute checks below only make sense for
    # coalesced-int4 tensors.
    if not isinstance(gate_w, CudaCoalescedInt4Tensor):
        return False
    if not isinstance(up_w, CudaCoalescedInt4Tensor):
        return False

    return (
        list(gate_w.block_size) == list(up_w.block_size)
        and gate_w.shape[1] == up_w.shape[1]
        # The fused tensor carries a single activation_dtype (taken from the
        # gate weight), so a mismatched pair must not be fused.
        and gate_w.activation_dtype == up_w.activation_dtype
        and gate_w.act_pre_scale is None
        and up_w.act_pre_scale is None
    )


def _fuse_gate_up_proj(model: nn.Module) -> None:
    """Replace each MLP's ``gate_proj``/``up_proj`` pair with ``gate_up_proj``.

    Both projections consume the same input, so the unfused path quantizes
    ``x`` to int8 twice and launches two W4A8 matvecs per layer. Packing the
    two weights into one ``[2*inter, hidden]`` tensor halves both. The number
    of weight bytes read does not change, so the gain is in launch and
    activation-quant overhead (decode is launch-bound). Only Q4_K
    (coalesced-int4) layers are fused; a layer with any non-int4 weight keeps
    its two separate matmuls, which is still correct.

    Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e.
    inside ``_export_cuda``), and is independent of TurboQuant.

    Args:
        model: Model exposing an iterable ``layers``; it is modified in
            place. Layers without an ``mlp`` that has both ``gate_proj`` and
            ``up_proj`` are ignored and not counted.

    Side effects:
        For every fused MLP: adds ``gate_up_proj`` and ``intermediate_size``,
        deletes ``gate_proj`` and ``up_proj``, and rebinds ``forward`` to
        ``_fused_mlp_forward``. Prints a one-line summary to stdout.
    """
    fused_count = 0
    skipped_count = 0

    for decoder_layer in model.layers:
        mlp = getattr(decoder_layer, "mlp", None)
        has_pair = (
            mlp is not None
            and hasattr(mlp, "gate_proj")
            and hasattr(mlp, "up_proj")
        )
        if not has_pair:
            continue

        gate_weight = mlp.gate_proj.weight
        up_weight = mlp.up_proj.weight
        if _is_fuseable_int4_pair(gate_weight, up_weight):
            inter = up_weight.shape[0]
            hidden = up_weight.shape[1]
            packed_weight = _concat_coalesced_int4_along_n(gate_weight, up_weight)

            # Allocate the Linear shell on the meta device so no dense
            # [2*inter, hidden] weight is ever materialized; its weight is
            # replaced by the packed tensor immediately below.
            fused_linear = nn.Linear(hidden, 2 * inter, bias=False, device="meta")
            fused_linear.weight = nn.Parameter(packed_weight, requires_grad=False)

            mlp.gate_up_proj = fused_linear
            mlp.intermediate_size = inter
            del mlp.gate_proj
            del mlp.up_proj
            # Bind per-instance so only this MLP object uses the fused forward.
            mlp.forward = types.MethodType(_fused_mlp_forward, mlp)
            fused_count += 1
        else:
            skipped_count += 1

    skipped_note = (
        f" ({skipped_count} skipped: non-int4 weights)" if skipped_count else ""
    )
    print(f"[gemma4_31b cuda] Fused gate+up on {fused_count} MLP layers{skipped_note}")


def cuda_source_transformations(
model: nn.Module,
*,
use_turboquant: bool = False,
) -> None:
"""Apply CUDA source transformations to a Gemma 4 31B model in place.

Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one
activation-quant + one W4A8 matvec per layer instead of two; Q4_K
coalesced-int4 layers only — other quant types are left untouched).
Optionally also swaps full-attention KV caches for TurboQuant TQ4.

Args:
model: ``Gemma4_31B`` instance to transform.
use_turboquant: When True, swap full-attention layers' KV caches
Expand All @@ -125,6 +230,8 @@ def cuda_source_transformations(
``torch.ops.triton.tq4_sdpa``. Sliding-window layers are
unaffected.
"""
_fuse_gate_up_proj(model)

if not use_turboquant:
return

Expand Down
9 changes: 4 additions & 5 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,11 @@ def _export_cuda(

materialize_runtime_buffers(model, dtype=torch.bfloat16)

if use_turboquant:
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
cuda_source_transformations,
)
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
cuda_source_transformations,
)

cuda_source_transformations(model, use_turboquant=True)
cuda_source_transformations(model, use_turboquant=use_turboquant)

# Int4Tensor weights are used directly — no format conversion.
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).
Expand Down
Loading