diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 26efb37962..6c615d63e0 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_mxfp8_grouped.cu test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu new file mode 100644 index 0000000000..5a3cf40828 --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -0,0 +1,777 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationKind { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise, + const bool is_single_tensor) +{ + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + + std::vector output_dbias_fp32(cols, 0); + #pragma omp parallel proc_bind(spread) + { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + + std::vector thread_dbias(cols, 0); + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= 
static_cast(grad[idx]); + } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + } + #pragma omp critical + { + for (size_t j = 0; j < cols; ++j) { + output_dbias_fp32[j] += thread_dbias[j]; + } + } + } + + if (is_single_tensor) { + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); + } + } +} + +template +void compare_scaled_elts(const std::string &name, + const T* ref_data, + const T* test_data, + const size_t rows, + const size_t cols, + const bool rowwise, + const size_t tolerable_mismatches_limit = 0, + const double atol = 1e-5, + const double rtol = 1e-8) { + size_t mismatches_num = 0; + int first_mismatch_idx = -1; + + for (size_t i = 0; i < rows * cols; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + std::string direction = rowwise ? 
"rowwise" : "columnwise"; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << first_mismatch_idx + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } + } +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest(const ProcessingMethod processing_method, + float (*OP)(const float), + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const std::vector& logical_shape_vec, + const std::vector& first_dims_h, + const std::vector& last_dims_h, + const std::vector& offsets_h, + const bool rowwise, + const bool colwise) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = logical_shape_vec[0]; + const size_t cols = logical_shape_vec[1]; + + const size_t elts_num = rows * cols; + const size_t sfs_num = (rows * cols) / 32; + + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); + std::vector scales_shape = {sfs_num}; + + std::mt19937 gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + + std::vector in_data(elts_num); + std::vector grad_data(elts_num); + + std::vector out_data_rowwise_h(rowwise ? elts_num : 0); + std::vector out_data_colwise_h(colwise ? elts_num : 0); + std::vector out_scales_rowwise_h(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_h(colwise ? sfs_num : 0); + + std::vector out_data_rowwise_ref(rowwise ? elts_num : 0); + std::vector out_data_colwise_ref(colwise ? elts_num : 0); + std::vector out_scales_rowwise_ref(rowwise ? sfs_num : 0); + std::vector out_scales_colwise_ref(colwise ? sfs_num : 0); + + std::vector ref_output_dbias(is_single_tensor ? 
cols : 0); + + for (size_t i = 0; i < elts_num; ++i) { + const float val = dis(gen); + grad_data[i] = static_cast(val); + in_data[i] = static_cast(val); + } + + const OutputType zero_elt = static_cast(0.0f); + const fp8e8m0 zero_SF = static_cast(0.0f); + if (rowwise) { + std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt); + std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt); + std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF); + std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF); + } + if (colwise) { + std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt); + std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt); + std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF); + std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF); + } + + const size_t in_data_size = elts_num * sizeof(InputType); + const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t out_scales_size = sfs_num * sizeof(fp8e8m0); + + const size_t first_dims_size = num_tensors * sizeof(size_t); + const size_t last_dims_size = num_tensors * sizeof(size_t); + const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + + InputType* grad_data_d; + InputType* in_data_d; + OutputType* out_data_rowwise_d; + OutputType* out_data_colwise_d; + fp8e8m0* out_scales_rowwise_d; + fp8e8m0* out_scales_colwise_d; + size_t* first_dims_d; + size_t* last_dims_d; + size_t* offsets_d; + + cudaMalloc((void**)&grad_data_d, in_data_size); + cudaMalloc((void**)&in_data_d, in_data_size); + cudaMalloc((void**)&first_dims_d, first_dims_size); + cudaMalloc((void**)&last_dims_d, last_dims_size); + cudaMalloc((void**)&offsets_d, offsets_size); + + cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); + + NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + NVTEShape first_dims_shape_; + NVTEShape last_dims_shape_; + NVTEShape offsets_shape_; + + first_dims_shape_.ndim = 1; + last_dims_shape_.ndim = 1; + offsets_shape_.ndim = 1; + + first_dims_shape_.data[0] = num_tensors; + last_dims_shape_.data[0] = num_tensors; + offsets_shape_.data[0] = num_tensors + 1; + + NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + + NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor); + + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor 
= {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + } + + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + } + + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + } + + if (rowwise) { + cudaMalloc((void**)&out_data_rowwise_d, out_data_size); + cudaMalloc((void**)&out_scales_rowwise_d, out_scales_size); + cudaMemset(out_data_rowwise_d, 0, out_data_size); + cudaMemset(out_scales_rowwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); + NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor); + } + + if (colwise) { + cudaMalloc((void**)&out_data_colwise_d, out_data_size); + cudaMalloc((void**)&out_scales_colwise_d, out_scales_size); + cudaMemset(out_data_colwise_d, 0, out_data_size); + cudaMemset(out_scales_colwise_d, 0, out_scales_size); + NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_shape.data(), scales_shape.size()); + NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor); + nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); + } + + Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + + // Reference (CPU) + if (is_single_tensor) { + const size_t scales_stride_rowwise = cols / 32; + const size_t scales_stride_colwise = cols; + + compute_ref( + processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), + out_data_rowwise_ref.data(), out_data_colwise_ref.data(), + out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), + ref_output_dbias.data(), rows, cols, + 
scales_stride_rowwise, + scales_stride_colwise, + is_single_tensor); + } else { + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + + const size_t scales_stride_rowwise = K / 32; + const size_t scales_stride_colwise = K; + const size_t data_offset = offsets_h[t]; + const size_t sfs_offset = data_offset / 32; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + sfs_offset; + + compute_ref( + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias.data(), M, K, + scales_stride_rowwise, + scales_stride_colwise, + is_single_tensor); + } + } + + // GPU + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_group_quantize(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } + + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias.data(), workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias.data(), workspace.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + auto nvte_group_act = &nvte_group_gelu; + if (OP == &silu) { nvte_group_act = &nvte_group_silu; } + else if (OP == &relu) { nvte_group_act = &nvte_group_relu; } + else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; } + else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; } + nvte_group_act(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DACT: { + auto nvte_group_dact = &nvte_group_dgelu; + if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; } + else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; } + else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; } + else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; } + nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0); + break; + } + } + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(otype); + const size_t 
scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + if (rowwise) { + cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, out_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; + compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), + out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + } + + if (colwise) { + cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, out_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; + compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + 1, sfs_num, sfs_num, mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), + out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + } + + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias); + } + + cudaFree(grad_data_d); + cudaFree(in_data_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); + cudaFree(offsets_d); + if (rowwise) { + cudaFree(out_data_rowwise_d); + cudaFree(out_scales_rowwise_d); + } + if (colwise) { + cudaFree(out_data_colwise_d); + cudaFree(out_scales_colwise_d); + } +} + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, +}; + +std::vector activation_kinds = { + ActivationKind::Identity, + ActivationKind::GeLU, + // ActivationKind::SiLU, + // ActivationKind::ReLU, + // ActivationKind::QGeLU, + // ActivationKind::SReLU, +}; + +enum ScalingDirection { + ROWWISE = 0, + COLWISE = 1, + BOTH = 2 +}; + +std::vector scaling_directions = { + ScalingDirection::ROWWISE, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, +}; + +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} +std::vector> input_config = { + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, +}; + +} // namespace + +class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam + , // Config + transformer_engine::DType, // InputType + transformer_engine::DType // OutputType + >> {}; + 
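+// Illustration only (not used by the tests): the reference path above derives each
+// MXFP8 scaling factor from the per-block amax as an E8M0 biased exponent and then
+// quantizes with its power-of-two reciprocal. The two helpers below are simplified,
+// hypothetical stand-ins for float_to_e8m0()/exp2f_rcp() from test_common.h; they
+// assume <cmath>/<cstdint> and skip the zero/inf/NaN handling of the real helpers.
+static inline uint8_t e8m0_from_block_amax_sketch(float block_amax, float fp8_max) {
+  // Pick a power-of-two scale no smaller than block_amax / fp8_max, so the scaled
+  // 32-element block fits into the FP8 range, and store its biased exponent.
+  int exp = 0;
+  std::frexp(block_amax / fp8_max, &exp);  // block_amax / fp8_max == m * 2^exp, m in [0.5, 1)
+  return static_cast<uint8_t>(exp + 127);
+}
+static inline float e8m0_reciprocal_sketch(uint8_t biased_exponent) {
+  // 1 / 2^(biased_exponent - 127); each element is quantized as fp8(x * reciprocal).
+  return std::ldexp(1.0f, 127 - static_cast<int>(biased_exponent));
+}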
+TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationKind activation = std::get<1>(GetParam()); + const ScalingDirection scaling_direction = std::get<2>(GetParam()); + const std::vector input_config = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); + + const size_t num_tensors = input_config[1]; + const std::vector logical_shape = {input_config[2], input_config[3]}; + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = input_config[t + 4]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = input_config[t + (4 + num_tensors)]; + break; + } + } + offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; + // Skips tests if tensor dims are not multiples of 128 + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 128 != 0)) { + GTEST_SKIP(); + } + } + // Skips DBias tests if last dimension of tensors variates + if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + && !is_single_tensor) { + GTEST_SKIP(); + } + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && activation != ActivationKind::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { + GTEST_SKIP(); + } + + bool rowwise = false; + bool colwise = false; + switch (scaling_direction) { + case ScalingDirection::ROWWISE: rowwise = true; break; + case ScalingDirection::COLWISE: colwise = true; break; + case ScalingDirection::BOTH: rowwise = true; colwise = true; break; + } + + auto OP = &identity; + + if (processing_method == ProcessingMethod::CAST_ACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &gelu; break; + case ActivationKind::SiLU: OP = &silu; break; + case ActivationKind::ReLU: OP = &relu; break; + case ActivationKind::QGeLU: OP = &qgelu; break; + case ActivationKind::SReLU: OP = &srelu; break; + } + } else if (processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &dgelu; break; + case ActivationKind::SiLU: OP = &dsilu; break; + case ActivationKind::ReLU: OP = &drelu; break; + case ActivationKind::QGeLU: OP = &dqgelu; break; + case ActivationKind::SReLU: OP 
= &dsrelu; break; + } + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(processing_method, OP, shape_rep, num_tensors, + logical_shape, first_dims, last_dims, offsets, + rowwise, colwise); + ); + ); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationKind activation) { + switch (activation) { + case ActivationKind::Identity: return "Identity"; + case ActivationKind::GeLU: return "GeLU"; + case ActivationKind::SiLU: return "SiLU"; + case ActivationKind::ReLU: return "ReLU"; + case ActivationKind::QGeLU: return "QGeLU"; + case ActivationKind::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedFusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(activation_kinds), + ::testing::ValuesIn(scaling_directions), + ::testing::ValuesIn(input_config), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + [](const testing::TestParamInfo& info) { + const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + name += "X" + to_string(std::get<1>(info.param)); + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; + case ScalingDirection::COLWISE: name += "_COLWISE_"; break; + case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; + } + + const std::vector input = std::get<3>(info.param); + + switch(static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + }; + + name += "_N_" + std::to_string(input[1]); + + name += "_SHAPE_" + + std::to_string(input[2]) + + "X" + std::to_string(input[3]); + + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + return name; + }); diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 675341f7db..d209ea8d47 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_gelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); @@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void 
nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dgelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; @@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_qgelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); @@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dqgelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git 
a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index fd70e38c1a..b6f758caf6 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_relu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); @@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_drelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_srelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); @@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsrelu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsrelu(const NVTETensor input, const 
NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index cc812a17fa..77d5b6867f 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_silu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); @@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsilu); + using namespace transformer_engine; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index de1a8864da..582172a88e 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -26,6 +26,15 @@ void nvte_quantize(const NVTETensor 
input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, stream); } +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); +} + void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_noop); @@ -60,6 +69,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTEGroupedTensor activation_input = nullptr; + + dispatch::group_quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a02e7f4f07..d42c967486 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -18,6 +18,7 @@ #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" +#include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" @@ -371,6 +372,95 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, } } +template +void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const NVTEGroupedTensor activation = nullptr; + NVTETensor dbias = nullptr; + NVTETensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + 
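+    // Only MXFP8 1D (32-element block) scaling is implemented for grouped tensors so
+    // far; every other scaling mode falls through to the "Not implemented" error below.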
default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + +template +void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh new file mode 100644 index 0000000000..b658b56cfd --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -0,0 +1,952 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_quantize_mxfp8.cuh + * \brief CUDA kernels to quantize grouped tensors to MXFP8. 
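+ *
+ *        A sketch of the approach taken by the kernels below: when the group cannot be
+ *        viewed as a single contiguous 2-D tensor, per-member TMA descriptors are first
+ *        rebuilt on the device (update_tma_descriptors); the quantization kernel then
+ *        assigns each CUDA block one CHUNK_DIM_Y x CHUNK_DIM_X chunk of the concatenated
+ *        data and locates the member it belongs to from the cumulative element offsets,
+ *        e.g. with offsets = {0, m0*k0, m0*k0 + m1*k1}, a chunk starting at element x
+ *        belongs to member i where offsets[i] <= x < offsets[i+1] (binary search in
+ *        get_current_tensor_id).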
+ */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace group_quantize_kernel { + +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; + +constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; + +constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + +constexpr size_t BUFF_DIM_Y = THREADS_Y; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +static_assert(STAGES >= 1); + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + size_t low = 0; + size_t hi = num_tensors; // [low, hi] + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + return low - 1; + } +} + +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { + size_t rows_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break; + case ShapeRepresentation::VARYING_LAST_DIM: + rows_num = first_logical_dim; + break; + case ShapeRepresentation::VARYING_FIRST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + rows_num = static_cast(first_dims_ptr[tensor_id]); + break; + } + return rows_num; +} + +__device__ __forceinline__ size_t 
get_tensor_cols_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_FIRST_DIM: + cols_num = last_logical_dim; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + cols_num = static_cast(last_dims_ptr[tensor_id]); + break; + } + return cols_num; +} + +// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +template +__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap *global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X) { + const size_t global_stride_bytes = global_dim_X * sizeof(T); + + __shared__ CUtensorMap shared_tensor_map; + shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) + : "memory"); + *global_tensor_map = shared_tensor_map; + } else { + NVTE_DEVICE_ERROR( + "tensormap.replace is architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +template +__global__ void update_tma_descriptors( + const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, + const bool compute_dactivations) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t offset_elts = offsets_ptr[tensor_id]; + + if (leading_thread && (tensor_id < num_tensors)) { + { + const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + global_data_ptr, rows, cols); + } + if (compute_dactivations) { + const uintptr_t global_data_ptr = + reinterpret_cast(act_input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], + global_data_ptr, rows, cols); + } + if (rowwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_rowwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_rowwise, + &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, + cols); + } + if (colwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_colwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_colwise, + &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, + cols); + } + } +} + +__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); +#else + NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const 
__restrict__ amax_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; + + const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, + first_logical_dim, last_logical_dim, offsets_ptr); + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t scale_stride_rowwise = cols / SCALE_DIM_X; + const size_t scale_stride_colwise = cols; + + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + // grouped tensor can be treated as continuous tensor for MXFP8 + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t offset_within_tensor = block_global_offset - tensor_base; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + + const bool leading_thread = (threadIdx.x == 0); + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } + + const size_t blocks_X_num_in_current_tensor = cols / CHUNK_DIM_X; + const size_t block_id_in_current_tensor = + is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK); + + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 
0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
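+  // Pipeline overview (descriptive only): each CUDA block processes one chunk in
+  // STAGES passes over BUFFS_NUM shared-memory buffers. While the current buffer
+  // is being quantized, the next stage's tile is prefetched asynchronously via TMA
+  // and tracked by the per-stage mbarriers declared below; each stage computes the
+  // MXFP8 block amax, derives the E8M0 scale, and TMA-stores the scaled outputs.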
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, leading_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + leading_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } + + // 2. 
Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. 
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (leading_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, leading_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace group_quantize_kernel + +template +void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, + const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + using namespace group_quantize_kernel; + + checkCuDriverContext(stream); + + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + + ScalingType scaling_type; + if 
(use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + ShapeRepresentation shape_rep; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + // Treat a grouped tensor with const last dims as a single tensor + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (IS_DACT) { + NVTE_CHECK(activations->has_data(), "Activations tensor must have data."); + NVTE_CHECK(input->num_tensors == activations->num_tensors, + "Number of grad and activations tensors must be same."); + NVTE_CHECK(input->dtype() == activations->dtype(), + "Grad and activations tensors must have the same type."); + } + + const size_t num_tensors = input->num_tensors; + NVTE_CHECK( + num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than the MAX number of supported descriptors (64)."); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + + // Logical shape of a tensor with varying all dims is [1, M*K] + if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + } + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + + if (use_rowwise_scaling) { + NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); + } + + const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + CheckNoopTensor(*noop, "cast_noop"); + + const size_t blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + const dim3 grid(blocks); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? 
output->columnwise_scale_inv.shape[1] : 1; + + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_cols = last_logical_dim; + if constexpr (IS_DBIAS) { + NVTE_CHECK(is_single_tensor, + "DBias is only supported for tensors with the const last dimension."); + NVTE_CHECK(dbias->data.dtype == input->dtype(), + "DBias must have the same type as input_tensor."); + NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = group_quantize_mxfp8_kernel; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::COLWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = group_quantize_mxfp8_kernel; + break; + } + } + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, + use_colwise_scaling, IS_DACT); + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 55cd44d9de..4c9eed3365 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -52,6 +52,16 @@ enum class NVTE_Activation_Type { */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the SiLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -62,6 +72,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the grouped input. 
+ * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -72,6 +92,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -82,6 +112,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -92,6 +132,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
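For reference, a minimal host-side call sketch for the grouped activation entry points declared above. It assumes the NVTEGroupedTensor handles have already been constructed elsewhere, with MXFP8 1D scaling configured on the output; grouped-tensor construction itself is outside the scope of this patch.

#include <cuda_runtime.h>
#include <transformer_engine/activation.h>

// Illustrative sketch only (not part of the patch): the grouped tensors are
// assumed to be pre-built by the caller; this merely shows the call shape of
// the new grouped activation API.
void run_group_gelu_example(const NVTEGroupedTensor input,  // BF16/FP16 grouped input
                            NVTEGroupedTensor output,       // grouped output, MXFP8 scaling mode
                            cudaStream_t stream) {
  // Applies GeLU to every member tensor of the group; with NVTE_MXFP8_1D_SCALING
  // set on `output`, the result is block-quantized to MXFP8 as documented above.
  nvte_group_gelu(input, output, stream);
}

The remaining grouped activations (nvte_group_silu, nvte_group_relu, nvte_group_qgelu, nvte_group_srelu) follow the same call shape.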
@@ -104,6 +154,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the SiLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -116,6 +178,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -128,6 +202,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -140,6 +226,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. 
+ * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -152,6 +250,18 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 576494a4de..04712d3003 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,6 +89,17 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output @@ -132,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. 
+ * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -155,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -178,6 +232,29 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -201,6 +278,29 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -224,6 +324,29 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! 
\brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -247,6 +370,29 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, * the block dequantization (MXFP8) of the specified shape of the block will be used.
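To illustrate the new grouped cast API, here is a minimal sketch of the two-call workspace pattern described in the nvte_group_quantize_dbias doc comment above. It assumes the grouped tensors and the dbias/workspace handles are already constructed, and leaves the actual workspace allocation (driven by the shape/dtype reported by the first call) to the caller.

#include <cuda_runtime.h>
#include <transformer_engine/cast.h>

// Illustrative sketch only (not part of the patch): the handles are assumed to
// be pre-built; `workspace` carries no data buffer on entry.
void run_group_quantize_dbias_example(const NVTEGroupedTensor input,
                                      NVTEGroupedTensor output,   // grouped MXFP8 output
                                      NVTETensor dbias,
                                      NVTETensor workspace,
                                      cudaStream_t stream) {
  // 1st call: with an empty workspace the operation is skipped and only the
  //           required workspace shape/dtype is reported back.
  nvte_group_quantize_dbias(input, output, dbias, workspace, stream);

  // ... query the reported shape/dtype from `workspace`, allocate its buffer ...

  // 2nd call: performs the MXFP8 cast and the column-wise dbias reduction.
  nvte_group_quantize_dbias(input, output, dbias, workspace, stream);
}

The fused dact variants (nvte_group_quantize_dbias_dgelu, _dsilu, _drelu, _dqgelu, _dsrelu) use the same two-call pattern with an additional act_input grouped tensor.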