From 57d69f5a84064489bb5bb21d8265b226a5a94394 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 9 Mar 2026 15:23:51 -0700 Subject: [PATCH 1/2] fallback Signed-off-by: Zhongbo Zhu --- tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | 2 ++ transformer_engine/pytorch/csrc/extensions/cast.cpp | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 5f35e9ad10..d4bf1fd3a1 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -130,6 +130,8 @@ def check_group_quantization_nvfp4_versus_reference( [ # edge case, zero tokens for all (0, 512), + # edge case, not 128 multiple hidden dimension + (1024, 320), # full tile cases (256, 1024), (1024, 256), diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..99a6b5b8a1 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1355,9 +1355,16 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { nvfp4_quantizers.push_back(static_cast(quantizer.get())); } - bool contiguous_data_and_scale; + bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); + if (!input_shape.empty() && input_shape.back() % 128 != 0) { + NVTE_WARN( + "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " + "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. "
" + "NVFP4 might bring performance regressions for this input tensor shape."); + quantization_method = QuantizationMethod::UNFUSED; + } if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; From e145d1803aab1dd1c12246bb8390f7ebf48770c8 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 9 Mar 2026 15:39:00 -0700 Subject: [PATCH 2/2] warn once Signed-off-by: Zhongbo Zhu --- transformer_engine/pytorch/csrc/extensions/cast.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 99a6b5b8a1..89cd90f347 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1359,10 +1360,13 @@ std::vector split_quantize(const at::Tensor &tensor, std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); if (!input_shape.empty() && input_shape.back() % 128 != 0) { - NVTE_WARN( - "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " - "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " - "NVFP4 might bring performance regressions for this input tensor shape."); + static std::once_flag once_unfused_nvfp4_fallback_warning; + std::call_once(once_unfused_nvfp4_fallback_warning, []() { + NVTE_WARN( + "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " + "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " + "NVFP4 might bring performance regressions for this input tensor shape."); + }); quantization_method = QuantizationMethod::UNFUSED; } if (!contiguous_data_and_scale) {