From 56323030309fa189c349ad971e0899ab05c0eea4 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 2 Apr 2026 15:08:05 +0000 Subject: [PATCH 01/14] Rebase onto dev --- tests/pytorch/test_numerics.py | 8 +- transformer_engine/common/CMakeLists.txt | 4 +- .../common/gemm/ck_grouped_gemm.cpp | 338 ------------------ .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 138 +++++++ .../{ => ck_grouped_gemm}/ck_grouped_gemm.h | 2 +- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 108 ++++++ .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 83 +++++ .../ck_grouped_gemm_fp16_impl.h | 214 +++++++++++ .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 134 +++++++ .../ck_grouped_gemm_fp8_impl.h | 260 ++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 12 +- .../transformer_engine/transformer_engine.h | 10 + 12 files changed, 965 insertions(+), 346 deletions(-) delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp rename transformer_engine/common/gemm/{ => ck_grouped_gemm}/ck_grouped_gemm.h (93%) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 82655a0ce..bb2ecdca6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2171,6 +2171,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy_cutlass( @@ -2178,6 +2180,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, + recipe, + fp8_model_params, fuse_wgrad_accumulation, delay_wgrad_compute, ): @@ -2187,8 +2191,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, - None, - False, + recipe, + fp8_model_params, fuse_wgrad_accumulation, False, delay_wgrad_compute, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 42e5c0449..104481cde 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -249,7 +249,9 @@ else() comm_gemm_overlap/rocm_comm_gemm_overlap.cpp fused_attn_rocm/fused_attn.cpp gemm/rocm_gemm.cu - gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp deleted file mode 100644 index def454f86..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ /dev/null @@ -1,338 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#include - -#include -#include "../common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. -static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - // Require at least a matrix (rank >= 2). Higher ranks are flattened. - if (t.shape().size() < 2) - return false; - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - -// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). -// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined -// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple<>; - using DsLayout = ck_tile::tuple<>; - using ElemOp = ck_tile::element_wise::PassThrough; -}; -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple; - using DsLayout = ck_tile::tuple; - using ElemOp = ck_tile::element_wise::Add; -}; - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; // rowwise data view -} - -// Primus-Turbo-like FP16/BF16 tile configs -// Selection rule: -// if (N % 256 == 0) use 256x256x64 -// else if (N % 128 == 0) use 256x128x64 -// else use 256x128x64 with N padding enabled -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -// This class instantiates CK_Tile's grouped GEMM pipeline. -// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. -template -struct Runner{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using ET = EpilogueTraits; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, typename ET::DsDataType, AccType, - CType, typename ET::DsLayout, CLayout, - typename ET::ElemOp, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC>>; - - using Kernel = ck_tile::GroupedGemmKernel; -}; - -template -static bool run_grouped_impl(const NVTETensor* A_use, - const NVTETensor* B_use, - NVTETensor* D, - int group_num, - bool transA_use, - bool transB_use, - void* workspace, - size_t workspace_bytes, - hipStream_t stream) -{ - using Kernel = typename Runner::Kernel; - - const size_t needed = Kernel::GetWorkSpaceSize(group_num); - if (!workspace || workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", workspace_bytes, ". Falling back."); - return false; - } - - // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. - using HostArgs = std::conditional_t, - ck_tile::GroupedGemmHostArgs<0>>; - - thread_local std::vector descs; - descs.clear(); - descs.reserve(group_num); - - for (int i = 0; i < group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(A_use[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(B_use[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); - return false; - } - - const int64_t M = transA_use ? Ad1 : Ad0; - const int64_t K = transA_use ? Ad0 : Ad1; - const int64_t N = transB_use ? Bd0 : Bd1; - const int64_t Kb = transB_use ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - return false; - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - return false; - } - - // Leading dimensions under the flattened-contiguous interpretation - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - if constexpr (Accumulate) { - // MultiD: E = Add(A@B, D1). D1 and E point to the same buffer for in-place accumulation. - descs.emplace_back( - a.dptr, b.dptr, - std::array{d.dptr}, // D1 = existing D contents (read) - d.dptr, // E = same buffer (write) - 1, M, N, K, - stride_A, stride_B, - std::array{stride_E}, - stride_E); - } else { - descs.emplace_back( - a.dptr, b.dptr, - std::array{}, - d.dptr, - 1, M, N, K, - stride_A, stride_B, - std::array{}, - stride_E); - } - } - - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); - return false; - } - - HIP_CHECK_ERROR(hipMemcpyAsync(workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - stream)); - - const ck_tile::stream_config s{stream}; - const dim3 blocks = Kernel::BlockSize(); - - ck_tile::launch_kernel( - s, - ck_tile::make_kernel<1>( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(workspace), - group_num)); - return true; -} - -} // namespace grouped_gemm -} // namespace transformer_engine - -bool ck_tile_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate, - hipStream_t stream) -{ - if (group_num <= 0) - return true; - - using namespace transformer_engine; - using namespace transformer_engine::grouped_gemm; - - // Workspace pointer + bytes - void* ws_ptr = nullptr; - size_t ws_bytes = 0; - if (workspace) { - auto* ws_te = convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; - ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); - } - - // Normalize similar to upstream - // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 - // I.e., swap A and B, as well as transa and transb. - const NVTETensor* A_use = B; - const NVTETensor* B_use = A; - const bool transA_use = transB; - const bool transB_use = transA; - - const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); - - // Get N from D[0] (assume uniform N across groups) - int64_t ref_d0 = 0, ref_d1 = 0; - Tensor* D0_te = convertNVTETensorCheck(D[0]); - if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); - return false; - } - const ck_tile::index_t N = static_cast(ref_d1); - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename TETypeToCKType::type; - - auto run_with_tilecfg = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; - - if (accumulate) { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } else { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } - }); - }); - }; - - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. - if ((N % 256) == 0) { - return run_with_tilecfg(TileCfg_256x256x64{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(TileCfg_256x128x64{}); - } else { - return run_with_tilecfg(TileCfg_256x128x64_padding{}); - } - }); -} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp new file mode 100644 index 000000000..ccd71fa43 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -0,0 +1,138 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" + +bool ck_tile_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + cudaStream_t stream) { + if (group_num <= 0) { + return true; + } + + using namespace transformer_engine; + using namespace transformer_engine::grouped_gemm; + + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + // Normalize similar to upstream + // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + bool transA_use = transB; + bool transB_use = transA; + bool use_b_columnwise_data = false; + + const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); + const bool is_fp8 = is_fp8_dtype(caller_a_dtype); + const bool is_fp16 = is_fp16_dtype(caller_a_dtype); + + // Currently the accumulate path is only supported on fp16 + if (accumulate && is_fp8) + return false; + + // Handle pathological NN case during fp8 dX GEMM by reading W columnwise and re-formulating as NT + if (!transA_use && !transB_use && is_fp8) { + auto* B0_te = convertNVTETensorCheck(B_use[0]); + if (B0_te->has_columnwise_data()) { + use_b_columnwise_data = true; + transB_use = true; + } + } + + const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); + + Tensor* D0_te = convertNVTETensorCheck(D[0]); + const auto d_dtype = D0_te->dtype(); + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + int64_t a0 = 0, a1 = 0; + int64_t b0 = 0, b1 = 0; + int64_t d0 = 0, d1 = 0; + + if (!get_flat_2d_dims(*A0_te, a0, a1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A_use[0]"); + return false; + } + + if (use_b_columnwise_data) { + if (B0_te->columnwise_data.shape.size() < 2) { + NVTE_ERROR("ck_tile_grouped_gemm: expected columnwise_data rank>=2 for B_use[0]"); + return false; + } + b0 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 2]); + b1 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 1]); + } else { + if (!get_flat_2d_dims(*B0_te, b0, b1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B_use[0]"); + return false; + } + } + + if (!get_flat_2d_dims(*D0_te, d0, d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + + const int64_t m = transA_use ? a1 : a0; + const int64_t kA = transA_use ? a0 : a1; + + const int64_t kB = transB_use ? b1 : b0; + const int64_t n = transB_use ? b0 : b1; + + if (kA != kB) { + NVTE_ERROR("ck_tile_grouped_gemm: normalized GEMM K mismatch: op(A_use) is ", + m, "x", kA, ", op(B_use) is ", kB, "x", n); + return false; + } + + if (d0 != m || d1 != n) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch for normalized GEMM. " + "D is ", d0, "x", d1, " but expected ", m, "x", n); + return false; + } + + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + static_cast(n), + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream, + use_b_columnwise_data, + accumulate}; + + + if (is_fp16) { + return ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx); + } + + else if (is_fp8) { + return ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx); + } + + return false; +} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h similarity index 93% rename from transformer_engine/common/gemm/ck_grouped_gemm.h rename to transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h index 97b4cfd88..2e0c71983 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h @@ -12,4 +12,4 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool transB, NVTETensor* workspace, bool accumulate, - hipStream_t stream); + cudaStream_t stream); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h new file mode 100644 index 000000000..556dc1eef --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -0,0 +1,108 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include +#include "../../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). +// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined +// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; + using ElemOp = ck_tile::element_wise::PassThrough; +}; +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using ElemOp = ck_tile::element_wise::Add; +}; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + int64_t N = 0; + + int group_num = 0; + bool transA = false; + bool transB = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + cudaStream_t stream = nullptr; + + bool use_b_columnwise_data = false; + bool accumulate = false; +}; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +class RunnerInterface { +public: + virtual ~RunnerInterface() = default; + virtual bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) = 0; +}; + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp new file mode 100644 index 000000000..ea33a7ace --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -0,0 +1,83 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +#define MAKE_RUNNER(TileCfg_) \ + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, {\ + using Runner = GroupedGemmRunner< \ + AType, BType, CType, \ + ALayout, BLayout, CLayout, \ + TileCfg_, accum_option>; \ + runner = std::make_unique(); \ + }) + +template +static std::unique_ptr make_fp16_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64); + } else { + MAKE_RUNNER(TileCfg_256x128x64_padding); + } + }); + return runner; +} + +#undef MAKE_RUNNER + +static std::unique_ptr +make_fp16_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_type, { + return make_fp16_runner_typed(d_dtype, ctx); + }); + }); + }); + + return nullptr; +} + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + + auto runner = make_fp16_runner( + a_dtype, b_dtype, d_dtype, ctx); + + if (!runner) { + return false; + } + + return runner->run(s, ctx); +} + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h new file mode 100644 index 000000000..e4138a132 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +template +class GroupedGemmRunner : public RunnerInterface { +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, + GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using ET = EpilogueTraits; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, typename ET::DsDataType, AccType, + CType, typename ET::DsLayout, CLayout, + typename ET::ElemOp, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. + using HostArgs = std::conditional_t, + ck_tile::GroupedGemmHostArgs<0>>; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return {}; + } + + thread_local std::vector descs; + descs.clear(); + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + if constexpr(Accumulate) { + descs.emplace_back( + a.dptr, b.dptr, + std::array{d.dptr}, // D1 = existing D contents (read) + d.dptr, // E = same buffer (write) + 1, M, N, K, + stride_A, stride_B, + std::array{stride_E}, + stride_E); + } else { + descs.emplace_back( + a.dptr, b.dptr, + std::array{}, + d.dptr, + 1, M, N, K, + stride_A, stride_B, + std::array{}, + stride_E); + } + } + + return descs; + }; + + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + NVTE_CHECK_CUDA(cudaMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + cudaMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp new file mode 100644 index 000000000..6419e3188 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -0,0 +1,134 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_impl.h" +#include "common/util/cuda_runtime.h" + +namespace transformer_engine { +namespace grouped_gemm { + +enum class GPUArch { + GFX942, + GFX950, + UNKNOWN +}; + +static inline GPUArch detect_gpu_arch() { + int arch = cuda::sm_arch(0); + + if (arch == 94) { + return GPUArch::GFX942; + } + if (arch == 95) { + return GPUArch::GFX950; + } + return GPUArch::UNKNOWN; +} + +template +struct FP8TileCfg; + +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_32x32x16_2x2x1; +}; + +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_16x16x128_2x2x1; +}; + +template +static std::unique_ptr make_fp8_runner_typed(DType d_dtype, + const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CTypeLayout = RowMajor; + using TileCfg = typename FP8TileCfg::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + using Runner = QuantGroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CTypeLayout, + TileCfg, ck_tile::memory_operation_enum::set>; + runner = std::make_unique(); + }); + + return runner; +} + +template +static std::unique_ptr make_fp8_runner_impl(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_type, { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_type, { + return make_fp8_runner_typed( + d_dtype, ctx); + }); + }); + }); + }); + + return nullptr; +} + +static inline std::unique_ptr make_fp8_runner_gfx942(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); +} + +static inline std::unique_ptr make_fp8_runner_gfx950(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); +} + +static std::unique_ptr +make_fp8_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { + case GPUArch::GFX942: + return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX950: + return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx); + default: + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); + return nullptr; + } +} + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + auto runner = make_fp8_runner(a_dtype, b_dtype, d_dtype, ctx); + if (!runner) { + return false; + } + return runner->run(s, ctx); +} + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h new file mode 100644 index 000000000..bbc994fb4 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h @@ -0,0 +1,260 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +struct TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + +// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile +// configuration due to an unsupported warp GEMM dispatcher configuration. +// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. +// +// To preserve the existing type name in shared template code, this struct +// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device +// compilation path, effectively reusing those parameters without redefining them. +// +// In all other compilation paths, the struct overrides the relevant fields to +// provide the intended 32x32x16 configuration. +#if defined(__gfx950__) +struct TileCfg_128x128x128_32x32x16_2x2x1 + : TileCfg_128x128x128_16x16x128_2x2x1 { +}; +#else +struct TileCfg_128x128x128_32x32x16_2x2x1 + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; +#endif + +template +class QuantGroupedGemmRunner : public RunnerInterface { +public: + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>; + + using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return {}; + } + + std::vector descs; + descs.reserve(ctx.group_num); + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& d = data_view(*D_te); + + const transformer_engine::SimpleTensor* b_src = nullptr; + if (ctx.use_b_columnwise_data) { + if (!B_te->has_columnwise_data()) { + NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but columnwise_data is absent."); + } + b_src = &B_te->columnwise_data; + } else { + b_src = &B_te->data; + } + + const auto& b = *b_src; + + int64_t Ad0 = 0, Ad1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected A and D to be rank>=2."); + } + + if (b.shape.size() < 2) { + NVTE_ERROR("ck_tile_grouped_gemm: expected chosen B source to be rank>=2."); + } + + int64_t Bd0 = static_cast(b.shape[b.shape.size() - 2]); + int64_t Bd1 = static_cast(b.shape[b.shape.size() - 1]); + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, + " op(B)=", Kb, "x", N, + " raw A=", Ad0, "x", Ad1, + " raw B=", Bd0, "x", Bd1, + " use_b_columnwise_data=", static_cast(ctx.use_b_columnwise_data), + " transA=", static_cast(ctx.transA), + " transB=", static_cast(ctx.transB)); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, + ", expected=", M, "x", N); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back( + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + } + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + NVTE_CHECK_CUDA(cudaMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + cudaMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + } +}; + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 35da3075c..620275d0b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -32,7 +32,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "./cutlass_grouped_gemm.cuh" #else -#include "ck_grouped_gemm.h" +#include "ck_grouped_gemm/ck_grouped_gemm.h" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -1140,9 +1140,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return (A_dt == B_dt) && (A_dt == D_dt) && - (A_dt == transformer_engine::DType::kFloat16 || - A_dt == transformer_engine::DType::kBFloat16); + return ( + (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) + ) || + ( + (A_dt == B_dt) && (A_dt == D_dt) && + (is_fp16_dtype(A_dt)) + ); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index bd91212e9..056dc14ac 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -423,6 +423,16 @@ enum class DType { kNumTypes }; +/*! \brief Check if TE datatype is FP16 + * + * Return true if TE datatype is FP16 + * \param[in] DType TE Datatype of interest + */ +inline bool is_fp16_dtype(const DType t) { + return t == DType::kFloat16 || t == DType::kBFloat16; +} + + /*! \brief Check if TE datatype is FP8 * * Return true if TE datatype is FP8 From 1b00bc04d8f52bdf8fdf542868fe5f01bd5464be Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 14:44:21 +0000 Subject: [PATCH 02/14] code refactoring to address pr comments --- tests/pytorch/test_numerics.py | 10 +- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 26 +- .../gemm/ck_grouped_gemm/ck_grouped_gemm.h | 2 +- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 88 ++++- .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 282 +++++++++++--- .../ck_grouped_gemm/ck_grouped_gemm_fp16.h | 18 + .../ck_grouped_gemm_fp16_impl.h | 214 ----------- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 360 +++++++++++++++--- .../ck_grouped_gemm/ck_grouped_gemm_fp8.h | 18 + .../ck_grouped_gemm_fp8_impl.h | 260 ------------- .../transformer_engine/transformer_engine.h | 3 +- 11 files changed, 686 insertions(+), 595 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.h delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.h delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index bb2ecdca6..a1e2e1e98 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2171,8 +2171,14 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8_model_params", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize( + "fp8_model_params", + all_boolean if IS_HIP_EXTENSION else [False], +) +@pytest.mark.parametrize( + "recipe", + (fp8_recipes + [None]) if IS_HIP_EXTENSION else [None], +) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy_cutlass( diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index ccd71fa43..347c10399 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -5,6 +5,8 @@ ************************************************************************/ #include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16.h" +#include "ck_grouped_gemm_fp8.h" bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, @@ -14,7 +16,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool transB, NVTETensor* workspace, bool accumulate, - cudaStream_t stream) { + hipStream_t stream) { if (group_num <= 0) { return true; } @@ -40,15 +42,15 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool use_b_columnwise_data = false; const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); - const bool is_fp8 = is_fp8_dtype(caller_a_dtype); - const bool is_fp16 = is_fp16_dtype(caller_a_dtype); + const bool is_8bit_float = is_fp8_dtype(caller_a_dtype); + const bool is_16bit_float = is_fp16_dtype(caller_a_dtype); // Currently the accumulate path is only supported on fp16 - if (accumulate && is_fp8) + if (accumulate && is_8bit_float) return false; // Handle pathological NN case during fp8 dX GEMM by reading W columnwise and re-formulating as NT - if (!transA_use && !transB_use && is_fp8) { + if (!transA_use && !transB_use && is_8bit_float) { auto* B0_te = convertNVTETensorCheck(B_use[0]); if (B0_te->has_columnwise_data()) { use_b_columnwise_data = true; @@ -75,12 +77,10 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, } if (use_b_columnwise_data) { - if (B0_te->columnwise_data.shape.size() < 2) { - NVTE_ERROR("ck_tile_grouped_gemm: expected columnwise_data rank>=2 for B_use[0]"); + if (!get_columnwise_storage_2d_dims(B0_te->columnwise_data, b0, b1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B_use[0]"); return false; } - b0 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 2]); - b1 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 1]); } else { if (!get_flat_2d_dims(*B0_te, b0, b1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B_use[0]"); @@ -125,14 +125,12 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, use_b_columnwise_data, accumulate}; - - if (is_fp16) { + if (is_16bit_float) { return ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx); - } - - else if (is_fp8) { + } else if (is_8bit_float) { return ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx); } + NVTE_WARN("ck_tile_grouped_gemm: input dtype is neither fp16 nor fp8."); return false; } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h index 2e0c71983..97b4cfd88 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h @@ -12,4 +12,4 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool transB, NVTETensor* workspace, bool accumulate, - cudaStream_t stream); + hipStream_t stream); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 556dc1eef..3d82654ad 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -6,8 +6,7 @@ #pragma once -#include -#include +#include #include #include @@ -18,6 +17,7 @@ #include "../../common.h" #include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" namespace transformer_engine { @@ -26,6 +26,18 @@ namespace grouped_gemm { using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +template +using GroupedGemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + +template +using GroupedGemmPartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GroupedGemmShape, + TileCfg::TilePartitionerGroupNum, + TileCfg::TilePartitionerM01>; + template struct TETypeToCKType; template <> struct TETypeToCKType { using type = float; }; template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; @@ -42,6 +54,7 @@ struct EpilogueTraits { using DsLayout = ck_tile::tuple<>; using ElemOp = ck_tile::element_wise::PassThrough; }; + template struct EpilogueTraits { using DsDataType = ck_tile::tuple; @@ -69,7 +82,7 @@ struct GroupedGemmRunContext { void* workspace = nullptr; size_t workspace_bytes = 0; - cudaStream_t stream = nullptr; + hipStream_t stream = nullptr; bool use_b_columnwise_data = false; bool accumulate = false; @@ -87,22 +100,69 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, return true; } -bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx); +// Extract GEMM dims from columnwise storage. +// This path expects columnwise_data to already be normalized to a 2D layout. +static inline bool get_columnwise_storage_2d_dims( + const transformer_engine::SimpleTensor& t, + int64_t& d0, + int64_t& d1) { + + if (t.shape.size() != 2) { + return false; + } + + d0 = static_cast(t.shape[0]); + d1 = static_cast(t.shape[1]); + return true; +} + +template +static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return false; + } + return true; +} + +template +static inline bool launch_grouped_gemm_kernel(const DescContainer& descs, + const GroupedGemmRunContext& ctx, + const ck_tile::stream_config& stream_cfg) { + constexpr int kBlockPerCu = 1; -bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + NVTE_CHECK_CUDA(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; +} class RunnerInterface { public: virtual ~RunnerInterface() = default; virtual bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) = 0; + const GroupedGemmRunContext& ctx) = 0; }; - + } // namespace grouped_gemm -} // namespace transformer_engine +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index ea33a7ace..da464d977 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -5,79 +5,269 @@ ************************************************************************/ #include "ck_grouped_gemm_common.h" -#include "ck_grouped_gemm_fp16_impl.h" +#include "ck_grouped_gemm_fp16.h" namespace transformer_engine { namespace grouped_gemm { -#define MAKE_RUNNER(TileCfg_) \ - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, {\ - using Runner = GroupedGemmRunner< \ - AType, BType, CType, \ - ALayout, BLayout, CLayout, \ - TileCfg_, accum_option>; \ - runner = std::make_unique(); \ - }) +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +template +class GroupedGemmRunner : public RunnerInterface { + public: + using GemmShape = GroupedGemmShape; + using Partitioner = GroupedGemmPartitioner; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = + ck_tile::UniversalGemmPipelineProblem; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using ET = EpilogueTraits; + + using Epilogue = + ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedGemmKernel; + + // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. + using HostArgs = std::conditional_t, + ck_tile::GroupedGemmHostArgs<0>>; + + public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + if (!has_sufficient_workspace(ctx)) { + return {}; + } + + std::vector descs; + descs.reserve(ctx.group_num); + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + if constexpr (Accumulate) { + descs.emplace_back(a.dptr, + b.dptr, + std::array{d.dptr}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{stride_E}, + stride_E); + } else { + descs.emplace_back(a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + } + + return descs; + } + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + if (descs.empty()) { + return false; + } + return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + } +}; + +#define MAKE_RUNNER(TileCfg_) \ + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ + using Runner = GroupedGemmRunner; \ + runner = std::make_unique(); \ + }) template -static std::unique_ptr make_fp16_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { +static std::unique_ptr make_fp16_runner_typed( + DType d_dtype, + const GroupedGemmRunContext& ctx) { std::unique_ptr runner = nullptr; using AType = typename TETypeToCKType::type; using BType = typename TETypeToCKType::type; using CLayout = RowMajor; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); - } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); - } + using CType = typename TETypeToCKType::type; + + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64); + } else { + MAKE_RUNNER(TileCfg_256x128x64_padding); + } }); + return runner; } #undef MAKE_RUNNER -static std::unique_ptr -make_fp16_runner(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; +static std::unique_ptr make_fp16_runner( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_type, { - return make_fp16_runner_typed(d_dtype, ctx); - }); - }); + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_type, { + return make_fp16_runner_typed(d_dtype, ctx); + }); }); + }); - return nullptr; + return nullptr; } bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - const ck_tile::stream_config s{ctx.stream}; + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; - auto runner = make_fp16_runner( - a_dtype, b_dtype, d_dtype, ctx); - - if (!runner) { - return false; - } + auto runner = make_fp16_runner(a_dtype, b_dtype, d_dtype, ctx); + if (!runner) { + return false; + } - return runner->run(s, ctx); + return runner->run(s, ctx); } } // namespace grouped_gemm -} // namespace transformer_engine +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.h new file mode 100644 index 000000000..02368c87b --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.h @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h deleted file mode 100644 index e4138a132..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h +++ /dev/null @@ -1,214 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#pragma once -#include "ck_grouped_gemm_common.h" - -#include -#include -#include -#include - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -// ------------------------- -// Tile configs: FP16/BF16 -// ------------------------- - -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -template -class GroupedGemmRunner : public RunnerInterface { -public: - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using UniversalTraits = - ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, - GemmShape, UniversalTraits, Scheduler>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using ET = EpilogueTraits; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, typename ET::DsDataType, AccType, - CType, typename ET::DsLayout, CLayout, - typename ET::ElemOp, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC>>; - - using Kernel = ck_tile::GroupedGemmKernel; - - // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. - using HostArgs = std::conditional_t, - ck_tile::GroupedGemmHostArgs<0>>; - -public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", ctx.workspace_bytes, ". Falling back."); - return {}; - } - - thread_local std::vector descs; - descs.clear(); - descs.reserve(ctx.group_num); - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); - } - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - } - - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - if constexpr(Accumulate) { - descs.emplace_back( - a.dptr, b.dptr, - std::array{d.dptr}, // D1 = existing D contents (read) - d.dptr, // E = same buffer (write) - 1, M, N, K, - stride_A, stride_B, - std::array{stride_E}, - stride_E); - } else { - descs.emplace_back( - a.dptr, b.dptr, - std::array{}, - d.dptr, - 1, M, N, K, - stride_A, stride_B, - std::array{}, - stride_E); - } - } - - return descs; - }; - - - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); - return false; - } - - NVTE_CHECK_CUDA(cudaMemcpyAsync(ctx.workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - cudaMemcpyHostToDevice, - ctx.stream)); - - ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - ctx.group_num)); - return true; - }; -}; - -} // namespace grouped_gemm -} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 6419e3188..47ba29e52 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -5,8 +5,14 @@ ************************************************************************/ #include "ck_grouped_gemm_common.h" -#include "ck_grouped_gemm_fp8_impl.h" -#include "common/util/cuda_runtime.h" +#include "ck_grouped_gemm_fp8.h" + +#include + +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" namespace transformer_engine { namespace grouped_gemm { @@ -17,13 +23,272 @@ enum class GPUArch { UNKNOWN }; +struct TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + +// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile +// configuration due to an unsupported warp GEMM dispatcher configuration. +// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. +// +// To preserve the existing type name in shared template code, this struct +// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device +// compilation path, effectively reusing those parameters without redefining them. +// +// In all other compilation paths, the struct overrides the relevant fields to +// provide the intended 32x32x16 configuration. +#if defined(__gfx950__) +struct TileCfg_128x128x128_32x32x16_2x2x1 + : TileCfg_128x128x128_16x16x128_2x2x1 {}; +#else +struct TileCfg_128x128x128_32x32x16_2x2x1 + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; +#endif + +// FP8 currently supports overwrite only. +// Preserve MemOp here for a future accumulate path. +template +class QuantGroupedGemmRunner : public RunnerInterface { + public: + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = GroupedGemmShape; + using Partitioner = GroupedGemmPartitioner; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits; + + using Problem = + ck_tile::GemmRowColTensorQuantPipelineProblem; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = + ck_tile::CShuffleEpilogue, + AccType, + CType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, + Partitioner::NPerBlock, + TileCfg::M_Warp, + TileCfg::N_Warp, + TileCfg::M_Warp_Tile, + TileCfg::N_Warp_Tile, + TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = + ck_tile::QuantGroupedGemmKernel; + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + + public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + if (!has_sufficient_workspace(ctx)) { + return {}; + } + + std::vector descs; + descs.reserve(ctx.group_num); + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& d = data_view(*D_te); + + const transformer_engine::SimpleTensor* b_src = nullptr; + if (ctx.use_b_columnwise_data) { + if (!B_te->has_columnwise_data()) { + NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but " + "columnwise_data is absent."); + } + b_src = &B_te->columnwise_data; + } else { + b_src = &B_te->data; + } + + const auto& b = *b_src; + + int64_t Ad0 = 0, Ad1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected A and D to be rank>=2."); + } + + if (b.shape.size() < 2) { + NVTE_ERROR("ck_tile_grouped_gemm: expected chosen B source to be rank>=2."); + } + + int64_t Bd0 = static_cast(b.shape[b.shape.size() - 2]); + int64_t Bd1 = static_cast(b.shape[b.shape.size() - 1]); + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", + i, + ". op(A)=", + M, + "x", + K, + " op(B)=", + Kb, + "x", + N, + " raw A=", + Ad0, + "x", + Ad1, + " raw B=", + Bd0, + "x", + Bd1, + " use_b_columnwise_data=", + static_cast(ctx.use_b_columnwise_data), + " transA=", + static_cast(ctx.transA), + " transB=", + static_cast(ctx.transB)); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", + i, + ". D=", + Dd0, + "x", + Dd1, + ", expected=", + M, + "x", + N); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back(a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + } + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + if (descs.empty()) { + return false; + } + return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + } +}; + static inline GPUArch detect_gpu_arch() { - int arch = cuda::sm_arch(0); + int device = 0; + HIP_CHECK_ERROR(hipGetDevice(&device)); + + hipDeviceProp_t props{}; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); - if (arch == 94) { + if (props.major == 9 && props.minor == 4) { return GPUArch::GFX942; } - if (arch == 95) { + if (props.major == 9 && props.minor == 5) { return GPUArch::GFX950; } return GPUArch::UNKNOWN; @@ -43,8 +308,9 @@ struct FP8TileCfg { }; template -static std::unique_ptr make_fp8_runner_typed(DType d_dtype, - const GroupedGemmRunContext& ctx) { +static std::unique_ptr make_fp8_runner_typed( + DType d_dtype, + const GroupedGemmRunContext& ctx) { std::unique_ptr runner = nullptr; using AType = typename TETypeToCKType::type; @@ -54,10 +320,14 @@ static std::unique_ptr make_fp8_runner_typed(DType d_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner< - AType, BType, CType, - ALayout, BLayout, CTypeLayout, - TileCfg, ck_tile::memory_operation_enum::set>; + using Runner = QuantGroupedGemmRunner; runner = std::make_unique(); }); @@ -65,48 +335,50 @@ static std::unique_ptr make_fp8_runner_typed(DType d_dtype, } template -static std::unique_ptr make_fp8_runner_impl(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_type, { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_type, { - return make_fp8_runner_typed( - d_dtype, ctx); - }); - }); +static std::unique_ptr make_fp8_runner_impl( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_type, { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_type, { + return make_fp8_runner_typed( + d_dtype, ctx); }); + }); }); + }); - return nullptr; + return nullptr; } -static inline std::unique_ptr make_fp8_runner_gfx942(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { +static inline std::unique_ptr make_fp8_runner_gfx942( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); } -static inline std::unique_ptr make_fp8_runner_gfx950(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { +static inline std::unique_ptr make_fp8_runner_gfx950( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); } -static std::unique_ptr -make_fp8_runner(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { +static std::unique_ptr make_fp8_runner( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { switch (detect_gpu_arch()) { case GPUArch::GFX942: return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); @@ -123,12 +395,14 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { const ck_tile::stream_config s{ctx.stream}; + auto runner = make_fp8_runner(a_dtype, b_dtype, d_dtype, ctx); if (!runner) { return false; } + return runner->run(s, ctx); } } // namespace grouped_gemm -} // namespace transformer_engine +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.h new file mode 100644 index 000000000..4109eef85 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.h @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +namespace transformer_engine { +namespace grouped_gemm { + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h deleted file mode 100644 index bbc994fb4..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h +++ /dev/null @@ -1,260 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#pragma once -#include "ck_grouped_gemm_common.h" - -#include -#include -#include -#include - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -struct TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; - static constexpr ck_tile::index_t TilePartitionerM01 = 8; -}; - -// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile -// configuration due to an unsupported warp GEMM dispatcher configuration. -// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. -// -// To preserve the existing type name in shared template code, this struct -// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device -// compilation path, effectively reusing those parameters without redefining them. -// -// In all other compilation paths, the struct overrides the relevant fields to -// provide the intended 32x32x16 configuration. -#if defined(__gfx950__) -struct TileCfg_128x128x128_32x32x16_2x2x1 - : TileCfg_128x128x128_16x16x128_2x2x1 { -}; -#else -struct TileCfg_128x128x128_32x32x16_2x2x1 - : TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; -#endif - -template -class QuantGroupedGemmRunner : public RunnerInterface { -public: - static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using AQLayout = RowMajor; - using BQLayout = RowMajor; - - using UniversalTraits = - ck_tile::TileGemmQuantTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - false, false, false, ALayout, BLayout, CLayout, - QuantMode, AQLayout, BQLayout, - false, TileCfg::DoubleSmemBuffer, false>; - - using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, ck_tile::tuple<>, AccType, - CType, ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC>>; - - using Kernel = ck_tile::QuantGroupedGemmKernel; - using HostArgs = ck_tile::QuantGroupedGemmHostArgs; - -public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", ctx.workspace_bytes, ". Falling back."); - return {}; - } - - std::vector descs; - descs.reserve(ctx.group_num); - - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& d = data_view(*D_te); - - const transformer_engine::SimpleTensor* b_src = nullptr; - if (ctx.use_b_columnwise_data) { - if (!B_te->has_columnwise_data()) { - NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but columnwise_data is absent."); - } - b_src = &B_te->columnwise_data; - } else { - b_src = &B_te->data; - } - - const auto& b = *b_src; - - int64_t Ad0 = 0, Ad1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected A and D to be rank>=2."); - } - - if (b.shape.size() < 2) { - NVTE_ERROR("ck_tile_grouped_gemm: expected chosen B source to be rank>=2."); - } - - int64_t Bd0 = static_cast(b.shape[b.shape.size() - 2]); - int64_t Bd1 = static_cast(b.shape[b.shape.size() - 1]); - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i, - ". op(A)=", M, "x", K, - " op(B)=", Kb, "x", N, - " raw A=", Ad0, "x", Ad1, - " raw B=", Bd0, "x", Bd1, - " use_b_columnwise_data=", static_cast(ctx.use_b_columnwise_data), - " transA=", static_cast(ctx.transA), - " transB=", static_cast(ctx.transB)); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i, - ". D=", Dd0, "x", Dd1, - ", expected=", M, "x", N); - } - - const ck_tile::index_t stride_A = static_cast(Ad1); - const ck_tile::index_t stride_B = static_cast(Bd1); - const ck_tile::index_t stride_E = static_cast(Dd1); - - ck_tile::index_t AQK = 1; - ck_tile::index_t BQK = 1; - ck_tile::index_t stride_AQ = 1; - ck_tile::index_t stride_BQ = 1; - - const auto& aq = scale_inv_view(*A_te); - const auto& bq = scale_inv_view(*B_te); - - descs.emplace_back( - a.dptr, - b.dptr, - d.dptr, - aq.dptr, - bq.dptr, - 1, - M, - N, - K, - AQK, - BQK, - stride_A, - stride_B, - stride_E, - stride_AQ, - stride_BQ); - } - - return descs; - } - - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); - return false; - } - - NVTE_CHECK_CUDA(cudaMemcpyAsync(ctx.workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - cudaMemcpyHostToDevice, - ctx.stream)); - - ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - ctx.group_num)); - return true; - } -}; - -} // namespace grouped_gemm -} // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 056dc14ac..db3deab3b 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -423,6 +423,7 @@ enum class DType { kNumTypes }; +#ifdef USE_ROCM /*! \brief Check if TE datatype is FP16 * * Return true if TE datatype is FP16 @@ -431,7 +432,7 @@ enum class DType { inline bool is_fp16_dtype(const DType t) { return t == DType::kFloat16 || t == DType::kBFloat16; } - +#endif /*! \brief Check if TE datatype is FP8 * From 68fb511143889f016772d434b798c3c62a8b1da6 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 15:02:38 +0000 Subject: [PATCH 03/14] add EOL --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 2 +- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index da464d977..2995cffb5 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -270,4 +270,4 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 47ba29e52..65f6f1acf 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -405,4 +405,4 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine From c77dff33a05952831b4a590f89b0ea6c3f1c5c42 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 15:03:24 +0000 Subject: [PATCH 04/14] add EOL --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 3d82654ad..f744e6eef 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -165,4 +165,4 @@ class RunnerInterface { }; } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine From f8293a0ae9b19c2373cbf2ed7eaf7ec105a91f46 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 15:15:30 +0000 Subject: [PATCH 05/14] fix alignment --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 347c10399..4e252efe8 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -43,7 +43,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); const bool is_8bit_float = is_fp8_dtype(caller_a_dtype); - const bool is_16bit_float = is_fp16_dtype(caller_a_dtype); + const bool is_16bit_float = is_fp16_dtype(caller_a_dtype); // Currently the accumulate path is only supported on fp16 if (accumulate && is_8bit_float) From 1ef53986c2c6e58dbf232e709c8b37a08585e544 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 15:33:16 +0000 Subject: [PATCH 06/14] use faster impl for detect gpu arch --- .../gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 65f6f1acf..9caae2d14 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -6,8 +6,7 @@ #include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp8.h" - -#include +#include "common/util/cuda_runtime.h" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" @@ -279,16 +278,12 @@ class QuantGroupedGemmRunner : public RunnerInterface { }; static inline GPUArch detect_gpu_arch() { - int device = 0; - HIP_CHECK_ERROR(hipGetDevice(&device)); - - hipDeviceProp_t props{}; - HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); + int arch = cuda::sm_arch(0); - if (props.major == 9 && props.minor == 4) { + if (arch == 94) { return GPUArch::GFX942; } - if (props.major == 9 && props.minor == 5) { + if (arch == 95) { return GPUArch::GFX950; } return GPUArch::UNKNOWN; From 391f22acb77257977e0d5d318cd84d0b9d7d35b6 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 3 Apr 2026 16:45:55 +0000 Subject: [PATCH 07/14] fuse small helpers into dispatch funcs --- .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 86 ++++----- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 168 ++++++------------ 2 files changed, 88 insertions(+), 166 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 2995cffb5..660dbefb8 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -196,72 +196,52 @@ class GroupedGemmRunner : public RunnerInterface { } }; -#define MAKE_RUNNER(TileCfg_) \ - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ - using Runner = GroupedGemmRunner; \ - runner = std::make_unique(); \ +#define MAKE_RUNNER(TileCfg_) \ + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ + using Runner = GroupedGemmRunner; \ + runner = std::make_unique(); \ }) -template -static std::unique_ptr make_fp16_runner_typed( - DType d_dtype, - const GroupedGemmRunContext& ctx) { +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; std::unique_ptr runner = nullptr; - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; - using CLayout = RowMajor; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); - } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); - } - }); - - return runner; -} - -#undef MAKE_RUNNER -static std::unique_ptr make_fp16_runner( - DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { using ALayout = std::conditional_t; TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { using BLayout = std::conditional_t; - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_type, { - return make_fp16_runner_typed(d_dtype, ctx); + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64); + } else { + MAKE_RUNNER(TileCfg_256x128x64_padding); + } + }); }); }); }); - return nullptr; -} - -bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - const ck_tile::stream_config s{ctx.stream}; - - auto runner = make_fp16_runner(a_dtype, b_dtype, d_dtype, ctx); if (!runner) { return false; } @@ -269,5 +249,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, return runner->run(s, ctx); } +#undef MAKE_RUNNER + } // namespace grouped_gemm } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 9caae2d14..567fec9ef 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -166,8 +166,7 @@ class QuantGroupedGemmRunner : public RunnerInterface { const transformer_engine::SimpleTensor* b_src = nullptr; if (ctx.use_b_columnwise_data) { if (!B_te->has_columnwise_data()) { - NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but " - "columnwise_data is absent."); + NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but columnwise_data is absent."); } b_src = &B_te->columnwise_data; } else { @@ -176,18 +175,24 @@ class QuantGroupedGemmRunner : public RunnerInterface { const auto& b = *b_src; - int64_t Ad0 = 0, Ad1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected A and D to be rank>=2."); + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A in group ", i); } - if (b.shape.size() < 2) { - NVTE_ERROR("ck_tile_grouped_gemm: expected chosen B source to be rank>=2."); + if (ctx.use_b_columnwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } } - int64_t Bd0 = static_cast(b.shape[b.shape.size() - 2]); - int64_t Bd1 = static_cast(b.shape[b.shape.size() - 1]); + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized D in group ", i); + } const int64_t M = ctx.transA ? Ad1 : Ad0; const int64_t K = ctx.transA ? Ad0 : Ad1; @@ -195,43 +200,13 @@ class QuantGroupedGemmRunner : public RunnerInterface { const int64_t Kb = ctx.transB ? Bd1 : Bd0; if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", - i, - ". op(A)=", - M, - "x", - K, - " op(B)=", - Kb, - "x", - N, - " raw A=", - Ad0, - "x", - Ad1, - " raw B=", - Bd0, - "x", - Bd1, - " use_b_columnwise_data=", - static_cast(ctx.use_b_columnwise_data), - " transA=", - static_cast(ctx.transA), - " transB=", - static_cast(ctx.transB)); + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); } if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", - i, - ". D=", - Dd0, - "x", - Dd1, - ", expected=", - M, - "x", - N); + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } const ck_tile::index_t stride_A = static_cast(Ad1); @@ -302,101 +277,66 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; -template -static std::unique_ptr make_fp8_runner_typed( - DType d_dtype, - const GroupedGemmRunContext& ctx) { +template +static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; std::unique_ptr runner = nullptr; - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; using CTypeLayout = RowMajor; using TileCfg = typename FP8TileCfg::type; - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - }); - - return runner; -} - -template -static std::unique_ptr make_fp8_runner_impl( - DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { using ALayout = std::conditional_t; TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { using BLayout = std::conditional_t; - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_type, { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_type, { - return make_fp8_runner_typed( - d_dtype, ctx); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + }); }); }); }); }); - return nullptr; -} - -static inline std::unique_ptr make_fp8_runner_gfx942( - DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); -} + if (!runner) { + return false; + } -static inline std::unique_ptr make_fp8_runner_gfx950( - DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); + return runner->run(s, ctx); } -static std::unique_ptr make_fp8_runner( - DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { switch (detect_gpu_arch()) { case GPUArch::GFX942: - return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); + return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); case GPUArch::GFX950: - return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx); + return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); default: NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); - return nullptr; - } -} - -bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { - const ck_tile::stream_config s{ctx.stream}; - - auto runner = make_fp8_runner(a_dtype, b_dtype, d_dtype, ctx); - if (!runner) { - return false; + return false; } - - return runner->run(s, ctx); } } // namespace grouped_gemm From d266d33fe8d78703eb7535fe275b471279c8d4fb Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 20:38:16 +0000 Subject: [PATCH 08/14] remove file --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 136 ------------------ 1 file changed, 136 deletions(-) delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp deleted file mode 100644 index 4e252efe8..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ /dev/null @@ -1,136 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#include "ck_grouped_gemm_common.h" -#include "ck_grouped_gemm_fp16.h" -#include "ck_grouped_gemm_fp8.h" - -bool ck_tile_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate, - hipStream_t stream) { - if (group_num <= 0) { - return true; - } - - using namespace transformer_engine; - using namespace transformer_engine::grouped_gemm; - - void* ws_ptr = nullptr; - size_t ws_bytes = 0; - if (workspace) { - auto* ws_te = convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; - ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); - } - - // Normalize similar to upstream - // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 - // I.e., swap A and B, as well as transa and transb. - const NVTETensor* A_use = B; - const NVTETensor* B_use = A; - bool transA_use = transB; - bool transB_use = transA; - bool use_b_columnwise_data = false; - - const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); - const bool is_8bit_float = is_fp8_dtype(caller_a_dtype); - const bool is_16bit_float = is_fp16_dtype(caller_a_dtype); - - // Currently the accumulate path is only supported on fp16 - if (accumulate && is_8bit_float) - return false; - - // Handle pathological NN case during fp8 dX GEMM by reading W columnwise and re-formulating as NT - if (!transA_use && !transB_use && is_8bit_float) { - auto* B0_te = convertNVTETensorCheck(B_use[0]); - if (B0_te->has_columnwise_data()) { - use_b_columnwise_data = true; - transB_use = true; - } - } - - const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); - const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); - - Tensor* D0_te = convertNVTETensorCheck(D[0]); - const auto d_dtype = D0_te->dtype(); - - Tensor* A0_te = convertNVTETensorCheck(A_use[0]); - Tensor* B0_te = convertNVTETensorCheck(B_use[0]); - - int64_t a0 = 0, a1 = 0; - int64_t b0 = 0, b1 = 0; - int64_t d0 = 0, d1 = 0; - - if (!get_flat_2d_dims(*A0_te, a0, a1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A_use[0]"); - return false; - } - - if (use_b_columnwise_data) { - if (!get_columnwise_storage_2d_dims(B0_te->columnwise_data, b0, b1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B_use[0]"); - return false; - } - } else { - if (!get_flat_2d_dims(*B0_te, b0, b1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B_use[0]"); - return false; - } - } - - if (!get_flat_2d_dims(*D0_te, d0, d1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); - return false; - } - - const int64_t m = transA_use ? a1 : a0; - const int64_t kA = transA_use ? a0 : a1; - - const int64_t kB = transB_use ? b1 : b0; - const int64_t n = transB_use ? b0 : b1; - - if (kA != kB) { - NVTE_ERROR("ck_tile_grouped_gemm: normalized GEMM K mismatch: op(A_use) is ", - m, "x", kA, ", op(B_use) is ", kB, "x", n); - return false; - } - - if (d0 != m || d1 != n) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch for normalized GEMM. " - "D is ", d0, "x", d1, " but expected ", m, "x", n); - return false; - } - - GroupedGemmRunContext ctx = { - A_use, - B_use, - D, - static_cast(n), - group_num, - transA_use, - transB_use, - ws_ptr, - ws_bytes, - stream, - use_b_columnwise_data, - accumulate}; - - if (is_16bit_float) { - return ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx); - } else if (is_8bit_float) { - return ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx); - } - - NVTE_WARN("ck_tile_grouped_gemm: input dtype is neither fp16 nor fp8."); - return false; -} From 7976ab4e64d9157081af7b0fb20f1deac0fee0a7 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 20:39:47 +0000 Subject: [PATCH 09/14] add file --- .../common/gemm/ck_grouped_gemm.cpp | 338 ++++++++++++++++++ 1 file changed, 338 insertions(+) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp new file mode 100644 index 000000000..def454f86 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -0,0 +1,338 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include + +#include +#include "../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + // Require at least a matrix (rank >= 2). Higher ranks are flattened. + if (t.shape().size() < 2) + return false; + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). +// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined +// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; + using ElemOp = ck_tile::element_wise::PassThrough; +}; +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using ElemOp = ck_tile::element_wise::Add; +}; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; // rowwise data view +} + +// Primus-Turbo-like FP16/BF16 tile configs +// Selection rule: +// if (N % 256 == 0) use 256x256x64 +// else if (N % 128 == 0) use 256x128x64 +// else use 256x128x64 with N padding enabled +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +// This class instantiates CK_Tile's grouped GEMM pipeline. +// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. +template +struct Runner{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using ET = EpilogueTraits; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, typename ET::DsDataType, AccType, + CType, typename ET::DsLayout, CLayout, + typename ET::ElemOp, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = ck_tile::GroupedGemmKernel; +}; + +template +static bool run_grouped_impl(const NVTETensor* A_use, + const NVTETensor* B_use, + NVTETensor* D, + int group_num, + bool transA_use, + bool transB_use, + void* workspace, + size_t workspace_bytes, + hipStream_t stream) +{ + using Kernel = typename Runner::Kernel; + + const size_t needed = Kernel::GetWorkSpaceSize(group_num); + if (!workspace || workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", workspace_bytes, ". Falling back."); + return false; + } + + // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. + using HostArgs = std::conditional_t, + ck_tile::GroupedGemmHostArgs<0>>; + + thread_local std::vector descs; + descs.clear(); + descs.reserve(group_num); + + for (int i = 0; i < group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(A_use[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(B_use[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); + return false; + } + + const int64_t M = transA_use ? Ad1 : Ad0; + const int64_t K = transA_use ? Ad0 : Ad1; + const int64_t N = transB_use ? Bd0 : Bd1; + const int64_t Kb = transB_use ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + return false; + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + return false; + } + + // Leading dimensions under the flattened-contiguous interpretation + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + if constexpr (Accumulate) { + // MultiD: E = Add(A@B, D1). D1 and E point to the same buffer for in-place accumulation. + descs.emplace_back( + a.dptr, b.dptr, + std::array{d.dptr}, // D1 = existing D contents (read) + d.dptr, // E = same buffer (write) + 1, M, N, K, + stride_A, stride_B, + std::array{stride_E}, + stride_E); + } else { + descs.emplace_back( + a.dptr, b.dptr, + std::array{}, + d.dptr, + 1, M, N, K, + stride_A, stride_B, + std::array{}, + stride_E); + } + } + + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + stream)); + + const ck_tile::stream_config s{stream}; + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, + ck_tile::make_kernel<1>( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(workspace), + group_num)); + return true; +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +bool ck_tile_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) +{ + if (group_num <= 0) + return true; + + using namespace transformer_engine; + using namespace transformer_engine::grouped_gemm; + + // Workspace pointer + bytes + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + // Normalize similar to upstream + // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + const bool transA_use = transB; + const bool transB_use = transA; + + const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + + // Get N from D[0] (assume uniform N across groups) + int64_t ref_d0 = 0, ref_d1 = 0; + Tensor* D0_te = convertNVTETensorCheck(D[0]); + if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + const ck_tile::index_t N = static_cast(ref_d1); + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { + using T = typename TETypeToCKType::type; + + auto run_with_tilecfg = [&](auto tile_tag) -> bool { + using TileCfgSel = decltype(tile_tag); + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using BLayout = std::conditional_t; + + if (accumulate) { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } else { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } + }); + }); + }; + + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. + if ((N % 256) == 0) { + return run_with_tilecfg(TileCfg_256x256x64{}); + } else if ((N % 128) == 0) { + return run_with_tilecfg(TileCfg_256x128x64{}); + } else { + return run_with_tilecfg(TileCfg_256x128x64_padding{}); + } + }); +} From 4dd1cfbabaced3f0d5103396f38e17fcbc666281 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 20:43:35 +0000 Subject: [PATCH 10/14] move file --- .../common/gemm/{ => ck_grouped_gemm}/ck_grouped_gemm.cpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename transformer_engine/common/gemm/{ => ck_grouped_gemm}/ck_grouped_gemm.cpp (100%) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp similarity index 100% rename from transformer_engine/common/gemm/ck_grouped_gemm.cpp rename to transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp From 3608c1afecd17d01a30731ce12de09ffd4a89ef6 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 20:44:18 +0000 Subject: [PATCH 11/14] Update file --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 386 +++++------------- 1 file changed, 92 insertions(+), 294 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index def454f86..2aaf96df6 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -4,254 +4,9 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -#include - -#include -#include "../common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. -static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - // Require at least a matrix (rank >= 2). Higher ranks are flattened. - if (t.shape().size() < 2) - return false; - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - -// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). -// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined -// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple<>; - using DsLayout = ck_tile::tuple<>; - using ElemOp = ck_tile::element_wise::PassThrough; -}; -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple; - using DsLayout = ck_tile::tuple; - using ElemOp = ck_tile::element_wise::Add; -}; - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; // rowwise data view -} - -// Primus-Turbo-like FP16/BF16 tile configs -// Selection rule: -// if (N % 256 == 0) use 256x256x64 -// else if (N % 128 == 0) use 256x128x64 -// else use 256x128x64 with N padding enabled -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -// This class instantiates CK_Tile's grouped GEMM pipeline. -// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. -template -struct Runner{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using ET = EpilogueTraits; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, typename ET::DsDataType, AccType, - CType, typename ET::DsLayout, CLayout, - typename ET::ElemOp, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC>>; - - using Kernel = ck_tile::GroupedGemmKernel; -}; - -template -static bool run_grouped_impl(const NVTETensor* A_use, - const NVTETensor* B_use, - NVTETensor* D, - int group_num, - bool transA_use, - bool transB_use, - void* workspace, - size_t workspace_bytes, - hipStream_t stream) -{ - using Kernel = typename Runner::Kernel; - - const size_t needed = Kernel::GetWorkSpaceSize(group_num); - if (!workspace || workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", workspace_bytes, ". Falling back."); - return false; - } - - // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. - using HostArgs = std::conditional_t, - ck_tile::GroupedGemmHostArgs<0>>; - - thread_local std::vector descs; - descs.clear(); - descs.reserve(group_num); - - for (int i = 0; i < group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(A_use[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(B_use[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); - return false; - } - - const int64_t M = transA_use ? Ad1 : Ad0; - const int64_t K = transA_use ? Ad0 : Ad1; - const int64_t N = transB_use ? Bd0 : Bd1; - const int64_t Kb = transB_use ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - return false; - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - return false; - } - - // Leading dimensions under the flattened-contiguous interpretation - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - if constexpr (Accumulate) { - // MultiD: E = Add(A@B, D1). D1 and E point to the same buffer for in-place accumulation. - descs.emplace_back( - a.dptr, b.dptr, - std::array{d.dptr}, // D1 = existing D contents (read) - d.dptr, // E = same buffer (write) - 1, M, N, K, - stride_A, stride_B, - std::array{stride_E}, - stride_E); - } else { - descs.emplace_back( - a.dptr, b.dptr, - std::array{}, - d.dptr, - 1, M, N, K, - stride_A, stride_B, - std::array{}, - stride_E); - } - } - - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); - return false; - } - - HIP_CHECK_ERROR(hipMemcpyAsync(workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - stream)); - - const ck_tile::stream_config s{stream}; - const dim3 blocks = Kernel::BlockSize(); - - ck_tile::launch_kernel( - s, - ck_tile::make_kernel<1>( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(workspace), - group_num)); - return true; -} - -} // namespace grouped_gemm -} // namespace transformer_engine +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16.h" +#include "ck_grouped_gemm_fp8.h" bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, @@ -261,20 +16,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool transB, NVTETensor* workspace, bool accumulate, - hipStream_t stream) -{ - if (group_num <= 0) + hipStream_t stream) { + if (group_num <= 0) { return true; + } using namespace transformer_engine; using namespace transformer_engine::grouped_gemm; - // Workspace pointer + bytes - void* ws_ptr = nullptr; + void* ws_ptr = nullptr; size_t ws_bytes = 0; if (workspace) { auto* ws_te = convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; + ws_ptr = ws_te->data.dptr; ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } @@ -283,56 +37,100 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // I.e., swap A and B, as well as transa and transb. const NVTETensor* A_use = B; const NVTETensor* B_use = A; - const bool transA_use = transB; - const bool transB_use = transA; + bool transA_use = transB; + bool transB_use = transA; + bool use_b_columnwise_data = false; + + const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); + const bool is_8bit_float = is_fp8_dtype(caller_a_dtype); + const bool is_16bit_float = is_fp16_dtype(caller_a_dtype); + + // Currently the accumulate path is only supported on fp16 + if (accumulate && is_8bit_float) + return false; + + // Handle pathological NN case during fp8 dX GEMM by reading W columnwise and re-formulating as NT + if (!transA_use && !transB_use && is_8bit_float) { + auto* B0_te = convertNVTETensorCheck(B_use[0]); + if (B0_te->has_columnwise_data()) { + use_b_columnwise_data = true; + transB_use = true; + } + } const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); - // Get N from D[0] (assume uniform N across groups) - int64_t ref_d0 = 0, ref_d1 = 0; Tensor* D0_te = convertNVTETensorCheck(D[0]); - if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { + const auto d_dtype = D0_te->dtype(); + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + int64_t a0 = 0, a1 = 0; + int64_t b0 = 0, b1 = 0; + int64_t d0 = 0, d1 = 0; + + if (!get_flat_2d_dims(*A0_te, a0, a1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A_use[0]"); + return false; + } + + if (use_b_columnwise_data) { + if (!get_columnwise_storage_2d_dims(B0_te->columnwise_data, b0, b1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B_use[0]"); + return false; + } + } else { + if (!get_flat_2d_dims(*B0_te, b0, b1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B_use[0]"); + return false; + } + } + + if (!get_flat_2d_dims(*D0_te, d0, d1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); return false; } - const ck_tile::index_t N = static_cast(ref_d1); - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename TETypeToCKType::type; + const int64_t m = transA_use ? a1 : a0; + const int64_t kA = transA_use ? a0 : a1; - auto run_with_tilecfg = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); + const int64_t kB = transB_use ? b1 : b0; + const int64_t n = transB_use ? b0 : b1; - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; + if (kA != kB) { + NVTE_ERROR("ck_tile_grouped_gemm: normalized GEMM K mismatch: op(A_use) is ", + m, "x", kA, ", op(B_use) is ", kB, "x", n); + return false; + } - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; + if (d0 != m || d1 != n) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch for normalized GEMM. " + "D is ", d0, "x", d1, " but expected ", m, "x", n); + return false; + } - if (accumulate) { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } else { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } - }); - }); - }; + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + static_cast(n), + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream, + use_b_columnwise_data, + accumulate}; + + if (is_16bit_float) { + return ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx); + } else if (is_8bit_float) { + return ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx); + } - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. - if ((N % 256) == 0) { - return run_with_tilecfg(TileCfg_256x256x64{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(TileCfg_256x128x64{}); - } else { - return run_with_tilecfg(TileCfg_256x128x64_padding{}); - } - }); -} + NVTE_WARN("ck_tile_grouped_gemm: input dtype is neither fp16 nor fp8."); + return false; +} \ No newline at end of file From 4b4f29df8e852cecbc14a11b5ab4a310942cf1fa Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 21:44:24 +0000 Subject: [PATCH 12/14] Merge changes from dev. --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4..be6c079be 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 From 9edd059174d0e19806cdf02e09ffdbcaaf8bf32f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 22:14:22 +0000 Subject: [PATCH 13/14] fix EOL --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 2aaf96df6..4e252efe8 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -133,4 +133,4 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, NVTE_WARN("ck_tile_grouped_gemm: input dtype is neither fp16 nor fp8."); return false; -} \ No newline at end of file +} From b35a023f12377665c109e5ee5f10c9fce8e2df1f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 6 Apr 2026 22:28:44 +0000 Subject: [PATCH 14/14] Sync cudnn-frontend submodule pointer with dev --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index be6c079be..0258951d4 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 +Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93