Skip to content

[gemma4_31b][cuda] Export Gemma4-31B @128k on 5090#20480

Open
Gasoonjia wants to merge 1 commit into
mainfrom
gemma4_31b_export_under_32gb
Open

[gemma4_31b][cuda] Export Gemma4-31B @128k on 5090#20480
Gasoonjia wants to merge 1 commit into
mainfrom
gemma4_31b_export_under_32gb

Conversation

@Gasoonjia

Copy link
Copy Markdown
Contributor

Current gemma4-31b can not be successfully exported on consumer gpu like 5090 with three reasons:

  1. During int4_dispatch we need to dequant whole matmul weight to bf16 for prefill in one step for lm_head, leading to weight duplcation;
  2. When lowering to AOTI-CUDA, we moved the whole model, including kv cache, onto gpu. With context length increased, the gpu memory consumption will also be increased dramatically.
  3. No autotune config for kernels like sdpa work for consumer gpu like 5090.

Three CUDA-export memory optimizations, all gated behind the existing low_memory_mode compile spec (no impact on other models or on runtime):

  • int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights, gated behind a low-memory flag with an N>65536 threshold so only the lm_head crosses it. Avoids transiently materializing the full ~10 GiB bf16 lm_head during AOTI autotune / cpp_wrapper. The prefill MLP path is untouched -> zero runtime impact.

  • cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile. A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the unlift_graph clone and at serialization, and excluded from constant dedup (resize(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). All gated behind low_memory_mode.

  • tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit.

Full Gemma4-31B on 128k TQ export: peak 28.0 GiB, runtime output correct ("...Paris.").

@pytorch-bot

pytorch-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20480

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures, 3 Unrelated Failures

As of commit 993cff5 with merge base 1b726b2 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Three CUDA-export memory optimizations:

- tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset
  is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner
  auto-prunes configs that exceed a GPU's shared memory (OutOfResources ->
  inf), so the same config list also works on the 5090 (Blackwell, ~101 KB
  SMEM) where the previous smallest config did not fit.

- int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized
  weights (N>65536, i.e. only the lm_head). Avoids transiently materializing
  the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom
  op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a
  shim and the M>4 prefill inline path is below the threshold, so this never
  enters the runtime graph -> zero runtime / accuracy impact. Applied
  unconditionally (no flag).

- cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache
  buffers during AOTI compile (gated behind low_memory_mode). A new
  move_program_to_device hook places KV constants on the target device but
  immediately frees their storage (resize_(0)), so the fake-tensor device
  check passes while no real KV bytes sit on the GPU during autotune. The
  emptied buffers are re-synthesized as zeros at the _unlift_graph clone and
  at serialization, and excluded from constant dedup (resize_(0) gives every
  KV data_ptr 0, which would otherwise collapse same-shape caches across
  layers).

Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the
exported model runs correctly (output "...Paris.").
Comment on lines +73 to +74
_DEQUANT_N_THRESHOLD = 65536
_DEQUANT_N_CHUNK = 32768

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these kind of device specific?

return _dequant_matmul(self, qdata, scale, zero, group_size)


# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish there is a better way to do this i.e. why does this logic needs to be aware of export issues?


# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size,
# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that
# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just a crude way of doing tile level dequant?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants