
[CI] Gate expensive ngram tests behind env vars#7170

Open
cloudforge1 wants to merge 25 commits into PaddlePaddle:develop from CloudForge-Solutions:task/049-envgate-ngram-tests

Conversation

@cloudforge1
Contributor

Motivation

Gate expensive ngram tests behind environment variables so they don't run in CI by default. These tests either require extreme GPU memory (>40 GB for test_large_batch_long_seq) or are benchmarks rather than correctness tests (test_latency).

Addresses Copilot AI review comments from PR #6960.

Depends on: #6960 (must merge first — this PR is stacked on top).

Modifications

Added @unittest.skipUnless decorators to 3 test methods in test_ngram_gpu_kernel.py:

| Method | Class | Env Var | Reason |
| --- | --- | --- | --- |
| test_large_batch_long_seq | TestNgramMatchKernel | RUN_LARGE_NGRAM_TESTS=1 | >40 GB GPU memory |
| test_latency | TestNgramMatchKernel | RUN_NGRAM_BENCHMARKS=1 | Benchmark, not correctness |
| test_large_batch_long_seq | TestHybridMtpNgramKernel | RUN_LARGE_NGRAM_TESTS=1 | >40 GB GPU memory |
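The gating pattern applied to these methods can be sketched as follows (a minimal, self-contained sketch; class and method names here are illustrative, the real decorators sit on the test classes listed above):

```python
import os
import unittest

# The env var is read at import time, so it must be set before pytest
# collects the test module (e.g. RUN_LARGE_NGRAM_TESTS=1 pytest ...).
RUN_LARGE = os.environ.get("RUN_LARGE_NGRAM_TESTS", "0") == "1"


class TestNgramGateSketch(unittest.TestCase):
    @unittest.skipUnless(
        RUN_LARGE,
        "Set RUN_LARGE_NGRAM_TESTS=1 to run (requires >40 GB GPU memory)",
    )
    def test_large_batch_long_seq(self):
        # Placeholder for the expensive test body; skipped by default.
        self.fail("only reachable when RUN_LARGE_NGRAM_TESTS=1")

    def test_always_runs(self):
        # Correctness tests stay ungated and run in every CI pass.
        self.assertTrue(True)
```

With the variable unset, pytest reports the expensive case as skipped rather than failed, so default CI behavior is unchanged.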

Usage or Command

# Run all correctness tests (default CI behavior — unchanged)
pytest tests/spec_decode/test_ngram_gpu_kernel.py

# Also run large-scale tests (needs >40 GB GPU)  
RUN_LARGE_NGRAM_TESTS=1 pytest tests/spec_decode/test_ngram_gpu_kernel.py

# Also run benchmarks
RUN_NGRAM_BENCHMARKS=1 pytest tests/spec_decode/test_ngram_gpu_kernel.py

Accuracy Tests

No accuracy impact — only adds skip conditions. All existing correctness tests continue to run by default.

Checklist

  • Env-gated tests still pass when enabled: RUN_LARGE_NGRAM_TESTS=1 RUN_NGRAM_BENCHMARKS=1 pytest ...
  • Default CI run skips expensive tests (no OOM risk)
  • pre-commit checks pass

Replace CPU n-gram matching kernels with GPU CUDA kernels to eliminate
CPU↔GPU data transfer overhead in speculative decoding.

Key changes:
- ngram_match.cc → ngram_match.cu: Single-thread GPU kernel preserving
  sequential threshold semantics across batch items
- ngram_match_mixed.cu: Replace CPU function with __global__ kernel
- ngram.py: Remove ~10 .cpu() tensor copies, pass GPU tensors directly
- mtp.py: Remove .cpu()/.cuda() round-trips and CUDAPinnedPlace copies

Design: <<<1,1>>> single-thread kernels (same approach as TensorRT-LLM).
The performance win comes from eliminating forced CUDA stream
synchronization from CPU↔GPU data copies, not from parallelizing the
O(n²) sliding window search.
Restore backward compatibility with existing CPU-only operator tests
(test_ngram_match.py, test_hybrid_mtp_ngram.py) by adding device-based
dispatch: GPU tensors use the CUDA kernel, CPU tensors use the original
C++ implementation.
Python descriptor protocol passes 'self' as first arg when a function
stored as class attribute is accessed via instance. Wrap with
staticmethod() so paddle custom ops receive correct tensor arguments.
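The descriptor issue this commit fixes can be reproduced in plain Python (hypothetical names; a paddle custom op stored as a class attribute behaves like the plain function here):

```python
# A plain function stored as a class attribute becomes a bound method on
# instance attribute access, so the instance is silently prepended as the
# first argument. staticmethod() suppresses that binding.
def custom_op(a, b):
    return (a, b)


class Proposer:
    op_raw = custom_op                 # descriptor protocol binds the instance
    op_safe = staticmethod(custom_op)  # no binding; args pass through untouched


p = Proposer()

try:
    p.op_raw(1, 2)  # actually calls custom_op(p, 1, 2) -> TypeError
    bound_call_failed = False
except TypeError:
    bound_call_failed = True

result = p.op_safe(1, 2)  # calls custom_op(1, 2) as intended
```

The same mechanism would hand a paddle custom op an unexpected `self` in place of its first tensor argument, which is why the wrap matters.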
Reverts line 39 to match develop (keeps .cpu()) so diff-cover
no longer flags it as an uncovered changed line. The tensor is
moved to GPU via .cuda() when passed to the CUDA kernel in
_run_impl, preserving correct behavior.
…n.cuh)

Per upstream requirement (translated): "The two kernels share fairly similar logic; the kernels should extract the common matching logic, with the business logic layered on top."

The core ngram sliding-window search + token copy logic is now defined
once in ngram_match_common.cuh as two __device__ __forceinline__
functions:
  - ngram_search_and_copy: single-haystack sliding window match
  - ngram_search_batch_item: two-phase search (input_ids then pre_ids)

Both kernels call ngram_search_batch_item with their business-specific
parameters:
  - ngram_match_kernel: write_offset=1, min_ngram_size=1
  - ngram_match_mixed_kernel: write_offset=ori_seq_len_this_time,
    min_ngram_size=configurable

No functional change. CPU fallback paths unchanged.
Two-phase parallel architecture addressing reviewer feedback:
- Phase 1: <<<bsz, 256>>> — parallel sliding-window ngram search
  using atomicMin64 CAS loop for leftmost-match semantics
- Phase 2: <<<1, 1>>> — serial threshold + token copy (inter-batch
  dependency via running sum of seq_lens_this_time)

Phase 1 is O(bsz × seq_len × ngram_size) distributed across bsz × 256
threads.  Phase 2 is O(bsz × max_draft_tokens) — negligible.

Shared code extracted into ngram_match_common.cuh:
  NgramMatchResult struct, atomicMin64, parallel_ngram_search,
  4 kernel functions (search+gather for both kernel types)

Tests: 6 new large-scale correctness tests with env-var threshold
override — bsz=256/seq_len=128k, bsz=1/seq_len=128k, bsz=256/seq_len=1k
for both ngram_match and hybrid_mtp_ngram.
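The leftmost-match semantics of the Phase-1 atomicMin64 CAS loop can be modeled on the host (a Python sketch under the assumption that each thread publishes its own candidate match position; this is an illustration of the reduction, not the CUDA code):

```python
import threading

INT64_MAX = (1 << 63) - 1


class AtomicMin:
    """Host-side stand-in for the GPU atomicMin64 CAS loop."""

    def __init__(self):
        self.value = INT64_MAX
        self._lock = threading.Lock()

    def update(self, candidate):
        # On the GPU: atomicMin64(&value, candidate); here a lock plays that role.
        with self._lock:
            if candidate < self.value:
                self.value = candidate


def find_leftmost(haystack, needle, n_threads=4):
    """Parallel sliding-window search; the min reduction yields the leftmost hit."""
    best = AtomicMin()
    positions = range(len(haystack) - len(needle) + 1)

    def worker(tid):
        # Strided partition: thread tid checks positions tid, tid+n, ...
        for pos in positions[tid::n_threads]:
            if haystack[pos:pos + len(needle)] == needle:
                best.update(pos)

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return None if best.value == INT64_MAX else best.value
```

Because every thread's candidate goes through the atomic minimum, the result is the leftmost match regardless of which thread finishes first, matching the serial CPU path's semantics.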
…ultiple-def error)

Both ngram_match.cu and ngram_match_mixed.cu include ngram_match_common.cuh.
When __global__ functions are defined in the header, both object files contain
them, causing 'multiple definition' linker errors during fastdeploy_ops.so link.

Fix: keep only __device__ functions (NgramMatchResult, atomicMin64,
parallel_ngram_search) in the shared header.  Move __global__ kernel
definitions into each respective .cu file.

Net code change: +304/-304 (zero net lines).
Fix 7 type-mismatch compilation errors in ngram_match_mixed.cu:
- Search kernel: replace seq_lens_encoder/decoder with seq_lens_this_time
  (host function does not have seq_lens_encoder tensor)
- Gather kernel: remove seq_lens_encoder param, compute ori_seq_len_this_time
  per-batch from seq_lens_this_time (matches CPU path logic)
- Fix max_draft_tokens computation to match CPU path formula
- Fix skip condition to match CPU path: ori_seq_len_this_time==0 || max_draft_tokens<=0
…n mtp, remove unused kernel param, isolate benchmark timing
Groups: seq_len, batch_size, ngram hit pattern, threshold, threshold×batch.
Data creation outside timing loop. GPU kernel vs CPU-copy path.
- Renamed benchmark_ngram_kernel.py → test_benchmark_ngram_kernel.py
  so pytest discovers it (test_*.py pattern)
- Bumped NUM_ITERS 10→10000, WARMUP 2→5 for noise-free profiling
- Gated benchmark class with RUN_NGRAM_BENCHMARKS=1 (won't bloat CI)
- test_large_batch_long_seq (×2): RUN_LARGE_NGRAM_TESTS=1 (>40 GB GPU memory)
- test_latency: RUN_NGRAM_BENCHMARKS=1 (benchmark, not correctness)
Copilot AI review requested due to automatic review settings April 2, 2026 19:34
@paddle-bot

paddle-bot bot commented Apr 2, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Apr 2, 2026

Copilot AI left a comment


Pull request overview

This PR's stated goal is to isolate the time-consuming / high-GPU-memory ngram tests (including benchmarks) from the default CI flow behind environment-variable switches, avoiding CI OOM and unnecessary benchmark runtime; judging from the diff, it also includes changes to the ngram CUDA kernels and their Python call chain.

Changes:

  • Added env gates (RUN_LARGE_NGRAM_TESTS / RUN_NGRAM_BENCHMARKS) to the large-scale cases and the latency benchmark in the ngram GPU kernel tests.
  • Added a new multi-dimension benchmark test file (also gated behind RUN_NGRAM_BENCHMARKS).
  • Changed how spec_decode invokes the ngram/hybrid_mtp_ngram custom ops, and introduced/replaced the corresponding CUDA implementations and a shared header.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| tests/spec_decode/test_ngram_gpu_kernel.py | Adds correctness and latency tests for ngram/hybrid_mtp_ngram, with env skips on the expensive cases |
| tests/spec_decode/test_benchmark_ngram_kernel.py | Adds multiple benchmark groups (skipped via env gate) |
| fastdeploy/spec_decode/ngram.py | Switches the ngram proposer call to the GPU path and adds input_ids_len_gpu |
| fastdeploy/spec_decode/mtp.py | Switches the hybrid ngram draft-token extension call to the GPU path |
| custom_ops/gpu_ops/speculate_decoding/ngram_match.cu | Adds/replaces the two-phase CUDA implementation of ngram_match while keeping a CPU-compatible path |
| custom_ops/gpu_ops/speculate_decoding/ngram_match.cc | Deletes the old pure-CPU implementation file |
| custom_ops/gpu_ops/speculate_decoding/ngram_match_common.cuh | Extracts the shared ngram search and atomicMin64 utilities |
| custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu | Adds the two-phase CUDA implementation for hybrid_mtp_ngram while keeping a CPU-compatible path |

Comment on lines 39 to +56
        self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
        self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64")

    def update(self, bid: int, seq_len: int):
        """
        update
        """
        self.input_ids_len[bid] = seq_len
        self.input_ids_len_gpu[bid] = seq_len

    def _run_impl(self, share_inputs):
        """
        run
        """
        draft_tokens = share_inputs["draft_tokens"].cpu()
        seq_lens_this_time = share_inputs["seq_lens_this_time"].cpu()
        seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu()
        seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu()

        ngram_match(
            share_inputs["input_ids_cpu"],
            self.input_ids_len.cpu(),
            share_inputs["token_ids_all"].cpu(),
            share_inputs["prompt_lens"].cpu(),
            share_inputs["step_idx"].cpu(),
            share_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            share_inputs["max_dec_len"].cpu(),
            share_inputs["input_ids_cpu"].cuda(),
            self.input_ids_len_gpu,
            share_inputs["token_ids_all"],

Copilot AI Apr 2, 2026


The GPU path here still depends on share_inputs["input_ids_cpu"] (a CPU tensor) being copied to the GPU via .cuda(), and input_ids_len_gpu can only be updated through this class's update(); but there is no call to NgramProposer.update() anywhere in the current codebase, so input_ids_len_gpu may stay at 0, making the ngram search length permanently 0 (a functional bug: no match will ever be found). Suggestion: reuse a GPU tensor the runner already maintains and updates (e.g. share_inputs["input_ids"] and its length source), or explicitly call update() during the runner's data-preparation phase and keep input_ids_cpu in sync.

Comment on lines 1178 to +1183
    def _extend_draft_token_with_ngram_match(self):
        # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency
        device = paddle.CUDAPinnedPlace()

        draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
        seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
        hybrid_mtp_ngram(
            self.model_inputs["input_ids_cpu"],
            self.model_inputs["input_ids_len"],
            self.model_inputs["pre_ids"]._copy_to(device, True),
            self.model_inputs["step_idx"].cpu(),
            self.target_model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_decoder,
            self.model_inputs["max_dec_len"].cpu(),
            self.model_inputs["input_ids_cpu"].cuda(),
            self.model_inputs["input_ids_len"].cuda(),
            self.model_inputs["pre_ids"],
            self.model_inputs["step_idx"],

Copilot AI Apr 2, 2026


Calling .cuda() on input_ids_cpu / input_ids_len on every invocation introduces a CPU→GPU copy plus potential extra allocation/synchronization, which contradicts the stated goal of "zero copy / reduced synchronization for the CUDA kernel" and may also affect the stability of CUDA Graph capture/replay. Suggestion: pass tensors that are already maintained on the GPU (e.g. self.model_inputs["input_ids"] and a GPU-resident input_ids_len), or make the needed length/ID buffers GPU-resident once during input preparation and update them in place.

5. threshold × batch (batch=128)

Run:
cd FastDeploy && python tests/spec_decode/benchmark_ngram_kernel.py

Copilot AI Apr 2, 2026


The run command in this file's docstring does not match the actual file name: the file is test_benchmark_ngram_kernel.py, but the example says python tests/spec_decode/benchmark_ngram_kernel.py. Please correct it to the real path so users following the docs don't hit an error.

Suggested change
- cd FastDeploy && python tests/spec_decode/benchmark_ngram_kernel.py
+ cd FastDeploy && python tests/spec_decode/test_benchmark_ngram_kernel.py

Comment on lines +416 to +420
@unittest.skipUnless(
os.environ.get("RUN_LARGE_NGRAM_TESTS", "0") == "1",
"Set RUN_LARGE_NGRAM_TESTS=1 to run (requires >40 GB GPU memory)",
)
def test_large_batch_long_seq(self):

Copilot AI Apr 2, 2026


The PR description says "only 3 test methods get skipUnless", but the current diff also contains: a new benchmark test file, spec_decode Python call-path changes, and substantial CUDA kernel / custom op implementation changes. Please update the PR description to accurately cover this scope (or split the PR); otherwise reviewers and CI maintainers cannot judge the change risk and regression surface.

Comment on lines +347 to +351
  int threshold = 128;
  const char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
  if (env_var) {
    threshold = std::stoi(env_var);
  }

Copilot AI Apr 2, 2026


This parses the environment variable directly with std::stoi(env_var); if a user/CI sets it to a non-integer (e.g. an empty string or a value with a unit suffix), it throws an exception that may terminate the process (exceptions crossing the C++/Python boundary are not always controllable). Consider strtol/std::from_chars, fall back to the default threshold on parse failure, and add bounds checks for negative/oversized values.

Comment on lines +341 to +342
  if (env_var) {
    threshold = std::stoi(env_var);

Copilot AI Apr 2, 2026


Same issue: std::stoi(env_var) is used to parse SPEC_TOKENUM_THRESHOLD; any non-numeric value throws and may crash the process. Use a non-throwing parse (std::from_chars/strtol), fall back to the default on failure, and validate the threshold range.

Suggested change
-  if (env_var) {
-    threshold = std::stoi(env_var);
+  if (env_var && env_var[0] != '\0') {
+    char *end_ptr = nullptr;
+    long parsed_threshold = strtol(env_var, &end_ptr, 10);
+    if (end_ptr != env_var && *end_ptr == '\0' && parsed_threshold > 0 &&
+        parsed_threshold <= static_cast<long>(std::numeric_limits<int>::max())) {
+      threshold = static_cast<int>(parsed_threshold);
+    }
Comment on lines +161 to +173
    # Warmup
    for _ in range(WARMUP):
        # Reset mutable outputs
        gpu_data["draft_tokens"] = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda()
        gpu_data["seq_lens_this_time"] = paddle.ones([batch_size], dtype="int32").cuda()
        _run_gpu(ngram_match_fn, gpu_data)
        paddle.device.synchronize()

    paddle.device.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        gpu_data["draft_tokens"] = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda()
        gpu_data["seq_lens_this_time"] = paddle.ones([batch_size], dtype="int32").cuda()

Copilot AI Apr 2, 2026


The benchmark timing code in _time_gpu() re-allocates draft_tokens / seq_lens_this_time with paddle.zeros(...).cuda() on every iteration (followed by a synchronize), so significant device-memory allocation and framework overhead is counted as "kernel time", distorting the benchmark results and contradicting the "measure only kernel execution" goal. Suggestion: pre-allocate the output tensors and reset them in place inside the loop (e.g. fill_/assign) to avoid repeated allocation.

Suggested change
- # Warmup
- for _ in range(WARMUP):
-     # Reset mutable outputs
-     gpu_data["draft_tokens"] = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda()
-     gpu_data["seq_lens_this_time"] = paddle.ones([batch_size], dtype="int32").cuda()
-     _run_gpu(ngram_match_fn, gpu_data)
-     paddle.device.synchronize()
- paddle.device.synchronize()
- t0 = time.perf_counter()
- for _ in range(n_runs):
-     gpu_data["draft_tokens"] = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda()
-     gpu_data["seq_lens_this_time"] = paddle.ones([batch_size], dtype="int32").cuda()
+ gpu_data["draft_tokens"] = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda()
+ gpu_data["seq_lens_this_time"] = paddle.ones([batch_size], dtype="int32").cuda()
+ # Warmup
+ for _ in range(WARMUP):
+     # Reset mutable outputs in place to avoid allocation overhead in benchmark.
+     gpu_data["draft_tokens"].fill_(0)
+     gpu_data["seq_lens_this_time"].fill_(1)
+     _run_gpu(ngram_match_fn, gpu_data)
+     paddle.device.synchronize()
+ paddle.device.synchronize()
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+     gpu_data["draft_tokens"].fill_(0)
+     gpu_data["seq_lens_this_time"].fill_(1)

@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-03 11:20 CST

📋 Review summary

PR overview: the PR claims to only add env-var gating to ngram tests, but it actually contains a substantial GPU kernel implementation.
Change scope: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/, tests/spec_decode/
Impact tags: [OP] [Speculative Decoding]


📝 PR convention check

🔴 PR title and description severely mismatch the actual changes

The PR title claims [CI] Gate expensive ngram tests behind env vars, and the description only mentions adding @unittest.skipUnless decorators.

But the actual changes include:

| Change type | File | Description |
| --- | --- | --- |
| 🆕 New GPU kernel | ngram_match.cu (438 lines) | Brand-new two-phase parallel CUDA kernel implementation |
| 🆕 New shared header | ngram_match_common.cuh (99 lines) | Extracted shared ngram matching logic |
| 🔄 Refactor | ngram_match_mixed.cu | Adds a GPU kernel path, keeps the CPU fallback |
| ❌ Deletion | ngram_match.cc (227 lines) | Removes the original CPU implementation |
| 🔄 API change | mtp.py, ngram.py | Switches from CPU↔GPU copies to direct GPU calls |

This is a major functional change (moving ngram matching from CPU to GPU) and should use the [OP][Speculative Decoding] tags, not [CI].

Suggested title (copy-paste ready):

[Speculative Decoding][OP] Add GPU CUDA kernels for ngram matching

Suggested description template (copy-paste ready):

## Motivation

Optimize ngram matching performance by implementing GPU CUDA kernels, eliminating CPU-GPU data transfer overhead.

## Modifications

1. **New GPU Kernels** (`ngram_match.cu`, `ngram_match_common.cuh`):
   - Phase 1: Parallel sliding-window search (<<<bsz, 256>>>)
   - Phase 2: Serial threshold + token copy (<<<1, 1>>>)
   - Shared `parallel_ngram_search` helper with atomicMin64

2. **API Changes** (`mtp.py`, `ngram.py`):
   - Switch from CPU path (copy to CPU → process → copy back) to direct GPU kernel invocation

3. **Test Updates**:
   - Gate expensive tests behind env vars (RUN_LARGE_NGRAM_TESTS, RUN_NGRAM_BENCHMARKS)

## Depends on
#6960 (must merge first)

Issues

| Level | File | Summary |
| --- | --- | --- |
| 🟡 Suggestion | Global | The PR description needs updating to reflect the actual GPU kernel implementation changes |

Overall assessment

The implementation quality is good:

  • The two-phase CUDA kernel architecture is sound: Phase 1 parallel search + Phase 2 serial aggregation
  • atomicMin64 correctly implements an int64 atomic minimum via a CAS loop
  • The CPU fallback path is preserved, so compatibility is good
  • Test coverage is complete, including both correctness tests and benchmarks

The main problem is inaccurate PR metadata; update the title and description before merging.


Labels

contributor External developers


3 participants