[Common] Tuned NVFP4 cast kernel #2412
Conversation
Greptile Summary
This PR introduces a specialized CUDA kernel for NVFP4 quantization of BF16 tensors on Blackwell architecture (sm_100+), achieving significant performance improvements (6.4 TB/s for round-to-nearest, 4.5 TB/s for stochastic rounding). Key changes:
Critical issue:
Other observations:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Dispatcher as quantize_transpose<br/>(nvfp4/quantize_transpose_nvfp4.cuh)
participant TunedKernel as quantize_transpose_tuned_1D<br/>(specialized/...tuned_1D.cuh)
participant GPU as GPU Kernel<br/>(Blackwell-specific)
participant PTX as PTX Instructions<br/>(TMA, mbarrier, cluster)
User->>Dispatcher: quantize_transpose(input, output, config)
Dispatcher->>Dispatcher: Check: !use_2d_quantization &&<br/>input.dtype == BF16
alt BF16 + 1D quantization
Dispatcher->>TunedKernel: quantize_transpose_tuned_1D()
TunedKernel->>TunedKernel: Validate inputs & setup
TunedKernel->>TunedKernel: Create TMA tensor maps
TunedKernel->>TunedKernel: Launch kernel with config<br/>(USE_STOCHASTIC_ROUNDING,<br/>USE_FAST_MATH, RETURN_TRANSPOSE)
TunedKernel->>GPU: quantize_transpose_nvfp4_tuned_1D_kernel<<<grid, block>>>
GPU->>PTX: TMA async copy (global->shared)
GPU->>PTX: mbarrier init/arrive/wait
GPU->>PTX: Quantize BF16->NVFP4 with scaling
GPU->>PTX: Optional: cluster launch control
GPU->>PTX: TMA async copy (shared->global)
GPU-->>TunedKernel: Return
TunedKernel-->>Dispatcher: Return
else Other configurations
Dispatcher->>Dispatcher: Use generic kernel path
end
Dispatcher-->>User: Return quantized output
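Stripped of the diagram notation, the dispatch decision above corresponds to a branch of roughly the following shape. The condition and the tuned entry point are quoted from the diff discussed later in this thread; the surrounding comments are paraphrase, not code from the PR:

// Inside the NVFP4 quantize_transpose dispatcher (paraphrased from the diagram):
// a BF16 input with 1D quantization takes the tuned Blackwell path, which builds
// the TMA tensor maps and launches the specialized kernel with the selected
// USE_STOCHASTIC_ROUNDING / USE_FAST_MATH / RETURN_TRANSPOSE configuration.
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
  quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
  return;
}
// Every other configuration falls through to the generic kernel path.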
445c870 to a7a0652 (Compare)
Additional Comments (1)
- transformer_engine/common/cast/dispatch/quantize.cuh, line 250 (link)
  logic: Missing is_supported_by_CC_100() check. The forward pass (line 104) has this check, but the backward pass does not. Without it, the persistent kernel can be dispatched on non-Blackwell hardware and hit NVTE_DEVICE_ERROR at runtime.
5 files reviewed, 1 comment
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, line 1171-1174 (link)
  logic: Missing runtime architecture check before dispatching to persistent kernel. The tuned 1D kernel uses Blackwell-specific PTX instructions (mbarrier variants, cluster launch control, etc.) and requires sm_100+. Should guard with an is_supported_by_CC_100() check, similar to other TMA kernel dispatches in the codebase (see quantize_fp8.cuh:535, gated.cuh:48).
4 files reviewed, 1 comment
91312be to a38eeff (Compare)
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, line 1171-1174 (link)
  logic: Missing runtime check for sm_100+ before dispatching to tuned kernel. The tuned kernel uses Blackwell-specific features (TMA, cluster launch control) that will fail on non-sm_100 devices.
5 files reviewed, 1 comment
Additional Comments (1)
- tests/cpp/operator/test_cast_nvfp4_transpose.cu, line 125-126 (link)
  syntax: Typo in comment: 'Numercial' should be 'Numerical'.
5 files reviewed, 1 comment
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, line 1171-1173 (link)
  logic: Add runtime architecture check before dispatching to tuned kernel. The tuned kernel requires sm_100+ but this dispatch lacks the is_supported_by_CC_100() check that other sm_100+ kernels use (see cast/dispatch/gated.cuh:48, cast/fp8/quantize_fp8.cuh:535). Without this check, the kernel will launch on non-Blackwell devices and hit NVTE_DEVICE_ERROR at runtime.
5 files reviewed, 1 comment
/te-ci
8 files reviewed, 1 comment
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
  quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
  return;
}
logic: Missing runtime check for sm_100+ before dispatching to tuned kernel. The kernel uses Blackwell-specific instructions (TMA, mbarrier, cluster launch control) that will cause runtime errors on older GPUs. Add an is_supported_by_CC_100() check like other architecture-specific kernels.
Suggested change:
- if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
-   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
-   return;
- }
+ if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && is_supported_by_CC_100()) {
+   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
+   return;
+ }
It is guarded by the FP4_TYPE_SUPPORTED macro, which triggers the runtime NVTE error.
You're absolutely right! The check should be within the #if FP4_TYPE_SUPPORTED block to prevent the runtime error. The current code will fall through to the NVTE_ERROR when FP4_TYPE_SUPPORTED is false.
Suggested change:
- if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
-   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
-   return;
- }
+ if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && is_supported_by_CC_100()) {
+   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
+   return;
+ }
However, I need to check if is_supported_by_CC_100() is available in this context. Could you show me where this function is defined or what the correct architecture check should be for sm_100+?
…(numerical truncation). Tightened the numerical tolerance Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
…formant when inputs are large Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
7362e47 to bc26160 (Compare)
/te-ci
- const double atol = 0.05;
- const double rtol = 0.1;
+ const double atol = 1.0E-6;
+ const double rtol = 1.0E-6;
Why is this changing? It looks like it got stricter, so if it passes, that is better. But I don't understand why it was looser before.
constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD;
constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;

constexpr int THREADS_X_TRANSP = TILE_DIM_X / 2;
Can we use one convention for the transpose macros? You previously used THREADS_X_T or BUFF_OUT_T_DIM_Y, so "_T", or "_TRANSP", or "_TR" (like BUFFS_NUM_OUT_TR). Just choose one.
Thanks, I’m fine to standardize. Please consolidate any remaining feedback into a single final pass by EOD today so we can merge without another round of incremental comments. This PR has been open since late November and was originally requested as time-sensitive.
inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output,
                                        const QuantizationConfig *quant_config,
                                        cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
But isn't the TunableConfig struct for exactly that purpose? To make sure these values are tuned separately from the common part of the two kernels, and we keep the common part in one place only. Then there should not be tight coupling.
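A sketch of the pattern this comment is describing, with hypothetical names and values (the actual structs and constants in the PR are not reproduced here): the per-variant tuning knobs live in a small config type, and the shared kernel body is written once against it.

#include <cuda_bf16.h>

// Hypothetical illustration only; type names, field names and numbers are invented.
struct GenericConfig {
  static constexpr int TILE_DIM_X = 128;
  static constexpr int TILE_DIM_Y = 64;
  static constexpr int THREADS_NUM = 256;
};

struct BlackwellTunedConfig {
  static constexpr int TILE_DIM_X = 128;
  static constexpr int TILE_DIM_Y = 128;
  static constexpr int THREADS_NUM = 512;
};

// The common quantize/transpose logic is written once against the config type,
// so tuning one variant does not fork the shared part of the two kernels.
template <typename Config, bool USE_STOCHASTIC_ROUNDING>
__global__ void quantize_transpose_kernel(const __nv_bfloat16 *input /* , outputs, scales, ... */) {
  constexpr int THREADS_X_ROWWISE = Config::TILE_DIM_X / 8;  // 8 elts per thread, illustrative
  // ... shared kernel body parameterized only by Config ...
}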
5 files reviewed, 1 comment
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
  quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
  return;
logic: add runtime check for sm_100+ before dispatching to tuned kernel
Suggested change:
- if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
-   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
-   return;
+ if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && is_supported_by_CC_100()) {
+   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
+   return;
+ }
The tuned kernel uses Blackwell-specific instructions (TMA, mbarrier, cluster launch control) that will cause runtime errors on older GPUs. FP4_TYPE_SUPPORTED is only a compile-time check for CUDA version, not device capability.
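To make the distinction concrete: FP4_TYPE_SUPPORTED is resolved by the preprocessor against the CUDA toolkit, whereas a device-capability guard has to query the GPU at run time. Below is a minimal sketch of what such a runtime helper could look like; this is an assumption about its shape, not the actual is_supported_by_CC_100() definition in Transformer Engine.

#include <cuda_runtime.h>

// Hypothetical sketch of a runtime Blackwell (sm_100+) check; the real helper
// in Transformer Engine may be implemented differently.
inline bool is_blackwell_or_newer() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return prop.major >= 10;  // compute capability 10.x and newer
}

Placed inside the existing #if FP4_TYPE_SUPPORTED block, a check like this would add the runtime guard the review comments are asking for, on top of the compile-time toolkit guard.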
/te-ci
Description
This PR introduces a specialized CUDA kernel optimized for NVFP4 quantization of BF16 inputs on Blackwell architecture (sm100f family). The implementation achieves performance improvements by leveraging architecture-specific features:
RN: round-to-nearest mode 6.4 TB/s (rowwise only 7.2 TB/s)
SR: stochastic rounding 4.5 TB/s (rowwise only 7.0 TB/s)
[Benchmark plots: (a) round-to-nearest and (b) stochastic rounding, each measured for Rowwise + Colwise (transpose) and Rowwise only]
Below are the performance measurements for quantizing tensors with dimensions representative of DSv3 [8192×8, 7168], measured on an internal cluster (B300).
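As a rough back-of-the-envelope check on what those bandwidth figures imply for this shape (this assumes the quoted TB/s counts total global-memory traffic: the BF16 read plus both NVFP4 outputs and one FP8 scale per 16 elements in each direction, which the PR does not state explicitly):

#include <cstdio>

int main() {
  // [8192*8, 7168] BF16 tensor from the PR description; everything below is an estimate.
  const double elems   = 8192.0 * 8 * 7168;                    // ~4.70e8 elements
  const double bf16_in = elems * 2.0;                          // BF16 read, ~0.94 GB
  const double fp4_out = elems * 0.5;                          // one NVFP4 output, ~0.23 GB
  const double scales  = elems / 16.0;                         // one FP8 scale per 16 elements
  const double total   = bf16_in + 2.0 * (fp4_out + scales);   // rowwise + colwise, ~1.47 GB
  std::printf("traffic %.2f GB, est. kernel time at 6.4 TB/s: %.0f us\n",
              total / 1e9, total / 6.4e12 * 1e6);
  return 0;
}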
Using --fast-math can improve performance of the kernel with stochastic rounding (RNG) by up to ~10%.

Threads-to-data mapping (colwise case)
To reduce shared memory bank conflicts, the following mapping is used when reading from and writing to shmem buffers:
where SCALE_DIM = 16. The arrows in the figure below illustrate how thread indices increment, forming a zigzag pattern.
a) Reads from SHMEM Input Buffer
b) Writes to SHMEM Output Transpose Buffer
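The PR's actual zigzag mapping is only shown in the figures above and is not reproduced here. As a generic illustration of the problem it addresses, the textbook way to avoid shared-memory bank conflicts in a tile transpose is to pad each row by one element; this is a baseline technique for comparison, not the mapping used in this kernel.

// Generic illustration only: classic padded-tile transpose. The PR instead uses
// a zigzag thread-to-data mapping with SCALE_DIM = 16 groups (see figures above).
constexpr int TILE = 32;  // illustrative tile size, not taken from the PR

// Launch for a single tile as: transpose_tile<<<1, dim3(TILE, TILE)>>>(d_in, d_out);
__global__ void transpose_tile(const float *in, float *out) {
  // The +1 padding shifts each row by one 4-byte bank, so the transposed reads
  // below no longer all land in the same bank.
  __shared__ float tile[TILE][TILE + 1];
  const int x = threadIdx.x, y = threadIdx.y;
  tile[y][x] = in[y * TILE + x];   // coalesced row-wise load from global memory
  __syncthreads();
  out[y * TILE + x] = tile[x][y];  // transposed read from shared memory
}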
Type of change
Changes
Checklist: