
[BugFix] fix flashinfer-cutedsl moe nvfp4#7120

Merged
lizexu123 merged 22 commits into PaddlePaddle:develop from lizexu123:fix_nvfpr
Apr 3, 2026

Conversation

@lizexu123
Collaborator

@lizexu123 lizexu123 commented Apr 1, 2026

Motivation

Fix the NVFP4 environment-variable issue and stop importing flashinfer lazily. The CI machines are not Blackwell GPUs, so the code now detects the GPU architecture and skips the flashinfer import in nvfp4.py when the GPU is not Blackwell. Also update the documentation for running NVFP4 with flashinfer_cutedsl.
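As a rough illustration of the guard described above, the capability check boils down to a sketch like the following. The names `NVFP4_MIN_SM_VERSION` and the `sm_version`/`is_cuda` parameters are assumptions drawn from the review comments below, not the exact PR code (the real PR reads the SM version via `fastdeploy.model_executor.utils.get_sm_version()`):

```python
# Hypothetical sketch of the capability gate; parameters are illustrative.
NVFP4_MIN_SM_VERSION = 100  # Blackwell GPUs report SM >= 100


def is_nvfp4_supported(sm_version: int, is_cuda: bool) -> bool:
    """Return True only when running on a CUDA Blackwell GPU."""
    if not is_cuda:
        return False
    return sm_version >= NVFP4_MIN_SM_VERSION


# Module-level gate: skip the flashinfer import on non-Blackwell CI machines.
if is_nvfp4_supported(sm_version=90, is_cuda=True):  # e.g. a Hopper CI host
    pass  # `import flashinfer` would go here on real Blackwell hosts
```

Pushing the check to module import time is what lets the non-Blackwell CI environment load nvfp4.py without flashinfer being installed at all.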

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Apr 1, 2026

Thanks for your contribution!

zhoutianzi666
zhoutianzi666 previously approved these changes Apr 1, 2026

@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-01 19:37:08

📋 Review Summary

PR overview: fixes the NVFP4 environment-variable type issue and makes the flashinfer import conditional (imported only on Blackwell GPUs)
Scope: model_executor/layers/quantization/, model_executor/layers/moe/, envs.py, docs
Impact tags: [Quantization] [OP] [Docs]

Issues

Level File Summary
🔴 Bug quant_base.py:26-32 is_nvfp4_supported() returns None instead of False on non-CUDA platforms, making boolean checks misbehave
🔴 Bug nvfp4.py:70 calling get_sm_version() on non-Blackwell hardware may fail in non-CUDA environments
🟡 Suggestion nvfp4.py:593-597 the load_up_proj_weight_first property is missing an else return value

Overall Assessment

The core goal of this PR is sound: conditional imports avoid pulling in the flashinfer dependency on non-Blackwell GPUs. However, is_nvfp4_supported() is missing a return value, which can cause runtime errors on non-CUDA platforms and should be fixed before merging.

from fastdeploy.model_executor.utils import get_sm_version

sm_version = get_sm_version()
return sm_version >= NVFP4_MIN_SM_VERSION

🔴 Bug: is_nvfp4_supported() implicitly returns None on non-CUDA platforms

When current_platform.is_cuda() returns False, the function has no explicit return statement and implicitly returns None. Since callers use if is_nvfp4_supported(): as a boolean check, None is treated as False, but this can still lead to unexpected behavior downstream.

In addition, the docstring is misplaced (it sits inside the if block).

Suggested fix:
```python
def is_nvfp4_supported() -> bool:
    """Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""
    if not current_platform.is_cuda():
        return False
    from fastdeploy.model_executor.utils import get_sm_version

    sm_version = get_sm_version()
    return sm_version >= NVFP4_MIN_SM_VERSION
```

flashinfer_cutlass_fused_moe = None
logger.warning(
    f"NVFP4 requires Blackwell GPU (SM >= 100), "
    f"current GPU has SM {get_sm_version()}. Skipping flashinfer imports."

🔴 Bug: calling get_sm_version() directly in the else branch may fail in non-CUDA environments

When is_nvfp4_supported() returns False (e.g. on a non-CUDA platform or SM < 100), this code calls get_sm_version() to format the warning message. Although get_sm_version() internally checks paddle.cuda.is_available() and returns 0, combined with the None-return issue in is_nvfp4_supported() above, this can still lead to unexpected behavior.

Keep the error handling consistent, and make sure is_nvfp4_supported() returns a proper boolean.

def load_up_proj_weight_first(self) -> bool:
    # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
    if self.backend == "flashinfer-cutlass":
        return True

🟡 Suggestion: the load_up_proj_weight_first property is missing an else return value

When self.backend is not "flashinfer-cutlass", this property implicitly returns None. It should explicitly return False:

```python
@property
def load_up_proj_weight_first(self) -> bool:
    # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
    return self.backend == "flashinfer-cutlass"
```

@codecov-commenter

codecov-commenter commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 52.72727% with 26 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@98f3fc9). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...deploy/model_executor/layers/quantization/nvfp4.py 62.85% 12 Missing and 1 partial ⚠️
...odel_executor/layers/moe/flashinfer_cutedsl_moe.py 0.00% 7 Missing ⚠️
...loy/model_executor/layers/quantization/__init__.py 0.00% 2 Missing and 2 partials ⚠️
...y/model_executor/layers/quantization/quant_base.py 77.77% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7120   +/-   ##
==========================================
  Coverage           ?   73.87%           
==========================================
  Files              ?      376           
  Lines              ?    52914           
  Branches           ?     8256           
==========================================
  Hits               ?    39089           
  Misses             ?    11096           
  Partials           ?     2729           
Flag Coverage Δ
GPU 73.87% <52.72%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-02 20:01 CST

📋 Review Summary

PR overview: fixes the NVFP4 environment-variable handling and makes the flashinfer import conditional on detected GPU capability, so it no longer fails on non-Blackwell GPUs.
Scope: quantization module, MoE layers, docs
Impact tags: [Quantization] [OP] [Docs]

📝 PR Convention Check

The PR title and tags follow the conventions. Consider filling in the Modifications section with the concrete changes.

Suggested description (copy-paste ready):

## Modifications

1. Add an `is_nvfp4_supported()` function that checks whether the GPU supports NVFP4 (SM >= 100)
2. Make the flashinfer module imports conditional, importing only on Blackwell GPUs
3. Fix the type handling of the `FD_NVFP4_LOAD_BLOCKSCALE_LEAVE` environment variable (string to boolean)
4. Update the NVFP4 docs with usage instructions for the flashinfer-cutedsl backend
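Item 3 in the list above (the string-to-boolean environment-variable fix) can be sketched with a generic helper. `env_flag` and its accepted values are illustrative, not the PR's actual implementation; the point is that `os.environ` always yields strings, so `"0"` would be truthy without explicit parsing:

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable ('1'/'true'/'yes'/'on' map to True)."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


# Example with the env var named in the PR; the semantics here are assumed.
os.environ["FD_NVFP4_LOAD_BLOCKSCALE_LEAVE"] = "1"
print(env_flag("FD_NVFP4_LOAD_BLOCKSCALE_LEAVE"))  # the raw "1" string maps to True
```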

Issues

Level File Summary
🔴 Accidental commit benchmark.sh a test script with hard-coded local paths that should not be committed
🟡 Suggestion quant_base.py:28 misplaced docstring; it should be at the top of the function
🟡 Suggestion nvfp4.py:75 global functions call a conditionally imported module; the error message on non-Blackwell hardware is unhelpful

Overall Assessment

The core logic change is reasonable: gating the flashinfer import on GPU capability is the right way to support non-Blackwell CI environments. However, benchmark.sh contains developer-local paths and must be removed before merging.

benchmark.sh Outdated
@@ -0,0 +1,15 @@
export PYTHONPATH="/root/paddlejob/workspace/env_run/output/lizexu/FastDeploy":$PYTHONPATH

🔴 Accidental commit: this file contains hard-coded local paths and should not be committed.

The file includes:

  • /root/paddlejob/workspace/env_run/output/lizexu/FastDeploy - a developer-local path
  • /raid0/dataset/dis_query_eb_0_32k_5k_converted.json - a local dataset path

Suggestions:

  1. Remove the file from the commit: git rm benchmark.sh
  2. If the test script should be kept, add it to .gitignore or switch to relative paths / environment variables


def is_nvfp4_supported() -> bool:
    if current_platform.is_cuda():
        """Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""

🟡 Suggestion: misplaced docstring

The docstring currently sits inside the if block; it should be the first statement of the function.

Suggested change:

def is_nvfp4_supported() -> bool:
    """Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""
    if current_platform.is_cuda():
        from fastdeploy.model_executor.utils import get_sm_version
        sm_version = get_sm_version()
        return sm_version >= NVFP4_MIN_SM_VERSION
    else:
        return False

# C++ op can detect the no-scale path via tensor.numel() == 0.
if scale is None:
    scale = paddle.empty([0], dtype=paddle.float32)

def call_prefill_permute_to_masked_gemm(

🟡 Suggestion: calling these functions on non-Blackwell hardware produces an unhelpful error

call_prefill_permute_to_masked_gemm and call_depermute_prefill_combine are now defined at global scope, but the prefill_permute_to_masked_gemm they call internally is None on non-Blackwell hardware.

Although using NVFP4 on non-Blackwell hardware is itself a configuration error, a friendlier error message would help:

def call_prefill_permute_to_masked_gemm(...):
    if prefill_permute_to_masked_gemm is None:
        raise RuntimeError(
            "NVFP4 requires Blackwell GPU (SM >= 100). "
            "Current GPU does not support this feature."
        )
    # ... original logic

@lizexu123 lizexu123 merged commit 5f612a3 into PaddlePaddle:develop Apr 3, 2026
77 of 85 checks passed

@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-03 15:45 CST

📋 Review Summary

PR overview: fixes the NVFP4 environment-variable issue, makes the flashinfer import conditional (Blackwell GPUs only), and updates the flashinfer-cutedsl documentation
Scope: model_executor/layers/quantization/, model_executor/layers/moe/, docs
Impact tags: [Quantization] [OP] [Docs]

Issues

Level File Summary
🔴 Bug nvfp4.py:62 deep_ep is wrongly set to None, crashing EP functionality on non-Blackwell GPUs
🔴 Bug docs/quantization/nvfp4.md:37 the patch example in the docs contains dead code that will mislead users
🟡 Suggestion quant_base.py:28 misplaced docstring

Overall Assessment

The core intent of the PR (conditionally loading flashinfer to support non-Blackwell CI environments) is sound, but deep_ep should not be set to None: it is the Expert Parallel communication library, is independent of GPU architecture, and must stay available on Hopper and other GPUs. Move the deep_ep import outside the conditional. The patch example in the docs also needs fixing.

)
else:
    # Not a Blackwell GPU, skip flashinfer imports
    deep_ep = None

🔴 Bug: deep_ep is wrongly set to None, so EP-related calls crash on non-Blackwell GPUs.

deep_ep is the DeepEP communication library imported from fastdeploy.model_executor.layers.moe.ep. It is used for Expert Parallel communication and is independent of the GPU architecture (SM version), so it should be available on every CUDA-capable platform.

On non-Blackwell GPUs the current code fails with:

  • Line 673: deep_ep.Buffer.capture() raises AttributeError: 'NoneType' object has no attribute 'Buffer'
  • Line 755: the same error

Suggested fix: move the deep_ep import outside the conditional, or import it correctly in the else branch as well:

# import unconditionally at the top of the file
from fastdeploy.model_executor.layers.moe.ep import deep_ep
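The split the reviewer is asking for — architecture-independent imports stay unconditional, only Blackwell-specific kernels are gated — can be sketched as follows. The function and module names here are placeholders standing in for FastDeploy's real layout, not its actual API:

```python
# Placeholder sketch: strings stand in for the real imported modules.
def load_moe_backends(sm_version: int):
    """Load EP comms unconditionally; gate only the SM >= 100 kernels."""
    deep_ep = "deep_ep"  # EP communication library: needed on Hopper, Ampere, etc.
    if sm_version >= 100:
        flashinfer_kernels = "flashinfer_cutedsl"  # Blackwell-only fused MoE path
    else:
        flashinfer_kernels = None  # safe to skip; EP still works without it
    return deep_ep, flashinfer_kernels
```

With this structure, a Hopper host still gets a usable `deep_ep` while the Blackwell-only kernel handle is None, which is exactly the failure mode the reviewer's crash report rules out.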

```python
@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
    return torch.cuda.get_device_capability(device)
```

🔴 Bug (docs): the code example is logically broken; everything after the first return is dead code that never executes.

def get_compute_capability(device: torch.device) -> Tuple[int, int]:
    return torch.cuda.get_device_capability(device)  # ← returns immediately
    if device.type != "cuda":  # ← never executed
        raise ValueError("device must be a cuda device")
    return torch.cuda.get_device_capability(device.index)  # ← never executed

Users who apply this patch as-is will get broken behavior. Reorder the logic correctly:

@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
    if device.type != "cuda":
        raise ValueError("device must be a cuda device")
    return torch.cuda.get_device_capability(device.index)


def is_nvfp4_supported() -> bool:
    if current_platform.is_cuda():
        """Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""

🟡 Suggestion: misplaced docstring. It currently sits inside the if block and should come right after the function definition.

def is_nvfp4_supported() -> bool:
    """Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""  # ← should be here
    if current_platform.is_cuda():
        # remove the docstring that was here
        from fastdeploy.model_executor.utils import get_sm_version
        ...


7 participants