Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions backends/aoti/aoti_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
"""
return

@classmethod
def move_program_to_device(
    cls,
    edge_program: ExportedProgram,
    device: str,
    compile_specs: List[CompileSpec],
) -> ExportedProgram:
    """Relocate ``edge_program`` onto ``device`` ahead of compilation.

    This base hook delegates straight to ``move_to_device_pass``, handing it
    the program and the device unchanged, so params, buffers and constants
    are all moved. ``compile_specs`` is accepted but not consulted here; it
    is part of the signature so that an overriding backend can pick its own
    strategy from it (for example, keeping large non-parameter tensors off
    the device during a low-memory export).

    Returns:
        The result of ``move_to_device_pass(edge_program, device)``.
    """
    moved_program = move_to_device_pass(edge_program, device)
    return moved_program

@classmethod
def release_moved_tensors(
cls,
Expand Down Expand Up @@ -196,9 +211,13 @@ def preprocess(
decomposition_table = cls.get_decomposition_table()
options = cls.get_aoti_compile_options(compile_specs)

# Move the edge_program to the target device
device_edge_program = move_to_device_pass(
edge_program, device_name if device_name != "metal" else "mps"
# Move the edge_program to the target device. Routed through a hook so
# backends can keep large non-parameter tensors (e.g. KV-cache buffers)
# off the device during a low-memory export.
device_edge_program = cls.move_program_to_device(
edge_program,
device_name if device_name != "metal" else "mps",
compile_specs,
)

# Replace view_copy with view
Expand Down
190 changes: 184 additions & 6 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool:
return getattr(_CPU_CLONE_GUARD, "active", False)


def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor:
"""Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``.

Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``)
during the low-memory device move. KV content is all zeros, so this exactly
reproduces the buffer for both the lifted graph value and serialization.
"""
needed = 1
for size, stride in zip(x.size(), x.stride()):
needed += (size - 1) * stride
buf = torch.zeros(int(needed), dtype=x.dtype, device=device)
return torch.as_strided(buf, x.size(), x.stride())


def _is_emptied(x) -> bool:
return (
isinstance(x, torch.Tensor)
and x.numel() > 0
and x.untyped_storage().nbytes() == 0
)


@contextlib.contextmanager
def _compile_time_cpu_clones(target_device: torch.device):
"""Force AOTI's mutated-buffer clones onto CPU while preserving the
serialized constants' target device."""
from torch._inductor import compile_fx as _cfx
from torch._inductor import compile_fx as _cfx, graph as _graph
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp
from torch._inductor.graph import GraphLowering as _GL

orig_clone = _cfx.clone_preserve_strides
orig_codegen_device = _Cpp.codegen_device
orig_get_const = _GL.get_original_value_of_constant
orig_is_same = _graph.is_same_tensor

def _is_same_skip_emptied(data, value):
# KV buffers freed via resize_(0) all have data_ptr 0, so the stock
# is_same_tensor would treat every same-shape KV constant as a duplicate
# and collapse the 60 layers' caches into one — the runtime needs each
# FQN's own buffer, so the collapsed ones load uninitialized garbage.
# Never dedup an emptied tensor.
if _is_emptied(data) or _is_emptied(value):
return False
return orig_is_same(data, value)

def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
# `clone_preserve_strides` is shared by `_unlift_graph` (clones
# lifted buffers — can be safely kept on CPU) and by autotuning code
# in `triton_heuristics.py` (clones for benchmark — must stay on
# GPU for Triton). Discriminate by caller frame so we only force
# CPU clones for the buffer-lifting path.
# `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted
# buffers — can be safely kept on CPU) and by autotuning code in
# `triton_heuristics.py` (clones for benchmark — must stay on GPU for
# Triton). Discriminate by caller frame so we only force CPU clones for
# the buffer-lifting path.
import sys

caller = sys._getframe(1).f_code.co_name
if caller == "_unlift_graph":
# KV-cache buffers are emptied (storage resize_(0)) by the low-memory
# device move so they never occupy GPU memory during compile. Their
# content is all zeros, so re-synthesize zeros (on CPU, strides
# preserved) instead of cloning the now-empty storage.
if _is_emptied(x):
return _full_zeros_preserving_strides(x, "cpu")
return orig_clone(x).cpu()
return orig_clone(x)

def _get_const_synthesize_zeros(self, name):
# AOTI serializes each constant via get_original_value_of_constant ->
# _to_bytes. For KV buffers we freed with resize_(0) this would otherwise
# fall back to the empty-storage constant and write 0 bytes, producing a
# .ptd with an uninitialized cache. Re-synthesize the zeros so the blob
# holds a correctly-zeroed KV cache.
value = orig_get_const(self, name)
if _is_emptied(value):
return _full_zeros_preserving_strides(value, "cpu")
return value

def _codegen_device_target_aware(self, device):
# Translate accidental CPU device strings back to the model target
# device only when a constant we forced to CPU is being serialized.
Expand All @@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device):

_cfx.clone_preserve_strides = _cpu_clone_preserve_strides
_Cpp.codegen_device = _codegen_device_target_aware
_GL.get_original_value_of_constant = _get_const_synthesize_zeros
_graph.is_same_tensor = _is_same_skip_emptied
prev_active = getattr(_CPU_CLONE_GUARD, "active", False)
_CPU_CLONE_GUARD.active = True
try:
Expand All @@ -107,6 +161,107 @@ def _codegen_device_target_aware(self, device):
_CPU_CLONE_GUARD.active = prev_active
_cfx.clone_preserve_strides = orig_clone
_Cpp.codegen_device = orig_codegen_device
_GL.get_original_value_of_constant = orig_get_const
_graph.is_same_tensor = orig_is_same


def _is_kv_buffer(name, v) -> bool:
"""True only for an actual KV-cache *content* buffer that is safe to free.

The low-memory path (``_move_to_device_resize_kv``) frees every buffer this
matches and re-synthesizes it as ZEROS in both the lifted graph and the
serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` /
``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*,
which is all-zeros at export time (caches start empty).

It must NOT match the non-zero constants that some KV-cache modules register
alongside the cache — e.g. TurboQuant registers its codebook/rotation
(``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the
``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing
those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage).
Gate on the buffer actually being all-zeros so only empty KV content is freed;
this is robust to any future constant name (a non-zero buffer is never freed).
"""
if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter):
return False
if "kv_cache" not in name or v.numel() == 0 or v.is_meta:
return False
# Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero
# constants (TurboQuant centroids/rotation/...) must be preserved as-is.
return bool(torch.count_nonzero(v) == 0)


def _empty_strided_on_device(v, location):
"""A device tensor with v's shape/stride/dtype but zero (freed) storage."""
t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location)
t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride
return t


def _move_graph_nodes_to_device(graph_module, location):
"""Point node device kwargs / aten.to.device targets / meta vals at location."""
import torch.utils._pytree as pytree

def _to_loc(v):
return v.to(location) if isinstance(v, torch.Tensor) else v

for m in graph_module.modules():
if not isinstance(m, torch.fx.GraphModule):
continue
for node in m.graph.nodes:
if "device" in node.kwargs:
node.kwargs = {**node.kwargs, "device": location}
if node.op == "call_function" and node.target is torch.ops.aten.to.device:
args = list(node.args)
args[1] = location
node.args = tuple(args)
node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val"))


def _move_to_device_resize_kv(ep, location):
    """Move ``ep`` to ``location`` while keeping KV-cache bytes off the device.

    Modeled on ``torch.export.passes.move_to_device_pass``, with one
    difference: any state-dict buffer or constant accepted by
    ``_is_kv_buffer`` is replaced by a tensor that lives on ``location`` with
    the same shape/stride/dtype but whose storage has been freed
    (``resize_(0)``).

    Why this shape of trick:

    * The replacement still reports ``device == location``, so the fake-tensor
      device check on the ``index_copy`` cache update (``self`` and ``values``
      both on the target device) keeps passing.
    * No real KV bytes stay resident on the device during the AOTI compile.
      Each buffer is allocated and freed before the next one is handled, so
      the transient device peak is one KV buffer rather than the whole cache.
    * KV content is all zeros, so the emptied tensors can be re-synthesized
      as zeros later (see ``_compile_time_cpu_clones``) for both the lifted
      initial value and the serialized ``.ptd`` constant.

    Everything not accepted by ``_is_kv_buffer`` is moved with a plain
    ``.to(location)``, so non-zero content is never dropped.

    Args:
        ep: The exported program; it is mutated in place.
        location: Target device.

    Returns:
        The same ``ep`` object, after ``ep.validate()`` has run.
    """
    import torch.utils._pytree as pytree

    def _tensors_to_location(tree):
        return pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), tree)

    # State dict: parameters are re-wrapped, KV content is emptied, the rest moved.
    for name, value in ep.state_dict.items():
        if isinstance(value, torch.nn.Parameter):
            relocated = torch.nn.Parameter(value.to(location), value.requires_grad)
        elif _is_kv_buffer(name, value):
            relocated = _empty_strided_on_device(value, location)
        else:
            relocated = value.to(location)
        ep._state_dict[name] = relocated

    # Constants: only tensor entries are touched; other objects stay as they are.
    for name, value in ep.constants.items():
        if not isinstance(value, torch.Tensor):
            continue
        if _is_kv_buffer(name, value):
            ep._constants[name] = _empty_strided_on_device(value, location)
        else:
            ep._constants[name] = value.to(location)

    if ep.example_inputs is not None:
        example_args, example_kwargs = ep.example_inputs
        ep._example_inputs = (
            _tensors_to_location(example_args),
            _tensors_to_location(example_kwargs),
        )

    _move_graph_nodes_to_device(ep.graph_module, location)
    ep.validate()
    return ep


