From 5970c44a68a10ae5441b7abba9025481ec56a525 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 17 Jun 2026 12:02:35 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- examples/models/gemma4_31b/export.py | 8 ++- examples/models/gemma4_31b/main.cpp | 76 ++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 64e55319490..5c2b21b2b98 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -190,8 +190,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M). max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) - seq_dim = Dim("seq_len", min=5, max=max_prefill) - print(f"Exporting prefill (T in [2, {max_prefill}])...") + min_prefill = 5 + seq_dim = Dim("seq_len", min=min_prefill, max=max_prefill) + print(f"Exporting prefill (T in [{min_prefill}, {max_prefill}])...") with torch.no_grad(): prefill_ep = export( model, @@ -250,6 +251,8 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - "get_vocab_size": config.vocab_size, "get_n_layers": config.num_hidden_layers, "get_max_prefill_chunk": max_prefill, + "get_min_prefill_chunk": min_prefill, + "get_sliding_window": config.sliding_window, "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, @@ -364,6 +367,7 @@ def _export_mlx( "get_vocab_size": config.vocab_size, "get_n_layers": config.num_hidden_layers, "get_max_prefill_chunk": max_prefill, + "get_sliding_window": config.sliding_window, "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 6cf65cc8246..6be1653dde7 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -68,6 +68,11 
@@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); +DEFINE_int32( + max_prefill_chunk, + 0, + "Override the prefill chunk size (<=0 uses metadata). Experiment: chunking " + "above sliding_window is inexact for sliding layers across boundaries."); DEFINE_bool( raw_prompt, false, @@ -168,13 +173,55 @@ int main(int argc, char** argv) { return 1; } - int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + int64_t exported_max_prefill = (*metadata_result)[llm::kMaxSeqLen] - 1; { auto get_result = module->get("get_max_prefill_chunk"); if (get_result.ok()) { - max_prefill_chunk = get_result->toScalar().to<int64_t>(); + exported_max_prefill = get_result->toScalar().to<int64_t>(); } } + // Cap prefill chunks at the sliding window: a chunk larger than the window + // overflows the 2*window ring cache across chunk boundaries, truncating + // sliding attention for the first ~(chunk-window) queries of each chunk (the + // global flat-cache layers stay exact). The export sets max_prefill = + // 2*sliding_window, so window = max_prefill/2 (prefer get_sliding_window + // metadata when present). + int64_t sliding_window = exported_max_prefill / 2; + { + auto sw = module->get("get_sliding_window"); + if (sw.ok()) { + sliding_window = sw->toScalar().to<int64_t>(); + } + } + int64_t max_prefill_chunk = std::min(sliding_window, exported_max_prefill); + if (FLAGS_max_prefill_chunk > 0) { + max_prefill_chunk = + std::min<int64_t>(FLAGS_max_prefill_chunk, exported_max_prefill); + } + // The exported prefill accepts T in [min_prefill, max_prefill]; a final chunk + // shorter than min_prefill (and > 1) is an out-of-range shape. Read the lower + // bound so chunking can avoid it (fallback 1 keeps older exports working: a + // length-1 tail already routes to decode). 
+ int64_t min_prefill = 1; + { + auto r = module->get("get_min_prefill_chunk"); + if (r.ok()) { + min_prefill = r->toScalar().to<int64_t>(); + } + } + // A --max_prefill_chunk below the exported minimum has no valid prefill shape + // (and a cap of 1 would make the tail adjustment compute chunk_len == 0 and + // loop forever), so reject it rather than feed an out-of-range / zero chunk. + if (FLAGS_max_prefill_chunk > 0 && max_prefill_chunk < min_prefill) { + ET_LOG( + Error, + "--max_prefill_chunk (%d) is below the exported prefill minimum " + "(%" PRId64 "); use a value >= %" PRId64 " or omit it.", + FLAGS_max_prefill_chunk, + min_prefill, + min_prefill); + return 1; + } auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); }; @@ -280,6 +327,21 @@ int main(int argc, char** argv) { printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; + // A prompt of 2..min_prefill-1 tokens has no valid prefill shape (the CUDA + // export specializes prefill to T >= min_prefill) and is too long for the + // single-token decode path, so reject it. A 1-token prompt is fine: it goes + // through decode below. + if (num_prompt_tokens > 1 && num_prompt_tokens < min_prefill) { + ET_LOG( + Error, + "Prompt (%" PRId64 + " tokens) is below the exported prefill minimum %" PRId64 + "; use a longer prompt.", + num_prompt_tokens, + min_prefill); + return 1; + } + stats.inference_start_ms = llm::time_in_ms(); // --------------------------------------------------------------- @@ -288,8 +350,14 @@ int main(int argc, char** argv) { uint64_t cur_token = 0; int64_t prefill_pos = 0; while (prefill_pos < num_prompt_tokens) { - int64_t chunk_len = - std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); + int64_t remaining = num_prompt_tokens - prefill_pos; + int64_t chunk_len = std::min(remaining, max_prefill_chunk); + // Shrink this chunk so the tail it leaves is never in (1, min_prefill): + // such a tail would be an out-of-range prefill shape. 
A length-1 tail is + // fine (routed to decode below); a >= min_prefill tail is fine too. + if (remaining - chunk_len > 1 && remaining - chunk_len < min_prefill) { + chunk_len = remaining - min_prefill; + } std::vector token_data( prompt_tokens.begin() + prefill_pos,