diff --git a/benchmarks/linear/benchmark_graph_safe_grouped_linear.py b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py new file mode 100644 index 0000000000..d8230c38fe --- /dev/null +++ b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py @@ -0,0 +1,380 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark MXFP8 graph-safe grouped MLP. + +This mirrors ``benchmark_grouped_linear.py`` but targets the graph-safe TE ops +path used by grouped MLP: + + GroupedLinear -> ScaledSwiGLU -> GroupedLinear + +The benchmark intentionally uses CUDA-device ``m_splits`` and MXFP8 only. + +Example: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py + +Forward-only: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only + +Nsight Systems: + + (optionally: unset DEBUGINFOD_URLS) + + nsys profile \ + --output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --profile +""" + +# Match the Qwen MXFP8 SFT launch toggles before importing TE. +import os + +os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") +os.environ.setdefault("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") +os.environ.setdefault("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") +os.environ.setdefault("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") + +import argparse +from contextlib import nullcontext + +import pandas as pd +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common.recipe import MXFP8BlockScaling +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +MXFP8_AVAILABLE, REASON_FOR_NO_MXFP8 = FP8GlobalStateManager.is_mxfp8_available() + + +def parse_int_list(value: str) -> list[int]: + """Parse comma-separated integers.""" + return [int(x) for x in value.split(",") if x] + + +def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]: + """Split tokens uniformly across groups.""" + if total_tokens % num_groups != 0: + raise ValueError( + "Uniform split requires total_tokens divisible by num_groups, " + f"got total_tokens={total_tokens}, num_groups={num_groups}" + ) + return [total_tokens // num_groups] * num_groups + + +def build_grouped_mlp( + *, + num_groups: int, + hidden_dim: int, + ffn_hidden_dim: int, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, +) -> te_ops.Sequential: + """Build graph-safe grouped MLP ops sequence.""" + recipe = MXFP8BlockScaling() + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_dim, + 2 * ffn_hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_dim, + hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + return te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + +def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None: + """Initialize Megatron-style main_grad buffers for accumulate_into_main_grad.""" + with 
torch.no_grad(): + for param in module.parameters(): + if getattr(param, "main_grad", None) is None: + param.main_grad = torch.empty( + param.size(), device=param.device, dtype=torch.float32 + ) + param.main_grad.fill_(value) + + +def zero_grads(module: torch.nn.Module, x: torch.Tensor, scales: torch.Tensor) -> None: + """Reset gradients without changing allocated main_grad buffers.""" + module.zero_grad(set_to_none=True) + x.grad = None + scales.grad = None + + +def run_grouped_mlp_steps( + module: torch.nn.Module, + x: torch.Tensor, + split_sizes: torch.Tensor, + scales: torch.Tensor, + grad_output: torch.Tensor, + *, + recipe: MXFP8BlockScaling, + fwd_only: bool, + num_steps: int, + accumulate_into_main_grad: bool, +) -> torch.Tensor: + """Run eager grouped MLP for a number of synthetic microbatches.""" + quantization_context = te.autocast(enabled=True, recipe=recipe) + + if fwd_only: + with torch.no_grad(), quantization_context: + for _ in range(num_steps): + out = module(x, split_sizes, scales, split_sizes) + return out + + zero_grads(module, x, scales) + if accumulate_into_main_grad: + init_main_grads(module) + + with quantization_context: + for step in range(num_steps): + torch.cuda.nvtx.range_push(f"step_{step}") + out = module(x, split_sizes, scales, split_sizes) + out.backward(grad_output) + torch.cuda.nvtx.range_pop() + return out + + +def benchmark_case( + *, + total_tokens: int, + hidden_dim: int, + ffn_hidden_dim: int, + num_groups: int, + dtype: torch.dtype, + fwd_only: bool, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, + num_microbatches: int, + min_run_time: float, + profile: bool, +) -> float: + """Benchmark one grouped MLP shape.""" + split_sizes_list = make_uniform_splits(total_tokens, num_groups) + split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda") + x = torch.randn( + (total_tokens, hidden_dim), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + scales = torch.ones( + (total_tokens,), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + grad_output = torch.ones((total_tokens, hidden_dim), dtype=dtype, device="cuda") + + module = build_grouped_mlp( + num_groups=num_groups, + hidden_dim=hidden_dim, + ffn_hidden_dim=ffn_hidden_dim, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + ) + recipe = MXFP8BlockScaling() + + print( + "case:", + f"tokens={total_tokens}", + f"hidden={hidden_dim}", + f"ffn_hidden={ffn_hidden_dim}", + f"num_groups={num_groups}", + f"fwd_only={fwd_only}", + f"single_grouped_weight={single_grouped_weight}", + f"accumulate_into_main_grad={accumulate_into_main_grad}", + f"glu_interleave_size={glu_interleave_size}", + ) + print(f"m_splits: {split_sizes_list}") + + # Warmup also forces the op-fuser to materialize the expected fused ops. 
+ run_grouped_mlp_steps( + module, + x, + split_sizes, + scales, + grad_output, + recipe=recipe, + fwd_only=fwd_only, + num_steps=128, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + torch.cuda.synchronize() + + forward_ops = module._module_groups[0]._forward_ops + print("forward fused op:", type(forward_ops[0][0]).__name__ if forward_ops else "none") + if not fwd_only: + backward_ops = module._module_groups[0]._backward_ops + print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none") + + label = "graph_safe_grouped_mlp_mxfp8_swiglu" + timing_context = ( + torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext() + ) + with timing_context: + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_grouped_mlp_steps(" + "module, x, split_sizes, scales, grad_output, " + "recipe=recipe, fwd_only=fwd_only, num_steps=num_microbatches, " + "accumulate_into_main_grad=accumulate_into_main_grad)" + ), + globals={ + "run_grouped_mlp_steps": run_grouped_mlp_steps, + "module": module, + "x": x, + "split_sizes": split_sizes, + "scales": scales, + "grad_output": grad_output, + "recipe": recipe, + "fwd_only": fwd_only, + "num_microbatches": num_microbatches, + "accumulate_into_main_grad": accumulate_into_main_grad, + }, + num_threads=1, + ).blocked_autorange(min_run_time=min_run_time) + torch.cuda.nvtx.range_pop() + + print(f"mxfp8_swiglu: {timing}\n") + return timing.median * 1000 / num_microbatches + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable NVTX profiling annotations") + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Benchmark forward only. Default benchmarks forward + backward.", + ) + parser.add_argument( + "--num-groups", + type=str, + default="8", + help="Comma-separated local grouped GEMM/expert counts.", + ) + parser.add_argument( + "--token-dims", + type=str, + default="65536", + help="Comma-separated total token counts to benchmark.", + ) + parser.add_argument("--hidden-dim", type=int, default=7168) + parser.add_argument("--ffn-hidden-dim", type=int, default=2048) + parser.add_argument("--num-microbatches", type=int, default=32) + parser.add_argument("--min-run-time", type=float, default=10.0) + parser.add_argument("--glu-interleave-size", type=int, default=32) + parser.add_argument( + "--single-grouped-weight", + action="store_true", + default=False, + help="Use one GroupedTensor parameter for each grouped linear.", + ) + args = parser.parse_args() + + if not MXFP8_AVAILABLE: + raise RuntimeError(f"MXFP8 is not available: {REASON_FOR_NO_MXFP8}") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + + dtype = torch.bfloat16 + accumulate_into_main_grad = True + token_dims = parse_int_list(args.token_dims) + num_groups_list = parse_int_list(args.num_groups) + + print("Environment toggles:") + for name in ( + "CUDA_DEVICE_MAX_CONNECTIONS", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", + "NVTE_CUTEDSL_FUSED_GROUPED_MLP", + "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", + ): + print(f" {name}={os.environ.get(name)}") + print("Recipe: MXFP8BlockScaling") + print("Activation: ScaledSwiGLU") + print(f"Default GLU interleave size: {args.glu_interleave_size}") + print() + + data = [] + for num_groups in num_groups_list: + for total_tokens in token_dims: + timing_ms = benchmark_case( + total_tokens=total_tokens, + hidden_dim=args.hidden_dim, + 
ffn_hidden_dim=args.ffn_hidden_dim, + num_groups=num_groups, + dtype=dtype, + fwd_only=args.fwd_only, + single_grouped_weight=args.single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=args.glu_interleave_size, + num_microbatches=args.num_microbatches, + min_run_time=args.min_run_time, + profile=args.profile, + ) + data.append( + [ + total_tokens, + args.hidden_dim, + args.ffn_hidden_dim, + num_groups, + args.glu_interleave_size, + args.single_grouped_weight, + accumulate_into_main_grad, + "fwd" if args.fwd_only else "fwd_bwd", + timing_ms, + ] + ) + + timing_col = "time_per_microbatch_ms" + df = pd.DataFrame( + data=data, + columns=[ + "tokens", + "hidden_dim", + "ffn_hidden_dim", + "num_groups", + "glu_interleave_size", + "single_grouped_weight", + "accumulate_into_main_grad", + "mode", + timing_col, + ], + ) + print(df) + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c54c9758ff..4a871dc6a3 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -163,6 +163,91 @@ def test_basic_construction_varying_first_dim(self) -> None: shape[0][1], ) # sum of first dims + @pytest.mark.parametrize( + "split_sizes_list,logical_last_dim", + [ + pytest.param([3, 4, 5, 2], 7, id="all_nonzero"), + pytest.param([3, 0, 5, 2], 7, id="zero_middle"), + pytest.param([0, 3, 5, 0], 11, id="zero_edges"), + pytest.param([1], 17, id="single_group"), + pytest.param([1, 2, 3, 4, 5, 6, 7, 8], 13, id="many_groups"), + # MoE-style group counts. ``split_points`` (an int32[num_groups] + # tensor packed into a shared buffer alongside int64 outputs) used + # to land at an 8-byte-aligned offset for these counts, which + # tripped cuDNN's 16-byte alignment requirement in grouped GEMM. + pytest.param([8192] * 8, 2048, id="num_groups_8_uniform"), + pytest.param([4096] * 16, 4096, id="num_groups_16_uniform"), + pytest.param([2048] * 32, 7168, id="num_groups_32_uniform"), + pytest.param([1024] * 64, 7168, id="num_groups_64_uniform"), + pytest.param([512] * 128, 7168, id="num_groups_128_uniform"), + # Non-uniform with large totals to also exercise tensor_offsets > 2^31. 
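+            # Illustrative arithmetic for the case below (not from the change itself):
+            # 12345 + 8192 + 1 + 65536 + 100 + 131072 + 7 = 217,253 rows, so the
+            # largest tensor_offset is 217,253 * 7,168 ~= 1.56e9 elements, which is
+            # kept in int64 on the TE side.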
+ pytest.param( + [12345, 0, 8192, 1, 65536, 100, 131072, 7], + 7168, + id="non_uniform_large_totals", + ), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) + @pytest.mark.parametrize("input_device", ["cuda", "cpu"], ids=["cuda", "cpu"]) + def test_prepare_grouped_splits( + self, + input_device: str, + input_dtype: torch.dtype, + split_sizes_list: List[int], + logical_last_dim: int, + ) -> None: + """Test fused grouped split metadata preparation.""" + split_sizes = torch.tensor(split_sizes_list, dtype=input_dtype, device=input_device) + num_groups = split_sizes.numel() + + ( + split_sizes_i64, + base_offsets, + split_points, + tensor_offsets, + ) = tex.prepare_grouped_splits(split_sizes, num_groups, logical_last_dim) + + expected_split_sizes = split_sizes.to(device="cuda", dtype=torch.int64) + expected_base_offsets = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device="cuda"), + torch.cumsum(expected_split_sizes, dim=0), + ) + ) + expected_split_points = expected_base_offsets[1:].to(torch.int32) + expected_tensor_offsets = expected_base_offsets * logical_last_dim + + assert split_sizes_i64.dtype == torch.int64 + assert base_offsets.dtype == torch.int64 + # cuDNN grouped GEMM consumes int32 end offsets; TE GroupedTensor metadata stays int64. + assert split_points.dtype == torch.int32 + assert tensor_offsets.dtype == torch.int64 + assert split_sizes_i64.device.type == "cuda" + assert base_offsets.device.type == "cuda" + assert split_points.device.type == "cuda" + assert tensor_offsets.device.type == "cuda" + assert torch.equal(split_sizes_i64, expected_split_sizes) + assert torch.equal(base_offsets, expected_base_offsets) + assert torch.equal(split_points, expected_split_points) + assert torch.equal(tensor_offsets, expected_tensor_offsets) + + # cuDNN CuTe-DSL grouped GEMM kernels require 16-byte-aligned data + # pointers for every tensor argument. ``split_points`` used to land at + # an 8-byte-aligned offset inside the bulk buffer; pin the fix here so + # any regression in ``prepare_grouped_splits`` / ``bulk_allocate`` + # alignment is caught immediately instead of surfacing as a runtime + # "Misaligned Tensor data" error from cuDNN. 
+        for name, tensor in (
+            ("split_sizes_i64", split_sizes_i64),
+            ("base_offsets", base_offsets),
+            ("split_points", split_points),
+            ("tensor_offsets", tensor_offsets),
+        ):
+            assert (
+                tensor.data_ptr() % 16 == 0
+            ), f"{name} data_ptr is not 16-byte aligned: {tensor.data_ptr():#x}"
+
     def test_split_into_quantized_tensors_no_quantization(self) -> None:
         """Test split_into_quantized_tensors for unquantized tensors"""
         num_tensors = 3
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
index 1bdd80a369..2de42ab2dc 100644
--- a/transformer_engine/common/common.cu
+++ b/transformer_engine/common/common.cu
@@ -129,6 +129,60 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
 }
 
+template <typename FirstDimT>
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    prepare_grouped_splits_kernel(const FirstDimT *__restrict__ first_dims,
+                                  int64_t *__restrict__ first_dims_i64,
+                                  int64_t *__restrict__ base_offsets,
+                                  int32_t *__restrict__ split_points,
+                                  int64_t *__restrict__ tensor_offsets, int64_t logical_last_dim,
+                                  size_t num_tensors) {
+  __shared__ int64_t block_scan[kThreadsPerBlock];
+  __shared__ int64_t chunk_prefix;
+
+  const size_t tid = threadIdx.x;
+  if (tid == 0) {
+    base_offsets[0] = 0;
+    tensor_offsets[0] = 0;
+    chunk_prefix = 0;
+  }
+  __syncthreads();
+
+  for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) {
+    const size_t idx = chunk_start + tid;
+
+    block_scan[tid] = 0;
+    if (idx < num_tensors) {
+      block_scan[tid] = static_cast<int64_t>(first_dims[idx]);
+      first_dims_i64[idx] = block_scan[tid];
+    }
+    __syncthreads();
+
+    // Inclusive scan in shared memory.
+    for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) {
+      const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0;
+      __syncthreads();
+      block_scan[tid] += addend;
+      __syncthreads();
+    }
+
+    if (idx < num_tensors) {
+      const int64_t prefix = chunk_prefix + block_scan[tid];
+      base_offsets[idx + 1] = prefix;
+      // cuDNN grouped GEMM expects padded split end offsets as int32. TE
+      // GroupedTensor metadata keeps the full int64 base_offsets/tensor_offsets.
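+      // For example, first_dims = {3, 4, 5} yields base_offsets = {0, 3, 7, 12},
+      // split_points = {3, 7, 12}, and tensor_offsets = {0, 3, 7, 12} * logical_last_dim.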
+      split_points[idx] = static_cast<int32_t>(prefix);
+      tensor_offsets[idx + 1] = prefix * logical_last_dim;
+    }
+    __syncthreads();
+
+    if (tid == kThreadsPerBlock - 1) {
+      chunk_prefix += block_scan[tid];
+    }
+    __syncthreads();
+  }
+}
+
 }  // namespace
 
 #define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
@@ -171,6 +225,59 @@ void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t n
                           logical_last_dim);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
+
+void nvte_prepare_grouped_splits(const NVTETensor first_dims, NVTETensor first_dims_i64,
+                                 NVTETensor base_offsets, NVTETensor split_points,
+                                 NVTETensor tensor_offsets, int64_t logical_last_dim,
+                                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_prepare_grouped_splits);
+
+  const auto *first_dims_tensor = convertNVTETensorCheck(first_dims);
+  const auto *first_dims_i64_tensor = convertNVTETensorCheck(first_dims_i64);
+  const auto *base_offsets_tensor = convertNVTETensorCheck(base_offsets);
+  const auto *split_points_tensor = convertNVTETensorCheck(split_points);
+  const auto *tensor_offsets_tensor = convertNVTETensorCheck(tensor_offsets);
+  const auto first_dims_dtype = first_dims_tensor->dtype();
+  const auto num_tensors = first_dims_tensor->numel();
+  const auto offsets_numel = num_tensors + 1;
+  const auto is_tensor = [](const Tensor *tensor, DType dtype, size_t numel) {
+    return tensor->dim() == 1 && tensor->dtype() == dtype && tensor->numel() == numel;
+  };
+
+  NVTE_CHECK(num_tensors > 0 && logical_last_dim >= 0 && first_dims_tensor->dim() == 1 &&
+                 (first_dims_dtype == DType::kInt32 || first_dims_dtype == DType::kInt64) &&
+                 is_tensor(first_dims_i64_tensor, DType::kInt64, num_tensors) &&
+                 is_tensor(base_offsets_tensor, DType::kInt64, offsets_numel) &&
+                 is_tensor(split_points_tensor, DType::kInt32, num_tensors) &&
+                 is_tensor(tensor_offsets_tensor, DType::kInt64, offsets_numel),
+             "Invalid grouped split metadata. Expected first_dims int32/int64[N], "
+             "first_dims_i64 int64[N], base_offsets int64[N+1], split_points int32[N], "
+             "tensor_offsets int64[N+1], and logical_last_dim >= 0.");
+  // split_points is the only int32 output by design: cuDNN grouped GEMM uses
+  // int32 padded split end offsets, while TE grouped tensor offsets are int64.
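+  // The kernel below is launched as a single block of kThreadsPerBlock threads and
+  // iterates over the groups in chunks, so num_tensors may exceed the block size.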
+
+  switch (first_dims_dtype) {
+    case DType::kInt32:
+      prepare_grouped_splits_kernel<int32_t><<<1, kThreadsPerBlock, 0, stream>>>(
+          static_cast<const int32_t *>(first_dims_tensor->data.dptr),
+          static_cast<int64_t *>(first_dims_i64_tensor->data.dptr),
+          static_cast<int64_t *>(base_offsets_tensor->data.dptr),
+          static_cast<int32_t *>(split_points_tensor->data.dptr),
+          static_cast<int64_t *>(tensor_offsets_tensor->data.dptr), logical_last_dim, num_tensors);
+      break;
+    case DType::kInt64:
+      prepare_grouped_splits_kernel<int64_t><<<1, kThreadsPerBlock, 0, stream>>>(
+          static_cast<const int64_t *>(first_dims_tensor->data.dptr),
+          static_cast<int64_t *>(first_dims_i64_tensor->data.dptr),
+          static_cast<int64_t *>(base_offsets_tensor->data.dptr),
+          static_cast<int32_t *>(split_points_tensor->data.dptr),
+          static_cast<int64_t *>(tensor_offsets_tensor->data.dptr), logical_last_dim, num_tensors);
+      break;
+    default:
+      NVTE_ERROR("first_dims must have dtype int32 or int64.");
+  }
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
 }  // extern "C"
 
 void checkCuDriverContext(CUstream stream) {
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index 045ae88893..5a8850d330 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -454,6 +454,31 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
 void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors,
                             int64_t logical_last_dim, cudaStream_t stream);
 
+/*! \brief Prepare grouped split metadata.
+ *
+ * This is a fused variant of split metadata preparation for grouped kernels.
+ * It accepts either int32 or int64 first dimensions, writes int64 metadata
+ * for TE grouped tensors, writes int32 split points for cuDNN grouped GEMM,
+ * and writes scaled int64 tensor offsets.
+ *
+ * \param[in] first_dims Device int32 or int64 tensor of shape [num_tensors].
+ * \param[out] first_dims_i64 Device int64 tensor of shape [num_tensors].
+ * \param[out] base_offsets Device int64 tensor of shape [num_tensors + 1],
+ * containing [0, cumsum(first_dims)].
+ * \param[out] split_points Device int32 tensor of shape [num_tensors],
+ * containing cumsum(first_dims) without the leading 0. This is int32 because
+ * it is consumed by cuDNN grouped GEMM padded offsets; TE grouped tensor
+ * offsets remain int64.
+ * \param[out] tensor_offsets Device int64 tensor of shape [num_tensors + 1],
+ * containing base_offsets * logical_last_dim.
+ * \param[in] logical_last_dim Scale factor for tensor_offsets.
+ * \param[in] stream CUDA stream to use for the operation.
+ */
+void nvte_prepare_grouped_splits(const NVTETensor first_dims, NVTETensor first_dims_i64,
+                                 NVTETensor base_offsets, NVTETensor split_points,
+                                 NVTETensor tensor_offsets, int64_t logical_last_dim,
+                                 cudaStream_t stream);
+
 /*! \brief TE Grouped Tensor type
  *
  * NVTEGroupedTensor is a collection of tensors with potentially different shapes
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 9b10a9c5a4..f11437f935 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -488,6 +488,8 @@ std::tuple get_device_pointer_for_data_and_s
     std::vector data_tensors, std::vector scale_tensors, bool swizzle, bool rowwise,
     transformer_engine::DType data_dtype);
 at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim);
+std::vector<at::Tensor> prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups,
+                                               int64_t logical_last_dim);
 
 /***************************************************************************************************
  * Support THD format for Context Parallel
diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp
index c5707fa53c..828746180a 100644
--- a/transformer_engine/pytorch/csrc/extensions/misc.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp
@@ -5,6 +5,7 @@
  ************************************************************************/
 
 #include "../extensions.h"
+#include "pybind.h"
 
 namespace transformer_engine::pytorch {
 
@@ -30,4 +31,67 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_
   return output;
 }
 
+std::vector<at::Tensor> prepare_grouped_splits(const at::Tensor &split_sizes, int64_t num_groups,
+                                               int64_t logical_last_dim) {
+  NVTE_CHECK(split_sizes.scalar_type() == at::kInt || split_sizes.scalar_type() == at::kLong,
+             "split_sizes must have dtype int32 or int64.");
+  NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor.");
+  NVTE_CHECK(num_groups > 0, "num_groups must be greater than 0.");
+  NVTE_CHECK(split_sizes.numel() == num_groups, "split_sizes must have length ", num_groups, ".");
+  NVTE_CHECK(logical_last_dim >= 0, "logical_last_dim must be non-negative.");
+  const c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
+
+  at::Tensor split_sizes_for_kernel;
+  if (split_sizes.is_cuda()) {
+    NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
+               device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
+    split_sizes_for_kernel = split_sizes;
+  } else {
+    // Preserve the legacy eager path: host m_splits are copied to the target
+    // CUDA device here, then all derived metadata is produced by one CUDA kernel.
+    split_sizes_for_kernel =
+        split_sizes.to(at::TensorOptions().dtype(split_sizes.scalar_type()).device(device),
+                       /*non_blocking=*/true);
+  }
+
+  const int64_t offsets_length = num_groups + 1;
+
+  // Return order is part of the Python contract:
+  //   0. split_sizes_i64: int64[num_groups], canonical TE GroupedTensor first dims.
+  //   1. base_offsets: int64[num_groups + 1], [0, cumsum(split_sizes)].
+  //   2. split_points: int32[num_groups], cumsum(split_sizes) without the leading 0
+  //      for cuDNN grouped GEMM padded offsets. This is intentionally int32
+  //      even though TE grouped tensor metadata uses int64 below.
+  //   3. tensor_offsets: int64[num_groups + 1], base_offsets * logical_last_dim.
+  //
+  // Force 16-byte alignment on every output so ``split_points`` (consumed by
+  // cuDNN CuTe-DSL grouped GEMM as ``padded_offsets``, which requires 16-byte
+  // alignment) lands on a 16-byte boundary inside the bulk buffer.
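+  // Illustrative failure mode (assuming the outputs are packed back to back in
+  // declaration order): for num_groups = 8 the int64[8] and int64[9] outputs occupy
+  // 64 + 72 = 136 bytes, so split_points would start at an offset that is 8-byte
+  // but not 16-byte aligned.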
+  std::vector<size_t> alignments = {16, 16, 16, 16};
+  auto outputs = bulk_allocate({{static_cast<size_t>(num_groups)},
+                                {static_cast<size_t>(offsets_length)},
+                                {static_cast<size_t>(num_groups)},
+                                {static_cast<size_t>(offsets_length)}},
+                               {at::kLong, at::kLong, at::kInt, at::kLong}, device, alignments);
+  auto split_sizes_i64 = outputs[0];
+  auto base_offsets = outputs[1];
+  auto split_points = outputs[2];
+  auto tensor_offsets = outputs[3];
+
+  auto split_sizes_nvte = makeTransformerEngineTensor(split_sizes_for_kernel);
+  auto split_sizes_i64_nvte = makeTransformerEngineTensor(split_sizes_i64);
+  auto base_offsets_nvte = makeTransformerEngineTensor(base_offsets);
+  auto split_points_nvte = makeTransformerEngineTensor(split_points);
+  auto tensor_offsets_nvte = makeTransformerEngineTensor(tensor_offsets);
+
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_prepare_grouped_splits(split_sizes_nvte.data(), split_sizes_i64_nvte.data(),
+                                base_offsets_nvte.data(), split_points_nvte.data(),
+                                tensor_offsets_nvte.data(), logical_last_dim,
+                                at::cuda::getCurrentCUDAStream());
+  });
+
+  return outputs;
+}
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index a813f3119d..2f60bdf8f9 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -497,6 +497,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets,
         "Compute grouped tensor offsets from split sizes", py::arg("first_dims"),
         py::arg("logical_last_dim"), py::call_guard<py::gil_scoped_release>());
+  m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
+        "Prepare grouped split metadata from CPU/CUDA int32 or int64 split sizes",
+        py::arg("split_sizes"), py::arg("num_groups"), py::arg("logical_last_dim"));
   m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
         py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
index 91db2ff9b7..0d384e7f6d 100644
--- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
@@ -159,10 +159,19 @@ def fuser_forward(
             split_sizes = fc1_split_sizes
         if int(split_sizes.numel()) != num_groups:
             raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.")
-        split_sizes = split_sizes.to(dtype=torch.int64, device=device)
-        base_split_offsets = tex.splits_to_offsets(split_sizes, 1)
-        split_points = base_split_offsets[1:].to(dtype=torch.int)
-        fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1]
+        # Prepare all split metadata in one CUDA kernel. The returned split_sizes is the
+        # canonical TE representation: int64[num_groups]. Python uses it from here
+        # onward for grouped quantization and backward state.
+        #
+        # base_split_offsets: int64[num_groups + 1], [0, cumsum(split_sizes)]
+        # split_points: int32[num_groups], cumsum(split_sizes) without the leading 0
+        # fc2_x_tensor_offsets: int64[num_groups + 1], base_split_offsets * fc2 K
+        (
+            split_sizes,
+            base_split_offsets,
+            split_points,
+            fc2_x_tensor_offsets,
+        ) = tex.prepare_grouped_splits(split_sizes, num_groups, fc2_weight_shape[1])
 
         # Extract post-scales from extra input
         scales = basic_op_extra_inputs[1][0]