@final
Expand Down Expand Up @@ -424,6 +579,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
return spec.value.decode("utf-8").upper() == "ON"
return False

@classmethod
def move_program_to_device(
    cls,
    edge_program,
    device: str,
    compile_specs: List[CompileSpec],
):
    """Place ``edge_program`` on ``device`` for the AOTI compile.

    The strategy is chosen by ``cls._is_low_memory_mode(compile_specs)``:

    * Low-memory export (``low_memory_mode="ON"``): the program goes through
      ``_move_to_device_resize_kv``. KV-cache buffers, which can be 10+ GiB at
      long context, end up on-device with their storage freed
      (``resize_(0)``), so they take no device memory during the autotune /
      cpp_wrapper compile yet still satisfy the device-match check on the
      cache update. They are re-synthesized as zeros for the lifted graph and
      the serialized blob.
    * Any other export: the stock ``move_to_device_pass`` is used.

    No extra opt-in is needed; the low-memory path follows the compile specs.
    """
    from torch.export.passes import move_to_device_pass

    if cls._is_low_memory_mode(compile_specs):
        return _move_to_device_resize_kv(edge_program, device)
    return move_to_device_pass(edge_program, device)

@classmethod
def release_moved_tensors(
cls,
Expand Down
46 changes: 36 additions & 10 deletions backends/cuda/quantize_op_dispatch/int4_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,54 @@ def _cuda(self, qdata, scale, zero, group_size):
return _dequant_matmul(self, qdata, scale, zero, group_size)


# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish there were a better way to do this, i.e., why does this logic need to be aware of export issues?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it is not an export issue, but it impacts the memory consumption during export, which is reasonable.

# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that
# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just a crude way of doing tile level dequant?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed. it is tile-level dequant

# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a
# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N
# caps that at ~chunk rows. It is numerically identical (F.linear output rows are
# independent), and because only the lm_head (custom-op) path crosses the N
# threshold — never the M>4 prefill inline path — it never enters the runtime
# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight
# whose row count exceeds the threshold.
_DEQUANT_N_THRESHOLD = 65536
_DEQUANT_N_CHUNK = 32768
Comment on lines +73 to +74

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these kind of device specific?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, these are just the parameters for controlling the peak memory we use for dequant.



def _dequant_matmul(x, qdata, scale, zero, group_size):
"""Dequant INT4 weights to input dtype and call F.linear.
scale/zero are in the coalesced [N, n_groups] layout (baked into the
weight constant at pack time), aligned row-for-row with qdata's [N, *].
Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound
the dequant intermediate (see note above); smaller weights take the original
single-shot dequant.
"""
N, K_half = qdata.shape
K = K_half * 2
n_groups = K // group_size
gs_half = group_size // 2
dtype = x.dtype

p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
low = (p & 0x0F).to(dtype)
high = ((p >> 4) & 0x0F).to(dtype)
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)

s = scale.to(dtype).unsqueeze(-1)
z = zero.to(dtype).unsqueeze(-1)
w_deq = ((data - z) * s).reshape(N, K)

return F.linear(x, w_deq)
def _dq(qd, sc, ze, rows):
p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half)
low = (p & 0x0F).to(dtype)
high = ((p >> 4) & 0x0F).to(dtype)
data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size)
s = sc.to(dtype).unsqueeze(-1)
z = ze.to(dtype).unsqueeze(-1)
w_deq = ((data - z) * s).reshape(rows, K)
return F.linear(x, w_deq)

if N <= _DEQUANT_N_THRESHOLD:
return _dq(qdata, scale, zero, N)

outs = []
for i in range(0, N, _DEQUANT_N_CHUNK):
j = min(i + _DEQUANT_N_CHUNK, N)
outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i))
return torch.cat(outs, dim=-1)


# ---------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions backends/cuda/triton/kernels/tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body(
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
],
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
)
Expand Down Expand Up @@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64(
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
],
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
)
Expand Down
Loading