From 88cf1b27c74afbc7f247ba763715246797d4db4a Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 21 Jan 2026 16:50:51 +0000 Subject: [PATCH 01/12] Rebased to main Signed-off-by: Oleg Goncharov --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_cast_mxfp8_grouped.cu | 660 +++++++++++++ transformer_engine/common/cast/cast.cu | 10 + .../common/cast/dispatch/quantize_grouped.cuh | 122 +++ .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 933 ++++++++++++++++++ .../common/include/transformer_engine/cast.h | 11 + 6 files changed, 1737 insertions(+) create mode 100644 tests/cpp/operator/test_cast_mxfp8_grouped.cu create mode 100644 transformer_engine/common/cast/dispatch/quantize_grouped.cuh create mode 100644 transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 26efb37962..6c615d63e0 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_mxfp8_grouped.cu test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu new file mode 100644 index 0000000000..9ce7b34bdd --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -0,0 +1,660 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationKind { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + + std::vector output_dbias_fp32(cols, 0); + #pragma omp parallel proc_bind(spread) + { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + + std::vector thread_dbias(cols, 0); + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, 
rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + // if (processing_method == ProcessingMethod::CAST_DBIAS) { + // // grad is the input + // elt = static_cast(grad[idx]); + // } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + // if (processing_method == ProcessingMethod::CAST_DACT || + // processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + // elt *= static_cast(grad[idx]); + // } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + } + #pragma omp critical + { + for (size_t j = 0; j < cols; ++j) { + output_dbias_fp32[j] += thread_dbias[j]; + } + } + } + // for (size_t j = 0; j < cols; ++j) { + // output_dbias[j] = static_cast(output_dbias_fp32[j]); + // } +} + +template +void compare_scaled_elts(const std::string &name, + const T* ref_data, + const T* test_data, + const size_t rows, + const size_t cols, + const bool rowwise, + const size_t tolerable_mismatches_limit = 0, + const double atol = 1e-5, + const double rtol = 1e-8) { + size_t mismatches_num = 0; + int first_mismatch_idx = -1; + + for (size_t i = 0; i < rows * cols; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check 
if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + std::string direction = rowwise ? "rowwise" : "columnwise"; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << first_mismatch_idx + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } + } +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest(const ProcessingMethod processing_method, + float (*OP)(const float), + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const std::vector& logical_shape_vec, + const std::vector& first_dims_h, + const std::vector& last_dims_h, + const std::vector& offsets_h, + const bool rowwise, + const bool colwise) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = logical_shape_vec[0]; + const size_t cols = logical_shape_vec[1]; + + const size_t elts_num = rows * cols; + const size_t sfs_num = (rows * cols) / 32; + + std::vector scales_shape = {sfs_num}; + + std::mt19937 gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + + std::vector in_data(elts_num); + + std::vector out_data_rowwise_h(rowwise ? elts_num : 0); + std::vector out_data_colwise_h(colwise ? elts_num : 0); + std::vector out_scales_rowwise_h(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_h(colwise ? sfs_num : 0); + + std::vector out_data_rowwise_ref(rowwise ? elts_num : 0); + std::vector out_data_colwise_ref(colwise ? elts_num : 0); + std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_ref(colwise ? 
sfs_num : 0); + + for (size_t i = 0; i < elts_num; ++i) { + const float val = dis(gen); + in_data[i] = static_cast(val); + } + + const OutputType zero_elt = static_cast(0.0f); + const fp8e8m0 zero_SF = static_cast(0.0f); + if (rowwise) { + std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt); + std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt); + std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF); + std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF); + } + if (colwise) { + std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt); + std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt); + std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF); + std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF); + } + + const size_t in_data_size = elts_num * sizeof(InputType); + const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t out_scales_size = sfs_num * sizeof(fp8e8m0); + + const size_t first_dims_size = num_tensors * sizeof(size_t); + const size_t last_dims_size = num_tensors * sizeof(size_t); + const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + + InputType* in_data_d; + OutputType* out_data_rowwise_d; + OutputType* out_data_colwise_d; + fp8e8m0* out_scales_rowwise_d; + fp8e8m0* out_scales_colwise_d; + size_t* first_dims_d; + size_t* last_dims_d; + size_t* offsets_d; + + cudaMalloc((void**)&in_data_d, in_data_size); + cudaMalloc((void**)&first_dims_d, first_dims_size); + cudaMalloc((void**)&last_dims_d, last_dims_size); + cudaMalloc((void**)&offsets_d, offsets_size); + + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); + + NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + NVTEShape first_dims_shape_; + NVTEShape last_dims_shape_; + NVTEShape offsets_shape_; + + first_dims_shape_.ndim = 1; + last_dims_shape_.ndim = 1; + offsets_shape_.ndim = 1; + + first_dims_shape_.data[0] = num_tensors; + last_dims_shape_.data[0] = num_tensors; + offsets_shape_.data[0] = num_tensors + 1; + + NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); + + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + } + + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, 
NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + } + + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + } + + if (rowwise) { + cudaMalloc((void**)&out_data_rowwise_d, out_data_size); + cudaMalloc((void**)&out_scales_rowwise_d, out_scales_size); + cudaMemset(out_data_rowwise_d, 0, out_data_size); + cudaMemset(out_scales_rowwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); + NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor); + } + + if (colwise) { + cudaMalloc((void**)&out_data_colwise_d, out_data_size); + cudaMalloc((void**)&out_scales_colwise_d, out_scales_size); + cudaMemset(out_data_colwise_d, 0, out_data_size); + cudaMemset(out_scales_colwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); + NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); + } + + // Reference (CPU) + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + + const size_t scales_stride_rowwise = K / 32; + const size_t scales_stride_colwise = K; + const size_t data_offset = offsets_h[t]; + const size_t sfs_offset = data_offset / 32; + + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; + + compute_ref( + processing_method, OP, rowwise, colwise, in_ptr, /*grad=*/ nullptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + /*output_dbias=*/ nullptr, M, K, + scales_stride_rowwise, + scales_stride_colwise); + } + + // GPU + nvte_quantize_grouped(in_group_tensor, out_group_tensor, 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(otype); + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const 
double rel_tolerable_mismatches_limit = 0.0; + + if (rowwise) { + cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, out_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; + compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), + out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + } + + if (colwise) { + cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, out_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; + compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), + out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + } + + cudaFree(in_data_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); + cudaFree(offsets_d); + if (rowwise) { + cudaFree(out_data_rowwise_d); + cudaFree(out_scales_rowwise_d); + } + if (colwise) { + cudaFree(out_data_colwise_d); + cudaFree(out_scales_colwise_d); + } +} + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector activation_kinds = { + ActivationKind::Identity, + // ActivationKind::GeLU, + // ActivationKind::SiLU, + // ActivationKind::ReLU, + // ActivationKind::QGeLU, + // ActivationKind::SReLU, +}; + +enum ScalingDirection { + ROWWISE = 0, + COLWISE = 1, + BOTH = 2 +}; + +std::vector scaling_directions = { + ScalingDirection::ROWWISE, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, +}; + +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} +std::vector> input_config = { + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, +}; + +} // namespace + +class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam + , // Config + transformer_engine::DType, // InputType + transformer_engine::DType // OutputType + >> {}; + +TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationKind activation = std::get<1>(GetParam()); + const ScalingDirection scaling_direction = std::get<2>(GetParam()); + const std::vector input_config = 
std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const size_t num_tensors = input_config[1]; + const std::vector logical_shape = {input_config[2], input_config[3]}; + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = input_config[t + 4]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = input_config[t + (4 + num_tensors)]; + break; + } + } + offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; + // Skips tests if tensor dims are not multiples of 128 + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 128 != 0)) { + GTEST_SKIP(); + } + } + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && activation != ActivationKind::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { + GTEST_SKIP(); + } + + bool rowwise = false; + bool colwise = false; + switch (scaling_direction) { + case ScalingDirection::ROWWISE: rowwise = true; break; + case ScalingDirection::COLWISE: colwise = true; break; + case ScalingDirection::BOTH: rowwise = true; colwise = true; break; + } + + auto OP = &identity; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(processing_method, OP, shape_rep, num_tensors, + logical_shape, first_dims, last_dims, offsets, + rowwise, colwise); + ); + ); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationKind activation) { + switch (activation) { + case ActivationKind::Identity: return "Identity"; + case ActivationKind::GeLU: return "GeLU"; + case ActivationKind::SiLU: return "SiLU"; + case ActivationKind::ReLU: return "ReLU"; + case ActivationKind::QGeLU: return "QGeLU"; + case ActivationKind::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedFusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(activation_kinds), + ::testing::ValuesIn(scaling_directions), + ::testing::ValuesIn(input_config), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + [](const testing::TestParamInfo& info) { + 
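    // Builds a readable test name: processing method (plus the activation for
    // DACT/ACT variants), scaling direction, shape representation, number of
    // tensors, the logical shape, and the input/output type names, e.g.
    // CAST_ONLY_ROWWISE_Shape_SAME_BOTH_DIMS_N_1_Shape_128X128_<itype>_<otype>.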
const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + if (method != ProcessingMethod::CAST_ONLY && method != ProcessingMethod::CAST_DBIAS) { + name += "X" + to_string(std::get<1>(info.param)); + } + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE"; break; + case ScalingDirection::COLWISE: name += "_COLWISE"; break; + case ScalingDirection::BOTH: name += "_BOTH"; break; + } + + const std::vector input = std::get<3>(info.param); + name += "_Shape_"; + switch(static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + }; + + name += "_N_" + std::to_string(input[1]); + + name += "_Shape_" + + std::to_string(input[2]) + + "X" + std::to_string(input[3]); + + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + return name; + }); \ No newline at end of file diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index de1a8864da..6c30d9e95c 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -16,6 +16,7 @@ #include "../utils.cuh" #include "dispatch/dequantize.cuh" #include "dispatch/quantize.cuh" +#include "dispatch/quantize_grouped.cuh" #include "transformer_engine/transpose.h" void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -26,6 +27,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, stream); } +void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_grouped); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_grouped_fwd_helper(input, output, nullptr, stream); +} + void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_noop); diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh new file mode 100644 index 0000000000..8220c45f15 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -0,0 +1,122 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_grouped.cuh + * \brief Quantize Grouped Tensor dispatcher. 
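+ *
+ *  Rough host-side usage sketch, mirroring the grouped-tensor setup in the unit
+ *  test of this patch (dims, N and stream below are placeholders; buffer
+ *  allocation and error handling are omitted):
+ *
+ *    NVTEShape shape = nvte_make_shape(dims, 2);
+ *    NVTEGroupedTensor in  = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, N, shape);
+ *    NVTEGroupedTensor out = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, N, shape);
+ *    // Attach device buffers via nvte_set_grouped_tensor_param():
+ *    //   kNVTEGroupedRowwiseData on both tensors, kNVTEGroupedRowwiseScaleInv /
+ *    //   kNVTEGroupedColumnwiseScaleInv on the output, and kNVTEGroupedFirstDims /
+ *    //   kNVTEGroupedLastDims / kNVTEGroupedTensorOffsets whenever per-tensor shapes vary.
+ *    nvte_quantize_grouped(in, out, stream);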
+ */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ + +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/vectorized_pointwise.h" +#include "../core/common.cuh" +#include "../mxfp8/quantize_grouped_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + const NVTEGroupedTensor activation = nullptr; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + mxfp8::quantize_grouped( + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + +// template +// void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTEGroupedTensor output, +// NVTEGroupedTensor dbias, NVTEGroupedTensor workspace, +// const NVTEQuantizationConfig quant_config, cudaStream_t stream) { +// using namespace detail; + +// const Tensor *grad_tensor = convertNVTETensorCheck(grad); +// const Tensor *input_tensor = convertNVTETensor(input); + +// Tensor *output_tensor = convertNVTETensorCheck(output); +// Tensor *dbias_tensor = convertNVTETensor(dbias); +// Tensor *workspace_tensor = convertNVTETensor(workspace); + +// // Quantization config +// QuantizationConfig quant_config_cpp; +// if (quant_config != nullptr) { +// quant_config_cpp = *reinterpret_cast(quant_config); +// } + +// // Noop flag +// Tensor dummy_tensor; +// Tensor *noop_tensor = &dummy_tensor; +// if (quant_config_cpp.noop_tensor != nullptr) { +// noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); +// } + +// // Check for unsupported options +// if (quant_config_cpp.stochastic_rounding) { +// NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, +// "Stochastic rounding is only supported for NVFP4 quantization."); +// } + +// NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), +// "Either rowwise or columnwise output data need to be allocated."); + +// // Dispatch to quantization kernel depending on data format +// switch (output_tensor->scaling_mode) { +// case NVTE_MXFP8_1D_SCALING: { +// mxfp8::quantize( +// *grad_tensor, 
input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, +// stream); +// break; +// } +// default: +// NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); +// } +// } + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ \ No newline at end of file diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh new file mode 100644 index 0000000000..e3110a9261 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -0,0 +1,933 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_grouped_mxfp8.cuh + * \brief CUDA kernels to quantize grouped tensors to MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_grouped_kernel { + +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; + +constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; + +constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + +constexpr size_t BUFF_DIM_Y = THREADS_Y; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +static_assert(STAGES >= 1); + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = 
first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + size_t low = 0; + size_t hi = num_tensors; // [low, hi] + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + return low - 1; + } +} + +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { + size_t rows_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; + case ShapeRepresentation::VARYING_LAST_DIM: + rows_num = first_logical_dim; + break; + case ShapeRepresentation::VARYING_FIRST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + rows_num = static_cast(first_dims_ptr[tensor_id]); + break; + } + return rows_num; +} + +__device__ __forceinline__ size_t get_tensor_cols_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_FIRST_DIM: + cols_num = last_logical_dim; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + cols_num = static_cast(last_dims_ptr[tensor_id]); + break; + } + return cols_num; +} + +// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +template +__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap *global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X) { + const size_t global_stride_bytes = global_dim_X * sizeof(T); + + __shared__ CUtensorMap shared_tensor_map; + shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem + + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) + : "memory"); + *global_tensor_map = shared_tensor_map; +} + +template +__global__ void update_tma_descriptors( + const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const 
int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, + const bool compute_activations) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t offset_elts = offsets_ptr[tensor_id]; + + if (leading_thread && (tensor_id < num_tensors)) { + { + const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + global_data_ptr, rows, cols); + } + if (compute_activations) { + const uintptr_t global_data_ptr = + reinterpret_cast(act_input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], + global_data_ptr, rows, cols); + } + if (rowwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_rowwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_rowwise, + &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, + cols); + } + if (colwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_colwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_colwise, + &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, + cols); + } + } +} + +__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_grouped_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; + + const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, + first_logical_dim, last_logical_dim, offsets_ptr); + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t scale_stride_rowwise = cols / SCALE_DIM_X; + const size_t scale_stride_colwise = cols; + + const bool is_single_tensor = (shape_rep 
== SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + // grouped tensor can be treated as continuous tensor for MXFP8 + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t offset_within_tensor = block_global_offset - tensor_base; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + + const bool leading_thread = (threadIdx.x == 0); + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } + + const size_t blocks_X_num_in_current_tensor = cols / CHUNK_DIM_X; + const size_t block_id_in_current_tensor = + is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK); + + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 
0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
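+  //
+  // The main loop below is a BUFFS_NUM-deep (double-buffered) software pipeline:
+  // while the current stage is quantized out of shared memory, the TMA load for
+  // the next stage is issued, and ptx::mbarrier_wait_parity() gates consumption
+  // of each buffer. Per 32-element scaling block the kernel computes the amax,
+  // derives a shared E8M0 exponent via ptx::float_to_e8m0(amax * max_norm_rcp),
+  // writes it to the scaling-factor tensor, and multiplies the block by
+  // ptx::exp2f_rcp(exponent) before storing the FP8 result back through TMA.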
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, leading_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + leading_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } + + // 2. 
Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. 
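+    // fence_proxy_async makes the generic-proxy shared-memory writes visible to
+    // the async (TMA) proxy; the __syncthreads() that follows orders all threads
+    // of the block before the leading thread issues the bulk-tensor store of
+    // this stage's output buffer.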
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (leading_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, leading_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_grouped_kernel + +template +void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activations, + const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + using namespace quantize_grouped_kernel; + + checkCuDriverContext(stream); + + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + + ScalingType scaling_type; + if 
(use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + ShapeRepresentation shape_rep; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + // Treat a grouped tensor with const last dims as a single tensor + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + const size_t num_tensors = input->num_tensors; + NVTE_CHECK( + num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + + // Logical shape of a tensor with varying all dims is [1, M*K] + if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + } + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + + if (use_rowwise_scaling) { + NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); + } + + const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + CheckNoopTensor(*noop, "cast_noop"); + + const size_t blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + const dim3 grid(blocks); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? 
output->columnwise_scale_inv.shape[1] : 1; + + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_cols = last_logical_dim; + if constexpr (IS_DBIAS) { + NVTE_CHECK(is_single_tensor, + "DBias is only supported for tensors with the const last dimension."); + NVTE_CHECK(dbias->data.dtype == input->dtype(), + "DBias must have the same type as input_tensor."); + NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = quantize_grouped_mxfp8_kernel; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = quantize_grouped_mxfp8_kernel; + break; + } + case ScalingType::COLWISE: { + kernel = quantize_grouped_mxfp8_kernel; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = quantize_grouped_mxfp8_kernel; + break; + } + } + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + (IS_DACT || IS_ACT) ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, + use_colwise_scaling, IS_ACT); + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 576494a4de..f6081cae7c 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,6 +89,17 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. 
* The type of quantized tensor in the output depends on the scaling mode of the output From ac23f0607a85954e3a66ff6a7d4c6d9e0e33f37e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:00:44 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 2 +- transformer_engine/common/cast/dispatch/quantize_grouped.cuh | 2 +- transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 9ce7b34bdd..0ba8ad3b4b 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -657,4 +657,4 @@ INSTANTIATE_TEST_SUITE_P( name += "_" + test::typeName(std::get<4>(info.param)) + "_" + test::typeName(std::get<5>(info.param)); return name; - }); \ No newline at end of file + }); diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh index 8220c45f15..bcfc35a843 100644 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -119,4 +119,4 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso } // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ \ No newline at end of file +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index e3110a9261..a279b56ac0 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -930,4 +930,4 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio } // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ \ No newline at end of file +#endif // TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ From 99f1f63b75867a31c9a0eb3d6cb43c3584645503 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 21 Jan 2026 18:41:20 +0000 Subject: [PATCH 03/12] Fixed the year to 2026 Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index a279b56ac0..d441152d43 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ From 74151381c7c0080a445f9a5e0f21f336f824d1d0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 21 Jan 2026 19:30:12 +0000 Subject: [PATCH 04/12] Added compilation guards Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index d441152d43..151d8fc51d 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -141,20 +141,26 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te __shared__ CUtensorMap shared_tensor_map; shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem - - asm volatile( - "{\n\t" - ".reg.b64 tensor_map_ptr; \n\t" - "mov.b64 tensor_map_ptr, %0; \n\t" - "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" - "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y - "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X - "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" - "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), - "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), - "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) - : "memory"); - *global_tensor_map = shared_tensor_map; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) + : "memory"); + *global_tensor_map = shared_tensor_map; + } else { + NVTE_DEVICE_ERROR( + "tensormap.replace is architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } } template @@ -210,7 +216,11 @@ __global__ void update_tma_descriptors( } __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); +#else + NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template Date: Wed, 21 Jan 2026 19:31:44 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/mxfp8/quantize_grouped_mxfp8.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index 151d8fc51d..02a833b274 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -157,9 +157,9 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te : "memory"); *global_tensor_map = shared_tensor_map; } else { - NVTE_DEVICE_ERROR( - "tensormap.replace is architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); + NVTE_DEVICE_ERROR( + "tensormap.replace is architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); } } From 39bb24f6d546b989515417fa0c50c97d44c8366c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 22 Jan 2026 18:07:13 +0000 Subject: [PATCH 06/12] Added BWD pass Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 32 ++++++++- transformer_engine/common/cast/cast.cu | 13 ++++ .../common/cast/dispatch/quantize_grouped.cuh | 65 +++++++++++++++---- .../cast/mxfp8/quantize_grouped_mxfp8.cuh | 7 +- .../common/include/transformer_engine/cast.h | 20 ++++++ 5 files changed, 122 insertions(+), 15 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 0ba8ad3b4b..3c44564e86 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -258,6 +258,7 @@ void performTest(const ProcessingMethod processing_method, std::uniform_real_distribution<> dis(-2.0, 1.0); std::vector in_data(elts_num); + std::vector grad_data(elts_num); std::vector out_data_rowwise_h(rowwise ? elts_num : 0); std::vector out_data_colwise_h(colwise ? 
elts_num : 0); @@ -271,6 +272,7 @@ void performTest(const ProcessingMethod processing_method, for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); + grad_data[i] = static_cast(val); in_data[i] = static_cast(val); } @@ -297,6 +299,7 @@ void performTest(const ProcessingMethod processing_method, const size_t last_dims_size = num_tensors * sizeof(size_t); const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + InputType* grad_data_d; InputType* in_data_d; OutputType* out_data_rowwise_d; OutputType* out_data_colwise_d; @@ -306,11 +309,13 @@ void performTest(const ProcessingMethod processing_method, size_t* last_dims_d; size_t* offsets_d; + cudaMalloc((void**)&grad_data_d, in_data_size); cudaMalloc((void**)&in_data_d, in_data_size); cudaMalloc((void**)&first_dims_d, first_dims_size); cudaMalloc((void**)&last_dims_d, last_dims_size); cudaMalloc((void**)&offsets_d, offsets_size); + cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice); cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); @@ -330,26 +335,32 @@ void performTest(const ProcessingMethod processing_method, last_dims_shape_.data[0] = num_tensors; offsets_shape_.data[0] = num_tensors + 1; + NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); } if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); } if (shape_rep != SAME_BOTH_DIMS) { NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); 
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); } @@ -378,6 +389,8 @@ void performTest(const ProcessingMethod processing_method, nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); } + Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + // Reference (CPU) for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; @@ -388,6 +401,7 @@ void performTest(const ProcessingMethod processing_method, const size_t data_offset = offsets_h[t]; const size_t sfs_offset = data_offset / 32; + const InputType* const grad_ptr = grad_data.data() + data_offset; const InputType* const in_ptr = in_data.data() + data_offset; OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; @@ -395,7 +409,7 @@ void performTest(const ProcessingMethod processing_method, fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, /*grad=*/ nullptr, + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, out_data_rowwise_ptr, out_data_colwise_ptr, out_scales_rowwise_ptr, out_scales_colwise_ptr, /*output_dbias=*/ nullptr, M, K, @@ -404,7 +418,20 @@ void performTest(const ProcessingMethod processing_method, } // GPU - nvte_quantize_grouped(in_group_tensor, out_group_tensor, 0); + Tensor workspace; + + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize_grouped(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias_grouped(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_quantize_dbias_grouped(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + break; + } + } cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); @@ -444,6 +471,7 @@ void performTest(const ProcessingMethod processing_method, out_data_colwise_h.data(), rows, cols, false, mismatches_elts); } + cudaFree(grad_data_d); cudaFree(in_data_d); cudaFree(first_dims_d); cudaFree(last_dims_d); diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 6c30d9e95c..333696bd60 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -70,6 +70,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_quantize_dbias_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_grouped); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTEGroupedTensor activation_input = nullptr; + + dispatch::quantize_grouped_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh 
b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh index bcfc35a843..723b883b35 100644 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -29,6 +29,16 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + const NVTEGroupedTensor activation = nullptr; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + // Quantization config QuantizationConfig quant_config_cpp; if (quant_config != nullptr) { @@ -42,22 +52,12 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); } - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), // "Either rowwise or columnwise output data need to be allocated."); // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - const NVTEGroupedTensor activation = nullptr; - NVTETensor dbias = nullptr; - NVTETensor workspace = nullptr; - - const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - mxfp8::quantize_grouped( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); @@ -68,6 +68,49 @@ void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTenso } } +template +void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize_grouped( + grad_tensor, input_tensor, noop_tensor, output_tensor, 
dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + // template // void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTEGroupedTensor output, // NVTEGroupedTensor dbias, NVTEGroupedTensor workspace, diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index 02a833b274..520b50726b 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -932,8 +932,11 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index f6081cae7c..5a6c08c724 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -143,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, From 452651a9bc0eb862949167af391aae073c7b46bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 18:13:36 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/dispatch/quantize_grouped.cuh | 4 ++-- .../common/cast/mxfp8/quantize_grouped_mxfp8.cuh | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh index 723b883b35..2841cdeb49 100644 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh @@ -102,8 +102,8 @@ void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGrouped switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { mxfp8::quantize_grouped( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); break; } default: diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh index 520b50726b..4d7be34c05 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh @@ -933,10 +933,9 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); } - - NVTE_CHECK_CUDA(cudaGetLastError()); - ); // NOLINT(*) - ); // NOLINT(*) + + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 From e8beb1e0393aff24b412ff313329b7ba9ac8abc4 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 23 Jan 2026 17:47:43 +0000 Subject: [PATCH 08/12] Added dbias and dact tests. Refactoring. 
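
The fused grouped entry point mirrors nvte_quantize_dbias_dgelu but operates on
NVTEGroupedTensor arguments. A minimal sketch of the call, following the
CAST_DBIAS_DACT branch of test_cast_mxfp8_grouped.cu (grouped-tensor setup
elided; as in the test, an empty-workspace query call precedes the real one so
the required workspace shape/dtype can be allocated first):

    // Fused dGeLU + MXFP8 cast + dbias reduction on a grouped tensor:
    // grad_group_tensor holds the incoming gradient, in_group_tensor the GeLU input.
    nvte_group_quantize_dbias_dgelu(grad_group_tensor, in_group_tensor, out_group_tensor,
                                    output_dbias.data(), workspace.data(), 0);
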
Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 161 ++++++++++++----- transformer_engine/common/activation/gelu.cu | 13 ++ transformer_engine/common/cast/cast.cu | 13 +- .../common/cast/dispatch/quantize.cuh | 90 ++++++++++ .../common/cast/dispatch/quantize_grouped.cuh | 165 ------------------ ...ped_mxfp8.cuh => group_quantize_mxfp8.cuh} | 38 ++-- .../common/include/transformer_engine/cast.h | 31 +++- 7 files changed, 275 insertions(+), 236 deletions(-) delete mode 100644 transformer_engine/common/cast/dispatch/quantize_grouped.cuh rename transformer_engine/common/cast/mxfp8/{quantize_grouped_mxfp8.cuh => group_quantize_mxfp8.cuh} (96%) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 3c44564e86..8d52f24fe5 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,7 +58,8 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) + const size_t scales_stride_colwise, + const bool is_single_tensor) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -93,18 +94,18 @@ void compute_ref(const ProcessingMethod processing_method, const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); float elt = static_cast(input[idx]); - // if (processing_method == ProcessingMethod::CAST_DBIAS) { - // // grad is the input - // elt = static_cast(grad[idx]); - // } + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } if (processing_method != ProcessingMethod::CAST_ONLY && processing_method != ProcessingMethod::CAST_DBIAS) { elt = OP(elt); } - // if (processing_method == ProcessingMethod::CAST_DACT || - // processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - // elt *= static_cast(grad[idx]); - // } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } thread_dbias[j] += elt; // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 @@ -167,9 +168,12 @@ void compute_ref(const ProcessingMethod processing_method, } } } - // for (size_t j = 0; j < cols; ++j) { - // output_dbias[j] = static_cast(output_dbias_fp32[j]); - // } + + if (is_single_tensor) { + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); + } + } } template @@ -252,6 +256,7 @@ void performTest(const ProcessingMethod processing_method, const size_t elts_num = rows * cols; const size_t sfs_num = (rows * cols) / 32; + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); std::vector scales_shape = {sfs_num}; std::mt19937 gen; @@ -270,6 +275,8 @@ void performTest(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? sfs_num : 0); + std::vector ref_output_dbias(is_single_tensor ? 
cols : 0); + for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); grad_data[i] = static_cast(val); @@ -342,7 +349,7 @@ void performTest(const ProcessingMethod processing_method, NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); - nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor); if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; @@ -392,43 +399,71 @@ void performTest(const ProcessingMethod processing_method, Tensor output_dbias("output_dbias", std::vector{ cols }, itype); // Reference (CPU) - for (size_t t = 0; t < num_tensors; ++t) { - const size_t M = first_dims_h[t]; - const size_t K = last_dims_h[t]; - - const size_t scales_stride_rowwise = K / 32; - const size_t scales_stride_colwise = K; - const size_t data_offset = offsets_h[t]; - const size_t sfs_offset = data_offset / 32; - - const InputType* const grad_ptr = grad_data.data() + data_offset; - const InputType* const in_ptr = in_data.data() + data_offset; - OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; - OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; - fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; - fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; + if (is_single_tensor) { + const size_t scales_stride_rowwise = cols / 32; + const size_t scales_stride_colwise = cols; compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, - out_data_rowwise_ptr, out_data_colwise_ptr, - out_scales_rowwise_ptr, out_scales_colwise_ptr, - /*output_dbias=*/ nullptr, M, K, + processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), + out_data_rowwise_ref.data(), out_data_colwise_ref.data(), + out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), + ref_output_dbias.data(), rows, cols, scales_stride_rowwise, - scales_stride_colwise); + scales_stride_colwise, + is_single_tensor); + } else { + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + + const size_t scales_stride_rowwise = K / 32; + const size_t scales_stride_colwise = K; + const size_t data_offset = offsets_h[t]; + const size_t sfs_offset = data_offset / 32; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; + + compute_ref( + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias.data(), M, K, + scales_stride_rowwise, + scales_stride_colwise, + is_single_tensor); + } } // GPU 
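+  // Dbias paths below use the two-step workspace query described in cast.h:
+  // the first library call sees an empty workspace and only fills in its
+  // required shape/dtype; the workspace is then allocated and the second
+  // call performs the actual quantization and dbias reduction.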
Tensor workspace; - switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize_grouped(in_group_tensor, out_group_tensor, 0); + nvte_group_quantize(in_group_tensor, out_group_tensor, 0); break; } case ProcessingMethod::CAST_DBIAS: { - nvte_quantize_dbias_grouped(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu; + // if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; } + // else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; } + // else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; } + // else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } + + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias.data(), workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_grouped(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias.data(), workspace.data(), 0); break; } } @@ -471,6 +506,19 @@ void performTest(const ProcessingMethod processing_method, out_data_colwise_h.data(), rows, cols, false, mismatches_elts); } + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias); + } + cudaFree(grad_data_d); cudaFree(in_data_d); cudaFree(first_dims_d); @@ -488,16 +536,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - // ProcessingMethod::CAST_DBIAS, - // ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, // ProcessingMethod::CAST_DACT, // ProcessingMethod::CAST_ACT, }; -// Only GeLU activation tests are supported std::vector activation_kinds = { ActivationKind::Identity, - // ActivationKind::GeLU, + ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -553,8 +600,10 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const std::vector input_config = std::get<3>(GetParam()); const DType input_type = std::get<4>(GetParam()); const DType output_type = std::get<5>(GetParam()); - + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); + const size_t num_tensors = input_config[1]; const std::vector logical_shape = {input_config[2], input_config[3]}; std::vector first_dims(num_tensors); @@ -589,6 +638,11 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { GTEST_SKIP(); } } + // Skips DBias tests if last dimension of 
tensors variates + if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + && !is_single_tensor) { + GTEST_SKIP(); + } // Skips non Act tests if the Activation type is not an identity if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) @@ -612,6 +666,25 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { auto OP = &identity; + if (processing_method == ProcessingMethod::CAST_ACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &gelu; break; + // case ActivationKind::SiLU: OP = &silu; break; + // case ActivationKind::ReLU: OP = &relu; break; + // case ActivationKind::QGeLU: OP = &qgelu; break; + // case ActivationKind::SReLU: OP = &srelu; break; + } + } else if (processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &dgelu; break; + // case ActivationKind::SiLU: OP = &dsilu; break; + // case ActivationKind::ReLU: OP = &drelu; break; + // case ActivationKind::QGeLU: OP = &dqgelu; break; + // case ActivationKind::SReLU: OP = &dsrelu; break; + } + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, performTest(processing_method, OP, shape_rep, num_tensors, @@ -657,9 +730,7 @@ INSTANTIATE_TEST_SUITE_P( [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); - if (method != ProcessingMethod::CAST_ONLY && method != ProcessingMethod::CAST_DBIAS) { - name += "X" + to_string(std::get<1>(info.param)); - } + name += "X" + to_string(std::get<1>(info.param)); switch (std::get<2>(info.param)) { case ScalingDirection::ROWWISE: name += "_ROWWISE"; break; diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 675341f7db..5e1d0d4fe1 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -33,6 +33,19 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 333696bd60..c24725495a 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -16,7 +16,6 @@ #include "../utils.cuh" #include "dispatch/dequantize.cuh" #include "dispatch/quantize.cuh" -#include "dispatch/quantize_grouped.cuh" #include "transformer_engine/transpose.h" void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -27,13 +26,13 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, 
stream); } -void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_grouped); + NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; constexpr bool IS_ACT = false; - dispatch::quantize_grouped_fwd_helper(input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -70,16 +69,16 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d input, activation_input, output, dbias, workspace, nullptr, stream); } -void nvte_quantize_dbias_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_grouped); + NVTE_API_CALL(nvte_group_quantize_dbias); using namespace transformer_engine; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = false; constexpr const NVTEGroupedTensor activation_input = nullptr; - dispatch::quantize_grouped_bwd_helper( + dispatch::group_quantize_bwd_helper( input, activation_input, output, dbias, workspace, nullptr, stream); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a02e7f4f07..87318a18a5 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -19,6 +19,7 @@ #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../mxfp8/group_quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -371,6 +372,95 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, } } +template +void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const NVTEGroupedTensor activation = nullptr; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + 
workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + +template +void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh b/transformer_engine/common/cast/dispatch/quantize_grouped.cuh deleted file mode 100644 index 2841cdeb49..0000000000 --- a/transformer_engine/common/cast/dispatch/quantize_grouped.cuh +++ /dev/null @@ -1,165 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file quantize_grouped.cuh - * \brief Quantize Grouped Tensor dispatcher. 
- */ - -#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ -#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ - -#include - -#include "../../common.h" -#include "../../transpose/cast_transpose.h" -#include "../../util/vectorized_pointwise.h" -#include "../core/common.cuh" -#include "../mxfp8/quantize_grouped_mxfp8.cuh" - -namespace transformer_engine { -namespace dispatch { - -template -void quantize_grouped_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - const NVTEGroupedTensor activation = nullptr; - NVTETensor dbias = nullptr; - NVTETensor workspace = nullptr; - - const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize_grouped( - input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } -} - -template -void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize_grouped( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } -} - -// 
template -// void quantize_grouped_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTEGroupedTensor output, -// NVTEGroupedTensor dbias, NVTEGroupedTensor workspace, -// const NVTEQuantizationConfig quant_config, cudaStream_t stream) { -// using namespace detail; - -// const Tensor *grad_tensor = convertNVTETensorCheck(grad); -// const Tensor *input_tensor = convertNVTETensor(input); - -// Tensor *output_tensor = convertNVTETensorCheck(output); -// Tensor *dbias_tensor = convertNVTETensor(dbias); -// Tensor *workspace_tensor = convertNVTETensor(workspace); - -// // Quantization config -// QuantizationConfig quant_config_cpp; -// if (quant_config != nullptr) { -// quant_config_cpp = *reinterpret_cast(quant_config); -// } - -// // Noop flag -// Tensor dummy_tensor; -// Tensor *noop_tensor = &dummy_tensor; -// if (quant_config_cpp.noop_tensor != nullptr) { -// noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); -// } - -// // Check for unsupported options -// if (quant_config_cpp.stochastic_rounding) { -// NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, -// "Stochastic rounding is only supported for NVFP4 quantization."); -// } - -// NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), -// "Either rowwise or columnwise output data need to be allocated."); - -// // Dispatch to quantization kernel depending on data format -// switch (output_tensor->scaling_mode) { -// case NVTE_MXFP8_1D_SCALING: { -// mxfp8::quantize( -// *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, -// stream); -// break; -// } -// default: -// NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); -// } -// } - -} // namespace dispatch -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_GROUPED_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh similarity index 96% rename from transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh rename to transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 4d7be34c05..a75c7607f0 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_grouped_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -4,12 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file quantize_grouped_mxfp8.cuh +/*! \file group_quantize_mxfp8.cuh * \brief CUDA kernels to quantize grouped tensors to MXFP8. 
*/ -#ifndef TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ -#define TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ #include #include @@ -25,7 +25,7 @@ namespace transformer_engine { namespace dispatch { namespace mxfp8 { -namespace quantize_grouped_kernel { +namespace group_quantize_kernel { constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; @@ -226,7 +226,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_grouped_mxfp8_kernel( +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, @@ -724,14 +724,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_grouped_mxfp8_kern destroy_barriers(mbar, leading_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace quantize_grouped_kernel +} // namespace group_quantize_kernel template -void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activations, - const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - using namespace quantize_grouped_kernel; +void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, + const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + using namespace group_quantize_kernel; checkCuDriverContext(stream); @@ -766,6 +766,14 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + if (IS_DACT) { + NVTE_CHECK(activations->has_data(), "Activations tensor must have data."); + NVTE_CHECK(input->num_tensors == activations->num_tensors, + "Number of grad and activations tensors must be same."); + NVTE_CHECK(input->dtype() == activations->dtype(), + "Grad and activations tensors must have the same type."); + } + const size_t num_tensors = input->num_tensors; NVTE_CHECK( num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, @@ -879,21 +887,21 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - auto kernel = quantize_grouped_mxfp8_kernel; switch (scaling_type) { case ScalingType::ROWWISE: { - kernel = quantize_grouped_mxfp8_kernel; break; } case ScalingType::COLWISE: { - kernel = quantize_grouped_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { - kernel = quantize_grouped_mxfp8_kernel; break; } @@ -942,4 +950,4 @@ void quantize_grouped(const GroupedTensor *input, const GroupedTensor *activatio } // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_QUANTIZE_GROUPED_MXFP8_CUH_ +#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 5a6c08c724..3aa05c1464 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ 
b/transformer_engine/common/include/transformer_engine/cast.h @@ -97,8 +97,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea * \param[in,out] output Output grouped MXFP8 tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_quantize_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream); +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. @@ -160,8 +160,8 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_quantize_dbias_grouped(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. @@ -186,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of GeLU operation on the gropued input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, + NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, From b3f846809248f2c529b475448f72110784b57593 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:50:14 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 12 ++++++------ transformer_engine/common/activation/gelu.cu | 7 ++++--- transformer_engine/common/cast/cast.cu | 4 ++-- transformer_engine/common/cast/dispatch/quantize.cuh | 2 +- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 10 +++++----- .../common/include/transformer_engine/cast.h | 6 +++--- 6 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 8d52f24fe5..266db0a68f 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -409,32 +409,32 @@ void performTest(const ProcessingMethod processing_method, out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), ref_output_dbias.data(), rows, cols, scales_stride_rowwise, - scales_stride_colwise, + scales_stride_colwise, is_single_tensor); } else { for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; const size_t K = last_dims_h[t]; - + const size_t scales_stride_rowwise = K / 32; const size_t scales_stride_colwise = K; const size_t data_offset = offsets_h[t]; const size_t sfs_offset = data_offset / 32; - + const InputType* const grad_ptr = grad_data.data() + data_offset; const InputType* const in_ptr = in_data.data() + data_offset; OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; - + compute_ref( processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, out_data_rowwise_ptr, out_data_colwise_ptr, out_scales_rowwise_ptr, out_scales_colwise_ptr, ref_output_dbias.data(), M, K, scales_stride_rowwise, - scales_stride_colwise, + scales_stride_colwise, is_single_tensor); } } @@ -600,7 +600,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const std::vector input_config = std::get<3>(GetParam()); const DType input_type = std::get<4>(GetParam()); const DType output_type = std::get<5>(GetParam()); - + const ShapeRepresentation shape_rep = static_cast(input_config[0]); const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 5e1d0d4fe1..201ca52507 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -33,9 +33,10 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } -void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor 
workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_dbias_dgelu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index c24725495a..582172a88e 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -27,7 +27,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea } void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; @@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d } void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 87318a18a5..d42c967486 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -18,8 +18,8 @@ #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" -#include "../mxfp8/quantize_mxfp8.cuh" #include "../mxfp8/group_quantize_mxfp8.cuh" +#include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a75c7607f0..b87896dadf 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -769,7 +769,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations if (IS_DACT) { NVTE_CHECK(activations->has_data(), "Activations tensor must have data."); NVTE_CHECK(input->num_tensors == activations->num_tensors, - "Number of grad and activations tensors must be same."); + "Number of grad and activations tensors must be same."); NVTE_CHECK(input->dtype() == activations->dtype(), "Grad and activations tensors must have the same type."); } @@ -888,21 +888,21 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; auto kernel = group_quantize_mxfp8_kernel; + OType, true, true>; switch (scaling_type) { case ScalingType::ROWWISE: { kernel = group_quantize_mxfp8_kernel; + OType, true, false>; break; } case ScalingType::COLWISE: { kernel = group_quantize_mxfp8_kernel; + OType, false, true>; break; } case ScalingType::BIDIMENSIONAL: { kernel = group_quantize_mxfp8_kernel; + OType, true, true>; break; } } diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 3aa05c1464..e740b90ef2 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -205,9 +205,9 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu * \param[out] workspace Workspace tensor. 
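The kernel-selection switch in the hunk above resolves the runtime scaling direction into a compile-time template specialization. Below is a minimal stand-alone illustration of that pattern, with a plain function template standing in for the CUDA kernel; the names here are illustrative only, not Transformer Engine symbols.

// Runtime enum -> compile-time boolean template parameters, one launch site.
#include <cstdio>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

template <bool ROWWISE_SCALING, bool COLWISE_SCALING>
void fake_kernel(int num_tensors) {
  std::printf("tensors=%d rowwise=%d colwise=%d\n",
              num_tensors, ROWWISE_SCALING, COLWISE_SCALING);
}

void dispatch(ScalingType scaling_type, int num_tensors) {
  auto kernel = fake_kernel<true, true>;          // default: both directions
  switch (scaling_type) {
    case ScalingType::ROWWISE:       kernel = fake_kernel<true, false>; break;
    case ScalingType::COLWISE:       kernel = fake_kernel<false, true>; break;
    case ScalingType::BIDIMENSIONAL: kernel = fake_kernel<true, true>;  break;
  }
  kernel(num_tensors);  // single call site regardless of specialization
}

Keeping the scaling directions as template parameters lets the compiler drop the unused row-wise or column-wise path entirely instead of branching per element.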
* \param[in] stream CUDA stream used for the operation. */ -void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream); +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. From 12351675fc0577fd8329efc8293c60d2ceac279c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 24 Jan 2026 00:52:22 +0000 Subject: [PATCH 10/12] Added grouped MXFP8 DACT and ACT API and tests Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 56 ++++++--- transformer_engine/common/activation/gelu.cu | 60 +++++++++- transformer_engine/common/activation/relu.cu | 72 ++++++++++++ .../common/activation/swiglu.cu | 36 ++++++ .../cast/mxfp8/group_quantize_mxfp8.cuh | 9 +- .../include/transformer_engine/activation.h | 110 ++++++++++++++++++ .../common/include/transformer_engine/cast.h | 92 +++++++++++++++ 7 files changed, 410 insertions(+), 25 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 266db0a68f..5a3cf40828 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -454,10 +454,10 @@ void performTest(const ProcessingMethod processing_method, } case ProcessingMethod::CAST_DBIAS_DACT: { auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu; - // if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; } - // else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; } - // else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; } - // else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } + if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); @@ -466,6 +466,24 @@ void performTest(const ProcessingMethod processing_method, output_dbias.data(), workspace.data(), 0); break; } + case ProcessingMethod::CAST_ACT: { + auto nvte_group_act = &nvte_group_gelu; + if (OP == &silu) { nvte_group_act = &nvte_group_silu; } + else if (OP == &relu) { nvte_group_act = &nvte_group_relu; } + else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; } + else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; } + nvte_group_act(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DACT: { + auto nvte_group_dact = &nvte_group_dgelu; + if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; } + else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; } + else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; } + else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; } + 
nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0); + break; + } } cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -538,8 +556,8 @@ std::vector processing_methods = { ProcessingMethod::CAST_ONLY, ProcessingMethod::CAST_DBIAS, ProcessingMethod::CAST_DBIAS_DACT, - // ProcessingMethod::CAST_DACT, - // ProcessingMethod::CAST_ACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { @@ -669,19 +687,19 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { if (processing_method == ProcessingMethod::CAST_ACT) { switch (activation) { case ActivationKind::GeLU: OP = &gelu; break; - // case ActivationKind::SiLU: OP = &silu; break; - // case ActivationKind::ReLU: OP = &relu; break; - // case ActivationKind::QGeLU: OP = &qgelu; break; - // case ActivationKind::SReLU: OP = &srelu; break; + case ActivationKind::SiLU: OP = &silu; break; + case ActivationKind::ReLU: OP = &relu; break; + case ActivationKind::QGeLU: OP = &qgelu; break; + case ActivationKind::SReLU: OP = &srelu; break; } } else if (processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { switch (activation) { case ActivationKind::GeLU: OP = &dgelu; break; - // case ActivationKind::SiLU: OP = &dsilu; break; - // case ActivationKind::ReLU: OP = &drelu; break; - // case ActivationKind::QGeLU: OP = &dqgelu; break; - // case ActivationKind::SReLU: OP = &dsrelu; break; + case ActivationKind::SiLU: OP = &dsilu; break; + case ActivationKind::ReLU: OP = &drelu; break; + case ActivationKind::QGeLU: OP = &dqgelu; break; + case ActivationKind::SReLU: OP = &dsrelu; break; } } @@ -733,13 +751,13 @@ INSTANTIATE_TEST_SUITE_P( name += "X" + to_string(std::get<1>(info.param)); switch (std::get<2>(info.param)) { - case ScalingDirection::ROWWISE: name += "_ROWWISE"; break; - case ScalingDirection::COLWISE: name += "_COLWISE"; break; - case ScalingDirection::BOTH: name += "_BOTH"; break; + case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; + case ScalingDirection::COLWISE: name += "_COLWISE_"; break; + case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; } const std::vector input = std::get<3>(info.param); - name += "_Shape_"; + switch(static_cast(input[0])) { case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; @@ -749,7 +767,7 @@ INSTANTIATE_TEST_SUITE_P( name += "_N_" + std::to_string(input[1]); - name += "_Shape_" + + name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]); diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 201ca52507..c65a7b4de4 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_gelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>( + input, output, nullptr, stream); +} + void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); @@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } 
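For the grouped entry points added in this file, a caller-side sketch is shown here. The grouped input and output tensors are assumed to be constructed elsewhere by the caller (this patch adds no grouped-tensor construction helpers), with the output configured for MXFP8 1D scaling.

// Hypothetical caller-side sketch for the new grouped GeLU entry point.
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>

void run_group_gelu(NVTEGroupedTensor input,   // BF16/FP16 grouped input
                    NVTEGroupedTensor output,  // grouped MXFP8 output
                    cudaStream_t stream) {
  // One call applies GeLU and MXFP8 quantization to every member tensor.
  nvte_group_gelu(input, output, stream);
}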
+void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dgelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -37,7 +59,7 @@ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dgelu); + NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); using namespace transformer_engine; constexpr bool IS_DBIAS = true; @@ -68,6 +90,14 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_qgelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>( + input, output, nullptr, stream); +} + void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); @@ -75,6 +105,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dqgelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -88,6 +132,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index fd70e38c1a..042c7cb4d0 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_relu(const NVTEGroupedTensor input, 
NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_relu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>( + input, output, nullptr, stream); +} + void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); @@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_drelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -54,6 +90,14 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_srelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>( + input, output, nullptr, stream); +} + void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); @@ -61,6 +105,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsrelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -74,6 +132,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor 
output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index cc812a17fa..6c6c3fb8db 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_silu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>( + input, output, nullptr, stream); +} + void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); @@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsilu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index b87896dadf..b658b56cfd 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -176,7 +176,7 @@ __global__ void update_tma_descriptors( const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, - const bool compute_activations) { + const bool 
compute_dactivations) { const bool leading_thread = (threadIdx.x == 0); const size_t tensor_id = blockIdx.x; @@ -192,7 +192,7 @@ __global__ void update_tma_descriptors( modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], global_data_ptr, rows, cols); } - if (compute_activations) { + if (compute_dactivations) { const uintptr_t global_data_ptr = reinterpret_cast(act_input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], @@ -912,8 +912,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const IType *const input_dptr = reinterpret_cast(input->data.dptr); const IType *const act_input_dptr = - (IS_DACT || IS_ACT) ? reinterpret_cast(activations->data.dptr) - : nullptr; + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; OType *const output_rowwise_dptr = use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; @@ -926,7 +925,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, - use_colwise_scaling, IS_ACT); + use_colwise_scaling, IS_DACT); } NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 55cd44d9de..3376dc522d 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -52,6 +52,16 @@ enum class NVTE_Activation_Type { */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the SiLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -62,6 +72,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
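To make the activation-gradient semantics used by the dact entry points in this header concrete, here is a reference GeLU and its derivative using the common tanh approximation. The device-side act_fn/dact_fn may use a different approximation, so treat this only as a sketch.

// Reference GeLU and dGeLU (tanh approximation), host-side float math.
#include <cmath>

inline float gelu_ref(float x) {
  const float k = 0.7978845608f;  // sqrt(2/pi)
  const float inner = k * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(inner));
}

inline float dgelu_ref(float x) {
  const float k = 0.7978845608f;
  const float inner = k * (x + 0.044715f * x * x * x);
  const float t = std::tanh(inner);
  const float sech2 = 1.0f - t * t;  // sech^2(inner)
  return 0.5f * (1.0f + t) +
         0.5f * x * sech2 * k * (1.0f + 3.0f * 0.044715f * x * x);
}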
@@ -72,6 +92,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -82,6 +112,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -92,6 +132,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -104,6 +154,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, + cudaStream_t stream); + /*! 
\brief Computes the SiLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -116,6 +178,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -128,6 +202,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -140,6 +226,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -152,6 +250,18 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient of the grouped input. 
+ * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index e740b90ef2..c786335f80 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -232,6 +232,29 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of SiLU operation on the gropued input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -255,6 +278,29 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of ReLU operation on the gropued input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. 
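A caller-side sketch of the two-pass workspace protocol described above follows; grouped-tensor construction and the actual workspace allocation are the caller's responsibility and are omitted, and only the repeated call to the entry point itself is taken from this header.

// Two-pass workspace query/run pattern for the grouped dbias + dReLU cast.
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>

void group_dbias_drelu(NVTEGroupedTensor grad, NVTEGroupedTensor act_in,
                       NVTEGroupedTensor out, NVTETensor dbias,
                       NVTETensor workspace, cudaStream_t stream) {
  // Pass 1: 'workspace' carries no data yet, so the call only records the
  // required workspace shape and dtype instead of launching the kernel.
  nvte_group_quantize_dbias_drelu(grad, act_in, out, dbias, workspace, stream);

  // ... caller allocates the workspace buffer to the size reported above ...

  // Pass 2: with workspace storage attached, the fused dReLU + cast + dbias
  // reduction actually runs.
  nvte_group_quantize_dbias_drelu(grad, act_in, out, dbias, workspace, stream);
}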
+ * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -278,6 +324,29 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Quick GeLU operation on the gropued input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -301,6 +370,29 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Squared ReLU operation on the gropued input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. 
+ * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, * the block dequantization (MXFP8) of the specified shape of the block will be used. From 34b9dfdb3938777ce9f9cdd8f6e0163df6435606 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Jan 2026 00:53:17 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/activation/gelu.cu | 13 ++++++------ transformer_engine/common/activation/relu.cu | 13 ++++++------ .../common/activation/swiglu.cu | 4 ++-- .../include/transformer_engine/activation.h | 20 +++++++++---------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index c65a7b4de4..d209ea8d47 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -17,8 +17,8 @@ void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cu NVTE_API_CALL(nvte_group_gelu); using namespace transformer_engine; constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>( - input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); } void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, @@ -90,12 +90,13 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } -void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_group_qgelu); using namespace transformer_engine; constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>( - input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); } void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, @@ -106,7 +107,7 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu } void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { + NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dqgelu); using namespace transformer_engine; NVTETensor dbias = nullptr; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index 042c7cb4d0..b6f758caf6 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -17,8 +17,8 @@ void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cu NVTE_API_CALL(nvte_group_relu); using namespace transformer_engine; constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>( - input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); } void 
nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, @@ -90,12 +90,13 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } -void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_group_srelu); using namespace transformer_engine; constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>( - input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); } void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, @@ -106,7 +107,7 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu } void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { + NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsrelu); using namespace transformer_engine; NVTETensor dbias = nullptr; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 6c6c3fb8db..77d5b6867f 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -17,8 +17,8 @@ void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cu NVTE_API_CALL(nvte_group_silu); using namespace transformer_engine; constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>( - input, output, nullptr, stream); + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); } void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 3376dc522d..4c9eed3365 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -163,8 +163,8 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output * \param[in,out] output Output grouped tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, - cudaStream_t stream); +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); /*! \brief Computes the SiLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -187,8 +187,8 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output * \param[in,out] output Output grouped tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, - cudaStream_t stream); +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); /*! \brief Computes the ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -211,8 +211,8 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output * \param[in,out] output Output grouped tensor. * \param[in] stream CUDA stream used for the operation. 
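For the SiLU and Squared-ReLU gradient entry points documented in this header, the corresponding reference derivatives are sketched below. They follow the usual textbook definitions; the device-side dact_fn implementations may be organized differently.

// Reference derivatives for SiLU and Squared ReLU, host-side float math.
#include <cmath>

inline float sigmoid_ref(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// SiLU: x * sigmoid(x); derivative: sigmoid(x) * (1 + x * (1 - sigmoid(x))).
inline float dsilu_ref(float x) {
  const float s = sigmoid_ref(x);
  return s * (1.0f + x * (1.0f - s));
}

// Squared ReLU: max(x, 0)^2; derivative: 2 * max(x, 0).
inline float dsrelu_ref(float x) { return x > 0.0f ? 2.0f * x : 0.0f; }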
*/ -void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, - cudaStream_t stream); +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); /*! \brief Computes the Quick GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -235,8 +235,8 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu * \param[in,out] output Output grouped tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, - cudaStream_t stream); +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); /*! \brief Computes the Squared ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -259,8 +259,8 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu * \param[in,out] output Output grouped tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, - cudaStream_t stream); +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, From 6dd381405309c47484bcab65b111a19de4838581 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Sat, 24 Jan 2026 01:07:32 +0000 Subject: [PATCH 12/12] Fixed a typo Signed-off-by: Oleg Goncharov --- .../common/include/transformer_engine/cast.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index c786335f80..04712d3003 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -186,7 +186,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); -/*! \brief Computes backward of GeLU operation on the gropued input, then casts to FP8/MXFP8. +/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -232,7 +232,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); -/*! \brief Computes backward of SiLU operation on the gropued input, then casts to FP8/MXFP8. +/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
@@ -278,7 +278,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); -/*! \brief Computes backward of ReLU operation on the gropued input, then casts to FP8/MXFP8. +/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -324,7 +324,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); -/*! \brief Computes backward of Quick GeLU operation on the gropued input, then casts to FP8/MXFP8. +/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -370,7 +370,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); -/*! \brief Computes backward of Squared ReLU operation on the gropued input, then casts to FP8/MXFP8. +/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used.
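As a closing reference for the fused dbias + dact + cast entry points declared above, the sketch below spells out their semantics for a single dense [rows, cols] tensor in plain C++. Here dact_ref stands in for whichever activation derivative is used (dGeLU, dSiLU, ...); the real grouped kernels additionally quantize the output to FP8/MXFP8 and repeat the computation per member tensor using the per-tensor offsets.

// CPU reference: output = cast(grad * dact(act_in)), dbias = column sums.
#include <cstddef>

void dbias_dact_ref(const float *grad, const float *act_in,
                    float *out, float *dbias,
                    size_t rows, size_t cols, float (*dact_ref)(float)) {
  for (size_t j = 0; j < cols; ++j) dbias[j] = 0.0f;
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      // Backward of the activation applied to the incoming gradient.
      const float v = grad[i * cols + j] * dact_ref(act_in[i * cols + j]);
      out[i * cols + j] = v;  // quantized to FP8/MXFP8 in the real kernel
      dbias[j] += v;          // dbias[j] accumulates the column sum over rows
    }
  }
}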