From 1d0a70eb69ff84bdfd4f32f308e6b4873a1639bd Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 2 Apr 2026 15:12:12 +0000 Subject: [PATCH 1/9] Rebase onto dev --- tests/cpp/operator/CMakeLists.txt | 1 + .../cpp/operator/test_cast_nvfp4_transpose.cu | 7 - tests/cpp/operator/test_dequantize_nvfp4.cu | 410 ++++++++++++++++++ tests/cpp/test_common.h | 22 +- .../common/cast/dispatch/dequantize.cuh | 4 - ...quantize_transpose_vector_blockwise_fp4.cu | 15 +- 6 files changed, 446 insertions(+), 13 deletions(-) create mode 100644 tests/cpp/operator/test_dequantize_nvfp4.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index dfd8fba29..8a19e84f5 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND test_cuda_sources test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu + test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 4e42fad92..50f2d36fe 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -32,13 +32,6 @@ enum ActivationType { SReLU }; -#ifdef __HIP_PLATFORM_AMD__ -static constexpr float E2M1_LUT[16] = { - 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, - -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, -}; -#endif - double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { #ifdef __HIP_PLATFORM_AMD__ uint8_t raw = *reinterpret_cast(&fp4_pair); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu new file mode 100644 index 000000000..1da70923a --- /dev/null +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -0,0 +1,410 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kFP4BlockSize1D = 16; +constexpr size_t kFP4BlockSize2DY = 16; +constexpr size_t kFP4BlockSize2DX = 16; + +size_t divide_round_up(size_t x, size_t y) { + return (x + y - 1) / y; +} + +// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns +// and rejects non-finite values (i.e., NaN) before storing. +// Values are written using memcpy to preserve exact +// bit patterns rather than relying on numeric conversion. +void generate_1d_scales(fp8e4m3* scale_buffer, + const size_t mathematical_rows, + const size_t mathematical_blocks_per_row, + const size_t physical_row_stride, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t total_elems = mathematical_rows * physical_row_stride; + std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); + + for (size_t row = 0; row < mathematical_rows; ++row) { + for (size_t block = 0; block < mathematical_blocks_per_row; ++block) { + const size_t idx = row * physical_row_stride + block; + + while (true) { + const uint8_t bits = static_cast(dis(gen)); + + fp8e4m3 candidate; + std::memcpy(&candidate, &bits, sizeof(bits)); + + const float decoded = static_cast(candidate); + if (std::isfinite(decoded)) { + scale_buffer[idx] = candidate; + break; + } + } + } + } +} + +// Generate compact 2D scales over 16x16 tiles, then replicate them row-wise +// into the physical scale layout expected by the existing 1D dequant kernel. +// +// replicated[row][block_x] = compact_2d[row / 16][block_x] +void generate_2d_scales_with_replication(fp8e4m3* scale_buffer, + const size_t rows, + const size_t cols, + const size_t mathematical_scale_rows, + const size_t mathematical_scale_blocks_per_row, + const size_t physical_row_stride, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t total_elems = mathematical_scale_rows * physical_row_stride; + std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); + + const size_t blocks_y = divide_round_up(rows, kFP4BlockSize2DY); + const size_t blocks_x = divide_round_up(cols, kFP4BlockSize2DX); + + std::vector compact_2d(blocks_y * blocks_x); + + for (size_t by = 0; by < blocks_y; ++by) { + for (size_t bx = 0; bx < blocks_x; ++bx) { + while (true) { + const uint8_t bits = static_cast(dis(gen)); + + fp8e4m3 candidate; + std::memcpy(&candidate, &bits, sizeof(bits)); + + const float decoded = static_cast(candidate); + if (std::isfinite(decoded)) { + compact_2d[by * blocks_x + bx] = candidate; + break; + } + } + } + } + + for (size_t row = 0; row < mathematical_scale_rows; ++row) { + const size_t by = row / kFP4BlockSize2DY; + for (size_t bx = 0; bx < mathematical_scale_blocks_per_row; ++bx) { + const size_t dst_idx = row * physical_row_stride + bx; + scale_buffer[dst_idx] = compact_2d[by * blocks_x + bx]; + } + } +} + +// Write one mathematical FP4 E2M1 value, represented as a raw nibble [0, 15], +// into packed storage. Two mathematical FP4 values are packed per byte: +// even mathematical index -> low nibble, odd mathematical index -> high nibble. +void set_fp4_nibble(fp4e2m1* data, const size_t mathematical_idx, const uint8_t nibble) { + ASSERT_TRUE(nibble < 16); + auto* raw = reinterpret_cast(data); + const size_t byte_idx = mathematical_idx / 2; + const uint8_t val = nibble; + + if ((mathematical_idx % 2) == 0) { + // set low nibble + raw[byte_idx] = static_cast((raw[byte_idx] & 0xF0) | val); + } else { + // set high nibble + raw[byte_idx] = static_cast((raw[byte_idx] & 0x0F) | (val << 4)); + } +} + +// Populate FP4 (E2M1) tensor using packed 4-bit encoding, and simultaneously +// populate its mathematical transpose in packed storage. +// +// data has mathematical shape [rows, cols] +// data_t has mathematical shape [cols, rows] +void generate_data_and_transpose(fp4e2m1* data, + fp4e2m1* data_t, + const size_t rows, + const size_t cols, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t packed_bytes = (rows * cols * BitsNumber::num_bits) / 8; + + std::memset(data, 0, packed_bytes); + std::memset(data_t, 0, packed_bytes); + + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const uint8_t nibble = static_cast(dis(gen)) & 0xF; + + const size_t idx = i * cols + j; + set_fp4_nibble(data, idx, nibble); + + const size_t idx_t = j * rows + i; + set_fp4_nibble(data_t, idx_t, nibble); + } + } +} + +// Decode a single FP4 (E2M1) value from packed storage. +float get_fp4_value(const fp4e2m1* data, const size_t mathematical_idx) { + const auto* raw = reinterpret_cast(data); + const size_t byte_idx = mathematical_idx / 2; + const uint8_t packed = raw[byte_idx]; + const uint8_t nibble = (mathematical_idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF); + return E2M1_LUT[nibble]; +} + +// Reference implementation: dequantize packed FP4 (E2M1) input using per-block FP8_E4M3 scales. +// Each block of 1x16 elements shares one scale; values are decoded to float and scaled. +template +void compute_ref(const fp4e2m1* input, + OutputType* output, + const fp8e4m3* scales, + const float amax, + const size_t rows, + const size_t cols, + const size_t scale_stride) { + constexpr float factor_inv = 1.0f / (6.0f * 448.0f); + + const size_t blocks_per_row = cols / kFP4BlockSize1D; + + for (size_t i = 0; i < rows; ++i) { + for (size_t b = 0; b < blocks_per_row; ++b) { + const float scale = + static_cast(scales[i * scale_stride + b]) * amax * factor_inv; + + for (size_t k = 0; k < kFP4BlockSize1D; ++k) { + const size_t col = b * kFP4BlockSize1D + k; + const size_t idx = i * cols + col; + const float x = get_fp4_value(input, idx); + output[idx] = static_cast(x * scale); + } + } + } +} + +template +void run_single_case(const std::string& case_name, + const fp4e2m1* host_input, + const fp8e4m3* host_scales, + const size_t rows, + const size_t cols, + const size_t blocks_y, + const size_t blocks_x, + const size_t scale_stride, + const float amax, + DType otype) { + const DType itype = DType::kFloat4E2M1; + + Tensor input(case_name + "_input", std::vector{rows, cols}, itype, + true, false, NVTE_NVFP4_1D_SCALING); + Tensor output(case_name + "_output", std::vector{rows, cols}, otype, true, false); + + std::unique_ptr ref_output = + std::make_unique(rows * cols); + + const size_t data_bytes = (rows * cols * BitsNumber::num_bits) / 8; + const size_t scale_bytes = blocks_y * blocks_x * sizeof(fp8e4m3); + + auto err = cudaMemcpy(input.rowwise_dptr(), + host_input, + data_bytes, + cudaMemcpyHostToDevice); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + err = cudaMemcpy(input.rowwise_scale_inv_dptr(), + host_scales, + scale_bytes, + cudaMemcpyHostToDevice); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + input.set_tensor_amax(amax); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + output.to_cpu(); + + compute_ref(host_input, + ref_output.get(), + host_scales, + amax, + rows, + cols, + scale_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults(case_name, output, ref_output.get(), true, atol, rtol); +} + +// End-to-end test: generate random FP4 input and FP8 scales, then exercise +// 1) row-wise 1D dequant +// 2) col-wise 1D dequant (by running the same dequant kernel on transposed data) +// 3) 2D dequant semantics using row-wise replicated scales +template +void performTest(const size_t rows, const size_t cols, DType otype) { + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + std::unique_ptr host_input = + std::make_unique(rows * cols); + + std::unique_ptr host_input_t = + std::make_unique(rows * cols); + + std::unique_ptr host_scales_rowwise_1d = + std::make_unique(blocks_Y * blocks_X); + + std::unique_ptr host_scales_colwise_1d = + std::make_unique(blocks_Y_t * blocks_X_t); + + std::unique_ptr host_scales_2d_replicated = + std::make_unique(blocks_Y * blocks_X); + + static std::mt19937 gen(42); + std::uniform_int_distribution fp4_dis(0, 15); + std::uniform_int_distribution fp8_dis(0, 255); + + generate_data_and_transpose(host_input.get(), + host_input_t.get(), + rows, + cols, + gen, + fp4_dis); + + // Row-wise 1D scales on [rows, cols] + generate_1d_scales(host_scales_rowwise_1d.get(), + unpadded_blocks_Y, + unpadded_blocks_X, + scales_stride, + gen, + fp8_dis); + + // Col-wise 1D scales on [cols, rows] + generate_1d_scales(host_scales_colwise_1d.get(), + unpadded_blocks_Y_t, + unpadded_blocks_X_t, + scales_stride_t, + gen, + fp8_dis); + + // 2D scales replicated row-wise + generate_2d_scales_with_replication(host_scales_2d_replicated.get(), + rows, + cols, + unpadded_blocks_Y, + unpadded_blocks_X, + scales_stride, + gen, + fp8_dis); + + const float amax = 1.0f; + + run_single_case("rowwise_1d_dequant", + host_input.get(), + host_scales_rowwise_1d.get(), + rows, + cols, + blocks_Y, + blocks_X, + scales_stride, + amax, + otype); + + run_single_case("colwise_1d_dequant", + host_input_t.get(), + host_scales_colwise_1d.get(), + cols, + rows, + blocks_Y_t, + blocks_X_t, + scales_stride_t, + amax, + otype); + + run_single_case("replicated_2d_dequant", + host_input.get(), + host_scales_2d_replicated.get(), + rows, + cols, + blocks_Y, + blocks_X, + scales_stride, + amax, + otype); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, +}; + +} // namespace + +class DequantizeNVFP4TestSuite + : public ::testing::TestWithParam< + std::tuple, transformer_engine::DType>> {}; + +TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { + const auto tensor_size = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + const size_t rows = tensor_size.first; + const size_t cols = tensor_size.second; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + output_type, OutputType, + performTest(rows, cols, output_type);); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + test::typeName(std::get<1>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index e2bfdfd57..9dfc2e1b5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -301,6 +301,21 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_tensor_amax(float amax) { + if (!amax_cpu_data_) { + amax_cpu_data_ = std::make_shared(amax); + } else { + *amax_cpu_data_ = amax; + } + + float *amax_gpu = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(), + sizeof(float), cudaMemcpyHostToDevice)); + + tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); @@ -356,6 +371,11 @@ constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; #endif +static constexpr float E2M1_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, +}; + inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; } @@ -519,7 +539,7 @@ template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, #ifdef USE_ROCM - std::vector& mismatch_indices, + std::vector& mismatch_indices, #endif //#ifdef USE_ROCM size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 1bdd7e218..6b70c7582 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -18,9 +18,7 @@ #include "../../common.h" #include "../fp8/dequantize_fp8.cuh" #include "../mxfp8/dequantize_mxfp8.cuh" -#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/dequantize_nvfp4.cuh" -#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -49,12 +47,10 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t #endif //#ifndef __HIP_PLATFORM_AMD__ break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { nvfp4::dequantize(input, output, stream); break; } -#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 59742d1e7..bdcfd8d0b 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -157,12 +157,25 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); // for 2D block scaling, we need to reduce amax in warp +#ifdef __HIP_PLATFORM_AMD__ +static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x0101010101010101ULL, 0x0202020202020202ULL, + 0x0404040404040404ULL, 0x0808080808080808ULL, + 0x1010101010101010ULL, 0x2020202020202020ULL, + 0x4040404040404040ULL, 0x8080808080808080ULL}; +#else static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; +#endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { +__device__ __forceinline__ float groupMax(float val, +#ifdef __HIP_PLATFORM_AMD__ + uint64_t groupMask) { +#else + unsigned int groupMask) { +#endif for (int offset = group_size / 2; offset > 0; offset /= 2) { #ifdef __HIP_PLATFORM_AMD__ (void)groupMask; // unused on AMD, __shfl_down does not take a mask From d6837f622518564c62a2b010f5aa2ba80be6ca9c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 7 Apr 2026 15:42:18 +0000 Subject: [PATCH 2/9] address pr comments: remove while loop based method for generating e4m3 finite encodings, use set_scale() for amax, check for amax being null in kernel --- tests/cpp/operator/test_dequantize_nvfp4.cu | 114 +++++++----------- tests/cpp/test_common.h | 17 +-- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 2 +- 3 files changed, 48 insertions(+), 85 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 1da70923a..96ded197a 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -31,39 +31,24 @@ constexpr size_t kFP4BlockSize1D = 16; constexpr size_t kFP4BlockSize2DY = 16; constexpr size_t kFP4BlockSize2DX = 16; -size_t divide_round_up(size_t x, size_t y) { - return (x + y - 1) / y; -} - -// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns -// and rejects non-finite values (i.e., NaN) before storing. +// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns. +// Only finite, non-negative scales are allowed. // Values are written using memcpy to preserve exact // bit patterns rather than relying on numeric conversion. void generate_1d_scales(fp8e4m3* scale_buffer, - const size_t mathematical_rows, - const size_t mathematical_blocks_per_row, - const size_t physical_row_stride, - std::mt19937& gen, - std::uniform_int_distribution& dis) { - const size_t total_elems = mathematical_rows * physical_row_stride; + const size_t unpadded_blocks_Y, + const size_t unpadded_blocks_X, + const size_t scales_stride, + std::mt19937& gen, + std::uniform_int_distribution& finite_nonneg_e4m3_dis) { + const size_t total_elems = unpadded_blocks_Y * scales_stride; std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); - for (size_t row = 0; row < mathematical_rows; ++row) { - for (size_t block = 0; block < mathematical_blocks_per_row; ++block) { - const size_t idx = row * physical_row_stride + block; - - while (true) { - const uint8_t bits = static_cast(dis(gen)); - - fp8e4m3 candidate; - std::memcpy(&candidate, &bits, sizeof(bits)); - - const float decoded = static_cast(candidate); - if (std::isfinite(decoded)) { - scale_buffer[idx] = candidate; - break; - } - } + for (size_t row = 0; row < unpadded_blocks_Y; ++row) { + for (size_t block = 0; block < unpadded_blocks_X; ++block) { + const size_t scale_idx = row * scales_stride + block; + const uint8_t scale = static_cast(finite_nonneg_e4m3_dis(gen)); + std::memcpy(&scale_buffer[scale_idx], &scale, sizeof(scale)); } } } @@ -75,12 +60,12 @@ void generate_1d_scales(fp8e4m3* scale_buffer, void generate_2d_scales_with_replication(fp8e4m3* scale_buffer, const size_t rows, const size_t cols, - const size_t mathematical_scale_rows, - const size_t mathematical_scale_blocks_per_row, - const size_t physical_row_stride, + const size_t unpadded_blocks_Y, + const size_t unpadded_blocks_X, + const size_t scales_stride, std::mt19937& gen, - std::uniform_int_distribution& dis) { - const size_t total_elems = mathematical_scale_rows * physical_row_stride; + std::uniform_int_distribution& finite_nonneg_e4m3_dis) { + const size_t total_elems = unpadded_blocks_Y * scales_stride; std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); const size_t blocks_y = divide_round_up(rows, kFP4BlockSize2DY); @@ -90,26 +75,17 @@ void generate_2d_scales_with_replication(fp8e4m3* scale_buffer, for (size_t by = 0; by < blocks_y; ++by) { for (size_t bx = 0; bx < blocks_x; ++bx) { - while (true) { - const uint8_t bits = static_cast(dis(gen)); - - fp8e4m3 candidate; - std::memcpy(&candidate, &bits, sizeof(bits)); - - const float decoded = static_cast(candidate); - if (std::isfinite(decoded)) { - compact_2d[by * blocks_x + bx] = candidate; - break; - } - } + const size_t compact_idx = by * blocks_x + bx; + const uint8_t scale = static_cast(finite_nonneg_e4m3_dis(gen)); + std::memcpy(&compact_2d[compact_idx], &scale, sizeof(scale)); } } - for (size_t row = 0; row < mathematical_scale_rows; ++row) { + for (size_t row = 0; row < unpadded_blocks_Y; ++row) { const size_t by = row / kFP4BlockSize2DY; - for (size_t bx = 0; bx < mathematical_scale_blocks_per_row; ++bx) { - const size_t dst_idx = row * physical_row_stride + bx; - scale_buffer[dst_idx] = compact_2d[by * blocks_x + bx]; + for (size_t bx = 0; bx < unpadded_blocks_X; ++bx) { + const size_t scale_idx = row * scales_stride + bx; + scale_buffer[scale_idx] = compact_2d[by * blocks_x + bx]; } } } @@ -142,7 +118,7 @@ void generate_data_and_transpose(fp4e2m1* data, const size_t rows, const size_t cols, std::mt19937& gen, - std::uniform_int_distribution& dis) { + std::uniform_int_distribution& e2m1_dis) { const size_t packed_bytes = (rows * cols * BitsNumber::num_bits) / 8; std::memset(data, 0, packed_bytes); @@ -150,7 +126,7 @@ void generate_data_and_transpose(fp4e2m1* data, for (size_t i = 0; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { - const uint8_t nibble = static_cast(dis(gen)) & 0xF; + const uint8_t nibble = static_cast(e2m1_dis(gen)) & 0xF; const size_t idx = i * cols + j; set_fp4_nibble(data, idx, nibble); @@ -162,11 +138,11 @@ void generate_data_and_transpose(fp4e2m1* data, } // Decode a single FP4 (E2M1) value from packed storage. -float get_fp4_value(const fp4e2m1* data, const size_t mathematical_idx) { +float get_fp4_value(const fp4e2m1* data, const size_t idx) { const auto* raw = reinterpret_cast(data); - const size_t byte_idx = mathematical_idx / 2; + const size_t byte_idx = idx / 2; const uint8_t packed = raw[byte_idx]; - const uint8_t nibble = (mathematical_idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF); + const uint8_t nibble = (idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF); return E2M1_LUT[nibble]; } @@ -234,7 +210,7 @@ void run_single_case(const std::string& case_name, cudaMemcpyHostToDevice); ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); - input.set_tensor_amax(amax); + input.set_scale(amax); nvte_dequantize(input.data(), output.data(), 0); @@ -293,31 +269,31 @@ void performTest(const size_t rows, const size_t cols, DType otype) { std::make_unique(blocks_Y * blocks_X); static std::mt19937 gen(42); - std::uniform_int_distribution fp4_dis(0, 15); - std::uniform_int_distribution fp8_dis(0, 255); + std::uniform_int_distribution e2m1_dis(0, 15); + std::uniform_int_distribution finite_nonneg_e4m3_dis(0, 126); generate_data_and_transpose(host_input.get(), host_input_t.get(), rows, cols, gen, - fp4_dis); + e2m1_dis); // Row-wise 1D scales on [rows, cols] generate_1d_scales(host_scales_rowwise_1d.get(), - unpadded_blocks_Y, - unpadded_blocks_X, - scales_stride, - gen, - fp8_dis); + unpadded_blocks_Y, + unpadded_blocks_X, + scales_stride, + gen, + finite_nonneg_e4m3_dis); // Col-wise 1D scales on [cols, rows] generate_1d_scales(host_scales_colwise_1d.get(), - unpadded_blocks_Y_t, - unpadded_blocks_X_t, - scales_stride_t, - gen, - fp8_dis); + unpadded_blocks_Y_t, + unpadded_blocks_X_t, + scales_stride_t, + gen, + finite_nonneg_e4m3_dis); // 2D scales replicated row-wise generate_2d_scales_with_replication(host_scales_2d_replicated.get(), @@ -327,7 +303,7 @@ void performTest(const size_t rows, const size_t cols, DType otype) { unpadded_blocks_X, scales_stride, gen, - fp8_dis); + finite_nonneg_e4m3_dis); const float amax = 1.0f; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 9dfc2e1b5..ae96cfeac 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -301,21 +301,6 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } - void set_tensor_amax(float amax) { - if (!amax_cpu_data_) { - amax_cpu_data_ = std::make_shared(amax); - } else { - *amax_cpu_data_ = amax; - } - - float *amax_gpu = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float))); - NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(), - sizeof(float), cudaMemcpyHostToDevice)); - - tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape); - } - void to_cpu() const; void from_cpu() const; void set_scale(float scale); @@ -371,10 +356,12 @@ constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; #endif +#ifdef __HIP_PLATFORM_AMD__ static constexpr float E2M1_LUT[16] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, }; +#endif inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 38677a707..66348e082 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -57,7 +57,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = (tensor_amax == nullptr) ? 1.0f : *tensor_amax; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll From d3f8dd1b42d643909d1fd433836a8af74ace39c9 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 7 Apr 2026 15:52:43 +0000 Subject: [PATCH 3/9] fix prior merge error --- .../quantize_transpose_vector_blockwise_fp4.cu | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index bdcfd8d0b..59742d1e7 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -157,25 +157,12 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); // for 2D block scaling, we need to reduce amax in warp -#ifdef __HIP_PLATFORM_AMD__ -static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { - 0x0101010101010101ULL, 0x0202020202020202ULL, - 0x0404040404040404ULL, 0x0808080808080808ULL, - 0x1010101010101010ULL, 0x2020202020202020ULL, - 0x4040404040404040ULL, 0x8080808080808080ULL}; -#else static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; -#endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, -#ifdef __HIP_PLATFORM_AMD__ - uint64_t groupMask) { -#else - unsigned int groupMask) { -#endif +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { for (int offset = group_size / 2; offset > 0; offset /= 2) { #ifdef __HIP_PLATFORM_AMD__ (void)groupMask; // unused on AMD, __shfl_down does not take a mask From 2cb08eb13d5b80d68dcf540ce5659b9d878fa055 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 7 Apr 2026 15:54:27 +0000 Subject: [PATCH 4/9] Sync cudnn-frontend submodule pointer with dev --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index be6c079be..0258951d4 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 +Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 From 25d72fff60dc8703fb37757e2903cdd7f0ee205d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 7 Apr 2026 18:25:06 +0000 Subject: [PATCH 5/9] add ifdef guard around rocm specific handling of amax --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index f55ffb201..d7e782a28 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -57,7 +57,16 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = (tensor_amax == nullptr) ? 1.0f : *tensor_amax; + // NVFP4 may reach this path with scale present but no separate amax buffer. + // Use 1.0f as the neutral fallback when tensor_amax is not provided on HIP. + float amax = 1.0f; +#ifndef __HIP_PLATFORM_AMD__ + amax = *tensor_amax; +#else + if (tensor_amax != nullptr) { + amax = *tensor_amax; + } +#endif constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll From f941143a2860c506765ce02aa111b8a56cc91aa2 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 8 Apr 2026 05:02:34 +0000 Subject: [PATCH 6/9] Limit NVFP4 dequant test to null-amax path on ROCm --- tests/cpp/operator/test_dequantize_nvfp4.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 96ded197a..cc0f3fdf5 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -210,8 +210,6 @@ void run_single_case(const std::string& case_name, cudaMemcpyHostToDevice); ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); - input.set_scale(amax); - nvte_dequantize(input.data(), output.data(), 0); cudaDeviceSynchronize(); @@ -305,6 +303,9 @@ void performTest(const size_t rows, const size_t cols, DType otype) { gen, finite_nonneg_e4m3_dis); + // With the current test_common NVFP4 helper path on ROCm, there is no direct + // way to populate a separate global amax buffer for dequant, so this test + // explicitly covers the HIP nullptr -> 1.0f fallback path for now. const float amax = 1.0f; run_single_case("rowwise_1d_dequant", From 92de4bd9f932f111548ccb5a287c5233744d4a1d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 8 Apr 2026 13:01:57 +0000 Subject: [PATCH 7/9] remove unnecessary cast in set_fp4_nibble --- tests/cpp/operator/test_dequantize_nvfp4.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index cc0f3fdf5..05241f8d4 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -101,10 +101,10 @@ void set_fp4_nibble(fp4e2m1* data, const size_t mathematical_idx, const uint8_t if ((mathematical_idx % 2) == 0) { // set low nibble - raw[byte_idx] = static_cast((raw[byte_idx] & 0xF0) | val); + raw[byte_idx] = (raw[byte_idx] & 0xF0) | val; } else { // set high nibble - raw[byte_idx] = static_cast((raw[byte_idx] & 0x0F) | (val << 4)); + raw[byte_idx] = (raw[byte_idx] & 0x0F) | (val << 4); } } From 458803f6568a444ef216e2a5ddec916bb718fcaa Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 8 Apr 2026 16:05:44 +0000 Subject: [PATCH 8/9] address pr comments: update copyright and ifdef guard --- transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index d7e782a28..48ba53274 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -59,10 +61,10 @@ __global__ void __launch_bounds__(512) fp8e4m3 scale = scales[my_scale_index]; // NVFP4 may reach this path with scale present but no separate amax buffer. // Use 1.0f as the neutral fallback when tensor_amax is not provided on HIP. - float amax = 1.0f; #ifndef __HIP_PLATFORM_AMD__ - amax = *tensor_amax; + float amax = *tensor_amax; #else + float amax = 1.0f; if (tensor_amax != nullptr) { amax = *tensor_amax; } From dfc00287e1c00071a051a36ede062c23bf1ccd0d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 10 Apr 2026 19:49:27 +0000 Subject: [PATCH 9/9] address pr comment: initialize amax with ternary operator --- transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 48ba53274..94fc16b03 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -64,10 +64,7 @@ __global__ void __launch_bounds__(512) #ifndef __HIP_PLATFORM_AMD__ float amax = *tensor_amax; #else - float amax = 1.0f; - if (tensor_amax != nullptr) { - amax = *tensor_amax; - } + float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f; #endif constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv;