diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bcacb2f801..12d6759f84 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -24,35 +25,38 @@ #include #include "../test_common.h" +#include "util/cuda_runtime.h" using namespace transformer_engine; using namespace test; namespace { -enum class InputCase { - kFP8Current, - kBF16, - kMXFP8, -}; +// std::nullopt means BF16 (no scaling); any other value is the scaling mode for FP8/NVFP4. +using InputRecipe = std::optional; + +inline const char* recipe_name(const InputRecipe& r) { + if (!r.has_value()) return "BF16"; + switch (*r) { + case NVTE_DELAYED_TENSOR_SCALING: return "FP8Current"; + case NVTE_MXFP8_1D_SCALING: return "MXFP8"; + case NVTE_NVFP4_1D_SCALING: return "NVFP4"; + case NVTE_BLOCK_SCALING_1D: return "FP8BlockScaling"; + default: return "Unknown"; + } +} +// Mul128 cases use dims that are multiples of 128 — full functionality across all recipes. +// kAllSameMul32 uses dims that are multiples of 32 but not 128, so each expert's scale_inv +// is padded. enum class ShapeCase { - kAllSame, - kSameFirst, - kSameLast, - kAllDifferent, + kAllSameMul128, + kSameFirstMul128, + kSameLastMul128, + kAllDifferentMul128, + kAllSameMul32, }; -size_t grouped_setup_workspace_size(const size_t num_tensors) { - const size_t ptr_bytes = num_tensors * sizeof(void*); - const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays - size_t size = 8 * ptr_bytes + 6 * int_bytes; - const size_t alignment = 256; - size = ((size + alignment - 1) / alignment) * alignment; - return size; -} - Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); @@ -88,69 +92,87 @@ Tensor make_bf16_operand(const std::string& name, const std::vector& sha return t; } -// Creates an MXFP8 operand with the correct data layout for GEMM. -// MXFP8 GEMM requirements (scales are along K dimension): -// A transposed -> needs rowwise data/scales -// A non-transposed -> needs columnwise data/scales -// B transposed -> needs columnwise data/scales -// B non-transposed -> needs rowwise data/scales +// Creates an MXFP8 operand with the given single direction (scales along K dimension). Tensor make_mxfp8_operand(const std::string& name, const std::vector& shape, - bool is_A, bool transposed) { - // Determine which data layout we need - bool use_rowwise, use_colwise; - if (is_A) { - // A: transposed -> rowwise, non-transposed -> columnwise - use_rowwise = transposed; - use_colwise = !transposed; - } else { - // B: transposed -> columnwise, non-transposed -> rowwise (opposite of A!) 
- use_rowwise = !transposed; - use_colwise = transposed; - } - - // Create BF16 input with random data + bool use_rowwise) { Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); fillUniform(&input_bf16); - // Create MXFP8 tensor with only the required data layout - Tensor mxfp8(name, shape, TypeInfo::dtype, use_rowwise, use_colwise, + Tensor mxfp8(name, shape, TypeInfo::dtype, use_rowwise, !use_rowwise, NVTE_MXFP8_1D_SCALING); - - // Quantize BF16 -> MXFP8 nvte_quantize(input_bf16.data(), mxfp8.data(), 0); - // Create output tensor for swizzled scales (same data shape, same layout) Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo::dtype, - use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING); + use_rowwise, !use_rowwise, NVTE_MXFP8_1D_SCALING); mxfp8_swizzled.set_with_gemm_swizzled_scales(true); // Must be set BEFORE swizzle call - // Copy quantized data from mxfp8 to mxfp8_swizzled - if (use_rowwise) { - size_t data_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype()); - NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(), - data_bytes, cudaMemcpyDeviceToDevice)); - } - if (use_colwise) { - size_t data_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype()); - NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(), - data_bytes, cudaMemcpyDeviceToDevice)); - } + const size_t data_bytes = test::bytes( + use_rowwise ? mxfp8.rowwise_shape() : mxfp8.columnwise_shape(), mxfp8.dtype()); + void* dst = use_rowwise ? mxfp8_swizzled.rowwise_dptr() : mxfp8_swizzled.columnwise_dptr(); + void* src = use_rowwise ? mxfp8.rowwise_dptr() : mxfp8.columnwise_dptr(); + NVTE_CHECK_CUDA(cudaMemcpy(dst, src, data_bytes, cudaMemcpyDeviceToDevice)); - // Swizzle scales for GEMM nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + return mxfp8_swizzled; +} + +// Creates an NVFP4 operand with the given single direction, swizzled scales. +Tensor make_nvfp4_operand(const std::string& name, const std::vector& shape, + bool use_rowwise, bool nvfp4_2d) { + Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); + fillUniform(&input_bf16); - // Sync to ensure operations are complete + Tensor nvfp4(name, shape, DType::kFloat4E2M1, use_rowwise, !use_rowwise, + NVTE_NVFP4_1D_SCALING); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_2d_quantization(nvfp4_2d); + nvte_quantize_v2(input_bf16.data(), nvfp4.data(), quant_config, 0); + + Tensor nvfp4_sw(name + "_sw", shape, DType::kFloat4E2M1, use_rowwise, !use_rowwise, + NVTE_NVFP4_1D_SCALING); + nvfp4_sw.set_with_gemm_swizzled_scales(true); + + // Copy quantized data + amax to swizzled tensor (swizzle only rewrites scale_inv). + const auto amax_kind = use_rowwise ? kNVTEAmax : kNVTEColumnwiseAmax; + const NVTEBasicTensor src_amax = nvte_get_tensor_param(nvfp4.data(), amax_kind); + const NVTEBasicTensor dst_amax = nvte_get_tensor_param(nvfp4_sw.data(), amax_kind); + NVTE_CHECK_CUDA(cudaMemcpy(dst_amax.data_ptr, src_amax.data_ptr, sizeof(float), + cudaMemcpyDeviceToDevice)); + const size_t data_bytes = test::bytes( + use_rowwise ? nvfp4.rowwise_shape() : nvfp4.columnwise_shape(), nvfp4.dtype()); + void* dst_data = use_rowwise ? nvfp4_sw.rowwise_dptr() : nvfp4_sw.columnwise_dptr(); + void* src_data = use_rowwise ? 
nvfp4.rowwise_dptr() : nvfp4.columnwise_dptr(); + NVTE_CHECK_CUDA(cudaMemcpy(dst_data, src_data, data_bytes, cudaMemcpyDeviceToDevice)); + + nvte_swizzle_scaling_factors(nvfp4.data(), nvfp4_sw.data(), 0); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + return nvfp4_sw; +} - return mxfp8_swizzled; +// Creates an FP8 block-scaling operand with the given single direction (TN-only on Hopper). +Tensor make_fp8_block_scaling_operand(const std::string& name, const std::vector& shape, + bool use_rowwise) { + Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); + fillUniform(&input_bf16); + + Tensor fp8_bs(name, shape, TypeInfo::dtype, use_rowwise, !use_rowwise, + NVTE_BLOCK_SCALING_1D); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(true); + nvte_quantize_v2(input_bf16.data(), fp8_bs.data(), quant_config, 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + return fp8_bs; } struct TestParams { - InputCase input_case; + InputRecipe recipe; // std::nullopt = BF16, otherwise the scaling mode. bool transa; bool transb; ShapeCase shape_case; bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) + bool nvfp4_2d = false; // NVFP4-only: use 2D (16x16) amax instead of 1D (1x16). + DType output_dtype = DType::kBFloat16; // Implementation also accepts FP16 / FP32. }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. @@ -159,113 +181,226 @@ struct TestParams { // K - reduction dimension shared between A and B std::vector> make_shapes(ShapeCase scase) { switch (scase) { - case ShapeCase::kAllSame: + case ShapeCase::kAllSameMul128: return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; - case ShapeCase::kSameFirst: + case ShapeCase::kSameFirstMul128: // Same M (first dim), varying N and K return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; - case ShapeCase::kSameLast: + case ShapeCase::kSameLastMul128: // Same N (last dim), varying M and K return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; - case ShapeCase::kAllDifferent: - default: + case ShapeCase::kAllDifferentMul128: return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; + case ShapeCase::kAllSameMul32: + default: + return {{160, 288, 416}, {160, 288, 416}, {160, 288, 416}}; } } -void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130300 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " - << CUBLAS_VERSION << "."; -#else - if (getDeviceComputeCapability() < blackwellComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; - } +constexpr size_t kCublasGroupedGemmVersion = 130300; // Blackwell-only grouped GEMM +constexpr size_t kCublasGroupedGemmHopperVersion = 130400; // adds Hopper support - const std::vector> shapes = make_shapes(params.shape_case); +inline std::string grouped_gemm_skip_reason(const TestParams& params) { + const size_t cublas_ver = transformer_engine::cuda::cublas_version(); + if (cublas_ver < kCublasGroupedGemmVersion) { + return "Grouped GEMM requires cuBLAS 13.3+, but run-time cuBLAS version is " + + std::to_string(cublas_ver) + "."; + } + const int32_t cc = getDeviceComputeCapability(); + const std::string cc_suffix = + "but device compute capability is " + std::to_string(cc) + "."; + if (cc < hopperComputeCapability) { + return "Grouped GEMM requires Hopper (SM90) or newer, " + cc_suffix; + } + if (cc < blackwellComputeCapability && cublas_ver < kCublasGroupedGemmHopperVersion) { + return "Grouped GEMM on Hopper (SM90) requires 
cuBLAS 13.4+, but run-time cuBLAS " + "version is " + std::to_string(cublas_ver) + "."; + } + if (params.recipe.has_value()) { + const bool is_blackwell_plus = cc >= blackwellComputeCapability; + if (!is_blackwell_plus && *params.recipe != NVTE_BLOCK_SCALING_1D) { + return std::string(recipe_name(params.recipe)) + + " grouped GEMM requires Blackwell (SM100) or newer, " + cc_suffix; + } + if (is_blackwell_plus && *params.recipe == NVTE_BLOCK_SCALING_1D) { + return "FP8 block scaling grouped GEMM is only supported on Hopper (SM90), " + cc_suffix; + } + if (*params.recipe == NVTE_NVFP4_1D_SCALING && params.output_dtype == DType::kFloat16) { + return "NVFP4 grouped GEMM does not support FP16 output."; + } + if (*params.recipe == NVTE_NVFP4_1D_SCALING && params.nvfp4_2d && + (!params.transa || params.transb)) { + return "NVFP4 2D quantization only supported in TN layout."; + } + } + return ""; +} - const size_t num_gemms = shapes.size(); +// Reference setup shared by the three run_* variants: builds A/B/D tensors per recipe, +// runs nvte_multi_tensor_gemm to fill D_multi with reference results, and keeps the +// workspaces alive (returned in the struct so callers don't have to track them). +// Output dtype comes from TestParams::output_dtype (BF16 / FP16 / FP32). +struct GroupedGemmRefSetup { + std::vector> shapes; + size_t num_gemms = 0; std::vector A_tensors; std::vector B_tensors; std::vector D_multi; + std::vector workspaces; + bool use_split_accum = false; +}; - A_tensors.reserve(num_gemms); - B_tensors.reserve(num_gemms); - D_multi.reserve(num_gemms); +inline GroupedGemmRefSetup make_grouped_gemm_ref(const TestParams& params) { + GroupedGemmRefSetup s; + s.shapes = make_shapes(params.shape_case); + s.num_gemms = s.shapes.size(); + s.A_tensors.reserve(s.num_gemms); + s.B_tensors.reserve(s.num_gemms); + s.D_multi.reserve(s.num_gemms); + + for (size_t i = 0; i < s.num_gemms; ++i) { + const auto [M, N, K] = s.shapes[i]; + const std::vector a_shape = + params.transa ? std::vector{N, K} : std::vector{K, N}; + const std::vector b_shape = + params.transb ? 
std::vector{K, M} : std::vector{M, K}; + if (!params.recipe.has_value()) { + s.A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + s.B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + } else { + const bool a_use_rowwise = params.transa; + const bool b_use_rowwise = !params.transb; + switch (*params.recipe) { + case NVTE_DELAYED_TENSOR_SCALING: + s.A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + s.B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + case NVTE_MXFP8_1D_SCALING: + s.A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + a_use_rowwise)); + s.B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + b_use_rowwise)); + break; + case NVTE_NVFP4_1D_SCALING: + s.A_tensors.emplace_back(make_nvfp4_operand("A" + std::to_string(i), a_shape, + a_use_rowwise, params.nvfp4_2d)); + s.B_tensors.emplace_back(make_nvfp4_operand("B" + std::to_string(i), b_shape, + b_use_rowwise, params.nvfp4_2d)); + break; + case NVTE_BLOCK_SCALING_1D: + s.A_tensors.emplace_back(make_fp8_block_scaling_operand("A" + std::to_string(i), + a_shape, a_use_rowwise)); + s.B_tensors.emplace_back(make_fp8_block_scaling_operand("B" + std::to_string(i), + b_shape, b_use_rowwise)); + break; + default: + NVTE_ERROR("Unsupported scaling mode in grouped GEMM test: " + + std::string(recipe_name(params.recipe))); + } + } + s.D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, params.output_dtype)); + } - for (size_t i = 0; i < num_gemms; ++i) { - const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{N, K} - : std::vector{K, N}; - const std::vector b_shape = params.transb ? std::vector{K, M} - : std::vector{M, K}; - switch (params.input_case) { - case InputCase::kFP8Current: { - A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + // FP8 block scaling requires split accumulator (no fast accumulation). + s.use_split_accum = (params.recipe.has_value() && *params.recipe == NVTE_BLOCK_SCALING_1D); + + std::vector A_ptrs(s.num_gemms), B_ptrs(s.num_gemms), D_ptrs(s.num_gemms); + std::vector workspace_ptrs(s.num_gemms, nullptr); + std::vector bias_ptrs(s.num_gemms, nullptr), gelu_ptrs(s.num_gemms, nullptr); + constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024; + s.workspaces.reserve(s.num_gemms); + for (size_t i = 0; i < s.num_gemms; ++i) { + A_ptrs[i] = s.A_tensors[i].data(); + B_ptrs[i] = s.B_tensors[i].data(); + D_ptrs[i] = s.D_multi[i].data(); + s.workspaces.emplace_back(Tensor("workspace" + std::to_string(i), + std::vector{cublas_ws_bytes}, DType::kByte)); + workspace_ptrs[i] = s.workspaces.back().data(); + } + nvte_multi_tensor_gemm(A_ptrs.data(), B_ptrs.data(), D_ptrs.data(), bias_ptrs.data(), + gelu_ptrs.data(), static_cast(s.num_gemms), + params.transa, params.transb, false, workspace_ptrs.data(), + false, s.use_split_accum, 0, 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + return s; +} + +// Allocate and initialize alpha/beta tensors for grouped GEMM. +// Hopper requires a single shared scalar; Blackwell+ uses per-matrix scalars. +struct AlphaBetaTensors { + Tensor alpha; + Tensor beta; +}; + +inline AlphaBetaTensors make_alpha_beta(size_t num_gemms) { + const int32_t cc = getDeviceComputeCapability(); + const size_t n = cc < blackwellComputeCapability ? 
1 : num_gemms; + AlphaBetaTensors ab{Tensor("alpha", std::vector{n}, DType::kFloat32), + Tensor("beta", std::vector{n}, DType::kFloat32)}; + std::vector a(n, 1.f); + std::vector b(n, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(ab.alpha.rowwise_dptr(), a.data(), n * sizeof(float), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(ab.beta.rowwise_dptr(), b.data(), n * sizeof(float), + cudaMemcpyHostToDevice)); + return ab; +} + +// Compare each tensor inside a grouped D buffer (with per-tensor offsets) against the +// reference D_multi[i] tensors. +inline void compare_grouped_d_to_multi( + const GroupedBuffers& grouped_D, + const std::vector>& shapes, + std::vector& D_multi, const char* tag) { + for (size_t i = 0; i < shapes.size(); ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = + static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.get_data()) + offset_bytes, + grouped_D.tensor_bytes[i], cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + switch (D_multi[i].dtype()) { + case DType::kBFloat16: + compareResults(tag, grouped_split, D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } - case InputCase::kBF16: { - A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + case DType::kFloat16: + compareResults(tag, grouped_split, D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } - case InputCase::kMXFP8: { - A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, - /*is_A=*/true, params.transa)); - B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, - /*is_A=*/false, params.transb)); + case DType::kFloat32: + compareResults(tag, grouped_split, D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } + default: + NVTE_ERROR("Unsupported D dtype in test: " + + std::to_string(static_cast(D_multi[i].dtype()))); } - D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); } +} - std::vector A_ptrs(num_gemms); - std::vector B_ptrs(num_gemms); - std::vector D_ptrs(num_gemms); - std::vector workspaces(num_gemms); - std::vector workspace_ptrs(num_gemms, nullptr); - std::vector A_views; - std::vector B_views; +void run_grouped_gemm_case(const TestParams& params) { + if (auto reason = grouped_gemm_skip_reason(params); !reason.empty()) { + GTEST_SKIP() << reason; + } + auto ref = make_grouped_gemm_ref(params); + const auto& shapes = ref.shapes; + const size_t num_gemms = ref.num_gemms; + + std::vector A_views, B_views; A_views.reserve(num_gemms); B_views.reserve(num_gemms); - - // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) - std::vector bias_ptrs(num_gemms, nullptr); - std::vector gelu_ptrs(num_gemms, nullptr); - - const size_t cublas_ws_bytes = 32ull * 1024 * 1024; - for (size_t i = 0; i < num_gemms; ++i) { - A_ptrs[i] = A_tensors[i].data(); - B_ptrs[i] = B_tensors[i].data(); - D_ptrs[i] = D_multi[i].data(); - workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); - workspace_ptrs[i] = workspaces[i].data(); - A_views.push_back(&A_tensors[i]); - B_views.push_back(&B_tensors[i]); + 
A_views.push_back(&ref.A_tensors[i]); + B_views.push_back(&ref.B_tensors[i]); } - nvte_multi_tensor_gemm(A_ptrs.data(), - B_ptrs.data(), - D_ptrs.data(), - bias_ptrs.data(), - gelu_ptrs.data(), - static_cast(num_gemms), - params.transa, - params.transb, - false, // grad - workspace_ptrs.data(), - false, // accumulate - false, // use_split_accumulator - 0, // sm_count - 0); - - GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_A = build_grouped_tensor(A_views, ref.A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode()); std::vector C_tensors; std::vector D_group_tensors; @@ -277,11 +412,11 @@ void run_grouped_gemm_case(const TestParams& params) { if (!params.use_null_c) { C_tensors.emplace_back(Tensor("C" + std::to_string(i), std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); + params.output_dtype)); } D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); + params.output_dtype)); NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); } @@ -299,152 +434,45 @@ void run_grouped_gemm_case(const TestParams& params) { } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) - Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); - std::vector alpha_vals(num_gemms, 1.f); - std::vector beta_vals(num_gemms, 0.f); - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - - const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + AlphaBetaTensors ab = make_alpha_beta(num_gemms); + + constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024; + const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); - nvte_grouped_gemm(grouped_A.get_handle(), - params.transa, - grouped_B.get_handle(), - params.transb, - params.use_null_c ? nullptr : grouped_C->get_handle(), - grouped_D.get_handle(), - alpha_tensor.data(), - beta_tensor.data(), - setup_ws.data(), - cublas_ws.data(), - nullptr, // config (use defaults) - 0); + GroupedMatmulConfigWrapper grouped_config; + if (ref.use_split_accum) { + grouped_config.set_use_split_accumulator(true); + } + + nvte_grouped_gemm(grouped_A.get_handle(), params.transa, grouped_B.get_handle(), params.transb, + params.use_null_c ? 
nullptr : grouped_C->get_handle(), grouped_D.get_handle(), + ab.alpha.data(), ab.beta.data(), setup_ws.data(), cublas_ws.data(), + grouped_config, 0); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - // Compare results - for (size_t i = 0; i < num_gemms; ++i) { - Tensor grouped_split("grouped_D" + std::to_string(i), - std::vector{static_cast(std::get<0>(shapes[i])), - static_cast(std::get<1>(shapes[i]))}, - D_multi[i].dtype()); - const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.get_data()) + offset_bytes, - grouped_D.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - grouped_split.to_cpu(); - D_multi[i].to_cpu(); - auto [atol, rtol] = getTolerances(D_multi[i].dtype()); - compareResults("grouped_vs_multi", - grouped_split, - D_multi[i].rowwise_cpu_dptr(), - true, - atol, - rtol); - } -#endif // CUBLAS_VERSION >= 130300 + compare_grouped_d_to_multi(grouped_D, shapes, ref.D_multi, "grouped_vs_multi"); } void run_grouped_gemm_discrete_out_case(const TestParams& params) { -#if CUBLAS_VERSION < 130300 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " - << CUBLAS_VERSION << "."; -#else - if (getDeviceComputeCapability() < blackwellComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + if (auto reason = grouped_gemm_skip_reason(params); !reason.empty()) { + GTEST_SKIP() << reason; } + auto ref = make_grouped_gemm_ref(params); + const auto& shapes = ref.shapes; + const size_t num_gemms = ref.num_gemms; - const std::vector> shapes = make_shapes(params.shape_case); - - const size_t num_gemms = shapes.size(); - std::vector A_tensors; - std::vector B_tensors; - std::vector D_multi; - - A_tensors.reserve(num_gemms); - B_tensors.reserve(num_gemms); - D_multi.reserve(num_gemms); - - for (size_t i = 0; i < num_gemms; ++i) { - const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{N, K} - : std::vector{K, N}; - const std::vector b_shape = params.transb ? 
std::vector{K, M} - : std::vector{M, K}; - switch (params.input_case) { - case InputCase::kFP8Current: { - A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); - break; - } - case InputCase::kBF16: { - A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); - break; - } - case InputCase::kMXFP8: { - A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, - /*is_A=*/true, params.transa)); - B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, - /*is_A=*/false, params.transb)); - break; - } - } - D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); - } - - std::vector A_ptrs(num_gemms); - std::vector B_ptrs(num_gemms); - std::vector D_ptrs(num_gemms); - std::vector workspaces(num_gemms); - std::vector workspace_ptrs(num_gemms, nullptr); - std::vector A_views; - std::vector B_views; + std::vector A_views, B_views; A_views.reserve(num_gemms); B_views.reserve(num_gemms); - - // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) - std::vector bias_ptrs(num_gemms, nullptr); - std::vector gelu_ptrs(num_gemms, nullptr); - - const size_t cublas_ws_bytes = 32ull * 1024 * 1024; - for (size_t i = 0; i < num_gemms; ++i) { - A_ptrs[i] = A_tensors[i].data(); - B_ptrs[i] = B_tensors[i].data(); - D_ptrs[i] = D_multi[i].data(); - workspaces[i] = - Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); - workspace_ptrs[i] = workspaces[i].data(); - A_views.push_back(&A_tensors[i]); - B_views.push_back(&B_tensors[i]); + A_views.push_back(&ref.A_tensors[i]); + B_views.push_back(&ref.B_tensors[i]); } - nvte_multi_tensor_gemm(A_ptrs.data(), - B_ptrs.data(), - D_ptrs.data(), - bias_ptrs.data(), - gelu_ptrs.data(), - static_cast(num_gemms), - params.transa, - params.transb, - false, // grad - workspace_ptrs.data(), - false, // accumulate - false, // use_split_accumulator - 0, // sm_count - 0); - - GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_A = build_grouped_tensor(A_views, ref.A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode()); std::vector C_tensors; std::vector D_list_tensors; @@ -455,10 +483,10 @@ void run_grouped_gemm_discrete_out_case(const TestParams& params) { (void)K; if (!params.use_null_c) { C_tensors.emplace_back( - Tensor("C" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + Tensor("C" + std::to_string(i), std::vector{M, N}, params.output_dtype)); } D_list_tensors.emplace_back( - Tensor("D_list" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + Tensor("D_list" + std::to_string(i), std::vector{M, N}, params.output_dtype)); NVTE_CHECK_CUDA(cudaMemset(D_list_tensors.back().rowwise_dptr(), 0, bytes(D_list_tensors.back().rowwise_shape(), D_list_tensors.back().dtype()))); @@ -477,160 +505,75 @@ void run_grouped_gemm_discrete_out_case(const TestParams& params) { D_list_ptrs.push_back(D_list_tensors[i].data()); } - // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) - Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); - std::vector 
alpha_vals(num_gemms, 1.f); - std::vector beta_vals(num_gemms, 0.f); - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - - const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + AlphaBetaTensors ab = make_alpha_beta(num_gemms); + + constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024; + const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); - nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(), - params.transa, - grouped_B.get_handle(), - params.transb, - params.use_null_c ? nullptr : C_list_ptrs.data(), - params.use_null_c ? 0 : num_gemms, - D_list_ptrs.data(), - num_gemms, - alpha_tensor.data(), - beta_tensor.data(), - setup_ws.data(), - cublas_ws.data(), - nullptr, // config (use defaults) - 0); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - // Compare results - for (size_t i = 0; i < num_gemms; ++i) { - D_list_tensors[i].to_cpu(); - D_multi[i].to_cpu(); - auto [atol, rtol] = getTolerances(D_multi[i].dtype()); - compareResults("grouped_list_vs_multi", - D_list_tensors[i], - D_multi[i].rowwise_cpu_dptr(), - true, - atol, - rtol); + GroupedMatmulConfigWrapper grouped_config; + if (ref.use_split_accum) { + grouped_config.set_use_split_accumulator(true); } -#endif // CUBLAS_VERSION >= 130300 -} - -void run_grouped_gemm_discrete_in_case(const TestParams& params) { -#if CUBLAS_VERSION < 130300 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " - << CUBLAS_VERSION << "."; -#else - if (getDeviceComputeCapability() < blackwellComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; - } - - const std::vector> shapes = make_shapes(params.shape_case); - - const size_t num_gemms = shapes.size(); - std::vector A_tensors; - std::vector B_tensors; - std::vector D_multi; - A_tensors.reserve(num_gemms); - B_tensors.reserve(num_gemms); - D_multi.reserve(num_gemms); + nvte_grouped_gemm_with_discrete_out( + grouped_A.get_handle(), params.transa, grouped_B.get_handle(), params.transb, + params.use_null_c ? nullptr : C_list_ptrs.data(), params.use_null_c ? 0 : num_gemms, + D_list_ptrs.data(), num_gemms, ab.alpha.data(), ab.beta.data(), setup_ws.data(), + cublas_ws.data(), grouped_config, 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); for (size_t i = 0; i < num_gemms; ++i) { - const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{N, K} - : std::vector{K, N}; - const std::vector b_shape = params.transb ? 
std::vector{K, M} - : std::vector{M, K}; - switch (params.input_case) { - case InputCase::kFP8Current: { - A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + D_list_tensors[i].to_cpu(); + ref.D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(ref.D_multi[i].dtype()); + switch (ref.D_multi[i].dtype()) { + case DType::kBFloat16: + compareResults("grouped_list_vs_multi", D_list_tensors[i], + ref.D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } - case InputCase::kBF16: { - A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + case DType::kFloat16: + compareResults("grouped_list_vs_multi", D_list_tensors[i], + ref.D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } - case InputCase::kMXFP8: { - A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, - /*is_A=*/true, params.transa)); - B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, - /*is_A=*/false, params.transb)); + case DType::kFloat32: + compareResults("grouped_list_vs_multi", D_list_tensors[i], + ref.D_multi[i].rowwise_cpu_dptr(), true, atol, rtol); break; - } + default: + NVTE_ERROR("Unsupported D dtype in test: " + + std::to_string(static_cast(ref.D_multi[i].dtype()))); } - D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); } +} + +void run_grouped_gemm_discrete_in_case(const TestParams& params) { + if (auto reason = grouped_gemm_skip_reason(params); !reason.empty()) { + GTEST_SKIP() << reason; + } + auto ref = make_grouped_gemm_ref(params); + const auto& shapes = ref.shapes; + const size_t num_gemms = ref.num_gemms; - std::vector A_ptrs(num_gemms); - std::vector B_ptrs(num_gemms); - std::vector D_ptrs(num_gemms); - std::vector workspaces(num_gemms); - std::vector workspace_ptrs(num_gemms, nullptr); - std::vector A_views; std::vector B_views; - A_views.reserve(num_gemms); B_views.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) B_views.push_back(&ref.B_tensors[i]); - // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) - std::vector bias_ptrs(num_gemms, nullptr); - std::vector gelu_ptrs(num_gemms, nullptr); - - const size_t cublas_ws_bytes = 32ull * 1024 * 1024; - - for (size_t i = 0; i < num_gemms; ++i) { - A_ptrs[i] = A_tensors[i].data(); - B_ptrs[i] = B_tensors[i].data(); - D_ptrs[i] = D_multi[i].data(); - workspaces[i] = - Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); - workspace_ptrs[i] = workspaces[i].data(); - A_views.push_back(&A_tensors[i]); - B_views.push_back(&B_tensors[i]); - } - - nvte_multi_tensor_gemm(A_ptrs.data(), - B_ptrs.data(), - D_ptrs.data(), - bias_ptrs.data(), - gelu_ptrs.data(), - static_cast(num_gemms), - params.transa, - params.transb, - false, // grad - workspace_ptrs.data(), - false, // accumulate - false, // use_split_accumulator - 0, // sm_count - 0); - - GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode()); - std::vector C_tensors; - std::vector D_group_tensors; + std::vector C_tensors, D_group_tensors; C_tensors.reserve(num_gemms); D_group_tensors.reserve(num_gemms); for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; (void)K; if (!params.use_null_c) { - 
C_tensors.emplace_back(Tensor("C" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); + C_tensors.emplace_back(Tensor("C" + std::to_string(i), std::vector{M, N}, + params.output_dtype)); } D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); + std::vector{M, N}, params.output_dtype)); NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); @@ -638,9 +581,7 @@ void run_grouped_gemm_discrete_in_case(const TestParams& params) { std::vector C_views, D_views; for (size_t i = 0; i < num_gemms; ++i) { - if (!params.use_null_c) { - C_views.push_back(&C_tensors[i]); - } + if (!params.use_null_c) C_views.push_back(&C_tensors[i]); D_views.push_back(&D_group_tensors[i]); } @@ -650,63 +591,29 @@ void run_grouped_gemm_discrete_in_case(const TestParams& params) { } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) - Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); - std::vector alpha_vals(num_gemms, 1.f); - std::vector beta_vals(num_gemms, 0.f); - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - - const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + AlphaBetaTensors ab = make_alpha_beta(num_gemms); + + constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024; + const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); std::vector A_list_ptrs; A_list_ptrs.reserve(num_gemms); - for (size_t i = 0; i < num_gemms; ++i) { - A_list_ptrs.push_back(A_tensors[i].data()); + for (size_t i = 0; i < num_gemms; ++i) A_list_ptrs.push_back(ref.A_tensors[i].data()); + + GroupedMatmulConfigWrapper grouped_config; + if (ref.use_split_accum) { + grouped_config.set_use_split_accumulator(true); } - nvte_grouped_gemm_with_discrete_inputA(A_list_ptrs.data(), - num_gemms, - params.transa, - grouped_B.get_handle(), - params.transb, - params.use_null_c ? nullptr : grouped_C->get_handle(), - grouped_D.get_handle(), - alpha_tensor.data(), - beta_tensor.data(), - setup_ws.data(), - cublas_ws.data(), - nullptr, // config (use defaults) - 0); + nvte_grouped_gemm_with_discrete_inputA( + A_list_ptrs.data(), num_gemms, params.transa, grouped_B.get_handle(), params.transb, + params.use_null_c ? 
nullptr : grouped_C->get_handle(), grouped_D.get_handle(), + ab.alpha.data(), ab.beta.data(), setup_ws.data(), cublas_ws.data(), grouped_config, 0); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - // Compare results - for (size_t i = 0; i < num_gemms; ++i) { - Tensor grouped_split("grouped_D" + std::to_string(i), - std::vector{static_cast(std::get<0>(shapes[i])), - static_cast(std::get<1>(shapes[i]))}, - D_multi[i].dtype()); - const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.get_data()) + offset_bytes, - grouped_D.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - grouped_split.to_cpu(); - D_multi[i].to_cpu(); - auto [atol, rtol] = getTolerances(D_multi[i].dtype()); - compareResults("grouped_discrete_in_vs_multi", - grouped_split, - D_multi[i].rowwise_cpu_dptr(), - true, - atol, - rtol); - } -#endif // CUBLAS_VERSION >= 130300 + compare_grouped_d_to_multi(grouped_D, shapes, ref.D_multi, "grouped_discrete_in_vs_multi"); } class GroupedGemmTest : public ::testing::TestWithParam {}; @@ -724,38 +631,86 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteIn) { } std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"}; - constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + constexpr const char* kShapeNames[] = {"AllSameMul128", "SameMMul128", "SameNMul128", + "AllDiffMul128", "AllSameMul32"}; const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + "tb" + (info.param.transb ? "T" : "N"); const std::string null_c = info.param.use_null_c ? "_NullC" : ""; - return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; + const std::string nvfp4_2d = info.param.nvfp4_2d ? "_2D" : ""; + std::string out_suffix; + switch (info.param.output_dtype) { + case DType::kBFloat16: break; // default, no suffix + case DType::kFloat16: out_suffix = "_outFP16"; break; + case DType::kFloat32: out_suffix = "_outFP32"; break; + default: out_suffix = "_outUnknown"; break; + } + return std::string(recipe_name(info.param.recipe)) + nvfp4_2d + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c + out_suffix; } -// TestParams: {input_case, transa, transb, shape_case, use_null_c} +// TestParams: {recipe, transa, transb, shape_case, use_null_c} +// recipe == std::nullopt means BF16 (no scaling), otherwise the FP8/NVFP4 scaling mode. 
const std::vector kTestParams = { // FP8 tests (each tensor has random mean/stddev -> different scales) - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {NVTE_DELAYED_TENSOR_SCALING, true, false, ShapeCase::kAllDifferentMul128, false}, + {NVTE_DELAYED_TENSOR_SCALING, false, true, ShapeCase::kAllDifferentMul128, false}, + {NVTE_DELAYED_TENSOR_SCALING, false, false, ShapeCase::kAllSameMul128, false}, // BF16 tests - {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + {std::nullopt, true, false, ShapeCase::kSameFirstMul128, false}, + {std::nullopt, false, true, ShapeCase::kSameLastMul128, false}, + {std::nullopt, false, false, ShapeCase::kAllSameMul128, false}, + {std::nullopt, true, true, ShapeCase::kAllDifferentMul128, false}, // Test NULL C (valid when beta=0) - {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, + {std::nullopt, false, false, ShapeCase::kAllSameMul128, true}, // MXFP8 tests - {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false}, - {InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false}, - {InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false}, - {InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false}, - {InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false}, - {InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false}, - {InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false}, + {NVTE_MXFP8_1D_SCALING, true, false, ShapeCase::kAllSameMul128, false}, + {NVTE_MXFP8_1D_SCALING, true, false, ShapeCase::kAllDifferentMul128, false}, + {NVTE_MXFP8_1D_SCALING, false, true, ShapeCase::kAllSameMul128, false}, + {NVTE_MXFP8_1D_SCALING, false, true, ShapeCase::kAllDifferentMul128, false}, + {NVTE_MXFP8_1D_SCALING, false, false, ShapeCase::kAllSameMul128, false}, + {NVTE_MXFP8_1D_SCALING, false, false, ShapeCase::kAllDifferentMul128, false}, + {NVTE_MXFP8_1D_SCALING, false, false, ShapeCase::kSameFirstMul128, false}, // MXFP8 with NULL C - {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true}, + {NVTE_MXFP8_1D_SCALING, true, false, ShapeCase::kAllSameMul128, true}, + // NVFP4 tests (all transpose combinations - GEMM internally forces TN) + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kAllSameMul128, false}, + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kAllDifferentMul128, false}, + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kSameFirstMul128, false}, + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kSameLastMul128, false}, + {NVTE_NVFP4_1D_SCALING, false, true, ShapeCase::kAllSameMul128, false}, + {NVTE_NVFP4_1D_SCALING, false, true, ShapeCase::kAllDifferentMul128, false}, + {NVTE_NVFP4_1D_SCALING, false, false, ShapeCase::kAllSameMul128, false}, + {NVTE_NVFP4_1D_SCALING, false, false, ShapeCase::kAllDifferentMul128, false}, + // NVFP4 with NULL C + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kAllSameMul128, true}, + // NVFP4 with 2D (16x16) quantization. 
+ {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kAllSameMul128, false, /*nvfp4_2d=*/true}, + {NVTE_NVFP4_1D_SCALING, false, true, ShapeCase::kAllDifferentMul128, false, /*nvfp4_2d=*/true}, + {NVTE_NVFP4_1D_SCALING, false, false, ShapeCase::kAllSameMul128, false, /*nvfp4_2d=*/true}, + // Non-default output dtypes (BF16 covered everywhere else). + {std::nullopt, false, false, ShapeCase::kAllSameMul128, false, + /*nvfp4_2d=*/false, /*output_dtype=*/DType::kFloat32}, + {NVTE_DELAYED_TENSOR_SCALING, true, false, ShapeCase::kAllSameMul128, false, + /*nvfp4_2d=*/false, /*output_dtype=*/DType::kFloat16}, + // FP8 Block Scaling tests (TN layout on Hopper, block size 128) + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kAllSameMul128, false}, + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kAllDifferentMul128, false}, + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kSameFirstMul128, false}, + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kSameLastMul128, false}, + {NVTE_BLOCK_SCALING_1D, false, true, ShapeCase::kAllSameMul128, false}, + {NVTE_BLOCK_SCALING_1D, false, false, ShapeCase::kAllSameMul128, false}, + // FP8 Block Scaling with NULL C + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kAllSameMul128, true}, + // Dims multiples of 32 but not 128 — exercises scale_inv padding offsets. + {NVTE_MXFP8_1D_SCALING, true, false, ShapeCase::kAllSameMul32, false}, + {NVTE_MXFP8_1D_SCALING, false, true, ShapeCase::kAllSameMul32, false}, + {NVTE_MXFP8_1D_SCALING, false, false, ShapeCase::kAllSameMul32, false}, + {NVTE_NVFP4_1D_SCALING, true, false, ShapeCase::kAllSameMul32, false}, + {NVTE_NVFP4_1D_SCALING, false, true, ShapeCase::kAllSameMul32, false}, + {NVTE_NVFP4_1D_SCALING, false, false, ShapeCase::kAllSameMul32, false}, + {NVTE_BLOCK_SCALING_1D, true, false, ShapeCase::kAllSameMul32, false}, + {NVTE_BLOCK_SCALING_1D, false, true, ShapeCase::kAllSameMul32, false}, + {NVTE_BLOCK_SCALING_1D, false, false, ShapeCase::kAllSameMul32, false}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 8990ce8db1..a2a1b9445d 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -261,8 +261,10 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const output_tensors.emplace_back(std::move(output)); } - GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); - GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); const uint8_t input_swizzled = 0; nvte_set_grouped_tensor_param(grouped_input.get_handle(), kNVTEGroupedWithGEMMSwizzledScales, @@ -369,8 +371,10 @@ void performTestGroupedUnswizzleMXFP8(const int num_tensors, const size_t M, con output_tensors.emplace_back(std::move(output)); } - GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); - GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); 
const uint8_t input_swizzled = 1; nvte_set_grouped_tensor_param(grouped_input.get_handle(), kNVTEGroupedWithGEMMSwizzledScales, @@ -459,9 +463,12 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si final_tensors.emplace_back(std::move(fin)); } - GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING); - GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING); - GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); + GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); + GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING, + /*enforce_grouped_gemm_alignment=*/false); const NVTEShape row_shape = orig_tensors[0]->rowwise_scale_inv_shape(); const NVTEShape col_shape = orig_tensors[0]->columnwise_scale_inv_shape(); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4fd75bb927..c324ad2dc6 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -315,7 +315,8 @@ Tensor::Tensor(const std::string& name, switch (scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: case NVTE_BLOCK_SCALING_1D: - case NVTE_BLOCK_SCALING_2D: { + case NVTE_BLOCK_SCALING_2D: + case NVTE_NVFP4_1D_SCALING: { // Column-wise data shape is transposed if (shape.ndim > 0) { columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); @@ -325,8 +326,7 @@ Tensor::Tensor(const std::string& name, } break; } - case NVTE_MXFP8_1D_SCALING: - case NVTE_NVFP4_1D_SCALING: { + case NVTE_MXFP8_1D_SCALING: { // Column-wise data matches shape for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); @@ -1052,7 +1052,8 @@ std::array get_scale_tensor_dims(const size_t rows, } GroupedBuffers build_grouped_tensor(const std::vector& tensors, - const NVTEScalingMode scaling_mode) { + const NVTEScalingMode scaling_mode, + bool enforce_grouped_gemm_alignment) { NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); // Check which data layouts are available (all tensors must have the same) @@ -1064,9 +1065,16 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, : tensors[0]->columnwise_shape(); const DType dtype = tensors[0]->dtype(); const size_t num_tensors = tensors.size(); - const size_t elem_size = typeToNumBits(dtype) / 8; + const size_t bits_per_elem = typeToNumBits(dtype); + const bool is_sub_byte = (bits_per_elem < 8); + const size_t elem_size = is_sub_byte ? 
0 : bits_per_elem / 8; GroupedBuffers grouped; - grouped.elem_size = elem_size; + grouped.elem_size = elem_size; // Only used for D output extraction (always >= 1 byte dtype) + + // Helper: convert element count to byte count (handles sub-byte types like FP4) + auto elems_to_bytes = [bits_per_elem](int64_t elems) -> size_t { + return static_cast((elems * static_cast(bits_per_elem)) / 8); + }; grouped.num_tensors = num_tensors; grouped.dtype = dtype; grouped.scaling_mode = scaling_mode; @@ -1095,9 +1103,14 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment in bytes, rounded up - const size_t align_elements = - std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size + // Calculate elements needed for 16-byte alignment + size_t align_elements; + if (is_sub_byte) { + // Sub-byte types (e.g. FP4): 16 bytes = 16*8/bits_per_elem elements + align_elements = (16 * 8) / bits_per_elem; + } else { + align_elements = std::max(1, (16 + elem_size - 1) / elem_size); + } return dist(gen) * static_cast(align_elements); }; @@ -1145,7 +1158,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, const int64_t total_elems = need_offsets ? (offsets[last_idx] + numel(last_idx)) : (logical_first * logical_last); - const size_t total_bytes = static_cast(total_elems) * elem_size; + const size_t total_bytes = elems_to_bytes(total_elems); NVTEGroupedTensor h = grouped.handle.get(); @@ -1155,8 +1168,8 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, if (has_rowwise) { grouped.data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, + const size_t offset_bytes_i = elems_to_bytes(offsets[i]); + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes_i, tensors[i]->rowwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); @@ -1169,8 +1182,8 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, if (has_columnwise) { grouped.columnwise_data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, + const size_t offset_bytes_i = elems_to_bytes(offsets[i]); + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes_i, tensors[i]->columnwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); @@ -1209,6 +1222,33 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor)); } + // Shared gather of per-tensor scale_inv buffers into a contiguous device buffer. + // Returns (device buffer, total element count). Used by all block-scaling recipes + // (MXFP8 / NVFP4 / FP8 block) — they only differ in element size and CPU getter. 
+ auto gather_scale_inv = [&](size_t bytes_per_elem, auto get_shape_fn, + auto get_cpu_ptr_fn) -> std::pair, size_t> { + size_t total_elems = 0; + std::vector elem_offsets(num_tensors); + std::vector numels(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + elem_offsets[i] = total_elems; + const NVTEShape sshape = get_shape_fn(tensors[i]); + size_t numel = 1; + for (size_t d = 0; d < sshape.ndim; ++d) numel *= sshape.data[d]; + numels[i] = numel; + total_elems += numel; + } + CudaPtr<> buffer = cuda_alloc(total_elems * bytes_per_elem); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + NVTE_CHECK_CUDA(cudaGetLastError()); + void* dst = static_cast(buffer.get()) + elem_offsets[i] * bytes_per_elem; + NVTE_CHECK_CUDA(cudaMemcpy(dst, get_cpu_ptr_fn(tensors[i]), + numels[i] * bytes_per_elem, cudaMemcpyHostToDevice)); + } + return {std::move(buffer), total_elems}; + }; + if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { // FP8 tensor scaling: one float scale_inv per tensor // For delayed scaling, rowwise and columnwise share the same scale @@ -1231,67 +1271,145 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, sizeof(scale_tensor)); } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - // MXFP8: E8M0 scale_inv per block of 32 elements - // Helper to gather scale_inv from individual tensors into a contiguous buffer - auto gather_scales = [&]( - auto get_shape_fn, - auto get_cpu_ptr_fn) -> std::pair, size_t> { - // Compute total size and offsets - size_t total_bytes = 0; - std::vector scale_offsets(num_tensors); - std::vector numels(num_tensors); - - for (size_t i = 0; i < num_tensors; ++i) { - scale_offsets[i] = total_bytes; - const NVTEShape shape = get_shape_fn(tensors[i]); - size_t numel = 1; - for (size_t d = 0; d < shape.ndim; ++d) { - numel *= shape.data[d]; - } - numels[i] = numel; - total_bytes += numel; // E8M0 is 1 byte per element - } - - // Allocate and copy - CudaPtr<> buffer = cuda_alloc(total_bytes); + // The grouped GEMM setup kernel now computes per-tensor scale offsets via + // compute_grouped_scale_inv_offset + padded_mxfp8_scale_inv_bytes, which sums + // the padded (roundup(., 128) x roundup(./32, 4)) scale tile sizes — so dims + // only need to satisfy the MXFP8 block alignment of 32, not 128. (Previously + // this assertion enforced /128 alignment because the old setup kernel computed + // offsets as data_offset / 32, which silently mismatched for unaligned dims.) + if (enforce_grouped_gemm_alignment) { for (size_t i = 0; i < num_tensors; ++i) { - tensors[i]->to_cpu(); - NVTE_CHECK_CUDA(cudaGetLastError()); - void* dst = static_cast(buffer.get()) + scale_offsets[i]; - const void* src = get_cpu_ptr_fn(tensors[i]); - NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice)); + NVTE_CHECK(first_dims[i] % 32 == 0, + "MXFP8 grouped GEMM test: first_dim must be divisible by 32, got ", + first_dims[i]); + NVTE_CHECK(last_dims[i] % 32 == 0, + "MXFP8 grouped GEMM test: last_dim must be divisible by 32, got ", + last_dims[i]); } - return {std::move(buffer), total_bytes}; - }; - - // Gather rowwise scale_inv if available + } + // MXFP8: E8M0 scale_inv per block of 32 elements (1 byte per scale element). 
if (has_rowwise) { - auto [row_buffer, row_total] = gather_scales( + auto [row_buffer, row_total] = gather_scale_inv( + /*bytes_per_elem=*/1, [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, - [](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr(); }); + [](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr(); }); grouped.scale_inv = std::move(row_buffer); - NVTEShape row_shape = nvte_make_shape(&row_total, 1); NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape}; nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); } - - // Gather columnwise scale_inv if available if (has_columnwise) { - auto [col_buffer, col_total] = gather_scales( + auto [col_buffer, col_total] = gather_scale_inv( + /*bytes_per_elem=*/1, [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, - [](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr(); }); + [](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr(); }); grouped.columnwise_scale_inv = std::move(col_buffer); - NVTEShape col_shape = nvte_make_shape(&col_total, 1); NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape}; nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); } - - // Mark as having swizzled scales (required for GEMM) const uint8_t swizzled = 1; nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled, sizeof(swizzled)); + } else if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling: float32 scale_inv per block of 128 elements. + // Unlike MXFP8/NVFP4, dims are not required to be multiples of the block size — the + // quantizer and padded_block_{1d,2d}_scale_inv_floats both use ceildiv, so the + // `enforce_grouped_gemm_alignment` flag intentionally has no effect here. + if (has_rowwise) { + auto [row_buffer, row_total] = gather_scale_inv( + /*bytes_per_elem=*/sizeof(float), + [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, + [](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr(); }); + grouped.scale_inv = std::move(row_buffer); + NVTEShape row_shape = nvte_make_shape(&row_total, 1); + NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat32, row_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); + } + if (has_columnwise) { + auto [col_buffer, col_total] = gather_scale_inv( + /*bytes_per_elem=*/sizeof(float), + [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, + [](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr(); }); + grouped.columnwise_scale_inv = std::move(col_buffer); + NVTEShape col_shape = nvte_make_shape(&col_total, 1); + NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat32, col_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); + } + } else if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // NVFP4 quantize (optimized BF16 path) requires dims % 32 for TMA 16B alignment. 
+ if (enforce_grouped_gemm_alignment) { + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(first_dims[i] % 32 == 0, + "NVFP4 grouped GEMM test: first_dim must be divisible by 32 " + "(NVFP4 quantize TMA alignment), got ", + first_dims[i]); + NVTE_CHECK(last_dims[i] % 32 == 0, + "NVFP4 grouped GEMM test: last_dim must be divisible by 32 " + "(NVFP4 quantize TMA alignment), got ", + last_dims[i]); + } + } + // NVFP4: E4M3 scale_inv per block of 16 elements (swizzled for GEMM, 1 byte per scale). + if (has_rowwise) { + auto [row_buffer, row_total] = gather_scale_inv( + /*bytes_per_elem=*/1, + [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, + [](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr(); }); + grouped.scale_inv = std::move(row_buffer); + NVTEShape row_shape = nvte_make_shape(&row_total, 1); + NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E4M3, row_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); + } + if (has_columnwise) { + auto [col_buffer, col_total] = gather_scale_inv( + /*bytes_per_elem=*/1, + [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, + [](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr(); }); + grouped.columnwise_scale_inv = std::move(col_buffer); + NVTEShape col_shape = nvte_make_shape(&col_total, 1); + NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E4M3, col_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); + } + + // Mark as having swizzled scales (required for NVFP4 GEMM) + uint8_t swizzled = 1; + nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled, sizeof(swizzled)); + + // Gather per-tensor amax values for NVFP4 global scale computation + auto gather_amax = [&](NVTETensorParam param) -> CudaPtr<> { + // Check if first tensor has this amax + NVTEBasicTensor first_amax = nvte_get_tensor_param(tensors[0]->data(), param); + if (first_amax.data_ptr == nullptr) return CudaPtr<>(); + + std::vector amax_cpu(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + NVTEBasicTensor amax_bt = nvte_get_tensor_param(tensors[i]->data(), param); + NVTE_CHECK(amax_bt.data_ptr != nullptr, "Tensor ", i, " is missing amax"); + float val; + NVTE_CHECK_CUDA(cudaMemcpy(&val, amax_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + amax_cpu[i] = val; + } + CudaPtr<> dev = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(dev.get(), amax_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + return dev; + }; + + grouped.amax_dev = gather_amax(kNVTEAmax); + if (grouped.amax_dev.get()) { + NVTEShape amax_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor amax_tensor{grouped.amax_dev.get(), kNVTEFloat32, amax_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedAmax, &amax_tensor, sizeof(amax_tensor)); + } + + grouped.columnwise_amax_dev = gather_amax(kNVTEColumnwiseAmax); + if (grouped.columnwise_amax_dev.get()) { + NVTEShape amax_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor amax_tensor{grouped.columnwise_amax_dev.get(), kNVTEFloat32, amax_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseAmax, &amax_tensor, sizeof(amax_tensor)); + } + } return grouped; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 17f36a99dd..43044ea011 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -593,6 +593,8 @@ struct GroupedBuffers { CudaPtr 
last_dims_dev; CudaPtr offsets_dev; CudaPtr<> columnwise_data; + CudaPtr<> amax_dev; // Per-tensor amax for NVFP4 grouped GEMM + CudaPtr<> columnwise_amax_dev; // Per-tensor columnwise amax for NVFP4 grouped GEMM NVTEShape logical_shape{}; std::vector offsets_host; std::vector tensor_bytes; @@ -614,7 +616,8 @@ struct GroupedBuffers { }; GroupedBuffers build_grouped_tensor(const std::vector& tensors, - const NVTEScalingMode scaling_mode); + const NVTEScalingMode scaling_mode, + bool enforce_grouped_gemm_alignment = true); } // namespace test diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..b301c0a77c 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2883,10 +2883,13 @@ def _apply_grouped_bias_ref( @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize("use_bias_scale", [False, True]) def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: + if torch.cuda.get_device_capability() < (9, 0): + pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.") + if torch.cuda.get_device_capability() < (10, 0): + if tex.get_cublasLt_version() < 130400: + pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.") if tex.get_cublasLt_version() < 130300: pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") if not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..694a01c1e3 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -91,6 +91,10 @@ inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_M inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } +inline bool is_fp8_block_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_BLOCK_SCALING_1D || mode == NVTE_BLOCK_SCALING_2D; +} + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 6a7af158e5..f0e62db0ac 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -33,6 +33,16 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { // MXFP8 support for grouped GEMM requires cuBLAS 13.3+ #define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 + +// Hopper (SM90) support for grouped GEMM requires cuBLAS 13.4+ +#define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400 + +// NVFP4 support for grouped GEMM requires cuBLAS 13.4+ +#define CUBLAS_NVFP4_GROUPED_GEMM_VERSION 130400 + +// FP8 block scaling support for grouped GEMM requires cuBLAS 13.4+ +#define CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION 130400 + // BF16 support for grouped GEMM requires cuBLAS 13.3+ #define CUBLAS_GROUPED_GEMM_VERSION 130300 @@ -132,119 +142,117 @@ inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) static constexpr size_t kGroupedGemmAlignment = 256; static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB -// Workspace layout for grouped GEMM 
+// Workspace layout for grouped GEMM. +// Layout described once in `from_buffers`; `required_setup_size` runs the same walker +// with base=nullptr to derive the total byte count, so the two stay in sync by construction. struct GroupedGemmSetupWorkspace { - void **A_ptrs; - void **B_ptrs; - void **C_ptrs; - void **D_ptrs; - float **alpha_ptrs; - float **beta_ptrs; - void ** - a_scale_inv_ptrs; // Per-tensor FP8 scale pointers for A (float* for tensor scaling, E8M0* for MXFP8) - void ** - b_scale_inv_ptrs; // Per-tensor FP8 scale pointers for B (float* for tensor scaling, E8M0* for MXFP8) + void **A_ptrs = nullptr; + void **B_ptrs = nullptr; + void **C_ptrs = nullptr; + void **D_ptrs = nullptr; + float **alpha_ptrs = nullptr; + float **beta_ptrs = nullptr; + // Per-tensor scale_inv pointers (float* for tensor scaling, E8M0* for MXFP8, E4M3* for NVFP4) + void **a_scale_inv_ptrs = nullptr; + void **b_scale_inv_ptrs = nullptr; // Storage dimensions for cuBLAS matrix layouts - int *a_rows; - int *a_cols; - int *b_rows; - int *b_cols; - int *d_rows; // M (first dim) - also used for C - int *d_cols; // N (last dim) - also used for C - - // Initialize from workspace buffer - // Layout: all pointer arrays first (16-byte aligned for cuBLAS), then int arrays - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + int *a_rows = nullptr; + int *a_cols = nullptr; + int *b_rows = nullptr; + int *b_cols = nullptr; + int *d_rows = nullptr; // M (first dim) - also used for C + int *d_cols = nullptr; // N (last dim) - also used for C + // NVFP4: per-group computed alpha values (alpha * amax_A * amax_B * factor_inv) + float *nvfp4_computed_alpha = nullptr; + // End-of-layout offset in bytes (unaligned). required_setup_size rounds this up. + size_t total_bytes = 0; + + // Walk the layout once. If `base` is non-null, fields are populated; otherwise + // only `total_bytes` is meaningful (used by required_setup_size). 
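// Sketch of the invariant the walker provides:
//   required_setup_size(n, a) == roundup(from_buffers(nullptr, n).total_bytes, a)
// so adding a field to the walk automatically grows the reported size.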
+ static GroupedGemmSetupWorkspace from_buffers(char *base, size_t num_tensors) { GroupedGemmSetupWorkspace ws; - size_t offset = 0; + constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays + const size_t float_size = num_tensors * sizeof(float); + size_t offset = 0; - // Helper to align offset to kPtrAlignment - auto align_offset = [&]() { + auto align_ptr = [&]() { offset = (offset + kPtrAlignment - 1) / kPtrAlignment * kPtrAlignment; }; + auto place = [&](auto *&field, size_t size_bytes) { + using Field = std::remove_reference_t; + if (base != nullptr) field = reinterpret_cast(base + offset); + offset += size_bytes; + }; - // Pointer arrays first (all 16-byte aligned for cuBLAS grouped GEMM) - align_offset(); - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.a_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - align_offset(); - ws.b_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - - // Int arrays for storage dimensions (4-byte aligned is fine) - align_offset(); - ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.b_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.b_cols = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.d_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.d_cols = reinterpret_cast(setup_ws_ptr + offset); - + // 8 pointer arrays (each 16-byte aligned), then 6 int arrays, then 1 float array. 
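    // Illustrative byte count (64-bit pointers, num_tensors = 8): ptr_size = 64 is already a
    // multiple of 16, so the walk below yields 8*64 + 6*32 + 32 = 736 bytes, which
    // required_setup_size rounds up to 768 under the 256-byte kGroupedGemmAlignment.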
+ align_ptr(); + place(ws.A_ptrs, ptr_size); + align_ptr(); + place(ws.B_ptrs, ptr_size); + align_ptr(); + place(ws.C_ptrs, ptr_size); + align_ptr(); + place(ws.D_ptrs, ptr_size); + align_ptr(); + place(ws.alpha_ptrs, ptr_size); + align_ptr(); + place(ws.beta_ptrs, ptr_size); + align_ptr(); + place(ws.a_scale_inv_ptrs, ptr_size); + align_ptr(); + place(ws.b_scale_inv_ptrs, ptr_size); + place(ws.a_rows, int_size); + place(ws.a_cols, int_size); + place(ws.b_rows, int_size); + place(ws.b_cols, int_size); + place(ws.d_rows, int_size); + place(ws.d_cols, int_size); + place(ws.nvfp4_computed_alpha, float_size); + + ws.total_bytes = offset; return ws; } - // Calculate required size for setup workspace static size_t required_setup_size(size_t num_tensors, size_t alignment) { - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - constexpr size_t kPtrAlignment = 16; // Must match from_buffers - - // Layout: 8 ptr arrays (each 16-byte aligned), then 6 int arrays - // Each ptr array takes ptr_size bytes but needs to start at 16-byte boundary - auto aligned_ptr_size = ((ptr_size + kPtrAlignment - 1) / kPtrAlignment) * kPtrAlignment; - size_t size = 8 * aligned_ptr_size + 6 * int_size; - size = ((size + alignment - 1) / alignment) * alignment; - return size; + const size_t raw = from_buffers(nullptr, num_tensors).total_bytes; + return ((raw + alignment - 1) / alignment) * alignment; } }; +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline bool grouped_gemm_supports_per_group_alpha_beta(int sm) { return sm >= 100; } + inline size_t validate_grouped_gemm_inputs( size_t num_tensors, std::initializer_list inputs, - const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor) { + const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor, + bool supports_per_group_alpha_beta) { NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1"); for (const auto *tensor : inputs) { NVTE_CHECK(tensor->num_tensors == num_tensors, "Grouped GEMM: inputs must have the same number of tensors"); } + // Hopper currently requires a uniform alpha/beta scalar for the whole grouped GEMM, + // while Blackwell+ supports per-matrix alpha/beta. const size_t alpha_numel = alpha_tensor->data.numel(); const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, - ") elements, got ", alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, - ") elements, got ", beta_numel); + const size_t expected_alphabeta_numel = supports_per_group_alpha_beta ? num_tensors : 1; + const char *alphabeta_desc = supports_per_group_alpha_beta ? 
"num_tensors" : "1"; + NVTE_CHECK(alpha_numel == expected_alphabeta_numel, "Grouped GEMM: alpha must have ", + alphabeta_desc, " element(s), got ", alpha_numel); + NVTE_CHECK(beta_numel == expected_alphabeta_numel, "Grouped GEMM: beta must have ", + alphabeta_desc, " element(s), got ", beta_numel); auto is_supported_input_dtype = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2 || dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16; + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat4E2M1; }; for (const auto *tensor : inputs) { if (tensor->has_data() || tensor->has_columnwise_data()) { @@ -264,15 +272,18 @@ inline size_t validate_grouped_gemm_inputs( if (ref != nullptr) { const bool ref_is_fp8 = is_fp8_dtype(ref->dtype()); const bool ref_is_mxfp8 = transformer_engine::is_mxfp_scaling(ref->scaling_mode); + const bool ref_is_fp8_block = transformer_engine::is_fp8_block_scaling(ref->scaling_mode); for (const auto *tensor : inputs) { if (!(tensor->has_data() || tensor->has_columnwise_data())) continue; NVTE_CHECK(is_fp8_dtype(tensor->dtype()) == ref_is_fp8, "Grouped GEMM: A and B must both be FP8 or both be non-FP8."); NVTE_CHECK(transformer_engine::is_mxfp_scaling(tensor->scaling_mode) == ref_is_mxfp8, - "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling."); - if (ref_is_mxfp8) { + "Grouped GEMM: A and B must both use MXFP8 scaling or both not."); + NVTE_CHECK(transformer_engine::is_fp8_block_scaling(tensor->scaling_mode) == ref_is_fp8_block, + "Grouped GEMM: A and B must both use FP8 block scaling or both not."); + if (ref_is_mxfp8 || transformer_engine::is_nvfp_scaling(tensor->scaling_mode)) { NVTE_CHECK(tensor->with_gemm_swizzled_scales, - "MXFP8 grouped GEMM: scales must be swizzled for GEMM."); + "Grouped GEMM: scales must be swizzled for GEMM (MXFP8/NVFP4)."); } } } @@ -303,11 +314,22 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { inline void check_grouped_gemm_requirements(const char *api_name) { const int current_device = transformer_engine::cuda::current_device(); - NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100, api_name, - " requires Blackwell (SM100) or newer architecture."); - NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_GROUPED_GEMM_VERSION, api_name, - " requires cuBLAS 13.3+, but run-time cuBLAS version is ", - transformer_engine::cuda::cublas_version()); + const int sm = transformer_engine::cuda::sm_arch(current_device); + const int cublas_ver = transformer_engine::cuda::cublas_version(); +#if CUBLAS_VERSION >= CUBLAS_GROUPED_GEMM_HOPPER_VERSION + NVTE_CHECK(sm >= 90, api_name, " requires Hopper (SM90) or newer architecture."); + NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name, + " requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver); + if (sm < 100) { + NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_HOPPER_VERSION, api_name, + " on Hopper (SM90) requires cuBLAS 13.4+, but run-time cuBLAS version is ", + cublas_ver); + } +#else + NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name, + " requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver); +#endif } inline transformer_engine::GroupedMatmulConfig parse_grouped_gemm_config( @@ -327,11 +349,29 @@ struct 
GroupedOperandSelection { TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed char *dptr = nullptr; void *scale_inv = nullptr; // Contiguous array of scales (input) + void *amax = nullptr; // Per-tensor amax values (NVFP4 only) transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; bool with_gemm_swizzled_scales = false; bool trans = false; bool rowwise = true; + // Whether sel.shape is pre-swapped relative to the original tensor shape (i.e., whether + // the columnwise data was set up as a transposed view). Together with `rowwise`, this + // determines the canonical orientation of the underlying scale_inv buffer when computing + // per-tensor scale offsets in setup_grouped_gemm_kernel for non-MXFP8 recipes. + bool swap_dims = false; +}; + +struct GroupedGemmConfig { + bool use_split_accumulator = false; + bool use_fp8 = false; + bool use_per_group_alpha_beta = false; + void *alpha_dptr = nullptr; + void *beta_dptr = nullptr; + int64_t avg_m = 0; + int64_t avg_n = 0; + int64_t avg_k = 0; + int sm_count = 0; }; constexpr int kMaxGroups = 64; @@ -346,6 +386,7 @@ struct MultiTensorGroupGemmOutputArgs { struct MultiTensorGroupGemmInputArgs { void *data_ptrs[kMaxGroups]; void *scale_inv_ptrs[kMaxGroups]; + void *amax_ptrs[kMaxGroups]; int rows[kMaxGroups]; int cols[kMaxGroups]; }; @@ -360,12 +401,16 @@ struct MultiTensorListInfo { struct OperandStorageChoice { bool use_rowwise = true; - bool swap_dims = true; + // Only meaningful when use_rowwise == false (columnwise storage). Indicates that + // the columnwise buffer is the logically-transposed tensor (e.g. NVFP4 colwise = + // transposed-rowwise), so sel.shape needs first/last swapped. + bool swap_dims = false; bool trans = false; }; inline OperandStorageChoice choose_grouped_operand_storage(bool trans, bool is_A, bool is_mxfp8, - bool is_fp8, bool non_tn_fp8_ok, + bool is_fp8, bool is_nvfp4, + bool is_fp8_block, bool non_tn_fp8_ok, bool has_row, bool has_col, const char *name) { NVTE_CHECK(has_row || has_col, "Grouped GEMM: ", name, @@ -374,7 +419,7 @@ inline OperandStorageChoice choose_grouped_operand_storage(bool trans, bool is_A if (is_A) { if (trans) { NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 transposed ", name, " is missing row-wise data"); - return {true, true, trans}; + return {true, false, trans}; } NVTE_CHECK(has_col, "Grouped GEMM: MXFP8 non-transposed ", name, " is missing column-wise data"); @@ -385,19 +430,51 @@ inline OperandStorageChoice choose_grouped_operand_storage(bool trans, bool is_A return {false, false, trans}; } NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 non-transposed ", name, " is missing row-wise data"); - return {true, true, trans}; + return {true, false, trans}; + } + + // FP8 block scaling on Hopper: force TN by using columnwise data with swap_dims=false. + // Unlike tensor scaling, block scaling has per-block scales whose offsets depend on physical + // dimensions. swap_dims=false keeps a_rows/b_rows = physical trailing dim = correct lda, + // matching how cublas_gemm sets lda = k for block scaling (cublaslt_gemm.cu). 
+ if (is_fp8_block && !non_tn_fp8_ok) { + if (is_A && !trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for TN layout"); + return {false, false, true}; + } + if (!is_A && trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for TN layout"); + return {false, false, false}; + } + } + + // NVFP4: force TN by switching layout and flipping transpose. + // NVFP4 columnwise data is the transposed tensor quantized rowwise, so swap_dims=true. + if (is_nvfp4) { + if (is_A && !trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for TN layout"); + return {false, true, true}; + } + if (!is_A && trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for TN layout"); + return {false, true, false}; + } } - // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + // Hopper-style TN-only FP8 (tensor scaling): force TN by switching layout and flipping transpose. if (is_fp8 && !non_tn_fp8_ok) { if (is_A && !trans) { NVTE_CHECK(has_col, "Grouped GEMM: ", name, - " is missing column-wise data needed for FP8 TN layout"); + " is missing column-wise data needed for TN layout"); return {false, true, true}; } if (!is_A && trans) { NVTE_CHECK(has_col, "Grouped GEMM: ", name, - " is missing column-wise data needed for FP8 TN layout"); + " is missing column-wise data needed for TN layout"); return {false, true, false}; } } @@ -410,7 +487,7 @@ inline OperandStorageChoice choose_grouped_operand_storage(bool trans, bool is_A } NVTE_CHECK(has_row, "Grouped GEMM: ", name, " is missing row-wise data"); - return {true, true, trans}; + return {true, false, trans}; } // Build Kernel Arguments detailing out addresses and other metadata for list of C/D tensors @@ -451,7 +528,8 @@ inline MultiTensorGroupGemmOutputArgs build_grouped_gemm_multi_out_args( // passed to the grouped GEMM kernel. Use-case: A --> List of Expert weights inline MultiTensorGroupGemmInputArgs build_grouped_gemm_multi_inputA_args( const NVTETensor *tensor_list, size_t list_size, bool use_rowwise, bool is_fp8, - int64_t *avg_first_dim, int64_t *avg_last_dim, const char *name) { + int64_t *avg_first_dim, int64_t *avg_last_dim, const char *name, bool needs_scale_inv = false, + bool swap_dims = false) { using namespace transformer_engine; MultiTensorGroupGemmInputArgs args{}; *avg_first_dim = 0; @@ -459,6 +537,7 @@ inline MultiTensorGroupGemmInputArgs build_grouped_gemm_multi_inputA_args( if (list_size == 0) { return args; } + const bool requires_scale = is_fp8 || needs_scale_inv; for (size_t i = 0; i < list_size; ++i) { const transformer_engine::Tensor *t = transformer_engine::convertNVTETensorCheck(tensor_list[i]); @@ -467,20 +546,27 @@ inline MultiTensorGroupGemmInputArgs build_grouped_gemm_multi_inputA_args( use_rowwise ? 
t->scale_inv : t->columnwise_scale_inv; NVTE_CHECK(data.has_data(), "Grouped GEMM: ", name, "_list tensor ", i, " is missing required data."); - NVTE_CHECK(data.shape.size() == 2, "Grouped GEMM: ", name, "_list tensor ", i, " must be 2D."); args.data_ptrs[i] = data.dptr; - args.rows[i] = static_cast(data.shape[1]); - args.cols[i] = static_cast(data.shape[0]); - *avg_first_dim += static_cast(data.shape[0]); - *avg_last_dim += static_cast(data.shape[1]); - - if (is_fp8) { + const auto logical_shape = t->shape(); + NVTE_CHECK(logical_shape.size() == 2, "Grouped GEMM: ", name, "_list tensor ", i, + " must be 2D."); + const size_t first_dim = swap_dims ? logical_shape[1] : logical_shape[0]; + const size_t last_dim = swap_dims ? logical_shape[0] : logical_shape[1]; + args.rows[i] = static_cast(last_dim); + args.cols[i] = static_cast(first_dim); + *avg_first_dim += static_cast(first_dim); + *avg_last_dim += static_cast(last_dim); + + if (requires_scale) { NVTE_CHECK(scale_inv.has_data(), "Grouped GEMM: ", name, "_list tensor ", i, - " requires scale_inv for FP8."); + " requires scale_inv."); args.scale_inv_ptrs[i] = scale_inv.dptr; } else { args.scale_inv_ptrs[i] = nullptr; } + + const transformer_engine::SimpleTensor &amax_src = use_rowwise ? t->amax : t->columnwise_amax; + args.amax_ptrs[i] = amax_src.has_data() ? amax_src.dptr : nullptr; } *avg_first_dim /= static_cast(list_size); *avg_last_dim /= static_cast(list_size); @@ -509,8 +595,11 @@ inline MultiTensorListInfo validate_grouped_gemm_multi_inputA_list(const NVTETen info.scaling_mode = t0->scaling_mode; info.with_gemm_swizzled_scales = t0->with_gemm_swizzled_scales; const bool mxfp8 = transformer_engine::is_mxfp_scaling(info.scaling_mode); - NVTE_CHECK(info.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || mxfp8, - "Grouped GEMM: input list only supports tensor scaling or MXFP8."); + const bool nvfp4 = transformer_engine::is_nvfp_scaling(info.scaling_mode); + const bool fp8_block = transformer_engine::is_fp8_block_scaling(info.scaling_mode); + NVTE_CHECK(info.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || mxfp8 || nvfp4 || fp8_block, + "Grouped GEMM: input list only supports tensor scaling, MXFP8, NVFP4, " + "or FP8 block scaling."); for (size_t i = 0; i < list_size; ++i) { const transformer_engine::Tensor *t = @@ -591,10 +680,13 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const auto sm = t->scaling_mode; const bool mxfp8 = is_mxfp_scaling(sm); + const bool nvfp4 = is_nvfp_scaling(sm); + const bool fp8_block = is_fp8_block_scaling(sm); // Validate scaling mode - NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING || mxfp8, - "Grouped GEMM is only supported with bf16, fp8 tensor scaling and MXFP8"); + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING || mxfp8 || nvfp4 || fp8_block, + "Grouped GEMM is only supported with bf16, fp8 tensor scaling, MXFP8, NVFP4, " + "and FP8 block scaling"); const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; @@ -605,7 +697,8 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const DType rep_dtype = has_row ? row_dtype : col_dtype; const bool is_fp8 = is_fp8_dtype(rep_dtype); - const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + // FP8 block scaling on Hopper requires TN layout (same as tensor scaling) + const bool non_tn_fp8_ok = fp8_block ? false : nvte_is_non_tn_fp8_gemm_supported(); // Helper to select columnwise storage. 
// swap_dims=true (default): swap first/last dims in shape info (used when columnwise == transposed). @@ -614,8 +707,10 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: auto use_columnwise = [&](bool swap_dims = true) { sel.dptr = static_cast(t->columnwise_data.dptr); sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.amax = t->columnwise_amax.dptr; sel.dtype = col_dtype; sel.rowwise = false; + sel.swap_dims = swap_dims; sel.shape = create_shape_info(t, swap_dims); }; @@ -623,13 +718,15 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: auto use_rowwise = [&]() { sel.dptr = static_cast(t->data.dptr); sel.scale_inv = t->scale_inv.dptr; + sel.amax = t->amax.dptr; sel.dtype = row_dtype; sel.rowwise = true; sel.shape = create_shape_info(t, /*swap_dims=*/false); }; - const auto choice = choose_grouped_operand_storage(trans, is_A, mxfp8, is_fp8, non_tn_fp8_ok, - has_row, has_col, is_A ? "A" : "B"); + const auto choice = + choose_grouped_operand_storage(trans, is_A, mxfp8, is_fp8, nvfp4, fp8_block, non_tn_fp8_ok, + has_row, has_col, is_A ? "A" : "B"); sel.trans = choice.trans; if (choice.use_rowwise) { use_rowwise(); @@ -669,7 +766,8 @@ inline void init_matrix_layouts( } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, - cublasOperation_t op_B, bool use_fp8, bool use_split_accumulator) { + cublasOperation_t op_B, bool use_fp8, bool use_split_accumulator, + bool use_per_group_alpha_beta) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, @@ -681,13 +779,15 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); - int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); + if (use_per_group_alpha_beta) { + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + } // Fast accumulation is only supported for FP8 (mirrors non-grouped GEMM logic). int8_t fastAccuMode = use_split_accumulator ? 0 : static_cast(use_fp8); @@ -720,6 +820,75 @@ inline void set_mxfp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, #endif // CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION } +// Configures cuBLAS for NVFP4 grouped GEMM: sets VEC16_UE4M3 scale mode and scale pointers +// for both A and B. Requires cuBLAS 12.8+. 
+inline void set_nvfp4_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + void **a_scale_inv_ptrs, void **b_scale_inv_ptrs) { +#if CUBLAS_VERSION >= CUBLAS_NVFP4_GROUPED_GEMM_VERSION + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_NVFP4_GROUPED_GEMM_VERSION, + "NVFP4 grouped GEMM requires cuBLAS 13.4+, but run-time cuBLAS version is ", + transformer_engine::cuda::cublas_version()); + const cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &a_scale_inv_ptrs, sizeof(a_scale_inv_ptrs))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &b_scale_inv_ptrs, sizeof(b_scale_inv_ptrs))); +#else + NVTE_CHECK(false, + "NVFP4 grouped GEMM requires cuBLAS 13.4+, but compile-time " + "cuBLAS version is ", + CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= CUBLAS_NVFP4_GROUPED_GEMM_VERSION +} + +// Configures cuBLAS for FP8 block-scaling grouped GEMM: sets VEC128_32F or BLK128x128_32F +// scale mode and scale pointers for A and B. Requires cuBLAS 12.9+. +inline void set_fp8_block_scaling_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + void **a_scale_inv_ptrs, void **b_scale_inv_ptrs, + NVTEScalingMode a_scaling_mode, + NVTEScalingMode b_scaling_mode) { +#if CUBLAS_VERSION >= CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION + NVTE_CHECK( + transformer_engine::cuda::cublas_version() >= CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION, + "FP8 block scaling grouped GEMM requires cuBLAS 13.4+, but run-time cuBLAS version is ", + transformer_engine::cuda::cublas_version()); + + // 2D by 2D is not supported + NVTE_CHECK(!(a_scaling_mode == NVTE_BLOCK_SCALING_2D && b_scaling_mode == NVTE_BLOCK_SCALING_2D), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling GEMM is supported, " + "but got 2D by 2D"); + + const cublasLtMatmulMatrixScale_t scale_mode_a = + a_scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + const cublasLtMatmulMatrixScale_t scale_mode_b = + b_scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scale_mode_a, sizeof(scale_mode_a))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scale_mode_b, sizeof(scale_mode_b))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &a_scale_inv_ptrs, sizeof(a_scale_inv_ptrs))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &b_scale_inv_ptrs, sizeof(b_scale_inv_ptrs))); +#else + NVTE_CHECK(false, + "FP8 block scaling grouped GEMM requires cuBLAS 13.4+, but compile-time " + "cuBLAS version is ", + CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION +} + // Configures cuBLAS for tensor-scaling FP8 grouped GEMM: sets PER_BATCH_SCALAR_32F scale mode // and scale pointers for A and B. Both operands are guaranteed FP8 by the caller. 
inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, void **a_scale_inv_ptrs, @@ -781,6 +950,10 @@ inline GroupedGemmWorkspace setup_grouped_gemm_workspace(transformer_engine::Ten "Grouped GEMM setup workspace"); void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + constexpr uintptr_t kSetupBaseAlignment = 16; + NVTE_CHECK(reinterpret_cast(setup_workspace_ptr) % kSetupBaseAlignment == 0, + "Grouped GEMM setup workspace must be ", kSetupBaseAlignment, + "-byte aligned (cuBLAS requires this for pointer arrays)."); auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( static_cast(setup_workspace_ptr), num_tensors); return {std::move(setup_workspace), cublas_workspace_ptr, num_tensors}; @@ -790,9 +963,8 @@ inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspac const GroupedOperandSelection &A_sel, const GroupedOperandSelection &B_sel, transformer_engine::DType d_dtype, size_t num_tensors, - bool use_split_accumulator, bool use_fp8, int64_t avg_m_val, - int64_t avg_n_val, int64_t avg_k_val, void *cublas_workspace_ptr, - cudaStream_t stream, int math_sm_count = 0) { + const GroupedGemmConfig &config, void *cublas_workspace_ptr, + cudaStream_t stream) { using cublasHandleManager = transformer_engine::detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -805,26 +977,42 @@ inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspac num_tensors); cublasLtMatmulDescOpaque_t matmulDesc; - init_matmul_desc(matmulDesc, op_A, op_B, use_fp8, use_split_accumulator); + init_matmul_desc(matmulDesc, op_A, op_B, config.use_fp8, config.use_split_accumulator, + config.use_per_group_alpha_beta); if (transformer_engine::is_mxfp_scaling(A_sel.scaling_mode)) { set_mxfp8_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, setup_workspace.b_scale_inv_ptrs); - } else if (use_fp8) { + } else if (transformer_engine::is_nvfp_scaling(A_sel.scaling_mode)) { + set_nvfp4_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, + setup_workspace.b_scale_inv_ptrs); + } else if (transformer_engine::is_fp8_block_scaling(A_sel.scaling_mode)) { + set_fp8_block_scaling_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, + setup_workspace.b_scale_inv_ptrs, A_sel.scaling_mode, + B_sel.scaling_mode); + } else if (config.use_fp8) { set_fp8_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, setup_workspace.b_scale_inv_ptrs); } - if (math_sm_count != 0) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); + if (config.sm_count != 0) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, + &config.sm_count, sizeof(config.sm_count))); } - cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, - descD, avg_m_val, avg_n_val, avg_k_val); - - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, - setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, - kGroupedGemmCublasWorkspaceSize, stream)); + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo( + handle, matmulDesc, descA, descB, descC, descD, config.avg_m, config.avg_n, config.avg_k); + + // Hopper uses a single 
scalar alpha/beta for the whole grouped GEMM; + // Blackwell+ uses per-matrix alpha/beta arrays. + void *alpha_arg = config.use_per_group_alpha_beta + ? static_cast(setup_workspace.alpha_ptrs) + : config.alpha_dptr; + void *beta_arg = config.use_per_group_alpha_beta ? static_cast(setup_workspace.beta_ptrs) + : config.beta_dptr; + + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, alpha_arg, setup_workspace.A_ptrs, &descA, + setup_workspace.B_ptrs, &descB, beta_arg, setup_workspace.C_ptrs, + &descC, setup_workspace.D_ptrs, &descD, &algo, + cublas_workspace_ptr, kGroupedGemmCublasWorkspaceSize, stream)); } // Device helper: compute the element offset for tensor `idx` given shape metadata. @@ -871,22 +1059,59 @@ __forceinline__ __device__ int64_t padded_mxfp8_scale_inv_bytes(int64_t first, i return padded_scale_dim_y * padded_scale_dim_x; } -// Device helper: byte offset into a contiguous grouped MXFP8 scale_inv buffer for -// tensor `idx`. Each expert's scale_inv is expected to be padded -// to the 128x4 swizzled layout. -__forceinline__ __device__ int64_t compute_grouped_tensor_mxfp8_scale_inv_offset( - const TensorShapeInfo &meta, size_t idx, bool rowwise) { +// Generic prefix-sum of per-tensor padded scale_inv sizes — used to locate where +// tensor `idx`'s scales start in a contiguous grouped scale_inv buffer. +// `PaddedFn` is a callable (int64_t first, int64_t last) -> int64_t returning the +// recipe-specific padded size (bytes for MXFP8/NVFP4, floats for FP8 block scaling). +template +__forceinline__ __device__ int64_t compute_grouped_scale_inv_offset(const TensorShapeInfo &meta, + size_t idx, PaddedFn padded) { if (meta.first_dims != nullptr || meta.last_dims != nullptr) { int64_t cumsum = 0; for (size_t i = 0; i < idx; i++) { const int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first; const int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last; - cumsum += padded_mxfp8_scale_inv_bytes(f, l, rowwise); + cumsum += padded(f, l); } return cumsum; } - return static_cast(idx) * - padded_mxfp8_scale_inv_bytes(meta.uniform_first, meta.uniform_last, rowwise); + return static_cast(idx) * padded(meta.uniform_first, meta.uniform_last); +} + +__forceinline__ __device__ int64_t padded_nvfp4_scale_inv_bytes(int64_t first, int64_t last) { + namespace mxfp8_swizzle = transformer_engine::dispatch::mxfp8::swizzle; + constexpr int64_t kNvfp4BlockSize = 16; + const int64_t scale_tile_y = static_cast(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y); + const int64_t scale_tile_x = static_cast(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X); + const int64_t padded_scale_dim_y = ((first + scale_tile_y - 1) / scale_tile_y) * scale_tile_y; + const int64_t scale_dim_x = (last + kNvfp4BlockSize - 1) / kNvfp4BlockSize; + const int64_t padded_scale_dim_x = + ((scale_dim_x + scale_tile_x - 1) / scale_tile_x) * scale_tile_x; + // E4M3 scales are 1 byte per element. + return padded_scale_dim_y * padded_scale_dim_x; +} + +__forceinline__ __device__ int64_t padded_block_1d_scale_inv_floats(int64_t first, int64_t last, + bool effective_rowwise) { + constexpr int64_t kBlockLen = 128; + constexpr int64_t kRowAlign = 4; + const int64_t y_dim = effective_rowwise ? last : first; + const int64_t x_dim = effective_rowwise ? 
first : last; + const int64_t y = (y_dim + kBlockLen - 1) / kBlockLen; + const int64_t x = ((x_dim + kRowAlign - 1) / kRowAlign) * kRowAlign; + return y * x; +} + +__forceinline__ __device__ int64_t padded_block_2d_scale_inv_floats(int64_t first, int64_t last, + bool effective_rowwise) { + constexpr int64_t kBlockLen = 128; + constexpr int64_t kRowAlign = 4; + const int64_t y_dim = effective_rowwise ? first : last; + const int64_t x_dim = effective_rowwise ? last : first; + const int64_t y = (y_dim + kBlockLen - 1) / kBlockLen; + const int64_t x_ceil = (x_dim + kBlockLen - 1) / kBlockLen; + const int64_t x = ((x_ceil + kRowAlign - 1) / kRowAlign) * kRowAlign; + return y * x; } // Linear scan to find which tensor contains the given row. @@ -1016,15 +1241,18 @@ __global__ void setup_grouped_gemm_kernel( void **a_scale_inv_ptrs, void **b_scale_inv_ptrs, // Inputs char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, - TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, - size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, + TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_bits_per_elem, + size_t b_bits_per_elem, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, + float *beta_ptr, bool use_per_group_alpha_beta, // Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr // For MXFP8, pass nullptr for tensor_scale and set mxfp8_base float *a_scale_base, float *b_scale_base, bool a_rowwise, bool b_rowwise, NVTEScalingMode scaling_mode, size_t num_tensors, MultiTensorGroupGemmInputArgs a_multi_tensor_args, MultiTensorGroupGemmOutputArgs c_multi_tensor_args, - MultiTensorGroupGemmOutputArgs d_multi_tensor_args) { + MultiTensorGroupGemmOutputArgs d_multi_tensor_args, + // NVFP4: per-group amax values and output buffer for computed alpha + float *a_amax, float *b_amax, float *nvfp4_computed_alpha) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -1053,9 +1281,9 @@ __global__ void setup_grouped_gemm_kernel( int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); // Compute data pointers - A_ptrs[idx] = - has_a_multi_tensor ? a_multi_tensor_args.data_ptrs[idx] : (a_base + a_offset * a_elem_size); - B_ptrs[idx] = b_base + b_offset * b_elem_size; + A_ptrs[idx] = has_a_multi_tensor ? a_multi_tensor_args.data_ptrs[idx] + : (a_base + (a_offset * a_bits_per_elem) / 8); + B_ptrs[idx] = b_base + (b_offset * b_bits_per_elem) / 8; C_ptrs[idx] = has_c_multi_tensor ? c_multi_tensor_args.data_ptrs[idx] : (c_base + c_offset * c_elem_size); D_ptrs[idx] = @@ -1076,34 +1304,86 @@ __global__ void setup_grouped_gemm_kernel( d_cols[idx] = static_cast(d_first); } - // Fill alpha/beta pointers (per-matrix) - alpha_ptrs[idx] = alpha_ptr + idx; - beta_ptrs[idx] = beta_ptr + idx; + // Fill alpha/beta pointers. + // Hopper uses one shared alpha/beta scalar for all groups; Blackwell+ uses per-matrix scalars. + // For NVFP4 on Blackwell+: compute per-group alpha that includes global scale (amax). + // A's amax: grouped path indexes a_amax[idx]; discrete path reads amax_ptrs[idx]. 
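  // factor_inv background: each NVFP4 operand is quantized against a global scale of
  // (6 * 448) / amax (6 = largest E2M1 magnitude, 448 = largest E4M3 block scale), so
  // undoing both operands' global scales multiplies the product by
  // amax_A * amax_B / (6 * 448)^2; that is the constant folded into the per-group alpha below.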
+ if (use_per_group_alpha_beta) { + float a_amax_val = 0.0f; + bool has_a_amax = false; + if (has_a_multi_tensor) { + auto *a_amax_p = static_cast(a_multi_tensor_args.amax_ptrs[idx]); + if (a_amax_p != nullptr) { + a_amax_val = *a_amax_p; + has_a_amax = true; + } + } else if (a_amax != nullptr) { + a_amax_val = a_amax[idx]; + has_a_amax = true; + } + if (has_a_amax && b_amax && nvfp4_computed_alpha) { + constexpr float factor_inv = 1.0f / (6.0f * 6.0f * 448.0f * 448.0f); + nvfp4_computed_alpha[idx] = alpha_ptr[idx] * a_amax_val * b_amax[idx] * factor_inv; + alpha_ptrs[idx] = &nvfp4_computed_alpha[idx]; + } else { + alpha_ptrs[idx] = alpha_ptr + idx; + } + beta_ptrs[idx] = beta_ptr + idx; + } else { + // Hopper: use single scalar for the whole grouped GEMM + alpha_ptrs[idx] = alpha_ptr; + beta_ptrs[idx] = beta_ptr; + } - // Fill scale pointers (per-matrix). - // The interpretation of the scale buffers depends on the shared scaling recipe: - // otherwise : one float per tensor, indexed by tensor index - if (a_scale_base) { - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int64_t a_scale_offset = - compute_grouped_tensor_mxfp8_scale_inv_offset(A_meta, idx, a_rowwise); - a_scale_inv_ptrs[idx] = reinterpret_cast( - static_cast(static_cast(a_scale_base)) + a_scale_offset); + // Fill scale pointers (per-matrix). For MXFP8/NVFP4 and FP8 block scaling, the per-expert + // scale_inv buffer is padded to a layout that depends on the recipe — offsets are computed + // from the same padded sizes that the quantizer uses at allocation, not from data_offset. + // NVTE_MXFP8_1D_SCALING : E8M0 byte stream; padded swizzled 128x4 tile, block_size=32. + // NVTE_NVFP4_1D_SCALING : E4M3 byte stream; padded swizzled 128x4 tile, block_size=16. + // NVTE_BLOCK_SCALING_1D : float32 array; ceildiv(./128) * roundup(./4) per tensor. + // NVTE_BLOCK_SCALING_2D : float32 array; ceildiv(./128) * roundup(ceildiv(./128), 4). + // otherwise (tensor) : one float per tensor, indexed by tensor index. 
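  // Worked example for the block-scaling helpers above (first = 256, last = 200,
  // effective_rowwise = true):
  //   NVTE_BLOCK_SCALING_1D: ceildiv(200,128) * roundup(256,4)               = 2 * 256 = 512 floats
  //   NVTE_BLOCK_SCALING_2D: ceildiv(256,128) * roundup(ceildiv(200,128), 4) = 2 * 4   = 8 floats
  // These are the per-tensor sizes that compute_grouped_scale_inv_offset accumulates into offsets.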
+ auto fill_scale_ptr = [&](void **ptrs, void *base, const TensorShapeInfo &meta, bool op_rowwise) { + int64_t byte_offset = -1; + int64_t float_offset = -1; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + byte_offset = compute_grouped_scale_inv_offset(meta, idx, [=](int64_t f, int64_t l) { + return padded_mxfp8_scale_inv_bytes(f, l, op_rowwise); + }); + break; + case NVTE_NVFP4_1D_SCALING: + byte_offset = compute_grouped_scale_inv_offset( + meta, idx, [](int64_t f, int64_t l) { return padded_nvfp4_scale_inv_bytes(f, l); }); + break; + case NVTE_BLOCK_SCALING_1D: + float_offset = compute_grouped_scale_inv_offset(meta, idx, [=](int64_t f, int64_t l) { + return padded_block_1d_scale_inv_floats(f, l, op_rowwise); + }); + break; + case NVTE_BLOCK_SCALING_2D: + float_offset = compute_grouped_scale_inv_offset(meta, idx, [=](int64_t f, int64_t l) { + return padded_block_2d_scale_inv_floats(f, l, op_rowwise); + }); + break; + default: + float_offset = static_cast(idx); + break; + } + if (byte_offset >= 0) { + ptrs[idx] = static_cast(base) + byte_offset; } else { - a_scale_inv_ptrs[idx] = static_cast(a_scale_base) + idx; + ptrs[idx] = static_cast(base) + float_offset; } + }; + + if (a_scale_base) { + fill_scale_ptr(a_scale_inv_ptrs, a_scale_base, A_meta, a_rowwise); } else { a_scale_inv_ptrs[idx] = a_multi_tensor_args.scale_inv_ptrs[idx]; } if (b_scale_base) { - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int64_t b_scale_offset = - compute_grouped_tensor_mxfp8_scale_inv_offset(B_meta, idx, b_rowwise); - b_scale_inv_ptrs[idx] = reinterpret_cast( - static_cast(static_cast(b_scale_base)) + b_scale_offset); - } else { - b_scale_inv_ptrs[idx] = static_cast(b_scale_base) + idx; - } + fill_scale_ptr(b_scale_inv_ptrs, b_scale_base, B_meta, b_rowwise); } } @@ -1112,7 +1392,8 @@ inline void launch_grouped_gemm_setup( const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream, + const transformer_engine::Tensor *beta_tensor, bool use_per_group_alpha_beta, + size_t num_tensors, cudaStream_t stream, const MultiTensorGroupGemmInputArgs &a_multi_tensor_args, const NVTETensor *C_list, const NVTETensor *D_list, char *a_base, transformer_engine::DType c_dtype, transformer_engine::DType d_dtype) { @@ -1153,8 +1434,8 @@ inline void launch_grouped_gemm_setup( d_base = static_cast(D->data.dptr); } - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t a_bits_per_elem = transformer_engine::typeToNumBits(A_sel.dtype); + const size_t b_bits_per_elem = transformer_engine::typeToNumBits(B_sel.dtype); const size_t c_elem_size = transformer_engine::typeToSize(c_dtype); const size_t d_elem_size = transformer_engine::typeToSize(d_dtype); @@ -1164,18 +1445,27 @@ inline void launch_grouped_gemm_setup( // A and B share the same scaling recipe (validated in validate_grouped_gemm_inputs). // Pass scale buffers as void* and let the kernel interpret them via scaling_mode. - // Scale rowwise flag for MXFP8/NVFP4: to calculate scale_inv padding based offsets - // within kernel. Ignored for tensor scaling. 
- const bool a_rowwise = A_sel.rowwise; - const bool b_rowwise = B_sel.rowwise; + // Scales rowwise of meta — only differs from sel.rowwise for NVFP4 colwise (swap_dims=true). + const bool a_rowwise = A_sel.rowwise || A_sel.swap_dims; + const bool b_rowwise = B_sel.rowwise || B_sel.swap_dims; + + // NVFP4 alpha needs A's amax from either A_sel.amax (grouped) or amax_ptrs (discrete). + const bool a_has_amax = (A_sel.amax != nullptr) || + (A_sel.dptr == nullptr && a_multi_tensor_args.amax_ptrs[0] != nullptr); + const bool needs_nvfp4_alpha = a_has_amax && (B_sel.amax != nullptr); + setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, - A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, - b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), reinterpret_cast(A_sel.scale_inv), - reinterpret_cast(B_sel.scale_inv), a_rowwise, b_rowwise, A_sel.scaling_mode, - num_tensors, a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args); + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_bits_per_elem, + b_bits_per_elem, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), use_per_group_alpha_beta, + reinterpret_cast(A_sel.scale_inv), reinterpret_cast(B_sel.scale_inv), + a_rowwise, b_rowwise, A_sel.scaling_mode, num_tensors, a_multi_tensor_args, + c_multi_tensor_args, d_multi_tensor_args, + A_sel.amax ? static_cast(A_sel.amax) : nullptr, + B_sel.amax ? static_cast(B_sel.amax) : nullptr, + needs_nvfp4_alpha ? ws.nvfp4_computed_alpha : nullptr); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1195,9 +1485,14 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.3+ + // Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+, + // or Hopper (SM90) with cuBLAS 13.4+. check_grouped_gemm_requirements("nvte_grouped_gemm"); + const int current_device = transformer_engine::cuda::current_device(); + const int sm = transformer_engine::cuda::sm_arch(current_device); + const bool use_per_group_alpha_beta = grouped_gemm_supports_per_group_alpha_beta(sm); + // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); @@ -1212,10 +1507,18 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); // Validate inputs and outputs. 
- const size_t num_tensors = validate_grouped_gemm_inputs(inputA->num_tensors, {inputA, inputB}, - alpha_tensor, beta_tensor); + const size_t num_tensors = validate_grouped_gemm_inputs( + inputA->num_tensors, {inputA, inputB}, alpha_tensor, beta_tensor, use_per_group_alpha_beta); validate_grouped_gemm_outputs(num_tensors, {inputC_raw, outputD}); + // NVFP4-specific output dtype restrictions (matching non-grouped GEMM) + const bool use_fp4 = is_fp4_dtype(inputA->dtype()) || is_fp4_dtype(inputB->dtype()); + if (use_fp4) { + NVTE_CHECK(!is_fp4_dtype(outputD->dtype()), "FP4 GEMM output is not supported!"); + NVTE_CHECK(get_cuda_dtype(outputD->dtype()) != CUDA_R_16F, + "FP4 GEMM does not support FP16 output!"); + } + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; // num_tensors validated above. @@ -1224,26 +1527,39 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + // NVFP4 global-scale alpha requires per-tensor amax for both operands; without it + // the kernel silently drops the (amax_A * amax_B / factor) factor and produces + // numerically wrong output. + if (is_nvfp_scaling(A_sel.scaling_mode)) { + NVTE_CHECK(A_sel.amax != nullptr, "Grouped GEMM: NVFP4 A is missing amax."); + NVTE_CHECK(B_sel.amax != nullptr, "Grouped GEMM: NVFP4 B is missing amax."); + } + // Workspaces: setup (pointer arrays) and cuBLAS auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors); MultiTensorGroupGemmInputArgs a_multi_tensor_args{}; launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, - beta_tensor, num_tensors, stream, a_multi_tensor_args, - /*C_list=*/nullptr, /*D_list=*/nullptr, A_sel.dptr, inputC->dtype(), - outputD->dtype()); + beta_tensor, use_per_group_alpha_beta, num_tensors, stream, + a_multi_tensor_args, /*C_list=*/nullptr, /*D_list=*/nullptr, A_sel.dptr, + inputC->dtype(), outputD->dtype()); // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim // Use original inputA and transa for heuristics (not modified A_sel.trans) - int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); - int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); - int64_t avg_k_val = + GroupedGemmConfig gemm_config; + gemm_config.use_split_accumulator = config_.use_split_accumulator; + gemm_config.use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + gemm_config.use_per_group_alpha_beta = use_per_group_alpha_beta; + gemm_config.alpha_dptr = alpha_tensor->data.dptr; + gemm_config.beta_dptr = beta_tensor->data.dptr; + gemm_config.avg_m = config_.avg_m.value_or(compute_avg_first_dim(outputD)); + gemm_config.avg_n = config_.avg_n.value_or(compute_avg_last_dim(outputD)); + gemm_config.avg_k = config_.avg_k.value_or(transa ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); - const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + gemm_config.sm_count = config_.sm_count; execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors, - config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, - workspace.cublas_workspace_ptr, stream, config_.sm_count); + gemm_config, workspace.cublas_workspace_ptr, stream); } void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, @@ -1255,9 +1571,14 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num NVTE_API_CALL(nvte_grouped_gemm_with_discrete_inputA); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.3+ + // Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+, + // or Hopper (SM90) with cuBLAS 13.4+. check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_inputA"); + const int current_device = transformer_engine::cuda::current_device(); + const int sm = transformer_engine::cuda::sm_arch(current_device); + const bool use_per_group_alpha_beta = grouped_gemm_supports_per_group_alpha_beta(sm); + NVTE_CHECK(A_list != nullptr, "Grouped GEMM: A_list is null."); NVTE_CHECK(num_a_tensors > 0, "Grouped GEMM: num_a_tensors must be > 0."); @@ -1273,9 +1594,8 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); // Validate inputs and outputs. - const size_t num_tensors = - validate_grouped_gemm_inputs(num_a_tensors, {inputB}, alpha_tensor, beta_tensor); - + const size_t num_tensors = validate_grouped_gemm_inputs(num_a_tensors, {inputB}, alpha_tensor, + beta_tensor, use_per_group_alpha_beta); validate_grouped_gemm_outputs(num_tensors, {inputC_raw, outputD}); // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) @@ -1284,27 +1604,33 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num // Validate A list and selection auto A_list_info = validate_grouped_gemm_multi_inputA_list(A_list, num_a_tensors, num_tensors, "A"); - auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + auto is_supported_dtype = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kFloat4E2M1 || dtype == transformer_engine::DType::kBFloat16 || dtype == transformer_engine::DType::kFloat16; }; - NVTE_CHECK(is_fp8_or_16bit(A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype), - "Grouped GEMM: A_list tensors must be FP8, BF16, or FP16."); + NVTE_CHECK( + is_supported_dtype(A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype), + "Grouped GEMM: A_list tensors must be FP8, NVFP4, BF16, or FP16."); // Cross-operand consistency (mirrors validate_grouped_gemm_inputs). const DType a_rep_dtype = A_list_info.all_row ? 
@@ -1284,27 +1604,33 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num
   // Validate A list and selection
   auto A_list_info =
       validate_grouped_gemm_multi_inputA_list(A_list, num_a_tensors, num_tensors, "A");
-  auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
+  auto is_supported_dtype = [](transformer_engine::DType dtype) {
     return dtype == transformer_engine::DType::kFloat8E4M3 ||
            dtype == transformer_engine::DType::kFloat8E5M2 ||
+           dtype == transformer_engine::DType::kFloat4E2M1 ||
            dtype == transformer_engine::DType::kBFloat16 ||
            dtype == transformer_engine::DType::kFloat16;
   };
-  NVTE_CHECK(is_fp8_or_16bit(A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype),
-             "Grouped GEMM: A_list tensors must be FP8, BF16, or FP16.");
+  NVTE_CHECK(
+      is_supported_dtype(A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype),
+      "Grouped GEMM: A_list tensors must be FP8, NVFP4, BF16, or FP16.");

   // Cross-operand consistency (mirrors validate_grouped_gemm_inputs).
   const DType a_rep_dtype = A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype;
-  NVTE_CHECK(is_fp8_dtype(a_rep_dtype) == is_fp8_dtype(inputB->dtype()),
-             "Grouped GEMM: A and B must both be FP8 or both be non-FP8.");
-  NVTE_CHECK(transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode) ==
-                 transformer_engine::is_mxfp_scaling(inputB->scaling_mode),
-             "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling.");
-  if (transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode)) {
+  const bool a_is_low_precision =
+      is_fp8_dtype(a_rep_dtype) || a_rep_dtype == transformer_engine::DType::kFloat4E2M1;
+  const bool b_is_low_precision =
+      is_fp8_dtype(inputB->dtype()) || inputB->dtype() == transformer_engine::DType::kFloat4E2M1;
+  NVTE_CHECK(a_is_low_precision == b_is_low_precision,
+             "Grouped GEMM: A and B must both be low-precision (FP8/NVFP4) or both be "
+             "high-precision (BF16/FP16).");
+  NVTE_CHECK(A_list_info.scaling_mode == inputB->scaling_mode,
+             "Grouped GEMM: A and B must use the same scaling mode.");
+  if (transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode) ||
+      transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode)) {
     NVTE_CHECK(A_list_info.with_gemm_swizzled_scales,
-               "MXFP8 grouped GEMM: A scales must be swizzled for GEMM.");
+               "Grouped GEMM: A scales must be swizzled for GEMM (MXFP8/NVFP4).");
     NVTE_CHECK(inputB->with_gemm_swizzled_scales,
-               "MXFP8 grouped GEMM: B scales must be swizzled for GEMM.");
+               "Grouped GEMM: B scales must be swizzled for GEMM (MXFP8/NVFP4).");
   }

   // Select operand storage for B (row-wise vs column-wise)
@@ -1317,53 +1643,79 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num
   const DType rep_dtype = A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype;
   const bool is_fp8 = is_fp8_dtype(rep_dtype);
-  const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
   const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode);
+  const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode);
+  const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode);
+  // FP8 block scaling on Hopper requires TN layout (matches select_grouped_operand logic for B).
+  const bool non_tn_fp8_ok = fp8_block ? false : nvte_is_non_tn_fp8_gemm_supported();
   int64_t avg_first_dim = 0;
   int64_t avg_last_dim = 0;
   MultiTensorGroupGemmInputArgs a_multi_tensor_args{};
-  const auto choice =
-      choose_grouped_operand_storage(static_cast<bool>(transa), /*is_A=*/true, mxfp8, is_fp8,
-                                     non_tn_fp8_ok, A_list_info.all_row, A_list_info.all_col, "A");
+  const auto choice = choose_grouped_operand_storage(static_cast<bool>(transa), /*is_A=*/true,
+                                                     mxfp8, is_fp8, nvfp4, fp8_block, non_tn_fp8_ok,
+                                                     A_list_info.all_row, A_list_info.all_col, "A");
   A_sel.trans = choice.trans;
   A_sel.rowwise = choice.use_rowwise;
+  A_sel.swap_dims = choice.swap_dims;
   if (choice.use_rowwise) {
     NVTE_CHECK(A_list_info.all_row, "Grouped GEMM: A_list is missing row-wise data");
     A_sel.dtype = A_list_info.row_dtype;
     a_multi_tensor_args = build_grouped_gemm_multi_inputA_args(
-        A_list, num_a_tensors, /*use_rowwise=*/true, is_fp8, &avg_first_dim, &avg_last_dim, "A");
+        A_list, num_a_tensors, /*use_rowwise=*/true, is_fp8, &avg_first_dim, &avg_last_dim, "A",
+        /*needs_scale_inv=*/nvfp4 || fp8_block,
+        /*swap_dims=*/false);
   } else {
     NVTE_CHECK(A_list_info.all_col, "Grouped GEMM: A_list is missing column-wise data");
     A_sel.dtype = A_list_info.col_dtype;
+    // NVFP4/MXFP8 columnwise data is physically transposed (logical shape == rowwise shape);
+    // pass swap_dims so rows/cols and avg_first/last match the physical layout cuBLAS sees.
     a_multi_tensor_args = build_grouped_gemm_multi_inputA_args(
-        A_list, num_a_tensors, /*use_rowwise=*/false, is_fp8, &avg_first_dim, &avg_last_dim, "A");
+        A_list, num_a_tensors, /*use_rowwise=*/false, is_fp8, &avg_first_dim, &avg_last_dim, "A",
+        /*needs_scale_inv=*/nvfp4 || fp8_block,
+        /*swap_dims=*/choice.swap_dims);
   }

-  // For discrete A_list, scale pointers are per-tensor; use multi-tensor args.
-  // Base pointer is unused when providing per-tensor pointers.
+  // Discrete A_list: per-tensor pointers come from `a_multi_tensor_args` (data/scale/amax).
   A_sel.scale_inv = nullptr;
   A_sel.dptr = nullptr;
+  A_sel.amax = nullptr;
+
+  if (nvfp4) {
+    const bool use_rowwise = choice.use_rowwise;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      const transformer_engine::Tensor *ti = transformer_engine::convertNVTETensorCheck(A_list[i]);
+      const auto &amax_i = use_rowwise ? ti->amax : ti->columnwise_amax;
+      NVTE_CHECK(amax_i.has_data(), "Grouped GEMM: NVFP4 A_list tensor ", i, " is missing amax.");
+    }
+    NVTE_CHECK(B_sel.amax != nullptr, "Grouped GEMM: NVFP4 B is missing amax.");
+  }

   // Workspaces: setup (pointer arrays) and cuBLAS
   auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors);

   launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor,
-                            beta_tensor, num_tensors, stream, a_multi_tensor_args,
-                            /*C_list=*/nullptr, /*D_list=*/nullptr, nullptr, inputC->dtype(),
-                            outputD->dtype());
-
-  // Compute average dimensions for heuristics
-  int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD));
-  int64_t avg_n_val =
+                            beta_tensor, use_per_group_alpha_beta, num_tensors, stream,
+                            a_multi_tensor_args, /*C_list=*/nullptr, /*D_list=*/nullptr, nullptr,
+                            inputC->dtype(), outputD->dtype());
+
+  GroupedGemmConfig gemm_config;
+  gemm_config.use_split_accumulator = config_.use_split_accumulator;
+  gemm_config.use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype);
+  gemm_config.use_per_group_alpha_beta = use_per_group_alpha_beta;
+  gemm_config.alpha_dptr = alpha_tensor->data.dptr;
+  gemm_config.beta_dptr = beta_tensor->data.dptr;
+  gemm_config.avg_m = config_.avg_m.value_or(compute_avg_first_dim(outputD));
+  gemm_config.avg_n =
       config_.avg_n.value_or(transb ? compute_avg_first_dim(inputB) : compute_avg_last_dim(inputB));
-  int64_t avg_k_val =
-      config_.avg_k.value_or(static_cast<bool>(transa) ? avg_last_dim : avg_first_dim);
-  const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype);
+  // After choose_grouped_operand_storage / swap_dims, avg_first/last reflect the physical
+  // layout cuBLAS sees, and A_sel.trans is the post-flip transpose flag. Use those (not the
+  // raw `transa`) so K is selected from the correct dim.
+  gemm_config.avg_k = config_.avg_k.value_or(A_sel.trans ? avg_last_dim : avg_first_dim);
+  gemm_config.sm_count = config_.sm_count;
   execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors,
-                       config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val,
-                       workspace.cublas_workspace_ptr, stream, config_.sm_count);
+                       gemm_config, workspace.cublas_workspace_ptr, stream);
 }

 void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
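All three entry points now take alpha/beta either as one shared scalar (Hopper) or as num_tensors scalars (Blackwell and newer). As a usage illustration only, a host-side caller could size those buffers as in the sketch below; the helper name is hypothetical and mirrors what the PyTorch wrapper later in this patch does with torch tensors.

```cpp
#include <cstddef>
#include <vector>

// Sketch: build an alpha (or beta) buffer of the length the grouped GEMM
// entry points expect. With per-group alpha/beta, each of the num_tensors
// groups gets its own scalar; otherwise a single shared scalar is passed.
std::vector<float> make_alpha_beta_buffer(bool per_group_alpha_beta,
                                          std::size_t num_tensors, float value) {
  return std::vector<float>(per_group_alpha_beta ? num_tensors : 1, value);
}
```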
check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_out"); + const int current_device = transformer_engine::cuda::current_device(); + const int sm = transformer_engine::cuda::sm_arch(current_device); + const bool use_per_group_alpha_beta = grouped_gemm_supports_per_group_alpha_beta(sm); + NVTE_CHECK(D_list != nullptr, "Grouped GEMM: D_list is null."); NVTE_CHECK(num_d_tensors > 0, "Grouped GEMM: num_d_tensors must be > 0."); if (num_c_tensors > 0) { @@ -1395,8 +1752,8 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, const Tensor *d0 = convertNVTETensorCheck(D_list[0]); const DType d_dtype = d0->dtype(); - const size_t num_tensors = validate_grouped_gemm_inputs(inputA->num_tensors, {inputA, inputB}, - alpha_tensor, beta_tensor); + const size_t num_tensors = validate_grouped_gemm_inputs( + inputA->num_tensors, {inputA, inputB}, alpha_tensor, beta_tensor, use_per_group_alpha_beta); NVTE_CHECK(num_d_tensors == num_tensors, "Grouped GEMM: D_list must have num_tensors (", num_tensors, ") entries, got ", num_d_tensors); if (num_c_tensors > 0) { @@ -1417,25 +1774,37 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, // mirror the non-grouped GEMM logic for FP8 layout constraints. auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // NVFP4 global-scale alpha requires per-tensor amax for both operands. + if (is_nvfp_scaling(A_sel.scaling_mode)) { + NVTE_CHECK(A_sel.amax != nullptr, "Grouped GEMM: NVFP4 A is missing amax."); + NVTE_CHECK(B_sel.amax != nullptr, "Grouped GEMM: NVFP4 B is missing amax."); + } + // Workspaces: setup (pointer arrays) and cuBLAS auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors); MultiTensorGroupGemmInputArgs a_multi_tensor_args{}; launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, /*C=*/nullptr, /*D=*/nullptr, - alpha_tensor, beta_tensor, num_tensors, stream, a_multi_tensor_args, - C_list, D_list, A_sel.dptr, d_dtype, d_dtype); - - // Compute average dimensions for heuristics - int64_t avg_m_val = + alpha_tensor, beta_tensor, use_per_group_alpha_beta, num_tensors, + stream, a_multi_tensor_args, C_list, D_list, A_sel.dptr, d_dtype, + d_dtype); + + GroupedGemmConfig gemm_config; + gemm_config.use_split_accumulator = config_.use_split_accumulator; + gemm_config.use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + gemm_config.use_per_group_alpha_beta = use_per_group_alpha_beta; + gemm_config.alpha_dptr = alpha_tensor->data.dptr; + gemm_config.beta_dptr = beta_tensor->data.dptr; + gemm_config.avg_m = config_.avg_m.value_or(transa ? compute_avg_last_dim(inputA) : compute_avg_first_dim(inputA)); - int64_t avg_n_val = + gemm_config.avg_n = config_.avg_n.value_or(transb ? compute_avg_first_dim(inputB) : compute_avg_last_dim(inputB)); - int64_t avg_k_val = + gemm_config.avg_k = config_.avg_k.value_or(transa ? 
-  const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype);
-  execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, d_dtype, num_tensors,
-                       config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val,
-                       workspace.cublas_workspace_ptr, stream, config_.sm_count);
+  gemm_config.sm_count = config_.sm_count;
+  execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, d_dtype, num_tensors, gemm_config,
+                       workspace.cublas_workspace_ptr, stream);
 }

 namespace {
diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version
index 706c237ccc..ccd18fc153 100644
--- a/transformer_engine/common/libtransformer_engine.version
+++ b/transformer_engine/common/libtransformer_engine.version
@@ -7,6 +7,7 @@
     transformer_engine::cuda::supports_multicast*;
     transformer_engine::cuda::stream_priority_range*;
     transformer_engine::cuda::current_device*;
+    transformer_engine::cuda::cublas_version*;
     transformer_engine::cuda_driver::get_symbol*;
     transformer_engine::cuda_driver::ensure_context_exists*;
     transformer_engine::ubuf_built_with_mpi*;
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 1a52d76019..8c5c87a846 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -356,6 +356,8 @@ static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name
   // Determine expected dtype based on data type and scaling mode
   if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
     check_scales(DType::kFloat32);
+  } else if (is_fp8_block_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat32);
   } else if (is_mxfp8_scaling(t.scaling_mode)) {
     check_scales(DType::kFloat8E8M0);
   } else if (is_nvfp4_scaling(t.scaling_mode)) {
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index edf2c1e1c2..48751d8c93 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -5,7 +5,6 @@
 """Python interface for GEMM extensions"""

 from typing import Iterable, Optional, Tuple, Union, List
-import ctypes
 import os
 import functools
 import torch
@@ -420,20 +419,8 @@ def general_grouped_gemm(

 @functools.lru_cache(maxsize=None)
 def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
-    """Return workspace size for grouped GEMM pointer setup.
-    Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu.
-    """
-    ptr_bytes = ctypes.sizeof(ctypes.c_void_p)
-    int_bytes = ctypes.sizeof(ctypes.c_int)
-    ptr_size = num_tensors * ptr_bytes
-    int_size = num_tensors * int_bytes
-    k_ptr_alignment = 16
-    # Each pointer array is placed at a 16-byte-aligned offset (matching kPtrAlignment in C++).
-    # aligned_ptr_size = round_up(num_tensors * ptr_bytes, 16)
-    aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment
-    size = 8 * aligned_ptr_size + 6 * int_size
-    alignment = 256
-    return ((size + alignment - 1) // alignment) * alignment
+    """Return workspace size for grouped GEMM pointer setup."""
+    return tex.get_grouped_gemm_setup_workspace_size(num_tensors)


 @functools.lru_cache(maxsize=None)
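The Python duplicate of the setup-workspace formula is removed above in favor of querying the C++ side through the new binding. For reference, the layout the deleted code encoded (eight pointer arrays, each padded to a 16-byte boundary, then six int arrays, with the total rounded up to 256 bytes) works out as in the sketch below; the arithmetic mirrors the removed Python and is descriptive, not the authoritative size.

```cpp
#include <cstddef>

// Sketch of the removed size computation. Example: num_tensors = 4 on a
// 64-bit platform gives 8 * 32 + 6 * 16 = 352 bytes, rounded up to 512.
constexpr std::size_t round_up(std::size_t x, std::size_t align) {
  return (x + align - 1) / align * align;
}

constexpr std::size_t setup_workspace_size_sketch(std::size_t num_tensors) {
  const std::size_t aligned_ptr_bytes = round_up(num_tensors * sizeof(void*), 16);
  const std::size_t int_bytes = num_tensors * sizeof(int);
  return round_up(8 * aligned_ptr_bytes + 6 * int_bytes, 256);
}
```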
@@ -510,13 +497,18 @@ def general_grouped_gemm_for_grouped_tensor(
     rowwise = B.rowwise_data
     device = rowwise.device if rowwise is not None else B.columnwise_data.device

+    # Hopper (SM90) uses a single shared alpha/beta scalar;
+    # Blackwell+ (SM100) supports per-group alpha/beta arrays.
+    per_group = torch.cuda.get_device_capability() >= (10, 0)
+    num_alphabeta = num_tensors if per_group else 1
+
     if alpha is None:
-        alpha = _get_fp32_ones_tensor(num_tensors, device)
+        alpha = _get_fp32_ones_tensor(num_alphabeta, device)
     if beta is None:
         if accumulate:
-            beta = _get_fp32_ones_tensor(num_tensors, device)
+            beta = _get_fp32_ones_tensor(num_alphabeta, device)
         else:
-            beta = _get_fp32_zeros_tensor(num_tensors, device)
+            beta = _get_fp32_zeros_tensor(num_alphabeta, device)

     if not alpha.is_cuda or not beta.is_cuda:
         raise ValueError("alpha and beta must be CUDA tensors.")
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 9cb1fb7f54..53ea76d83b 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -86,10 +86,13 @@ GroupedGemmConfig prepare_grouped_gemm_config(at::Tensor alpha, at::Tensor beta,
                                               at::Tensor workspace_setup,
                                               at::Tensor workspace_cublas, size_t num_tensors,
                                               int math_sm_count, bool use_split_accumulator) {
-  NVTE_CHECK(alpha.numel() == static_cast<int64_t>(num_tensors),
-             "Grouped GEMM expects alpha to have num_tensors elements.");
-  NVTE_CHECK(beta.numel() == static_cast<int64_t>(num_tensors),
-             "Grouped GEMM expects beta to have num_tensors elements.");
+  const bool per_group = (alpha.numel() == static_cast<int64_t>(num_tensors));
+  const bool scalar = (alpha.numel() == 1);
+  NVTE_CHECK(per_group || scalar, "Grouped GEMM expects alpha to have 1 or num_tensors (",
+             num_tensors, ") elements, got ", alpha.numel());
+  NVTE_CHECK(beta.numel() == alpha.numel(),
+             "Grouped GEMM expects beta to have the same number of elements as alpha (",
+             alpha.numel(), "), got ", beta.numel());

   GroupedGemmConfig grouped_gemm_config{
       makeTransformerEngineTensor(alpha),
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index eb7576d905..82cda07cac 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -278,6 +278,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
        "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
        py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
+  m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size,
+        "Required workspace size for grouped GEMM setup");
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
        "Grouped GEMM");
   m.def("te_general_grouped_gemm_for_grouped_tensor",