From b6938ffd03749feffb427cb6b2729f1636a66cd4 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 13 May 2026 23:35:44 -0700 Subject: [PATCH 1/9] updated Signed-off-by: Zhongbo Zhu --- tests/pytorch/test_grouped_tensor.py | 48 ++++++++ transformer_engine/common/common.cu | 107 ++++++++++++++++++ .../transformer_engine/transformer_engine.h | 25 ++++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/misc.cpp | 48 ++++++++ .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/ops/fused/forward_grouped_mlp.py | 17 ++- 7 files changed, 246 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c54c9758ff..c8dd51603d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -163,6 +163,54 @@ def test_basic_construction_varying_first_dim(self) -> None: shape[0][1], ) # sum of first dims + @pytest.mark.parametrize( + "split_sizes_list,logical_last_dim", + [ + pytest.param([3, 4, 5, 2], 7, id="all_nonzero"), + pytest.param([3, 0, 5, 2], 7, id="zero_middle"), + pytest.param([0, 3, 5, 0], 11, id="zero_edges"), + pytest.param([1], 17, id="single_group"), + pytest.param([1, 2, 3, 4, 5, 6, 7, 8], 13, id="many_groups"), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) + def test_prepare_grouped_splits( + self, + input_dtype: torch.dtype, + split_sizes_list: List[int], + logical_last_dim: int, + ) -> None: + """Test fused grouped split metadata preparation.""" + split_sizes = torch.tensor(split_sizes_list, dtype=input_dtype, device="cuda") + num_groups = split_sizes.numel() + + ( + split_sizes_i64, + base_offsets, + split_points, + tensor_offsets, + ) = tex.prepare_grouped_splits(split_sizes, num_groups, logical_last_dim) + + expected_split_sizes = split_sizes.to(torch.int64) + expected_base_offsets = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device="cuda"), + torch.cumsum(expected_split_sizes, dim=0), + ) + ) + expected_split_points = expected_base_offsets[1:].to(torch.int32) + expected_tensor_offsets = expected_base_offsets * logical_last_dim + + assert split_sizes_i64.dtype == torch.int64 + assert base_offsets.dtype == torch.int64 + # cuDNN grouped GEMM consumes int32 end offsets; TE GroupedTensor metadata stays int64. + assert split_points.dtype == torch.int32 + assert tensor_offsets.dtype == torch.int64 + assert torch.equal(split_sizes_i64, expected_split_sizes) + assert torch.equal(base_offsets, expected_base_offsets) + assert torch.equal(split_points, expected_split_points) + assert torch.equal(tensor_offsets, expected_tensor_offsets) + def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1bdd80a369..81dbd76bc6 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -129,6 +129,59 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } } +template +__global__ void __launch_bounds__(kThreadsPerBlock) prepare_grouped_splits_kernel( + const FirstDimT *__restrict__ first_dims, int64_t *__restrict__ first_dims_i64, + int64_t *__restrict__ base_offsets, + int32_t *__restrict__ split_points, int64_t *__restrict__ tensor_offsets, + int64_t logical_last_dim, size_t num_tensors) { + + __shared__ int64_t block_scan[kThreadsPerBlock]; + __shared__ int64_t chunk_prefix; + + const size_t tid = threadIdx.x; + if (tid == 0) { + base_offsets[0] = 0; + tensor_offsets[0] = 0; + chunk_prefix = 0; + } + __syncthreads(); + + for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) { + const size_t idx = chunk_start + tid; + + block_scan[tid] = 0; + if (idx < num_tensors) { + block_scan[tid] = static_cast(first_dims[idx]); + first_dims_i64[idx] = block_scan[tid]; + } + __syncthreads(); + + // Inclusive scan in shared memory. + for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) { + const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0; + __syncthreads(); + block_scan[tid] += addend; + __syncthreads(); + } + + if (idx < num_tensors) { + const int64_t prefix = chunk_prefix + block_scan[tid]; + base_offsets[idx + 1] = prefix; + // cuDNN grouped GEMM expects padded split end offsets as int32. TE + // GroupedTensor metadata keeps the full int64 base_offsets/tensor_offsets. + split_points[idx] = static_cast(prefix); + tensor_offsets[idx + 1] = prefix * logical_last_dim; + } + __syncthreads(); + + if (tid == kThreadsPerBlock - 1) { + chunk_prefix += block_scan[tid]; + } + __syncthreads(); + } +} + } // namespace #define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \ @@ -171,6 +224,60 @@ void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t n logical_last_dim); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_prepare_grouped_splits(const NVTETensor first_dims, NVTETensor first_dims_i64, + NVTETensor base_offsets, NVTETensor split_points, + NVTETensor tensor_offsets, int64_t logical_last_dim, + cudaStream_t stream) { + NVTE_API_CALL(nvte_prepare_grouped_splits); + + const auto *first_dims_tensor = convertNVTETensorCheck(first_dims); + const auto *first_dims_i64_tensor = convertNVTETensorCheck(first_dims_i64); + const auto *base_offsets_tensor = convertNVTETensorCheck(base_offsets); + const auto *split_points_tensor = convertNVTETensorCheck(split_points); + const auto *tensor_offsets_tensor = convertNVTETensorCheck(tensor_offsets); + const auto first_dims_dtype = first_dims_tensor->dtype(); + const auto num_tensors = first_dims_tensor->numel(); + const auto offsets_numel = num_tensors + 1; + const auto is_tensor = [](const Tensor *tensor, DType dtype, size_t numel) { + return tensor->dim() == 1 && tensor->dtype() == dtype && tensor->numel() == numel; + }; + + NVTE_CHECK( + num_tensors > 0 && logical_last_dim >= 0 && first_dims_tensor->dim() == 1 && + (first_dims_dtype == DType::kInt32 || first_dims_dtype == DType::kInt64) && + is_tensor(first_dims_i64_tensor, DType::kInt64, num_tensors) && + is_tensor(base_offsets_tensor, DType::kInt64, offsets_numel) && + is_tensor(split_points_tensor, DType::kInt32, num_tensors) && + is_tensor(tensor_offsets_tensor, DType::kInt64, offsets_numel), + "Invalid grouped split metadata. Expected first_dims int32/int64[N], " + "first_dims_i64 int64[N], base_offsets int64[N+1], split_points int32[N], " + "tensor_offsets int64[N+1], and logical_last_dim >= 0."); + // split_points is the only int32 output by design: cuDNN grouped GEMM uses + // int32 padded split end offsets, while TE grouped tensor offsets are int64. + + switch (first_dims_dtype) { + case DType::kInt32: + prepare_grouped_splits_kernel<<<1, kThreadsPerBlock, 0, stream>>>( + static_cast(first_dims_tensor->data.dptr), + static_cast(first_dims_i64_tensor->data.dptr), + static_cast(base_offsets_tensor->data.dptr), + static_cast(split_points_tensor->data.dptr), + static_cast(tensor_offsets_tensor->data.dptr), logical_last_dim, num_tensors); + break; + case DType::kInt64: + prepare_grouped_splits_kernel<<<1, kThreadsPerBlock, 0, stream>>>( + static_cast(first_dims_tensor->data.dptr), + static_cast(first_dims_i64_tensor->data.dptr), + static_cast(base_offsets_tensor->data.dptr), + static_cast(split_points_tensor->data.dptr), + static_cast(tensor_offsets_tensor->data.dptr), logical_last_dim, num_tensors); + break; + default: + NVTE_ERROR("first_dims must have dtype int32 or int64."); + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // extern "C" void checkCuDriverContext(CUstream stream) { diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 045ae88893..5a8850d330 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -454,6 +454,31 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors, int64_t logical_last_dim, cudaStream_t stream); +/*! \brief Prepare grouped split metadata. + * + * This is a fused variant of split metadata preparation for grouped kernels. + * It accepts either int32 or int64 first dimensions, writes int64 metadata + * for TE grouped tensors, writes int32 split points for cuDNN grouped GEMM, + * and writes scaled int64 tensor offsets. + * + * \param[in] first_dims Device int32 or int64 tensor of shape [num_tensors]. + * \param[out] first_dims_i64 Device int64 tensor of shape [num_tensors]. + * \param[out] base_offsets Device int64 tensor of shape [num_tensors + 1], + * containing [0, cumsum(first_dims)]. + * \param[out] split_points Device int32 tensor of shape [num_tensors], + * containing cumsum(first_dims) without the leading 0. This is int32 because + * it is consumed by cuDNN grouped GEMM padded offsets; TE grouped tensor + * offsets remain int64. + * \param[out] tensor_offsets Device int64 tensor of shape [num_tensors + 1], + * containing base_offsets * logical_last_dim. + * \param[in] logical_last_dim Scale factor for tensor_offsets. + * \param[in] stream CUDA stream to use for the operation. + */ +void nvte_prepare_grouped_splits(const NVTETensor first_dims, NVTETensor first_dims_i64, + NVTETensor base_offsets, NVTETensor split_points, + NVTETensor tensor_offsets, int64_t logical_last_dim, + cudaStream_t stream); + /*! \brief TE Grouped Tensor type * * NVTEGroupedTensor is a collection of tensors with potentially different shapes diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b10a9c5a4..f11437f935 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -488,6 +488,8 @@ std::tuple get_device_pointer_for_data_and_s std::vector data_tensors, std::vector scale_tensors, bool swizzle, bool rowwise, transformer_engine::DType data_dtype); at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); +std::vector prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups, + int64_t logical_last_dim); /*************************************************************************************************** * Support THD format for Context Parallel diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index c5707fa53c..e1e554deda 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "../extensions.h" +#include "pybind.h" namespace transformer_engine::pytorch { @@ -30,4 +31,51 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ return output; } +std::vector prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups, + int64_t logical_last_dim) { + NVTE_CHECK(split_sizes.is_cuda(), "split_sizes must be on CUDA."); + NVTE_CHECK(split_sizes.scalar_type() == at::kInt || split_sizes.scalar_type() == at::kLong, + "split_sizes must have dtype int32 or int64."); + NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); + NVTE_CHECK(split_sizes.is_contiguous(), "split_sizes must be contiguous."); + NVTE_CHECK(num_groups > 0, "num_groups must be greater than 0."); + NVTE_CHECK(split_sizes.numel() == num_groups, "split_sizes must have length ", num_groups, "."); + NVTE_CHECK(logical_last_dim >= 0, "logical_last_dim must be non-negative."); + + const int64_t offsets_length = num_groups + 1; + + // Return order is part of the Python contract: + // 0. split_sizes_i64: int64[num_groups], canonical TE GroupedTensor first dims. + // 1. base_offsets: int64[num_groups + 1], [0, cumsum(split_sizes)]. + // 2. split_points: int32[num_groups], cumsum(split_sizes) without the leading 0 + // for cuDNN grouped GEMM padded offsets. This is intentionally int32 + // even though TE grouped tensor metadata uses int64 below. + // 3. tensor_offsets: int64[num_groups + 1], base_offsets * logical_last_dim. + auto outputs = bulk_allocate( + {{static_cast(num_groups)}, + {static_cast(offsets_length)}, + {static_cast(num_groups)}, + {static_cast(offsets_length)}}, + {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), std::nullopt); + auto split_sizes_i64 = outputs[0]; + auto base_offsets = outputs[1]; + auto split_points = outputs[2]; + auto tensor_offsets = outputs[3]; + + auto split_sizes_nvte = makeTransformerEngineTensor(split_sizes); + auto split_sizes_i64_nvte = makeTransformerEngineTensor(split_sizes_i64); + auto base_offsets_nvte = makeTransformerEngineTensor(base_offsets); + auto split_points_nvte = makeTransformerEngineTensor(split_points); + auto tensor_offsets_nvte = makeTransformerEngineTensor(tensor_offsets); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_prepare_grouped_splits(split_sizes_nvte.data(), split_sizes_i64_nvte.data(), + base_offsets_nvte.data(), split_points_nvte.data(), + tensor_offsets_nvte.data(), logical_last_dim, + at::cuda::getCurrentCUDAStream()); + }); + + return outputs; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a813f3119d..ca1a72ddf5 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -497,6 +497,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); + m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits, + "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"), + py::arg("num_groups"), py::arg("logical_last_dim")); m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", py::call_guard()); diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 91db2ff9b7..0d384e7f6d 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -159,10 +159,19 @@ def fuser_forward( split_sizes = fc1_split_sizes if int(split_sizes.numel()) != num_groups: raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") - split_sizes = split_sizes.to(dtype=torch.int64, device=device) - base_split_offsets = tex.splits_to_offsets(split_sizes, 1) - split_points = base_split_offsets[1:].to(dtype=torch.int) - fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] + # Prepare all split metadata in one CUDA kernel. The returned split_sizes is the + # canonical TE representation: int64[num_groups]. Python uses it from here + # onward for grouped quantization and backward state. + # + # base_split_offsets: int64[num_groups + 1], [0, cumsum(split_sizes)] + # split_points: int32[num_groups], cumsum(split_sizes) without the leading 0 + # fc2_x_tensor_offsets: int64[num_groups + 1], base_split_offsets * fc2 K + ( + split_sizes, + base_split_offsets, + split_points, + fc2_x_tensor_offsets, + ) = tex.prepare_grouped_splits(split_sizes, num_groups, fc2_weight_shape[1]) # Extract post-scales from extra input scales = basic_op_extra_inputs[1][0] From 5643f46cd02fa3b6d1d313f703cae96bcfc58e30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 06:39:13 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/common.cu | 32 +++++++++---------- .../pytorch/csrc/extensions/misc.cpp | 12 +++---- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 81dbd76bc6..2de42ab2dc 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -130,12 +130,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } template -__global__ void __launch_bounds__(kThreadsPerBlock) prepare_grouped_splits_kernel( - const FirstDimT *__restrict__ first_dims, int64_t *__restrict__ first_dims_i64, - int64_t *__restrict__ base_offsets, - int32_t *__restrict__ split_points, int64_t *__restrict__ tensor_offsets, - int64_t logical_last_dim, size_t num_tensors) { - +__global__ void __launch_bounds__(kThreadsPerBlock) + prepare_grouped_splits_kernel(const FirstDimT *__restrict__ first_dims, + int64_t *__restrict__ first_dims_i64, + int64_t *__restrict__ base_offsets, + int32_t *__restrict__ split_points, + int64_t *__restrict__ tensor_offsets, int64_t logical_last_dim, + size_t num_tensors) { __shared__ int64_t block_scan[kThreadsPerBlock]; __shared__ int64_t chunk_prefix; @@ -243,16 +244,15 @@ void nvte_prepare_grouped_splits(const NVTETensor first_dims, NVTETensor first_d return tensor->dim() == 1 && tensor->dtype() == dtype && tensor->numel() == numel; }; - NVTE_CHECK( - num_tensors > 0 && logical_last_dim >= 0 && first_dims_tensor->dim() == 1 && - (first_dims_dtype == DType::kInt32 || first_dims_dtype == DType::kInt64) && - is_tensor(first_dims_i64_tensor, DType::kInt64, num_tensors) && - is_tensor(base_offsets_tensor, DType::kInt64, offsets_numel) && - is_tensor(split_points_tensor, DType::kInt32, num_tensors) && - is_tensor(tensor_offsets_tensor, DType::kInt64, offsets_numel), - "Invalid grouped split metadata. Expected first_dims int32/int64[N], " - "first_dims_i64 int64[N], base_offsets int64[N+1], split_points int32[N], " - "tensor_offsets int64[N+1], and logical_last_dim >= 0."); + NVTE_CHECK(num_tensors > 0 && logical_last_dim >= 0 && first_dims_tensor->dim() == 1 && + (first_dims_dtype == DType::kInt32 || first_dims_dtype == DType::kInt64) && + is_tensor(first_dims_i64_tensor, DType::kInt64, num_tensors) && + is_tensor(base_offsets_tensor, DType::kInt64, offsets_numel) && + is_tensor(split_points_tensor, DType::kInt32, num_tensors) && + is_tensor(tensor_offsets_tensor, DType::kInt64, offsets_numel), + "Invalid grouped split metadata. Expected first_dims int32/int64[N], " + "first_dims_i64 int64[N], base_offsets int64[N+1], split_points int32[N], " + "tensor_offsets int64[N+1], and logical_last_dim >= 0."); // split_points is the only int32 output by design: cuDNN grouped GEMM uses // int32 padded split end offsets, while TE grouped tensor offsets are int64. diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index e1e554deda..020cc2e051 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -51,12 +51,12 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in // for cuDNN grouped GEMM padded offsets. This is intentionally int32 // even though TE grouped tensor metadata uses int64 below. // 3. tensor_offsets: int64[num_groups + 1], base_offsets * logical_last_dim. - auto outputs = bulk_allocate( - {{static_cast(num_groups)}, - {static_cast(offsets_length)}, - {static_cast(num_groups)}, - {static_cast(offsets_length)}}, - {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), std::nullopt); + auto outputs = bulk_allocate({{static_cast(num_groups)}, + {static_cast(offsets_length)}, + {static_cast(num_groups)}, + {static_cast(offsets_length)}}, + {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), + std::nullopt); auto split_sizes_i64 = outputs[0]; auto base_offsets = outputs[1]; auto split_points = outputs[2]; From 848c99e746ddcb2ffb99f0afa61dc33d01bb5ae6 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Thu, 14 May 2026 19:18:55 -0700 Subject: [PATCH 3/9] benchmark Signed-off-by: zhongboz --- .../benchmark_graph_safe_grouped_linear.py | 376 ++++++++++++++++++ 1 file changed, 376 insertions(+) create mode 100644 benchmarks/linear/benchmark_graph_safe_grouped_linear.py diff --git a/benchmarks/linear/benchmark_graph_safe_grouped_linear.py b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py new file mode 100644 index 0000000000..26fb232c0a --- /dev/null +++ b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py @@ -0,0 +1,376 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark MXFP8 graph-safe grouped MLP. + +This mirrors ``benchmark_grouped_linear.py`` but targets the graph-safe TE ops +path used by grouped MLP: + + GroupedLinear -> ScaledSwiGLU -> GroupedLinear + +The benchmark intentionally uses CUDA-device ``m_splits`` and MXFP8 only. + +Example: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py + +Forward-only: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only + +Nsight Systems: + + (optionally: unset DEBUGINFOD_URLS) + + nsys profile \ + --output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --profile +""" + +# Match the Qwen MXFP8 SFT launch toggles before importing TE. +import os + +os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") +os.environ.setdefault("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") +os.environ.setdefault("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") +os.environ.setdefault("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") + +import argparse +from contextlib import nullcontext + +import pandas as pd +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common.recipe import MXFP8BlockScaling +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +MXFP8_AVAILABLE, REASON_FOR_NO_MXFP8 = FP8GlobalStateManager.is_mxfp8_available() + + +def parse_int_list(value: str) -> list[int]: + """Parse comma-separated integers.""" + return [int(x) for x in value.split(",") if x] + + +def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]: + """Split tokens uniformly across groups.""" + if total_tokens % num_groups != 0: + raise ValueError( + f"Uniform split requires total_tokens divisible by num_groups, " + f"got total_tokens={total_tokens}, num_groups={num_groups}" + ) + return [total_tokens // num_groups] * num_groups + + +def build_grouped_mlp( + *, + num_groups: int, + hidden_dim: int, + ffn_hidden_dim: int, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, +) -> te_ops.Sequential: + """Build graph-safe grouped MLP ops sequence.""" + recipe = MXFP8BlockScaling() + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_dim, + 2 * ffn_hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_dim, + hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + return te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + +def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None: + """Initialize Megatron-style main_grad buffers for accumulate_into_main_grad.""" + with torch.no_grad(): + for param in module.parameters(): + if getattr(param, "main_grad", None) is None: + param.main_grad = torch.empty(param.size(), device=param.device, dtype=torch.float32) + param.main_grad.fill_(value) + + +def zero_grads(module: torch.nn.Module, x: torch.Tensor, scales: torch.Tensor) -> None: + """Reset gradients without changing allocated main_grad buffers.""" + module.zero_grad(set_to_none=True) + x.grad = None + scales.grad = None + + +def run_grouped_mlp_steps( + module: torch.nn.Module, + x: torch.Tensor, + split_sizes: torch.Tensor, + scales: torch.Tensor, + grad_output: torch.Tensor, + *, + recipe: MXFP8BlockScaling, + fwd_only: bool, + num_steps: int, + accumulate_into_main_grad: bool, +) -> torch.Tensor: + """Run eager grouped MLP for a number of synthetic microbatches.""" + quantization_context = te.autocast(enabled=True, recipe=recipe) + + if fwd_only: + with torch.no_grad(), quantization_context: + for _ in range(num_steps): + out = module(x, split_sizes, scales, split_sizes) + return out + + zero_grads(module, x, scales) + if accumulate_into_main_grad: + init_main_grads(module) + + with quantization_context: + for step in range(num_steps): + torch.cuda.nvtx.range_push(f"step_{step}") + out = module(x, split_sizes, scales, split_sizes) + out.backward(grad_output) + torch.cuda.nvtx.range_pop() + return out + + +def benchmark_case( + *, + total_tokens: int, + hidden_dim: int, + ffn_hidden_dim: int, + num_groups: int, + dtype: torch.dtype, + fwd_only: bool, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, + num_microbatches: int, + min_run_time: float, + profile: bool, +) -> float: + """Benchmark one grouped MLP shape.""" + split_sizes_list = make_uniform_splits(total_tokens, num_groups) + split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda") + x = torch.randn( + (total_tokens, hidden_dim), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + scales = torch.ones( + (total_tokens,), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + grad_output = torch.ones((total_tokens, hidden_dim), dtype=dtype, device="cuda") + + module = build_grouped_mlp( + num_groups=num_groups, + hidden_dim=hidden_dim, + ffn_hidden_dim=ffn_hidden_dim, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + ) + recipe = MXFP8BlockScaling() + + print( + "case:", + f"tokens={total_tokens}", + f"hidden={hidden_dim}", + f"ffn_hidden={ffn_hidden_dim}", + f"num_groups={num_groups}", + f"fwd_only={fwd_only}", + f"single_grouped_weight={single_grouped_weight}", + f"accumulate_into_main_grad={accumulate_into_main_grad}", + f"glu_interleave_size={glu_interleave_size}", + ) + print(f"m_splits: {split_sizes_list}") + + # Warmup also forces the op-fuser to materialize the expected fused ops. + run_grouped_mlp_steps( + module, + x, + split_sizes, + scales, + grad_output, + recipe=recipe, + fwd_only=fwd_only, + num_steps=128, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + torch.cuda.synchronize() + + forward_ops = module._module_groups[0]._forward_ops + print("forward fused op:", type(forward_ops[0][0]).__name__ if forward_ops else "none") + if not fwd_only: + backward_ops = module._module_groups[0]._backward_ops + print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none") + + label = "graph_safe_grouped_mlp_mxfp8_swiglu" + timing_context = torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext() + with timing_context: + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_grouped_mlp_steps(" + "module, x, split_sizes, scales, grad_output, " + "recipe=recipe, fwd_only=fwd_only, num_steps=num_microbatches, " + "accumulate_into_main_grad=accumulate_into_main_grad)" + ), + globals={ + "run_grouped_mlp_steps": run_grouped_mlp_steps, + "module": module, + "x": x, + "split_sizes": split_sizes, + "scales": scales, + "grad_output": grad_output, + "recipe": recipe, + "fwd_only": fwd_only, + "num_microbatches": num_microbatches, + "accumulate_into_main_grad": accumulate_into_main_grad, + }, + num_threads=1, + ).blocked_autorange(min_run_time=min_run_time) + torch.cuda.nvtx.range_pop() + + print(f"mxfp8_swiglu: {timing}\n") + return timing.median * 1000 / num_microbatches + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable NVTX profiling annotations") + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Benchmark forward only. Default benchmarks forward + backward.", + ) + parser.add_argument( + "--num-groups", + type=str, + default="8", + help="Comma-separated local grouped GEMM/expert counts.", + ) + parser.add_argument( + "--token-dims", + type=str, + default="65536", + help="Comma-separated total token counts to benchmark.", + ) + parser.add_argument("--hidden-dim", type=int, default=7168) + parser.add_argument("--ffn-hidden-dim", type=int, default=2048) + parser.add_argument("--num-microbatches", type=int, default=32) + parser.add_argument("--min-run-time", type=float, default=10.0) + parser.add_argument("--glu-interleave-size", type=int, default=32) + parser.add_argument( + "--single-grouped-weight", + action="store_true", + default=False, + help="Use one GroupedTensor parameter for each grouped linear.", + ) + args = parser.parse_args() + + if not MXFP8_AVAILABLE: + raise RuntimeError(f"MXFP8 is not available: {REASON_FOR_NO_MXFP8}") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + + dtype = torch.bfloat16 + accumulate_into_main_grad = True + token_dims = parse_int_list(args.token_dims) + num_groups_list = parse_int_list(args.num_groups) + + print("Environment toggles:") + for name in ( + "CUDA_DEVICE_MAX_CONNECTIONS", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", + "NVTE_CUTEDSL_FUSED_GROUPED_MLP", + "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", + ): + print(f" {name}={os.environ.get(name)}") + print("Recipe: MXFP8BlockScaling") + print("Activation: ScaledSwiGLU") + print(f"Default GLU interleave size: {args.glu_interleave_size}") + print() + + data = [] + for num_groups in num_groups_list: + for total_tokens in token_dims: + timing_ms = benchmark_case( + total_tokens=total_tokens, + hidden_dim=args.hidden_dim, + ffn_hidden_dim=args.ffn_hidden_dim, + num_groups=num_groups, + dtype=dtype, + fwd_only=args.fwd_only, + single_grouped_weight=args.single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=args.glu_interleave_size, + num_microbatches=args.num_microbatches, + min_run_time=args.min_run_time, + profile=args.profile, + ) + data.append( + [ + total_tokens, + args.hidden_dim, + args.ffn_hidden_dim, + num_groups, + args.glu_interleave_size, + args.single_grouped_weight, + accumulate_into_main_grad, + "fwd" if args.fwd_only else "fwd_bwd", + timing_ms, + ] + ) + + timing_col = "time_per_microbatch_ms" + df = pd.DataFrame( + data=data, + columns=[ + "tokens", + "hidden_dim", + "ffn_hidden_dim", + "num_groups", + "glu_interleave_size", + "single_grouped_weight", + "accumulate_into_main_grad", + "mode", + timing_col, + ], + ) + print(df) + + +if __name__ == "__main__": + main() From 1a32a1f74bb2d8ce7e0aaf708ca77e73877f931f Mon Sep 17 00:00:00 2001 From: zhongboz Date: Thu, 14 May 2026 19:19:11 -0700 Subject: [PATCH 4/9] update test Signed-off-by: zhongboz --- tests/pytorch/test_grouped_tensor.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c8dd51603d..4a82319d1b 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -171,6 +171,21 @@ def test_basic_construction_varying_first_dim(self) -> None: pytest.param([0, 3, 5, 0], 11, id="zero_edges"), pytest.param([1], 17, id="single_group"), pytest.param([1, 2, 3, 4, 5, 6, 7, 8], 13, id="many_groups"), + # MoE-style group counts. ``split_points`` (an int32[num_groups] + # tensor packed into a shared buffer alongside int64 outputs) used + # to land at an 8-byte-aligned offset for these counts, which + # tripped cuDNN's 16-byte alignment requirement in grouped GEMM. + pytest.param([8192] * 8, 2048, id="num_groups_8_uniform"), + pytest.param([4096] * 16, 4096, id="num_groups_16_uniform"), + pytest.param([2048] * 32, 7168, id="num_groups_32_uniform"), + pytest.param([1024] * 64, 7168, id="num_groups_64_uniform"), + pytest.param([512] * 128, 7168, id="num_groups_128_uniform"), + # Non-uniform with large totals to also exercise tensor_offsets > 2^31. + pytest.param( + [12345, 0, 8192, 1, 65536, 100, 131072, 7], + 7168, + id="non_uniform_large_totals", + ), ], ) @pytest.mark.parametrize("input_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) @@ -211,6 +226,22 @@ def test_prepare_grouped_splits( assert torch.equal(split_points, expected_split_points) assert torch.equal(tensor_offsets, expected_tensor_offsets) + # cuDNN CuTe-DSL grouped GEMM kernels require 16-byte-aligned data + # pointers for every tensor argument. ``split_points`` used to land at + # an 8-byte-aligned offset inside the bulk buffer; pin the fix here so + # any regression in ``prepare_grouped_splits`` / ``bulk_allocate`` + # alignment is caught immediately instead of surfacing as a runtime + # "Misaligned Tensor data" error from cuDNN. + for name, tensor in ( + ("split_sizes_i64", split_sizes_i64), + ("base_offsets", base_offsets), + ("split_points", split_points), + ("tensor_offsets", tensor_offsets), + ): + assert tensor.data_ptr() % 16 == 0, ( + f"{name} data_ptr is not 16-byte aligned: {tensor.data_ptr():#x}" + ) + def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 From 07afb08a50509d7859f58df33e904310f5891502 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Thu, 14 May 2026 19:21:39 -0700 Subject: [PATCH 5/9] alignment Signed-off-by: zhongboz --- transformer_engine/pytorch/csrc/extensions/misc.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 020cc2e051..83b3e49484 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -51,12 +51,17 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in // for cuDNN grouped GEMM padded offsets. This is intentionally int32 // even though TE grouped tensor metadata uses int64 below. // 3. tensor_offsets: int64[num_groups + 1], base_offsets * logical_last_dim. + // + // Force 16-byte alignment on every output so ``split_points`` (consumed by + // cuDNN CuTe-DSL grouped GEMM as ``padded_offsets``, which requires 16-byte + // alignment) lands on a 16-byte boundary inside the bulk buffer. + std::vector alignments = {16, 16, 16, 16}; auto outputs = bulk_allocate({{static_cast(num_groups)}, {static_cast(offsets_length)}, {static_cast(num_groups)}, {static_cast(offsets_length)}}, {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), - std::nullopt); + alignments); auto split_sizes_i64 = outputs[0]; auto base_offsets = outputs[1]; auto split_points = outputs[2]; From 56fd72abe83ce3692662b5ad6effbb58fb517b09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 02:23:28 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../linear/benchmark_graph_safe_grouped_linear.py | 14 +++++++++----- tests/pytorch/test_grouped_tensor.py | 6 +++--- .../pytorch/csrc/extensions/misc.cpp | 12 ++++++------ 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/benchmarks/linear/benchmark_graph_safe_grouped_linear.py b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py index 26fb232c0a..d8230c38fe 100644 --- a/benchmarks/linear/benchmark_graph_safe_grouped_linear.py +++ b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py @@ -19,10 +19,10 @@ python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only -Nsight Systems: +Nsight Systems: (optionally: unset DEBUGINFOD_URLS) - + nsys profile \ --output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \ --force-overwrite true \ @@ -63,7 +63,7 @@ def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]: """Split tokens uniformly across groups.""" if total_tokens % num_groups != 0: raise ValueError( - f"Uniform split requires total_tokens divisible by num_groups, " + "Uniform split requires total_tokens divisible by num_groups, " f"got total_tokens={total_tokens}, num_groups={num_groups}" ) return [total_tokens // num_groups] * num_groups @@ -114,7 +114,9 @@ def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None: with torch.no_grad(): for param in module.parameters(): if getattr(param, "main_grad", None) is None: - param.main_grad = torch.empty(param.size(), device=param.device, dtype=torch.float32) + param.main_grad = torch.empty( + param.size(), device=param.device, dtype=torch.float32 + ) param.main_grad.fill_(value) @@ -236,7 +238,9 @@ def benchmark_case( print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none") label = "graph_safe_grouped_mlp_mxfp8_swiglu" - timing_context = torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext() + timing_context = ( + torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext() + ) with timing_context: torch.cuda.nvtx.range_push(label) timing = benchmark.Timer( diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 4a82319d1b..1dd008926e 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -238,9 +238,9 @@ def test_prepare_grouped_splits( ("split_points", split_points), ("tensor_offsets", tensor_offsets), ): - assert tensor.data_ptr() % 16 == 0, ( - f"{name} data_ptr is not 16-byte aligned: {tensor.data_ptr():#x}" - ) + assert ( + tensor.data_ptr() % 16 == 0 + ), f"{name} data_ptr is not 16-byte aligned: {tensor.data_ptr():#x}" def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 83b3e49484..4181522895 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -56,12 +56,12 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in // cuDNN CuTe-DSL grouped GEMM as ``padded_offsets``, which requires 16-byte // alignment) lands on a 16-byte boundary inside the bulk buffer. std::vector alignments = {16, 16, 16, 16}; - auto outputs = bulk_allocate({{static_cast(num_groups)}, - {static_cast(offsets_length)}, - {static_cast(num_groups)}, - {static_cast(offsets_length)}}, - {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), - alignments); + auto outputs = + bulk_allocate({{static_cast(num_groups)}, + {static_cast(offsets_length)}, + {static_cast(num_groups)}, + {static_cast(offsets_length)}}, + {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), alignments); auto split_sizes_i64 = outputs[0]; auto base_offsets = outputs[1]; auto split_points = outputs[2]; From bf5a8655aa1ecd1f4da3ba192f3744d0c289af45 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 15 May 2026 01:43:27 -0700 Subject: [PATCH 7/9] handle device of split_sizes Signed-off-by: Zhongbo Zhu --- tests/pytorch/test_grouped_tensor.py | 10 +++++++-- .../pytorch/csrc/extensions/misc.cpp | 21 +++++++++++++++---- .../pytorch/csrc/extensions/pybind.cpp | 4 ++-- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 1dd008926e..4a871dc6a3 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -189,14 +189,16 @@ def test_basic_construction_varying_first_dim(self) -> None: ], ) @pytest.mark.parametrize("input_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) + @pytest.mark.parametrize("input_device", ["cuda", "cpu"], ids=["cuda", "cpu"]) def test_prepare_grouped_splits( self, + input_device: str, input_dtype: torch.dtype, split_sizes_list: List[int], logical_last_dim: int, ) -> None: """Test fused grouped split metadata preparation.""" - split_sizes = torch.tensor(split_sizes_list, dtype=input_dtype, device="cuda") + split_sizes = torch.tensor(split_sizes_list, dtype=input_dtype, device=input_device) num_groups = split_sizes.numel() ( @@ -206,7 +208,7 @@ def test_prepare_grouped_splits( tensor_offsets, ) = tex.prepare_grouped_splits(split_sizes, num_groups, logical_last_dim) - expected_split_sizes = split_sizes.to(torch.int64) + expected_split_sizes = split_sizes.to(device="cuda", dtype=torch.int64) expected_base_offsets = torch.cat( ( torch.zeros(1, dtype=torch.int64, device="cuda"), @@ -221,6 +223,10 @@ def test_prepare_grouped_splits( # cuDNN grouped GEMM consumes int32 end offsets; TE GroupedTensor metadata stays int64. assert split_points.dtype == torch.int32 assert tensor_offsets.dtype == torch.int64 + assert split_sizes_i64.device.type == "cuda" + assert base_offsets.device.type == "cuda" + assert split_points.device.type == "cuda" + assert tensor_offsets.device.type == "cuda" assert torch.equal(split_sizes_i64, expected_split_sizes) assert torch.equal(base_offsets, expected_base_offsets) assert torch.equal(split_points, expected_split_points) diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 4181522895..9f55b27bb7 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -33,14 +33,27 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups, int64_t logical_last_dim) { - NVTE_CHECK(split_sizes.is_cuda(), "split_sizes must be on CUDA."); + NVTE_CHECK(split_sizes.scalar_type() == at::kInt || split_sizes.scalar_type() == at::kLong, "split_sizes must have dtype int32 or int64."); NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); - NVTE_CHECK(split_sizes.is_contiguous(), "split_sizes must be contiguous."); NVTE_CHECK(num_groups > 0, "num_groups must be greater than 0."); NVTE_CHECK(split_sizes.numel() == num_groups, "split_sizes must have length ", num_groups, "."); NVTE_CHECK(logical_last_dim >= 0, "logical_last_dim must be non-negative."); + const c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device()); + + at::Tensor split_sizes_for_kernel; + if (split_sizes.is_cuda()) { + NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current device ", device, + ", but got ", split_sizes.device(), "."); + split_sizes_for_kernel = split_sizes; + } else { + // Preserve the legacy eager path: host m_splits are copied to the target + // CUDA device here, then all derived metadata is produced by one CUDA kernel. + split_sizes_for_kernel = + split_sizes.to(at::TensorOptions().dtype(split_sizes.scalar_type()).device(device), + /*non_blocking=*/true); + } const int64_t offsets_length = num_groups + 1; @@ -61,13 +74,13 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in {static_cast(offsets_length)}, {static_cast(num_groups)}, {static_cast(offsets_length)}}, - {at::kLong, at::kLong, at::kInt, at::kLong}, split_sizes.device(), alignments); + {at::kLong, at::kLong, at::kInt, at::kLong}, device, alignments); auto split_sizes_i64 = outputs[0]; auto base_offsets = outputs[1]; auto split_points = outputs[2]; auto tensor_offsets = outputs[3]; - auto split_sizes_nvte = makeTransformerEngineTensor(split_sizes); + auto split_sizes_nvte = makeTransformerEngineTensor(split_sizes_for_kernel); auto split_sizes_i64_nvte = makeTransformerEngineTensor(split_sizes_i64); auto base_offsets_nvte = makeTransformerEngineTensor(base_offsets); auto split_points_nvte = makeTransformerEngineTensor(split_points); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index ca1a72ddf5..2f60bdf8f9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -498,8 +498,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits, - "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"), - py::arg("num_groups"), py::arg("logical_last_dim")); + "Prepare grouped split metadata from CPU/CUDA int32 or int64 split sizes", + py::arg("split_sizes"), py::arg("num_groups"), py::arg("logical_last_dim")); m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", py::call_guard()); From b78299179640968d6a9e926ba0086f295b8aa87e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 08:44:22 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/misc.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 9f55b27bb7..29d9c9da8b 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -33,7 +33,6 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups, int64_t logical_last_dim) { - NVTE_CHECK(split_sizes.scalar_type() == at::kInt || split_sizes.scalar_type() == at::kLong, "split_sizes must have dtype int32 or int64."); NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); @@ -44,8 +43,8 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in at::Tensor split_sizes_for_kernel; if (split_sizes.is_cuda()) { - NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current device ", device, - ", but got ", split_sizes.device(), "."); + NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current device ", + device, ", but got ", split_sizes.device(), "."); split_sizes_for_kernel = split_sizes; } else { // Preserve the legacy eager path: host m_splits are copied to the target @@ -69,12 +68,11 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in // cuDNN CuTe-DSL grouped GEMM as ``padded_offsets``, which requires 16-byte // alignment) lands on a 16-byte boundary inside the bulk buffer. std::vector alignments = {16, 16, 16, 16}; - auto outputs = - bulk_allocate({{static_cast(num_groups)}, - {static_cast(offsets_length)}, - {static_cast(num_groups)}, - {static_cast(offsets_length)}}, - {at::kLong, at::kLong, at::kInt, at::kLong}, device, alignments); + auto outputs = bulk_allocate({{static_cast(num_groups)}, + {static_cast(offsets_length)}, + {static_cast(num_groups)}, + {static_cast(offsets_length)}}, + {at::kLong, at::kLong, at::kInt, at::kLong}, device, alignments); auto split_sizes_i64 = outputs[0]; auto base_offsets = outputs[1]; auto split_points = outputs[2]; From 1025557cf256c07916887a8455701ba70229f4dc Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 15 May 2026 02:13:32 -0700 Subject: [PATCH 9/9] fix Signed-off-by: Zhongbo Zhu --- transformer_engine/pytorch/csrc/extensions/misc.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 29d9c9da8b..828746180a 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -43,8 +43,8 @@ std::vector prepare_grouped_splits(const at::Tensor &split_sizes, in at::Tensor split_sizes_for_kernel; if (split_sizes.is_cuda()) { - NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current device ", - device, ", but got ", split_sizes.device(), "."); + NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ", + device.index(), ", but got CUDA device ", split_sizes.device().index(), "."); split_sizes_for_kernel = split_sizes; } else { // Preserve the legacy eager path: host m_splits are copied to the target