Skip to content

fix(python/sglang): FlashInfer kv_indptr issue#2

Closed
hxkwan wants to merge 1 commit intoOpenBMB:minicpm_salafrom
hxkwan:minicpm_sala
Closed

fix(python/sglang): FlashInfer kv_indptr issue#2
hxkwan wants to merge 1 commit intoOpenBMB:minicpm_salafrom
hxkwan:minicpm_sala

Conversation

@hxkwan
Copy link

@hxkwan hxkwan commented Mar 8, 2026

FlashInfer kv_indptr 缓冲区污染导致 CUDA Graph Replay 崩溃

1. 错误现象

使用 --dense-as-sparse --attention-backend minicpm_flashinfer 部署模型,推理运行一段时间后,
Scheduler 在 CUDA Graph Replay 阶段抛出异常:

tvm.error.InternalError: Error in function 'PrefillSplitQOKVIndptr':
  kv_indptr[9]34891 - kv_indptr[8]36606 should be non-negative

FlashInfer 要求 kv_indptr 数组是单调非递减的(作为 CSR 格式的索引指针),
但实际传入的值在位置 8 处为 36606,位置 9 处为 34891,出现了递减(差值 -1715),违反了约束。


2. 错误根因

2.1 背景:共享缓冲区设计

init_cuda_graph_state() 中,系统预计算了一个 flashinfer_kv_indptr 缓冲区,
其值为 [0, K, 2K, 3K, ..., (max_bs*2)*K](K = num_sparse_topk_tokens),
存储在 self.decode_cuda_graph_metadata["flashinfer_kv_indptr"] 中。

CUDA Graph 捕获时,FlashInfer 的 BatchDecodeWithPagedKVCacheWrapper 通过
paged_kv_indptr_buffer=kv_indptr_view 参数将该缓冲区的切片作为内部缓冲区
wrapper._paged_kv_indptr_buf两者指向同一块 GPU 内存

2.2 污染链条

                       ┌─────────────────────────────────────┐
                       │  decode_cuda_graph_metadata          │
                       │  ["flashinfer_kv_indptr"]            │
                       │  预计算值: [0, K, 2K, 3K, ...]       │
                       └──────────┬──────────────────────────┘
                                  │ 同一块 GPU 内存
                                  ▼
                       ┌─────────────────────────────────────┐
                       │  wrapper._paged_kv_indptr_buf       │
                       │  (FlashInfer wrapper 的内部缓冲区)    │
                       └──────────┬──────────────────────────┘
                                  │
         ┌────────────────────────┼──────────────────────────┐
         │  CUDA Graph 前 (replay_prepare)                    │
         │  读取该缓冲区 → wrapper.begin_forward()             │
         │  期望值: [0, K, 2K, ...]                           │
         └────────────────────────┬──────────────────────────┘
                                  │
         ┌────────────────────────┼──────────────────────────┐
         │  CUDA Graph 中 (captured forward)                  │
         │  FlashInferKernel.forward() 调用                    │
         │  convert_sparse_page_table_to_flashinfer()         │
         │  → 原地覆写该缓冲区为动态的 per-request 值           │
         │  覆写后值: [0, 3201, 7103, ...] (不再是等差数列)     │
         └────────────────────────┬──────────────────────────┘
                                  │
         ┌────────────────────────┼──────────────────────────┐
         │  下一轮 replay_prepare                              │
         │  读取已被污染的缓冲区                                │
         │  → 传入 wrapper.begin_forward()                     │
         │  → FlashInfer 检测到 kv_indptr 非单调递增 → 报错     │
         └────────────────────────────────────────────────────┘

具体代码路径:

  1. replay_prepare (init_forward_metadata_replay_cuda_graph, line ~1774):
    调用 wrapper.begin_forward(kv_indptr_view, ...) 读取预计算缓冲区

  2. CUDA Graph replay (FlashInferKernel.forward, line ~344):

    kv_indptr = wrapper._paged_kv_indptr_buf  # 指向预计算缓冲区!
    convert_sparse_page_table_to_flashinfer(
        params.page_table, params.cache_seqlens,
        kv_indptr, kv_indices, kv_last_page_len  # 原地覆写 kv_indptr!
    )

    该函数根据每个请求实际的 sparse page table 重新计算 kv_indptr,
    覆写后的值取决于各请求的实际 token 分布,不再是等差 [0, K, 2K, ...]

  3. 下一轮 replay_prepare 再次读取同一缓冲区时,看到的是上一轮被动态覆写的值,
    这些值可能非单调递增,导致 FlashInfer 的 plan() 函数校验失败。

2.3 为什么"用了一会儿之后"才报错

  • 第一轮 decode 使用的是干净的预计算值,begin_forward() 通过校验
  • Graph replay 中 convert_sparse_page_table_to_flashinfer() 污染缓冲区
  • 第二轮 decode 的 begin_forward() 才看到被污染的值
  • 是否触发报错取决于被污染的值是否恰好非单调递增——
    只要不同请求的 topk 页面分布产生了非单调的 cumsum 就会触发

3. 修改方法

核心思路

保存一份预计算值的只读副本flashinfer_kv_indptr_original),
在每次 wrapper.begin_forward() 调用前,从副本恢复缓冲区的预计算值。

修改文件

python/sglang/srt/layers/attention/minicpm_backend.py(共 3 处修改)

修改详情

修改 1:init_cuda_graph_state() — 保存只读副本

  "flashinfer_kv_indptr": precomputed_kv_indptr,
+ "flashinfer_kv_indptr_original": precomputed_kv_indptr.clone(),

新增一个 .clone() 副本,该副本不会被任何 wrapper 引用,因此不会被 CUDA Graph
中的 convert_sparse_page_table_to_flashinfer() 覆写。

修改 2:init_forward_metadata_capture_cuda_graph() — 捕获前恢复

  kv_indptr_view = self.decode_cuda_graph_metadata["flashinfer_kv_indptr"][:sparse_bs + 1]
+ kv_indptr_view.copy_(
+     self.decode_cuda_graph_metadata["flashinfer_kv_indptr_original"][:sparse_bs + 1]
+ )

多个 batch size 的 CUDA Graph 按顺序捕获,前一次捕获的 forward 会污染缓冲区,
影响后续捕获。加上恢复可保证每次捕获都使用干净的预计算值。

修改 3:init_forward_metadata_replay_cuda_graph() — 每次 replay 前恢复

  kv_indptr_view = self.decode_cuda_graph_metadata["flashinfer_kv_indptr"][:sparse_bs + 1]
+ kv_indptr_view.copy_(
+     self.decode_cuda_graph_metadata["flashinfer_kv_indptr_original"][:sparse_bs + 1]
+ )
  ...
- kv_indptr_view[sparse_real_bs:].fill_(kv_indptr_view[-1])
+ kv_indptr_view[sparse_real_bs:].fill_(kv_indptr_view[sparse_real_bs])

恢复后,填充 padding 区域的值改为使用 kv_indptr_view[sparse_real_bs]
(即 sparse_real_bs * K,恢复后的正确值),而非 kv_indptr_view[-1]
(在未恢复的情况下可能是被污染的值)。

@kfeng123
Copy link

Thank you for reporting the bug! We have fixed the issue in #3. Your solution uses fixed values of kv_indptr_view for each siginfer plan, which may have chance to result in unknown issue.

@hxkwan hxkwan closed this Mar 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants