From 993cff58ae5232e5be0ec360b0cde49b746b2019 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 24 Jun 2026 01:03:55 -0700 Subject: [PATCH 1/2] [gemma4_31b][cuda] Export Gemma4-31B @128k under 32 GB Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris."). --- backends/aoti/aoti_backend.py | 25 ++- backends/cuda/cuda_backend.py | 172 +++++++++++++++++- .../quantize_op_dispatch/int4_dispatch.py | 46 ++++- backends/cuda/triton/kernels/tq4_sdpa.py | 5 + 4 files changed, 229 insertions(+), 19 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 91a8a60078e..22f6feeab6c 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None: """ return + @classmethod + def move_program_to_device( + cls, + edge_program: ExportedProgram, + device: str, + compile_specs: List[CompileSpec], + ) -> ExportedProgram: + """Move the exported program to the target device for compilation. + + Default implementation moves everything (params, buffers, constants) via + ``move_to_device_pass``. Concrete backends may override to keep large + non-parameter tensors off the device during a low-memory export. + """ + return move_to_device_pass(edge_program, device) + @classmethod def release_moved_tensors( cls, @@ -196,9 +211,13 @@ def preprocess( decomposition_table = cls.get_decomposition_table() options = cls.get_aoti_compile_options(compile_specs) - # Move the edge_program to the target device - device_edge_program = move_to_device_pass( - edge_program, device_name if device_name != "metal" else "mps" + # Move the edge_program to the target device. Routed through a hook so + # backends can keep large non-parameter tensors (e.g. KV-cache buffers) + # off the device during a low-memory export. + device_edge_program = cls.move_program_to_device( + edge_program, + device_name if device_name != "metal" else "mps", + compile_specs, ) # Replace view_copy with view diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f9f23a842f9..1781c5bfd39 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool: return getattr(_CPU_CLONE_GUARD, "active", False) +def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor: + """Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``. + + Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``) + during the low-memory device move. KV content is all zeros, so this exactly + reproduces the buffer for both the lifted graph value and serialization. + """ + needed = 1 + for size, stride in zip(x.size(), x.stride()): + needed += (size - 1) * stride + buf = torch.zeros(int(needed), dtype=x.dtype, device=device) + return torch.as_strided(buf, x.size(), x.stride()) + + +def _is_emptied(x) -> bool: + return ( + isinstance(x, torch.Tensor) + and x.numel() > 0 + and x.untyped_storage().nbytes() == 0 + ) + + @contextlib.contextmanager def _compile_time_cpu_clones(target_device: torch.device): """Force AOTI's mutated-buffer clones onto CPU while preserving the serialized constants' target device.""" - from torch._inductor import compile_fx as _cfx + from torch._inductor import compile_fx as _cfx, graph as _graph from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp + from torch._inductor.graph import GraphLowering as _GL orig_clone = _cfx.clone_preserve_strides orig_codegen_device = _Cpp.codegen_device + orig_get_const = _GL.get_original_value_of_constant + orig_is_same = _graph.is_same_tensor + + def _is_same_skip_emptied(data, value): + # KV buffers freed via resize_(0) all have data_ptr 0, so the stock + # is_same_tensor would treat every same-shape KV constant as a duplicate + # and collapse the 60 layers' caches into one — the runtime needs each + # FQN's own buffer, so the collapsed ones load uninitialized garbage. + # Never dedup an emptied tensor. + if _is_emptied(data) or _is_emptied(value): + return False + return orig_is_same(data, value) def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor: - # `clone_preserve_strides` is shared by `_unlift_graph` (clones - # lifted buffers — can be safely kept on CPU) and by autotuning code - # in `triton_heuristics.py` (clones for benchmark — must stay on - # GPU for Triton). Discriminate by caller frame so we only force - # CPU clones for the buffer-lifting path. + # `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted + # buffers — can be safely kept on CPU) and by autotuning code in + # `triton_heuristics.py` (clones for benchmark — must stay on GPU for + # Triton). Discriminate by caller frame so we only force CPU clones for + # the buffer-lifting path. import sys caller = sys._getframe(1).f_code.co_name if caller == "_unlift_graph": + # KV-cache buffers are emptied (storage resize_(0)) by the low-memory + # device move so they never occupy GPU memory during compile. Their + # content is all zeros, so re-synthesize zeros (on CPU, strides + # preserved) instead of cloning the now-empty storage. + if _is_emptied(x): + return _full_zeros_preserving_strides(x, "cpu") return orig_clone(x).cpu() return orig_clone(x) + def _get_const_synthesize_zeros(self, name): + # AOTI serializes each constant via get_original_value_of_constant -> + # _to_bytes. For KV buffers we freed with resize_(0) this would otherwise + # fall back to the empty-storage constant and write 0 bytes, producing a + # .ptd with an uninitialized cache. Re-synthesize the zeros so the blob + # holds a correctly-zeroed KV cache. + value = orig_get_const(self, name) + if _is_emptied(value): + return _full_zeros_preserving_strides(value, "cpu") + return value + def _codegen_device_target_aware(self, device): # Translate accidental CPU device strings back to the model target # device only when a constant we forced to CPU is being serialized. @@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device): _cfx.clone_preserve_strides = _cpu_clone_preserve_strides _Cpp.codegen_device = _codegen_device_target_aware + _GL.get_original_value_of_constant = _get_const_synthesize_zeros + _graph.is_same_tensor = _is_same_skip_emptied prev_active = getattr(_CPU_CLONE_GUARD, "active", False) _CPU_CLONE_GUARD.active = True try: @@ -107,6 +161,89 @@ def _codegen_device_target_aware(self, device): _CPU_CLONE_GUARD.active = prev_active _cfx.clone_preserve_strides = orig_clone _Cpp.codegen_device = orig_codegen_device + _GL.get_original_value_of_constant = orig_get_const + _graph.is_same_tensor = orig_is_same + + +def _is_kv_buffer(name, v) -> bool: + return ( + isinstance(v, torch.Tensor) + and not isinstance(v, torch.nn.Parameter) + and "kv_cache" in name + ) + + +def _empty_strided_on_device(v, location): + """A device tensor with v's shape/stride/dtype but zero (freed) storage.""" + t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location) + t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride + return t + + +def _move_graph_nodes_to_device(graph_module, location): + """Point node device kwargs / aten.to.device targets / meta vals at location.""" + import torch.utils._pytree as pytree + + def _to_loc(v): + return v.to(location) if isinstance(v, torch.Tensor) else v + + for m in graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if "device" in node.kwargs: + node.kwargs = {**node.kwargs, "device": location} + if node.op == "call_function" and node.target is torch.ops.aten.to.device: + args = list(node.args) + args[1] = location + node.args = tuple(args) + node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val")) + + +def _move_to_device_resize_kv(ep, location): + """``move_to_device_pass`` variant that frees KV-cache storage on-device. + + Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache + buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their + storage immediately freed via ``resize_(0)``. This keeps ``device == + location`` — so the fake-tensor device check on the ``index_copy`` cache + update passes (``self`` and ``values`` both on cuda) — while no real KV bytes + occupy the device during the AOTI compile. KV content is all zeros, so the + emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone + (see ``_compile_time_cpu_clones``), which is reused as both the lifted initial + value and the serialized ``.ptd`` constant. The empty/free is interleaved per + tensor so the transient device peak is a single KV buffer, not the whole cache. + Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers); + every other tensor is moved normally so non-zero content is never lost. + """ + import torch.utils._pytree as pytree + + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter(v.to(location), v.requires_grad) + elif _is_kv_buffer(k, v): + ep._state_dict[k] = _empty_strided_on_device(v, location) + else: + ep._state_dict[k] = v.to(location) + + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = ( + _empty_strided_on_device(v, location) + if _is_kv_buffer(k, v) + else v.to(location) + ) + + if ep.example_inputs is not None: + args, kwargs = ep.example_inputs + ep._example_inputs = ( + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), args), + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), kwargs), + ) + + _move_graph_nodes_to_device(ep.graph_module, location) + ep.validate() + return ep @final @@ -424,6 +561,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool: return spec.value.decode("utf-8").upper() == "ON" return False + @classmethod + def move_program_to_device( + cls, + edge_program, + device: str, + compile_specs: List[CompileSpec], + ): + """Move the program to ``device`` for AOTI compile. + + On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers — + which can be 10+ GiB at long context — are placed on-device but with their + storage freed (``resize_(0)``), so they never occupy device memory during + the autotune / cpp_wrapper compile while still satisfying the device-match + check on the cache update. They are re-synthesized as zeros for the lifted + graph and the serialized blob. This activates automatically with low-memory + mode. Other (non-low-memory) exports use the stock pass. + """ + from torch.export.passes import move_to_device_pass + + if not cls._is_low_memory_mode(compile_specs): + return move_to_device_pass(edge_program, device) + return _move_to_device_resize_kv(edge_program, device) + @classmethod def release_moved_tensors( cls, diff --git a/backends/cuda/quantize_op_dispatch/int4_dispatch.py b/backends/cuda/quantize_op_dispatch/int4_dispatch.py index c3b8921e2fe..1b8c370eecf 100644 --- a/backends/cuda/quantize_op_dispatch/int4_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int4_dispatch.py @@ -60,11 +60,29 @@ def _cuda(self, qdata, scale, zero, group_size): return _dequant_matmul(self, qdata, scale, zero, group_size) +# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size, +# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that +# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds +# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a +# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N +# caps that at ~chunk rows. It is numerically identical (F.linear output rows are +# independent), and because only the lm_head (custom-op) path crosses the N +# threshold — never the M>4 prefill inline path — it never enters the runtime +# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight +# whose row count exceeds the threshold. +_DEQUANT_N_THRESHOLD = 65536 +_DEQUANT_N_CHUNK = 32768 + + def _dequant_matmul(x, qdata, scale, zero, group_size): """Dequant INT4 weights to input dtype and call F.linear. scale/zero are in the coalesced [N, n_groups] layout (baked into the weight constant at pack time), aligned row-for-row with qdata's [N, *]. + + Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound + the dequant intermediate (see note above); smaller weights take the original + single-shot dequant. """ N, K_half = qdata.shape K = K_half * 2 @@ -72,16 +90,24 @@ def _dequant_matmul(x, qdata, scale, zero, group_size): gs_half = group_size // 2 dtype = x.dtype - p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) - low = (p & 0x0F).to(dtype) - high = ((p >> 4) & 0x0F).to(dtype) - data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) - - s = scale.to(dtype).unsqueeze(-1) - z = zero.to(dtype).unsqueeze(-1) - w_deq = ((data - z) * s).reshape(N, K) - - return F.linear(x, w_deq) + def _dq(qd, sc, ze, rows): + p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half) + low = (p & 0x0F).to(dtype) + high = ((p >> 4) & 0x0F).to(dtype) + data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size) + s = sc.to(dtype).unsqueeze(-1) + z = ze.to(dtype).unsqueeze(-1) + w_deq = ((data - z) * s).reshape(rows, K) + return F.linear(x, w_deq) + + if N <= _DEQUANT_N_THRESHOLD: + return _dq(qdata, scale, zero, N) + + outs = [] + for i in range(0, N, _DEQUANT_N_CHUNK): + j = min(i + _DEQUANT_N_CHUNK, N) + outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i)) + return torch.cat(outs, dim=-1) # --------------------------------------------------------------------------- diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index 10f02c7fa3c..7a41eaf92c1 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body( triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64( triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) From 92d62c974345d8fd387d6698f5e161b542ec9939 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 08:55:35 -0700 Subject: [PATCH 2/2] Fix TurboQuant KV zeroed by low-mem export (993cff58ae): _is_kv_buffer only frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T --- backends/cuda/cuda_backend.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 1781c5bfd39..b328a05df54 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -166,11 +166,29 @@ def _codegen_device_target_aware(self, device): def _is_kv_buffer(name, v) -> bool: - return ( - isinstance(v, torch.Tensor) - and not isinstance(v, torch.nn.Parameter) - and "kv_cache" in name - ) + """True only for an actual KV-cache *content* buffer that is safe to free. + + The low-memory path (``_move_to_device_resize_kv``) frees every buffer this + matches and re-synthesizes it as ZEROS in both the lifted graph and the + serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` / + ``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*, + which is all-zeros at export time (caches start empty). + + It must NOT match the non-zero constants that some KV-cache modules register + alongside the cache — e.g. TurboQuant registers its codebook/rotation + (``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the + ``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing + those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage). + Gate on the buffer actually being all-zeros so only empty KV content is freed; + this is robust to any future constant name (a non-zero buffer is never freed). + """ + if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter): + return False + if "kv_cache" not in name or v.numel() == 0 or v.is_meta: + return False + # Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero + # constants (TurboQuant centroids/rotation/...) must be preserved as-is. + return bool(torch.count_nonzero(v) == 0) def _empty_strided_on_device(v, location):