From 8de5bb5456d1f375ee80454b25ea5e08b666df2a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 3 Dec 2025 12:29:30 -0800 Subject: [PATCH 01/61] init einsum Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 53 +++ tests/jax/test_einsum.py | 219 +++++++++ transformer_engine/jax/cpp_extensions/amax.py | 36 ++ transformer_engine/jax/cpp_extensions/base.py | 89 +++- transformer_engine/jax/cpp_extensions/gemm.py | 53 +-- .../jax/cpp_extensions/quantization.py | 40 +- transformer_engine/jax/einsum.py | 424 ++++++++++++++++++ transformer_engine/jax/quantize/tensor.py | 119 +++-- 8 files changed, 930 insertions(+), 103 deletions(-) create mode 100644 tests/jax/test_einsum.py create mode 100644 transformer_engine/jax/einsum.py diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index c8bd9d47c3..897d9f683e 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1290,6 +1290,59 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +class TestQuantizeWithVmap: + """Test vmap support for quantization primitives.""" + + @pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) + @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE]) + def test_vmap_quantize(self, in_dtype, scaling_mode, q_layout): + """Test that vmap works with tex.quantize using the general batcher.""" + # Determine q_dtype based on scaling mode + if scaling_mode.is_nvfp4_scaling: + q_dtype = jnp.float4_e2m1fn + else: + q_dtype = jnp.float8_e4m3fn + + # Create batched input (E, M, K) - E experts + E, M, K = 4, 64, 128 + key = jax.random.PRNGKey(0) + batched_input = jax.random.uniform(key, (E, M, K), in_dtype) + + # Create per-expert quantizers + quantizers = [ + QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + ) + for _ in range(E) + ] + + # Stack quantizers for vmap + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizers) + + # Vmap over expert dimension + def quantize_single(x, quantizer): + return tex.quantize(x, quantizer=quantizer, flatten_axis=-1) + + vmapped_quantize = jax.vmap(quantize_single, in_axes=(0, 0)) + result = vmapped_quantize(batched_input, stacked_quantizers) + + # Verify shapes + assert result.data.shape == (E, M, K) + assert result.scale_inv.shape[0] == E # Per-expert scales + + # Compare with calling quantize for each expert individually + individual_results = [] + for i in range(E): + res_i = tex.quantize(batched_input[i], quantizer=quantizers[i], flatten_axis=-1) + individual_results.append(res_i.data) + + expected = jnp.stack(individual_results, axis=0) + assert_allclose(result.data, expected, dtype=quantizers[0].q_dtype) + + valid_fp8_gemm_operand_types = [ (jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e5m2, jnp.float8_e4m3fn), diff --git a/tests/jax/test_einsum.py b/tests/jax/test_einsum.py new file mode 100644 index 0000000000..39dffa6787 --- /dev/null +++ b/tests/jax/test_einsum.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Tests for TE einsum operation with FP8 quantization.""" + +import jax +import jax.numpy as jnp +import pytest +from jax import value_and_grad + +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.jax.einsum import einsum +from transformer_engine.jax.quantize import ( + QuantizerFactory, + QuantizeMeta, + QuantizeMetaSet, +) +from transformer_engine.jax.quantize import helper + + +# Test parameters +DTYPES = [jnp.bfloat16] +# (B, S, M, E, C, H) +# B: Batch size +# S: Sequence length (number of tokens) +# M: Model dimension (hidden size) +# E: Number of experts +# C: Capacity (max tokens per expert) +# H: Hidden dimension (MLP intermediate size) +MOE_CASES = [ + (2, 32, 128, 4, 32, 64), +] + +# Get supported recipes +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """WAR for CUDA uninitialize error""" + # Calling customcalls before jax may cause CUDA uninitialize error + _ = jnp.zeros(0) + yield + + +class TestMoEMLPWithRecipes: + """Test MoE MLP operations with different FP8 recipes and gradients.""" + + def _get_quantizer_sets(self, recipe, num_experts): + return QuantizerFactory.create_set( + n_quantizer_sets=num_experts, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), + ) + + def _einsum(self, equation, *operands, quantizer_sets=None, quantizer_dim=None, fallback=False): + out = einsum( + equation, + *operands, + quantizer_sets=quantizer_sets, + quantizer_dim=quantizer_dim, + fallback=fallback, + ) + return jnp.mean(out) + + def _ref_einsum(self, equation, *operands): + out = jnp.einsum(equation, *operands) + return jnp.mean(out) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_up_grad(self, B, S, M, E, C, H, recipe): + """Test MLP up: EBCM,EMH->EBCH with gradients and different recipes.""" + # Create per-expert quantizers + quantizer_sets = self._get_quantizer_sets(recipe, E) + dispatched = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, M, H), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == dispatched.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_down_grad(self, B, S, M, E, C, H, recipe): + """Test MLP down: EBCH,EHM->EBCM with gradients and different recipes.""" + # Create per-expert quantizers + 
quantizer_sets = self._get_quantizer_sets(recipe, E) + + hidden = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, H, M), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == hidden.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_full_moe_grad(self, B, S, M, E, C, H, recipe): + """Test full MoE pipeline (all 4 einsums) with gradients and different recipes.""" + # Create per-expert quantizers for each einsum + mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E) + mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E) + + tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt(M) + routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16) + routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights + up_weights = jax.random.normal( + jax.random.PRNGKey(2), (E, M, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + down_weights = jax.random.normal( + jax.random.PRNGKey(3), (E, H, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + + # TE implementation with quantization + def full_moe_te(tokens, routing, up_w, down_w): + """Complete MoE pipeline with TE einsum.""" + dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + hidden = einsum( + "EBCM,EMH->EBCH", + dispatched, + up_w, + quantizer_sets=mlp_up_quantizer_sets, + quantizer_dim="E", + ) + expert_out = einsum( + "EBCH,EHM->EBCM", + hidden, + down_w, + quantizer_sets=mlp_down_quantizer_sets, + quantizer_dim="E", + ) + output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + return jnp.sum(output) + + # Reference implementation with jnp.einsum + def full_moe_ref(tokens, routing, up_w, down_w): + """Complete MoE pipeline with jnp.einsum.""" + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + hidden = jnp.einsum("EBCM,EMH->EBCH", dispatched, up_w) + expert_out = jnp.einsum("EBCH,EHM->EBCM", hidden, down_w) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + return jnp.sum(output) + + loss_te, grads_te = value_and_grad(full_moe_te, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + loss_ref, grads_ref = value_and_grad(full_moe_ref, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + # Verify all gradient shapes + assert grads_te[0].shape == tokens.shape, f"tokens grad shape mismatch" + assert grads_te[1].shape == routing.shape, f"routing grad shape mismatch" + assert grads_te[2].shape == up_weights.shape, f"up_weights grad shape mismatch" + assert grads_te[3].shape == 
down_weights.shape, f"down_weights grad shape mismatch" + + # Verify no NaNs or Infs + assert not jnp.isnan(loss_te), "Loss is NaN" + assert jnp.isfinite(loss_te), "Loss is Inf" + assert jnp.all(jnp.isfinite(grads_te[0])), "tokens grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[1])), "routing grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[2])), "up_weights grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[3])), "down_weights grad has NaN/Inf" + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=mlp_up_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[2], grads_ref[2], dtype=mlp_down_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[3], grads_ref[3], dtype=mlp_down_quantizer_sets[0].dgrad.q_dtype) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py index 2f3bc402ec..19e229c1ee 100644 --- a/transformer_engine/jax/cpp_extensions/amax.py +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -160,6 +160,18 @@ def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types output_spec = (f"{prefix}_amax",) return SdyShardingRule((input_spec,), (output_spec,)) + @staticmethod + def batcher(batched_args, batch_dims, *, amax_scope, transpose_batch_sequence): + """Batcher for amax calculation - returns single amax value.""" + return AmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + }, + ) + register_primitive(AmaxCalculationPrimitive, outer_only=True) @@ -370,6 +382,30 @@ def shardy_sharding_rule( output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",) return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec)) + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """Batcher for RHT amax calculation - returns 2 amax values.""" + return RHTAmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + "rht_matrix_random_sign_mask_t": rht_matrix_random_sign_mask_t, + "produce_regular_amax": produce_regular_amax, + "flatten_axis": flatten_axis, + }, + ) + register_primitive(RHTAmaxCalculationPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 556b587191..9f88265e93 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,13 +7,14 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial +from typing import Any, Sequence, Union, Tuple from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch -from jax import ffi +from jax import ffi, numpy as jnp import transformer_engine_jax @@ -168,6 +169,92 @@ def shardy_sharding_rule(*args): del args return "... -> ..." 
+ @classmethod + def batcher_impl( + cls, + batched_args: Sequence[Any], + batch_dims: Sequence[Union[int, None]], + static_kwargs: dict, + ) -> Tuple[Tuple[Any, ...], Tuple[Union[int, None], ...]]: + """Batcher implementation for JAX primitives. + + Implements the standard batching pattern: loop over batch dimension, + call primitive for each slice, and stack results. + + Args: + batched_args: Tuple of input tensors (some may be batched) + batch_dims: Tuple indicating batch dimension for each arg (None if not batched) + static_kwargs: Dictionary of static arguments to pass to primitive.bind() + + Returns: + Tuple of (output_tensors, output_batch_dims) + + Example: + @staticmethod + def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): + return MyPrimitive.batcher_impl( + batched_args, batch_dims, + static_kwargs={'arg1': arg1, 'arg2': arg2, 'arg3': arg3}, + ) + """ + from jax import lax + + # Find batch dimension and validate all batched args have the same batch_dim + batch_dim = None + batch_size = None + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + if batch_dim is None: + batch_dim = bdim + batch_size = arg.shape[bdim] + elif bdim != batch_dim: + raise ValueError( + "All batched arguments must have the same batch dimension. " + f"Got batch_dims={batch_dims}" + ) + assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + + # Loop over batch dimension and collect results + all_results = [] + + for i in range(batch_size): + # Extract slice for each argument + sliced_args = [] + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + slice_i = lax.index_in_dim(arg, i, bdim, keepdims=False) + sliced_args.append(slice_i) + else: # For empty args + sliced_args.append(arg) + + # Call primitive with unbatched slices + result_i = cls.outer_primitive.bind(*sliced_args, **static_kwargs) + + # Normalize to tuple + if not isinstance(result_i, (tuple, list)): + result_i = (result_i,) + elif isinstance(result_i, list): + result_i = tuple(result_i) + + all_results.append(result_i) + + # Transpose: from list of tuples to tuple of lists + # all_results = [(out0_0, out1_0, ...), (out0_1, out1_1, ...), ...] + # transposed = ([out0_0, out0_1, ...], [out1_0, out1_1, ...], ...) 
+ transposed = tuple(zip(*all_results)) + + # Stack each output along the batch dimension + stacked_results = tuple( + jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed + ) + + # Single output: return unwrapped result + if len(stacked_results) == 1: + return stacked_results[0], batch_dim + + # Multiple outputs: return tuple of results + return stacked_results, [batch_dim for _ in stacked_results] + # Registry to store all registered primitive classes _primitive_registry = {} diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 76a8b225ba..55a1700838 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -808,40 +808,33 @@ def batcher( sequence_dim, is_outer, ): - del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims - # Batched GEMM is not supported - assert ( - lhs_bdims is None and rhs_bdims is None - ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" - out_bdims = (None,) - - # Bias gradient is never batched - bias_bdims = (None,) - - # Pre-GeLU output, if exists, is batched like GEMM output - pre_gelu_bdims = (None,) - if fuse_gelu and not grad: - pre_gelu_bdims = out_bdims + # Validate batch dimensions + if lhs_bdims is not None or rhs_bdims is not None: + assert lhs_bdims == rhs_bdims, ( + "Batched GEMM requires matching batch dimensions, " + f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + ) - return ( - GemmPrimitive.outer_primitive.bind( - *batched_args, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, - use_split_accumulator=use_split_accumulator, - collective_op=collective_op, - transpose_batch_sequence=transpose_batch_sequence, - sequence_dim=sequence_dim, - is_outer=is_outer, - ), - (out_bdims, bias_bdims, pre_gelu_bdims), + # Use general batcher from BasePrimitive + return GemmPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "contracting_dims": contracting_dims, + "scaling_mode": scaling_mode, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + "collective_op": collective_op, + "transpose_batch_sequence": transpose_batch_sequence, + "sequence_dim": sequence_dim, + "is_outer": is_outer, + }, ) @staticmethod diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f24e9337..53c6937fb4 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -361,34 +361,24 @@ def batcher( stochastic_rounding, use_rht, ): - """ - to describe batch rules for vmap - """ - del is_outer + """Batch rule for quantization primitive using general batcher.""" check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args - x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims - out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim - return ( - BaseDBiasQuantizePrimitive.outer_primitive.bind( - x, - scale, - amax, - sr_rng_state, - post_rht_amax, - rht_matrix, - out_dtype=out_dtype, - scaling_mode=scaling_mode, - q_layout=q_layout, - flatten_axis=flatten_axis, - 
scale_dtype=scale_dtype, - is_dbias=is_dbias, - stochastic_rounding=stochastic_rounding, - use_rht=use_rht, - ), - out_bdims, + return BaseDBiasQuantizePrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "scaling_mode": scaling_mode, + "q_layout": q_layout, + "flatten_axis": flatten_axis, + "scale_dtype": scale_dtype, + "is_dbias": is_dbias, + "is_outer": is_outer, + "stochastic_rounding": stochastic_rounding, + "use_rht": use_rht, + }, ) @staticmethod diff --git a/transformer_engine/jax/einsum.py b/transformer_engine/jax/einsum.py new file mode 100644 index 0000000000..20084c77ea --- /dev/null +++ b/transformer_engine/jax/einsum.py @@ -0,0 +1,424 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Einsum operation with FP8 quantization support for Transformer Engine in JAX. + +This module provides an einsum implementation that decomposes einsum operations into +a sequence of GEMMs, each with its own quantizer for FP8 support. It follows the +pattern of jax.numpy.einsum but uses TE's optimized GEMM operations. + +This module provides an einsum implementation optimized for Mixture-of-Experts (MoE) +models with per-expert quantization support. It leverages JAX's vmap and TE's dense +layer to efficiently handle tensor contractions with a single batch dimension. + +Key Features: + - **Per-expert quantization**: Each expert can have independent scaling and quantization parameters + - **Automatic differentiation**: Full gradient support via dense layer's VJP + - **Single batch dimension**: Optimized for MoE patterns (expert dimension) + - **Explicit API**: Requires quantizer_dim when using quantization + +Limitations: + - **NN layout only**: LHS last dim must contract, RHS last dim must not contract + - **Single batch dimension**: Only one batch dimension supported + - **2-operand only**: Only supports binary operations + - **Explicit quantizer_dim**: Required when quantizer_sets is provided + + For operations that don't meet these requirements (e.g., routing operations + like "BSM,BSEC->EBCM"), use jnp.einsum instead, or set fallback=True to + automatically fall back to jnp.einsum when the operation is not supported. + +Example - MoE Forward Pass with Per-Expert FP8: + ```python + from transformer_engine.jax.einsum import einsum + from transformer_engine.jax.quantize import QuantizerFactory, QuantizeMeta, QuantizeMetaSet + + # Create per-expert quantizers (E experts) + quantizer_sets = [ + QuantizerFactory.create_set( + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ) + ) for _ in range(num_experts) + ] + + # MoE pipeline with per-expert quantization, + # 1. Dispatch: BSM,BSEC -> EBCM (no quantization - routing operation) + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + # Or with fallback: + # dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + + # 2. MLP Up: EBCM,EMH -> EBCH (per-expert quantization) + hidden = einsum("EBCM,EMH->EBCH", dispatched, expert_up_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 3. MLP Down: EBCH,EHM -> EBCM (per-expert quantization) + expert_out = einsum("EBCH,EHM->EBCM", hidden, expert_down_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 4. 
Combine: EBCM,BSEC -> BSM (no quantization - routing operation) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + # Or with fallback: + # output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + ``` + +Implementation Details: + The einsum function works by: + 1. Parsing the einsum equation to identify the single batch dimension and contracting dimensions + 2. Validating that quantizer_sets length matches the quantizer dimension size + 3. Creating a vmapped version of TE's dense layer over the batch dimension + 4. Vmapping over quantizer_sets to provide per-batch (e.g., per-expert) quantization + 5. Leveraging dense's existing VJP for automatic differentiation + + This design reuses TE's well-tested dense layer infrastructure while enabling + per-expert quantization for MoE models with minimal code complexity. +""" + +from typing import Tuple, Optional, List +import jax +import jax.numpy as jnp + +from .dense import dense +from .quantize import ( + QuantizerSet, + noop_quantizer_set, +) + + +def _parse_einsum_input(equation: str, *operands) -> Tuple[str, List[str], str]: + """Parse einsum equation into input specs and output spec. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik" or "BNSM,BNSEC->EBNCM") + operands: Input tensors + + Returns: + Tuple of (equation, input_specs, output_spec) + + Raises: + ValueError: If number of operands doesn't match equation + """ + # Remove spaces + equation = equation.replace(" ", "") + + if "->" in equation: + inputs_str, output_str = equation.split("->") + input_specs = inputs_str.split(",") + else: + # Implicit output mode + inputs_str = equation + input_specs = inputs_str.split(",") + # Compute implicit output + all_indices = set() + for spec in input_specs: + all_indices.update(spec) + output_str = "".join(sorted(all_indices)) + + # Validate each operand's ndim matches its spec + for i, (operand, spec) in enumerate(zip(operands, input_specs)): + expected_ndim = len(spec) + actual_ndim = operand.ndim + if actual_ndim != expected_ndim: + raise ValueError( + f"Operand {i} has {actual_ndim} dimensions but equation '{equation}' " + f"expects {expected_ndim} dimensions (spec: '{spec}'). " + f"Operand shape: {operand.shape}" + ) + + return equation, input_specs, output_str + + +def _find_contracting_and_batch_dims(lhs_spec: str, rhs_spec: str, output_spec: str): + """Find contracting and batch dimensions for a GEMM operation. 
+ + Args: + lhs_spec: Index specification for LHS (e.g., "BNSM") + rhs_spec: Index specification for RHS (e.g., "BNSEC") + output_spec: Index specification for output (e.g., "EBNCM") + + Returns: + Tuple of (lhs_contracting, rhs_contracting, lhs_batch, rhs_batch) + """ + # Contracting dimensions: indices in both lhs and rhs but not in output + lhs_set = set(lhs_spec) + rhs_set = set(rhs_spec) + output_set = set(output_spec) + + contracting_indices = (lhs_set & rhs_set) - output_set + + # Batch dimensions: indices in lhs, rhs, and output + batch_indices = lhs_set & rhs_set & output_set + + # Find positions + lhs_contracting = tuple(i for i, c in enumerate(lhs_spec) if c in contracting_indices) + rhs_contracting = tuple(i for i, c in enumerate(rhs_spec) if c in contracting_indices) + lhs_batch = tuple(i for i, c in enumerate(lhs_spec) if c in batch_indices) + rhs_batch = tuple(i for i, c in enumerate(rhs_spec) if c in batch_indices) + + return lhs_contracting, rhs_contracting, lhs_batch, rhs_batch + + +def _einsum_to_gemm_info(equation: str, *operands): + """Extract GEMM information from einsum equation. + + Args: + equation: Einsum equation + operands: Input tensors + + Returns: + Dict with keys: lhs_idx, rhs_idx, contracting_dims, batch_dims, output_spec + """ + equation, input_specs, output_spec = _parse_einsum_input(equation, *operands) + + if len(input_specs) != 2: + raise NotImplementedError(f"Einsum with {len(input_specs)} operands not yet supported") + + lhs_spec, rhs_spec = input_specs + + lhs_contracting, rhs_contracting, lhs_batch, rhs_batch = _find_contracting_and_batch_dims( + lhs_spec, rhs_spec, output_spec + ) + + return { + "lhs_idx": 0, + "rhs_idx": 1, + "lhs_spec": lhs_spec, + "rhs_spec": rhs_spec, + "output_spec": output_spec, + "contracting_dims": (lhs_contracting, rhs_contracting), + "batch_dims": (lhs_batch, rhs_batch), + } + + +def einsum( + equation: str, + *operands: jnp.ndarray, + quantizer_sets: Optional[List[QuantizerSet]] = None, + quantizer_dim: Optional[str] = None, + operand_axes: Optional[List[Tuple[str, ...]]] = None, + output_axes: Optional[Tuple[str, ...]] = None, + fallback: bool = False, +) -> jnp.ndarray: + """Perform einsum operation with optional FP8 quantization using vmap + dense. + + This function implements einsum by: + 1. Identifying batch dimensions + 2. Using vmap to vectorize over batch dimensions + 3. Calling the existing dense() function which has VJP already implemented + + Each batched GEMM can have its own quantizer_set, enabling per-expert + quantization in MoE models. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik", "BSM,BSEC->EBCM") + *operands: Input tensors + quantizer_sets: List or tuple of QuantizerSets. Length must match the size of + the dimension specified by quantizer_dim. If None, creates noop quantizers. + quantizer_dim: Index label indicating which dimension the quantizers correspond to. + For MoE, this is typically 'E' (expert dimension). If None and + quantizer_sets is provided, assumes first batch dimension at position 0. + operand_axes: List of logical axes tuples for sharding each operand + output_axes: Logical axes for sharding the output + fallback: Whether to fallback to jnp.einsum if the einsum operation is not supported. + When fallback=True, unsupported operations (e.g., non-NN layouts, routing + operations) will use jnp.einsum. Note: quantization will NOT be applied + when falling back. 
+ + Returns: + Result of the einsum operation + + Examples: + # Simple matrix multiplication with FP8 + result = einsum("ij,jk->ik", A, B, quantizer_sets=my_quantizer_set) + + # MoE with per-expert quantizers (E experts) + expert_quantizers = [quantizer_e0, quantizer_e1, ..., quantizer_eN] + result = einsum("EBNCM,EMH->EBNCH", tokens, weights, + quantizer_sets=expert_quantizers) + + # With fallback for routing operations + result = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + # Falls back to jnp.einsum (no quantization) + """ + if operand_axes is None: + operand_axes = [None] * len(operands) + + if len(operands) != 2: + if fallback: + import warnings + + warnings.warn( + f"TE einsum only supports 2-operand einsum, got {len(operands)} operands. " + "Falling back to jnp.einsum (no quantization will be applied).", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError("Only 2-operand einsum currently supported") + + # Parse einsum to get GEMM info + gemm_info = _einsum_to_gemm_info(equation, *operands) + contracting_dims = gemm_info["contracting_dims"] + batch_dims = gemm_info["batch_dims"] + lhs_spec = gemm_info["lhs_spec"] + rhs_spec = gemm_info["rhs_spec"] + + lhs, rhs = operands + + # Validate quantizer_dim is provided when quantizer_sets is given + if quantizer_sets is not None and quantizer_dim is None: + raise ValueError( + "quantizer_dim must be specified when quantizer_sets is provided. " + "This explicitly indicates which dimension the quantizers correspond to." + ) + + # Find quantizer dimension + quantizer_dim_lhs = None + quantizer_dim_rhs = None + + if quantizer_dim is not None: + # Find position of quantizer_dim in lhs and rhs specs + if quantizer_dim in lhs_spec: + quantizer_dim_lhs = lhs_spec.index(quantizer_dim) + if quantizer_dim in rhs_spec: + quantizer_dim_rhs = rhs_spec.index(quantizer_dim) + + if quantizer_dim_lhs is None and quantizer_dim_rhs is None: + raise ValueError(f"quantizer_dim '{quantizer_dim}' not found in equation '{equation}'") + + # Check if we have batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + # Determine expected quantizer_sets length based on quantizer_dim + if quantizer_dim is not None: + if quantizer_dim_lhs is not None: + expected_length = lhs.shape[quantizer_dim_lhs] + else: + expected_length = rhs.shape[quantizer_dim_rhs] + else: + # No quantizer_dim: determine from batch dimension + if has_batch_dims: + expected_length = lhs.shape[batch_dims[0][0]] + else: + expected_length = 1 + + # Validate and initialize quantizer_sets + if quantizer_sets is None: + quantizer_sets = [noop_quantizer_set] * expected_length + elif not isinstance(quantizer_sets, (list, tuple)): + raise TypeError(f"quantizer_sets must be a list or tuple, got {type(quantizer_sets)}") + elif len(quantizer_sets) != expected_length: + raise ValueError( + f"quantizer_sets length ({len(quantizer_sets)}) must match " + f"{'dimension ' + repr(quantizer_dim) if quantizer_dim else 'batch dimension'} " + f"size ({expected_length})" + ) + + # Validate that this is NN layout (required by dense) + # For NN: lhs last dim must contract, rhs last dim must NOT contract + lhs_ndim = len(gemm_info["lhs_spec"]) + rhs_ndim = len(gemm_info["rhs_spec"]) + lhs_last_contracts = lhs_ndim - 1 in contracting_dims[0] + rhs_last_contracts = rhs_ndim - 1 in contracting_dims[1] + + if not lhs_last_contracts or rhs_last_contracts: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] 
* len( + quantizer_sets + ): + warnings.warn( + f"TE einsum only supports NN layout. Equation '{equation}' is not NN layout. " + "Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise ValueError( + "TE einsum only supports NN layout (non-transposed matrix multiplication). Equation" + f" '{equation}' is not NN layout:\n - LHS '{gemm_info['lhs_spec']}': last dimension" + f" must contract (got contracting_dims={contracting_dims[0]})\n - RHS" + f" '{gemm_info['rhs_spec']}': last dimension must NOT contract (got" + f" contracting_dims={contracting_dims[1]})\nFor non-NN layouts (e.g., routing" + " operations), use jnp.einsum instead." + ) + + # Create vmapped dense function for batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + if has_batch_dims: + # Validate single batch dimension (MoE use case) + if len(batch_dims[0]) != 1 or len(batch_dims[1]) != 1: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] * len( + quantizer_sets + ): + warnings.warn( + "TE einsum only supports single batch dimension. Got" + f" {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs." + " Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError( + "Only single batch dimension is currently supported. " + f"Got {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs. " + f"Equation: '{equation}'" + ) + + lhs_batch_dim = batch_dims[0][0] + rhs_batch_dim = batch_dims[1][0] + + # Adjust contracting dims for the unbatched shapes seen by Python code + # (primitives will see batched shapes, but Python validation sees unbatched) + adj_lhs_contracting = tuple( + dim - (1 if dim > lhs_batch_dim else 0) for dim in contracting_dims[0] + ) + adj_rhs_contracting = tuple( + dim - (1 if dim > rhs_batch_dim else 0) for dim in contracting_dims[1] + ) + adj_contracting_dims = (adj_lhs_contracting, adj_rhs_contracting) + + # Stack quantizers into a pytree structure that vmap can handle + # QuantizerSet is already a pytree, so we can stack them + # For BF16 without quantizer_dim, this will be a stack of noop_quantizer_sets + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizer_sets) + + # Vmap over quantizers (or repeated noop quantizers for BF16) + def dense_with_quantizer(lhs_single, rhs_single, quantizer_set): + """Dense with explicit quantizer argument for vmapping.""" + return dense( + lhs_single, + rhs_single, + None, + contracting_dims=adj_contracting_dims, # Adjusted for unbatched shapes + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_set, + ) + + vmapped_func = jax.vmap( + dense_with_quantizer, + in_axes=(lhs_batch_dim, rhs_batch_dim, 0), # vmap over stacked quantizers + out_axes=0, + ) + output = vmapped_func(lhs, rhs, stacked_quantizers) + else: + # No batch dimensions - direct dense call + # quantizer_set length already validated to be 1 + output = dense( + lhs, + rhs, + None, + contracting_dims=contracting_dims, + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_sets[0], + ) + + return output diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 
90f139c3da..120bd05c13 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -209,49 +209,63 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): flatten_axis: int has_rht_applied: bool - def __post_init__(self): - """Validates and adjusts the scale_inv shape after initialization. - - Ensures the scale_inv shape matches the expected shape based on the scaling mode - and quantization direction. Pads the scale_inv if necessary. - """ - assert self.flatten_axis > 0 - assert ( - 0 < self.flatten_axis < len(self.data.shape) - ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - - if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((0,), dtype=jnp.float32) - else: - unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - ) - unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - broadcast_2d_scale_shape_to_1d=True, - ) - assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( - f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" - f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." - ) + # def __post_init__(self): + # """Validates and adjusts the scale_inv shape after initialization. + # + # Ensures the scale_inv shape matches the expected shape based on the scaling mode + # and quantization direction. Pads the scale_inv if necessary. 
+ # """ + # assert self.flatten_axis > 0 + # assert ( + # 0 < self.flatten_axis < len(self.data.shape) + # ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" + # + # if self.scaling_mode == ScalingMode.NO_SCALING: + # self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + # else: + # unpadded_scale_shape = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # ) + # unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # broadcast_2d_scale_shape_to_1d=True, + # ) + # # Check shape, allowing for batch dimensions from vmap + # # If vmapped, shape will be (batch_size, *expected_shape) + # actual_shape = self.scale_inv.shape + # if actual_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # # Check if it's a batched version (extra leading dimensions) + # if len(actual_shape) > len(unpadded_scale_shape): + # # Batched: check that trailing dimensions match + # trailing_shape = actual_shape[-(len(unpadded_scale_shape)):] + # if trailing_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} (possibly with batch dims) but got {self.scale_inv.shape}." + # ) + # else: + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." + # ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. @@ -431,10 +445,21 @@ def __post_init__(self): flatten_axis=self.flatten_axis, ) - assert self.scale_inv.shape == expected_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv, got {self.scale_inv.shape}" - ) + # Check shape, allowing for batch dimensions from vmap + actual_shape = self.scale_inv.shape + if actual_shape != expected_scale_shape: + # Check if it's a batched version + if len(actual_shape) > len(expected_scale_shape): + trailing_shape = actual_shape[-(len(expected_scale_shape)) :] + assert trailing_shape == expected_scale_shape, ( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv (possibly with batch dims), got {self.scale_inv.shape}" + ) + else: + raise AssertionError( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv, got {self.scale_inv.shape}" + ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. 
From 1f02cf41c7b521b82d99058e8f0fb6f2bd5b048e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Dec 2025 21:08:42 +0000 Subject: [PATCH 02/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_einsum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_einsum.py b/tests/jax/test_einsum.py index 39dffa6787..7580a14638 100644 --- a/tests/jax/test_einsum.py +++ b/tests/jax/test_einsum.py @@ -145,7 +145,9 @@ def test_full_moe_grad(self, B, S, M, E, C, H, recipe): mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E) mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E) - tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt(M) + tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt( + M + ) routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16) routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights up_weights = jax.random.normal( From bf3ebc2ccf98a016ff61f859df7fa2686f36114d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 15:29:37 +0100 Subject: [PATCH 03/61] code drop Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_grouped_gemm.cu | 511 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 484 +++++++++++++++++ .../common/include/transformer_engine/gemm.h | 36 ++ 4 files changed, 1032 insertions(+) create mode 100644 tests/cpp/operator/test_grouped_gemm.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index b2f14b1892..1392ffdadc 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_operator test_causal_softmax.cu test_swizzle.cu test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu new file mode 100644 index 0000000000..0e9c6c6a4d --- /dev/null +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -0,0 +1,511 @@ +/*********************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * + * See LICENSE for license information. + **********************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class InputCase { + kFP8Delayed, + kFP8Current, + kBF16, +}; + +enum class ShapeCase { + kAllSame, + kSameFirst, + kSameLast, + kAllDifferent, +}; + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. 
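+// A rough usage sketch (illustrative, not a verbatim excerpt) of the wiring
+// this helper performs for each operand via the grouped-tensor API used below:
+//   handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape);
+//   NVTEBasicTensor data{device_ptr, dtype, logical_shape};
+//   nvte_set_grouped_tensor_param(&handle, kNVTEGroupedRowwiseData, &data);
+//   // plus kNVTEGroupedFirstDims / kNVTEGroupedLastDims / kNVTEGroupedTensorOffsets
+//   // when member shapes differ, and the *ScaleInv params for FP8 inputs.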
+struct GroupedBuffers { + NVTEGroupedTensor handle{nullptr}; + void* data{nullptr}; + void* scale_inv{nullptr}; + int64_t* first_dims_dev{nullptr}; + int64_t* last_dims_dev{nullptr}; + int64_t* offsets_dev{nullptr}; + void* columnwise_data{nullptr}; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&& other) noexcept { + *this = std::move(other); + } + GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { + if (this == &other) return *this; + handle = other.handle; + data = other.data; + scale_inv = other.scale_inv; + first_dims_dev = other.first_dims_dev; + last_dims_dev = other.last_dims_dev; + offsets_dev = other.offsets_dev; + logical_shape = other.logical_shape; + offsets_host = std::move(other.offsets_host); + tensor_bytes = std::move(other.tensor_bytes); + num_tensors = other.num_tensors; + elem_size = other.elem_size; + dtype = other.dtype; + scaling_mode = other.scaling_mode; + + other.handle = nullptr; + other.data = nullptr; + other.scale_inv = nullptr; + other.first_dims_dev = nullptr; + other.last_dims_dev = nullptr; + other.offsets_dev = nullptr; + other.num_tensors = 0; + return *this; + } + + ~GroupedBuffers() { + if (data) { + cudaFree(data); + data = nullptr; + } + if (scale_inv) { + cudaFree(scale_inv); + scale_inv = nullptr; + } + if (columnwise_data) { + cudaFree(columnwise_data); + columnwise_data = nullptr; + } + if (first_dims_dev) { + cudaFree(first_dims_dev); + first_dims_dev = nullptr; + } + if (last_dims_dev) { + cudaFree(last_dims_dev); + last_dims_dev = nullptr; + } + if (offsets_dev) { + cudaFree(offsets_dev); + offsets_dev = nullptr; + } + if (handle) { + nvte_destroy_grouped_tensor(handle); + handle = nullptr; + } + } +}; + +size_t grouped_setup_workspace_size(const size_t num_tensors) { + const size_t ptr_bytes = num_tensors * sizeof(void*); + const size_t int_bytes = num_tensors * sizeof(int); + size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + const size_t alignment = 256; + size = ((size + alignment - 1) / alignment) * alignment; + return size; +} + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + const NVTEShape shape = tensors[0]->rowwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToSize(dtype); + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = tensors[i]->rowwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped GEMM test expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = 
std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + return dist(gen); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + + NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + + const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); + if (include_columnwise) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data, + static_cast(dtype), + grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + } + + if (!same_first) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + } + + if (!same_last) { + 
NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t)));
+    NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(),
+                               num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
+    NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
+    NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape};
+    nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor);
+  }
+
+  if (!same_first || !same_last) {
+    NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t)));
+    NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(),
+                               num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
+    NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
+    NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape};
+    nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor);
+  }
+
+  if (isFp8Type(dtype)) {
+    std::vector<float> scale_inv_cpu(num_tensors, 1.f);
+    for (size_t i = 0; i < num_tensors; ++i) {
+      tensors[i]->to_cpu();
+      scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0];
+    }
+    NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors));
+    NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(),
+                               sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
+    NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
+    NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape};
+    nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
+    nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
+  }
+
+  return grouped;
+}
+
+Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
+  Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
+  fillUniform(&input_fp32);
+
+  Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
+
+  nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
+  QuantizationConfigWrapper config;
+  nvte_compute_scale_from_amax(fp8.data(), config, 0);
+  nvte_quantize(input_fp32.data(), fp8.data(), 0);
+  return fp8;
+}
+
+Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
+  Tensor t(name, shape, DType::kBFloat16);
+  fillUniform(&t);
+  return t;
+}
+
+struct TestParams {
+  InputCase input_case;
+  bool transa;
+  bool transb;
+  ShapeCase shape_case;
+};
+
+std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
+  switch (scase) {
+    case ShapeCase::kAllSame:
+      return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
+    case ShapeCase::kSameFirst:  // M shared, N/K vary
+      return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}};
+    case ShapeCase::kSameLast:  // N shared, M/K vary
+      return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}};
+    case ShapeCase::kAllDifferent:
+    default:
+      return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}};
+  }
+}
+
+void run_grouped_gemm_case(const TestParams& params) {
+  if (params.input_case != InputCase::kBF16 &&
+      getDeviceComputeCapability() < hopperComputeCapability) {
+    GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer.";
+  }
+
+  const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
+
+  const size_t num_gemms = shapes.size();
+  std::vector<Tensor> A_tensors;
+  std::vector<Tensor> B_tensors;
+  std::vector<Tensor> D_multi;
+
+  A_tensors.reserve(num_gemms);
+  B_tensors.reserve(num_gemms);
+  D_multi.reserve(num_gemms);
+
+  for (size_t i = 0; i < num_gemms; ++i) {
+    const auto [M, N, K] = shapes[i];
+    const std::vector a_shape = params.transa ?
std::vector{K, M} + : std::vector{M, K}; + const std::vector b_shape = params.transb ? std::vector{N, K} + : std::vector{K, N}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, + workspace_ptrs.data(), + false, + false, + 0, + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + C_views.push_back(&C_tensors[i]); + D_views.push_back(&D_group_tensors[i]); + } + + GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); + const float alpha_val = 1.f; + const float beta_val = 0.f; + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + nvte_grouped_gemm(params.transa, + params.transb, + alpha_tensor.data(), + grouped_A.handle, + grouped_B.handle, + beta_tensor.data(), + grouped_C.handle, + grouped_D.handle, + setup_ws.data(), + 
cublas_ws.data(), + nullptr, + 0, + nullptr, + nullptr, + nullptr); + + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.data) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +} + +class GroupedGemmTest : public ::testing::TestWithParam {}; + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { + run_grouped_gemm_case(GetParam()); +} + +std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { + constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + + "tb" + (info.param.transb ? "T" : "N"); + return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; +} + +const std::vector kTestParams = { + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, +}; + +INSTANTIATE_TEST_SUITE_P(OperatorTest, + GroupedGemmTest, + ::testing::ValuesIn(kTestParams), + MakeGroupedGemmTestName); + +} // namespace + + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 97e8ec9a3e..53be59cc00 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1104,3 +1104,487 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } + + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + return { + t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, + t->get_common_first_dim(), + t->get_common_last_dim()}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + return { + nullptr, + nullptr, + C->tensor_offsets.has_data() ? 
static_cast(C->tensor_offsets.dptr) : nullptr, + D->get_common_first_dim(), + D->get_common_last_dim()}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } else { + // logical_shape[1] is sum_of_N, divide by num_tensors + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); + } +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + + offset = ((offset + alignment - 1) / alignment) * alignment; + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, + const transformer_engine::GroupedTensor* inputB, + const transformer_engine::GroupedTensor* inputC, + const transformer_engine::GroupedTensor* outputD) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + NVTE_CHECK(outputD->num_tensors == num_tensors, + 
"Grouped GEMM: A and D must have the same num_tensors"); + + auto is_fp8_or_16bit = [](DType dtype) { + return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || + dtype == DType::kBFloat16 || dtype == DType::kFloat16; + }; + auto is_output_dtype = [](DType dtype) { + return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), + "Grouped GEMM outputs must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const char* base = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. + const auto sm = t->scaling_mode; + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && + !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (or column-wise if row-wise absent, covered above). + sel.base = static_cast(has_row ? 
t->data.dptr : t->columnwise_data.dptr); + sel.dtype = has_row ? row_dtype : col_dtype; + sel.use_columnwise = !has_row && has_col; + return sel; +} + +inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, + const char* workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, + "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, + " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, + const GroupedGemmWorkspace& ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, + size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + cudaDataType_t D_type) { + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int* lda = rowa; + const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + const int* ldb = rowb; + + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t& matmulDesc, + cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, + int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, 
&kGroupedGemmCublasWorkspaceSize, + sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, + int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, + TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (same for all groups) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + M[idx] = static_cast(transa ? a_last : a_first); + K[idx] = static_cast(transa ? a_first : a_last); + N[idx] = static_cast(transb ? 
b_first : b_last); + + // Fill alpha/beta pointers (same for all groups) + alpha_ptrs[idx] = alpha_ptr; + beta_ptrs[idx] = beta_ptr; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmWorkspace &ws, + const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, + const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, + const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, + bool transa, bool transb, + size_t num_tensors, cudaStream_t stream) { + + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, + ws.M, ws.N, ws.K, + ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, + A_meta, B_meta, C_meta, D_meta, + a_elem_size, b_elem_size, c_elem_size, d_elem_size, + static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), + transa, transb, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Constants for grouped GEMM workspace +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. 
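// A minimal decision sketch for the flag adjustment below, assuming a device where only
// TN FP8 GEMM layouts are supported (the Hopper-style path described above) and both the
// row-wise and column-wise copies are present; select_grouped_operand() above remains the
// authoritative logic:
//   FP8 A, transa == false -> use the column-wise (pre-transposed) copy, report transa = true
//   FP8 A, transa == true  -> keep the row-wise copy, flag unchanged
//   FP8 B, transb == true  -> use the column-wise (pre-transposed) copy, report transb = false
//   FP8 B, transb == false -> keep the row-wise copy, flag unchanged
//   BF16/FP16 operands     -> row-wise copy with flags unchanged, unless only the
//                             column-wise copy exists, in which case the flag is flipped.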
+ bool transa_flag = static_cast(transa); + bool transb_flag = static_cast(transb); + const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); + transa_flag = A_sel.trans; + transb_flag = B_sel.trans; + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void* setup_workspace_ptr = + validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); + void* cublas_workspace_ptr = + validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + + NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, + alpha_tensor, beta_tensor, + A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, + num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Get data types + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(outputD->dtype()); + + // Setup cuBLAS operations + cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, + transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, + num_tensors, A_type, B_type, D_type); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = + avg_k ? *avg_k : (transa_flag ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = + select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, + avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, + &descC, setup_workspace.D_ptrs, &descD, + &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 950014cc9b..51241aef6b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -228,6 +228,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, + * or single element for uniform alpha). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, + * or single element for uniform beta). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace Workspace tensor for intermediate computations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. 
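 *
 * Workspace sizing sketch (mirrors GroupedGemmSetupWorkspace and the constants in
 * cublaslt_gemm.cu; treat the exact layout as an implementation detail of this
 * experimental API rather than a guarantee): workspace_setup stages six pointer arrays
 * (A/B/C/D/alpha/beta) and three int arrays (M/N/K), each of length num_tensors, rounded
 * up to a 256-byte boundary, while workspace_cublas is the cuBLASLt scratch buffer
 * (32 MiB in the current implementation). With illustrative variable names:
 * \code
 *   size_t setup_bytes = 6 * num_tensors * sizeof(void *) + 3 * num_tensors * sizeof(int);
 *   setup_bytes = (setup_bytes + 255) / 256 * 256;      // round up to 256-byte alignment
 *   const size_t cublas_bytes = 32ull * 1024 * 1024;    // cuBLASLt scratch workspace
 * \endcode
 * The avg_m/avg_n/avg_k pointers are optional hints for cuBLASLt heuristics; passing NULL
 * lets the implementation derive average dimensions from the grouped tensors' logical
 * shapes.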
+ * + * Requirements: + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From 76293d4dc9ebb8a7e1c7ba2ae47f866d56998d33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:32:15 +0000 Subject: [PATCH 04/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_grouped_gemm.cu | 2 - .../common/gemm/cublaslt_gemm.cu | 279 +++++++++--------- .../common/include/transformer_engine/gemm.h | 11 +- 3 files changed, 141 insertions(+), 151 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0e9c6c6a4d..d346e06887 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -507,5 +507,3 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, MakeGroupedGemmTestName); } // namespace - - diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 53be59cc00..2c8c2093c6 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1105,46 +1105,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } } - // Helper struct to pass per-tensor shape/offset info (pointer or uniform value) struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return { - t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), - t->get_common_last_dim()}; + return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + t->get_common_first_dim(), t->get_common_last_dim()}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return { - nullptr, - nullptr, - C->tensor_offsets.has_data() ? 
static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), - D->get_common_last_dim()}; + return {nullptr, nullptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + D->get_common_first_dim(), D->get_common_last_dim()}; } }; // Helper functions to compute average dimensions from logical_shape for heuristics // These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) // In both cases, dividing by num_tensors gives the average return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); } -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); @@ -1167,21 +1163,31 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, + size_t alignment) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; @@ -1201,10 +1207,10 @@ struct GroupedGemmSetupWorkspace { // ----------------------------------------------------------------------------- // Helper routines to keep nvte_grouped_gemm readable // ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, - const transformer_engine::GroupedTensor* inputB, - const transformer_engine::GroupedTensor* inputC, - const transformer_engine::GroupedTensor* outputD) { +inline void 
validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, + const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1235,23 +1241,24 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { - const char* base = nullptr; + const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; bool use_columnwise = false; }; -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, bool trans, bool is_A) { using namespace transformer_engine; const bool has_row = t->has_data(); const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && - !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && + !is_nvfp_scaling(sm), "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); const DType row_dtype = t->data.dtype; @@ -1268,7 +1275,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: if (is_A) { if (!sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = true; // using pre-transposed storage sel.use_columnwise = true; @@ -1277,7 +1284,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } else { // B if (sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = false; // using pre-transposed storage sel.use_columnwise = true; @@ -1288,7 +1295,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; sel.use_columnwise = true; @@ -1296,81 +1303,81 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); + sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); sel.dtype = has_row ? 
row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + sel.use_columnwise = !has_row && has_col; return sel; } -inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, - const char* workspace_name) { +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, - "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, - " bytes, Available: ", provided_size, " bytes."); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); return ws->data.dptr; } -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, - const GroupedGemmWorkspace& ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int* lda = rowa; - const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int* ldb = rowb; - - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int *lda = rowa; + const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int *colb = b_columnwise ? ws.K : (transb ? 
ws.K : ws.N); + const int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, + (void *)cola, (void *)lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, + (void *)colb, (void *)ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); } -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, cublasOperation_t op_B) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); } inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t& matmulDesc, - cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, - int64_t avg_n, int64_t avg_k) { + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { cublasLtMatmulPreferenceOpaque_t preference; NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &kGroupedGemmCublasWorkspaceSize, - sizeof(size_t))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -1382,7 +1389,8 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, int returnedResults = 0; auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, &preference, 1, &heuristicResult, 
&returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); NVTE_CHECK_CUBLAS(status); NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); return heuristicResult.algo; @@ -1394,14 +1402,12 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, // We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. __global__ void setup_grouped_gemm_kernel( // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, - int *M, int *N, int *K, + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, float **alpha_ptrs, float **beta_ptrs, // Base pointers const char *a_base, const char *b_base, const char *c_base, char *d_base, // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, - TensorShapeInfo C_meta, TensorShapeInfo D_meta, + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, // Alpha/beta pointers (same for all groups) @@ -1410,7 +1416,6 @@ __global__ void setup_grouped_gemm_kernel( bool transa, bool transb, // Number of tensors size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -1421,10 +1426,14 @@ __global__ void setup_grouped_gemm_kernel( int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? 
D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); // Compute data pointers A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; @@ -1444,18 +1453,12 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, - const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, - const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, - const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, - bool transa, bool transb, - size_t num_tensors, cudaStream_t stream) { - + const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, + cudaStream_t stream) { TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); @@ -1471,15 +1474,10 @@ inline void launch_grouped_gemm_setup( const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, - ws.M, ws.N, ws.K, - ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, - A_meta, B_meta, C_meta, D_meta, - a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), - transa, transb, num_tensors); + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, + c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1492,12 +1490,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k) { NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; @@ -1530,20 +1527,18 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); const size_t cublas_workspace_size = 
kGroupedGemmCublasWorkspaceSize; - void* setup_workspace_ptr = - validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); - void* cublas_workspace_ptr = - validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, - alpha_tensor, beta_tensor, - A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, - num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, + beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1560,9 +1555,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, - transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, - num_tensors, A_type, B_type, D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, + A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, + D_type); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; @@ -1576,15 +1571,13 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, avg_k ? *avg_k : (transa_flag ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); // Heuristic selection - cublasLtMatmulAlgo_t algo = - select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, - avg_k_val); + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); // Execute the grouped GEMM NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, - &descC, setup_workspace.D_ptrs, &descD, - &algo, cublas_workspace_ptr, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, kGroupedGemmCublasWorkspaceSize, stream)); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 51241aef6b..948058295e 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -257,12 +257,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 296d77362099c52fa8e19a299f4a4134dc184096 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 18:25:39 +0100 Subject: [PATCH 05/61] Add FP8 scale support and fix alignment for grouped GEMM - Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM - Fix random padding in tests to ensure 16-byte alignment for all dtypes - Reorder GroupedGemmSetupWorkspace members for natural alignment - Remove debug prints Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 55 +++++--- .../common/gemm/cublaslt_gemm.cu | 119 +++++++++++++----- .../common/include/transformer_engine/gemm.h | 2 + 3 files changed, 131 insertions(+), 45 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index d346e06887..bff175f405 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -1,8 +1,8 @@ -/*********************************************************************** - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
- **********************************************************************/ + ************************************************************************/ #include #include @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include "../test_common.h" @@ -136,7 +138,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, const NVTEShape shape = tensors[0]->rowwise_shape(); const DType dtype = tensors[0]->dtype(); const size_t num_tensors = tensors.size(); - const size_t elem_size = typeToSize(dtype); + const size_t elem_size = typeToNumBits(dtype) / 8; GroupedBuffers grouped; grouped.elem_size = elem_size; grouped.num_tensors = num_tensors; @@ -162,9 +164,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, std::vector offsets(num_tensors, 0); auto random_padding = [&]() -> int64_t { + // Random padding ensuring 16-byte alignment regardless of element size + // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - return dist(gen); + // Calculate elements needed for 16-byte alignment + const size_t align_elements = (16 * 8) / typeToNumBits(dtype); // 16 bytes / element_size + return dist(gen) * static_cast(align_elements); }; auto numel = [&](size_t idx) -> int64_t { @@ -301,7 +307,12 @@ Tensor make_fp8_operand(const std::string& name, const std::vector& shap Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { Tensor t(name, shape, DType::kBFloat16); - fillUniform(&t); + // Fill with ones for easier debugging + //fillUniform(&t); + const size_t numel = shape[0] * shape[1]; + std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); + NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), + numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice)); return t; } @@ -312,17 +323,21 @@ struct TestParams { ShapeCase shape_case; }; +// Returns a vector of (M, N, K) tuples for each GEMM in the group. +// M - number of rows in output D +// N - number of columns in output D +// K - reduction dimension shared between A and B std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; - case ShapeCase::kSameFirst: // M wspólne, N/K zróżnicowane - return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}}; - case ShapeCase::kSameLast: // N wspólne, M/K zróżnicowane - return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}}; + case ShapeCase::kSameFirst: + return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + case ShapeCase::kSameLast: + return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; case ShapeCase::kAllDifferent: default: - return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}}; + return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}}; } } @@ -345,10 +360,10 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{K, M} - : std::vector{M, K}; - const std::vector b_shape = params.transb ? std::vector{N, K} - : std::vector{K, N}; + const std::vector a_shape = params.transa ? std::vector{M, K} + : std::vector{K, M}; + const std::vector b_shape = params.transb ? 
std::vector{K, N} + : std::vector{N, K}; switch (params.input_case) { case InputCase::kFP8Current: { A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); @@ -373,6 +388,10 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector gelu_ptrs(num_gemms, nullptr); std::vector workspaces(num_gemms); std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); const size_t cublas_ws_bytes = 32ull * 1024 * 1024; @@ -382,6 +401,8 @@ void run_grouped_gemm_case(const TestParams& params) { D_ptrs[i] = D_multi[i].data(); workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); } nvte_multi_tensor_gemm(A_ptrs.data(), @@ -399,8 +420,8 @@ void run_grouped_gemm_case(const TestParams& params) { 0, 0); - GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); std::vector C_tensors; std::vector D_group_tensors; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 2c8c2093c6..bb29d58de4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1115,20 +1115,50 @@ struct TensorShapeInfo { // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, + last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), t->get_common_last_dim()}; + uniform_first, + uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return {nullptr, nullptr, + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? 
static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, + last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), D->get_common_last_dim()}; + uniform_first, + uniform_last}; } }; @@ -1144,10 +1174,9 @@ inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); - } else { - // logical_shape[1] is sum_of_N, divide by num_tensors - return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } // Workspace layout for grouped GEMM @@ -1163,6 +1192,7 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { GroupedGemmSetupWorkspace ws; @@ -1170,6 +1200,7 @@ struct GroupedGemmSetupWorkspace { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); + // Pointer arrays first (all 8-byte aligned) ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); @@ -1178,27 +1209,30 @@ struct GroupedGemmSetupWorkspace { offset += ptr_size; ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } - // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + // Calculate required size for setup workspace (pointer arrays + M/N/K) static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -1220,12 +1254,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); - auto is_fp8_or_16bit = [](DType dtype) { - return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || - 
dtype == DType::kBFloat16 || dtype == DType::kFloat16; + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; }; - auto is_output_dtype = [](DType dtype) { - return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); @@ -1321,7 +1359,8 @@ inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { @@ -1366,6 +1405,10 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, cublasLtMatmulDescOpaque_t &matmulDesc, cublasLtMatrixLayoutOpaque_t &descA, @@ -1442,9 +1485,11 @@ __global__ void setup_grouped_gemm_kernel( D_ptrs[idx] = d_base + d_offset * d_elem_size; // Compute M, N, K dimensions - M[idx] = static_cast(transa ? a_last : a_first); - K[idx] = static_cast(transa ? a_first : a_last); - N[idx] = static_cast(transb ? b_first : b_last); + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? 
b_last : b_first); // Fill alpha/beta pointers (same for all groups) alpha_ptrs[idx] = alpha_ptr; @@ -1453,7 +1498,7 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, @@ -1482,10 +1527,6 @@ inline void launch_grouped_gemm_setup( NVTE_CHECK_CUDA(cudaGetLastError()); } -// Constants for grouped GEMM workspace -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } @@ -1563,6 +1604,28 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); + // Set FP8 scale pointers if needed + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (is_fp8_a || is_fp8_b) { + // For FP8 grouped GEMM, we need to pass scale_inv pointers + // The scale_inv arrays contain one float per tensor in the group + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr + : inputA->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr + : inputB->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } + } + // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? 
*avg_m : compute_avg_first_dim(outputD); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 948058295e..246fb5fefd 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ +#include + #include "transformer_engine.h" #ifdef __cplusplus From 785df3440a443b72340dfdf33db7391280e3a968 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:26:49 +0000 Subject: [PATCH 06/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bb29d58de4..55f52a1c4d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,18 +1123,17 @@ struct TensorShapeInfo { NVTE_CHECK(has_last || t->all_same_last_dim(), "GroupedTensor is missing last_dims for varying shapes"); - const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) @@ -1153,12 +1152,10 @@ struct TensorShapeInfo { const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } }; @@ -1360,9 +1357,9 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - bool transa, bool transb, bool a_columnwise, bool b_columnwise, - size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, size_t num_tensors, + cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. @@ -1611,15 +1608,15 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // For FP8 grouped GEMM, we need to pass scale_inv pointers // The scale_inv arrays contain one float per tensor in the group if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? 
inputA->columnwise_scale_inv.dptr - : inputA->scale_inv.dptr; + void *a_scale_inv = + A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr - : inputB->scale_inv.dptr; + void *b_scale_inv = + B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); From 1329b3746abfe3f9d845e90da7945bede6e3893c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 22:34:16 +0100 Subject: [PATCH 07/61] fix Signed-off-by: Pawel Gadzinski --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 55f52a1c4d..3662247b51 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1217,9 +1217,6 @@ struct GroupedGemmSetupWorkspace { ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - - offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } @@ -1363,21 +1360,21 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int *lda = rowa; - const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, - (void *)cola, (void *)lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, - (void *)colb, (void *)ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); + int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *lda = rowa; + int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + int *colb = b_columnwise ? ws.K : (transb ? 
ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, + rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, + rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, + ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, + ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From 47c58be8ce0ee14fc26a90a2f8b3ad8035283b4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:35:06 +0000 Subject: [PATCH 08/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3662247b51..91405bd42f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1367,14 +1367,10 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); int *ldb = rowb; - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, - rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, - rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, - ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, - ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From a155a8a3dd17663c82882f64b30a5a118ba3695b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 11:55:44 +0100 Subject: [PATCH 09/61] Grouped GEMM: code cleanup and NULL C support - Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers - Simplify select_grouped_operand by removing dead code branches - Add GroupedOperandSelection.tensor field to avoid passing tensor separately - Extract set_fp8_scale_pointers and init_matrix_layouts helpers - Add safety check for FP8 on Hopper column-wise fallback - Support NULL C tensor when beta=0 (uses D as placeholder) - Remove unused get_scale_inv() from test - Add use_null_c test parameter and test case - Fix documentation: alpha/beta are single element tensors only Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 210 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 163 +++++++------- .../common/include/transformer_engine/gemm.h | 34 +-- 3 files changed, 203 insertions(+), 204 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bff175f405..5e5144fa4c 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ 
b/tests/cpp/operator/test_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ using namespace test; namespace { enum class InputCase { - kFP8Delayed, kFP8Current, kBF16, }; @@ -40,17 +40,37 @@ enum class ShapeCase { kAllDifferent, }; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + // Helper owning GPU buffers that back NVTEGroupedTensor. // NVTEGroupedTensor does not own memory; data/offsets/scales // must be allocated and freed by the test. struct GroupedBuffers { - NVTEGroupedTensor handle{nullptr}; - void* data{nullptr}; - void* scale_inv{nullptr}; - int64_t* first_dims_dev{nullptr}; - int64_t* last_dims_dev{nullptr}; - int64_t* offsets_dev{nullptr}; - void* columnwise_data{nullptr}; + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; NVTEShape logical_shape{}; std::vector offsets_host; std::vector tensor_bytes; @@ -62,65 +82,13 @@ struct GroupedBuffers { GroupedBuffers() = default; GroupedBuffers(const GroupedBuffers&) = delete; GroupedBuffers& operator=(const GroupedBuffers&) = delete; - GroupedBuffers(GroupedBuffers&& other) noexcept { - *this = std::move(other); - } - GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { - if (this == &other) return *this; - handle = other.handle; - data = other.data; - scale_inv = other.scale_inv; - first_dims_dev = other.first_dims_dev; - last_dims_dev = other.last_dims_dev; - offsets_dev = other.offsets_dev; - logical_shape = other.logical_shape; - offsets_host = std::move(other.offsets_host); - tensor_bytes = std::move(other.tensor_bytes); - num_tensors = other.num_tensors; - elem_size = other.elem_size; - dtype = other.dtype; - scaling_mode = other.scaling_mode; - - other.handle = nullptr; - other.data = nullptr; - other.scale_inv = nullptr; - other.first_dims_dev = nullptr; - other.last_dims_dev = nullptr; - other.offsets_dev = nullptr; - other.num_tensors = 0; - return *this; - } + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; - ~GroupedBuffers() { - if (data) { - cudaFree(data); - data = nullptr; - } - if (scale_inv) { - cudaFree(scale_inv); - scale_inv = nullptr; - } - if (columnwise_data) { - cudaFree(columnwise_data); - columnwise_data = nullptr; - } - if (first_dims_dev) { - cudaFree(first_dims_dev); - first_dims_dev = nullptr; - } - if (last_dims_dev) { - cudaFree(last_dims_dev); - last_dims_dev = nullptr; - } - if (offsets_dev) { - cudaFree(offsets_dev); - offsets_dev = nullptr; - } - if (handle) { - nvte_destroy_grouped_tensor(handle); - handle = nullptr; - } - } + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } }; size_t grouped_setup_workspace_size(const size_t num_tensors) { @@ -211,7 +179,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& 
tensors, size_t logical_data[2] = {static_cast(logical_first), static_cast(logical_last)}; grouped.logical_shape = nvte_make_shape(logical_data, 2); - grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); const int64_t last_idx = static_cast(num_tensors - 1); const int64_t total_elems = need_offsets @@ -219,59 +187,60 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, : (logical_first * logical_last); const size_t total_bytes = static_cast(total_elems) * elem_size; - NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + grouped.data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, tensors[i]->rowwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + NVTEGroupedTensor h = grouped.handle.get(); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); if (include_columnwise) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + grouped.columnwise_data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, tensors[i]->columnwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor col_tensor{grouped.columnwise_data, + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); } if (!same_first) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); } if (!same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(), + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape ld_shape = nvte_make_shape(&num_tensors, 
1); - NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); } if (!same_first || !same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(), + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); } if (isFp8Type(dtype)) { @@ -280,13 +249,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, tensors[i]->to_cpu(); scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; } - NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(), + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); } return grouped; @@ -321,6 +290,7 @@ struct TestParams { bool transa; bool transb; ShapeCase shape_case; + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. 
@@ -332,12 +302,14 @@ std::vector> make_shapes(ShapeCase scase) { case ShapeCase::kAllSame: return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; case ShapeCase::kSameFirst: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same M (first dim), varying N and K + return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; case ShapeCase::kSameLast: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same N (last dim), varying M and K + return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}}; + return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; } } @@ -430,9 +402,11 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; (void)K; - C_tensors.emplace_back(Tensor("C" + std::to_string(i), - std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + } D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16)); @@ -441,11 +415,16 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector C_views, D_views; for (size_t i = 0; i < num_gemms; ++i) { - C_views.push_back(&C_tensors[i]); + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } D_views.push_back(&D_group_tensors[i]); } - GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); @@ -462,11 +441,11 @@ void run_grouped_gemm_case(const TestParams& params) { nvte_grouped_gemm(params.transa, params.transb, alpha_tensor.data(), - grouped_A.handle, - grouped_B.handle, + grouped_A.get_handle(), + grouped_B.get_handle(), beta_tensor.data(), - grouped_C.handle, - grouped_D.handle, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), nullptr, @@ -482,7 +461,7 @@ void run_grouped_gemm_case(const TestParams& params) { D_multi[i].dtype()); const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.data) + offset_bytes, + static_cast(grouped_D.get_data()) + offset_bytes, grouped_D.tensor_bytes[i], cudaMemcpyDeviceToDevice)); grouped_split.to_cpu(); @@ -504,22 +483,25 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { } std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + "tb" + (info.param.transb ? "T" : "N"); + const std::string null_c = info.param.use_null_c ? 
"_NullC" : ""; return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; } const std::vector kTestParams = { - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Test NULL C (valid when beta=0) + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 91405bd42f..9d9a5097d4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1190,8 +1190,7 @@ struct GroupedGemmSetupWorkspace { // Initialize from workspace buffer // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, - size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); @@ -1243,8 +1242,11 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, "Grouped GEMM: A and B must have the same num_tensors"); - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); @@ -1261,8 +1263,13 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); - NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), - "Grouped GEMM outputs must be BF16, FP16, or FP32."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), + "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), + "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); 
NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1273,6 +1280,7 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; @@ -1296,6 +1304,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; GroupedOperandSelection sel; + sel.tensor = t; sel.trans = trans; const DType rep_dtype = has_row ? row_dtype : col_dtype; @@ -1327,6 +1336,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1334,10 +1346,10 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: return sel; } - // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); - sel.dtype = has_row ? row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + // Default: use row-wise data (column-wise case already handled above) + sel.base = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; return sel; } @@ -1354,17 +1366,22 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, size_t num_tensors, - cudaDataType_t A_type, cudaDataType_t B_type, - cudaDataType_t D_type) { + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, + size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); int *lda = rowa; - int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); int *ldb = rowb; NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); @@ -1395,6 +1412,31 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise + ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise + ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + // Constants for grouped GEMM workspace (declared early for use in heuristics) static constexpr size_t kGroupedGemmAlignment = 256; static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB @@ -1488,20 +1530,20 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, - cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); const char *c_base = static_cast(C->data.dptr); char *d_base = static_cast(D->data.dptr); + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); @@ -1510,9 +1552,9 @@ inline void launch_grouped_gemm_setup( setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, - c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - 
static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); + A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1532,7 +1574,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); const Tensor *beta_tensor = convertNVTETensorCheck(beta); @@ -1540,19 +1582,16 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; const size_t num_tensors = inputA->num_tensors; // Select operand storage (row-wise vs column-wise) and adjust transpose flags to // mirror the non-grouped GEMM logic for FP8 layout constraints. - bool transa_flag = static_cast(transa); - bool transb_flag = static_cast(transb); - const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); - transa_flag = A_sel.trans; - transb_flag = B_sel.trans; - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); // Workspaces: setup (pointer arrays) and cuBLAS const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); @@ -1563,65 +1602,35 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); - NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, - beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, + alpha_tensor, beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - // Get data types - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t 
D_type = get_cuda_dtype(outputD->dtype()); - // Setup cuBLAS operations - cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, - A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, - D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); - - // Set FP8 scale pointers if needed - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (is_fp8_a || is_fp8_b) { - // For FP8 grouped GEMM, we need to pass scale_inv pointers - // The scale_inv arrays contain one float per tensor in the group - if (is_fp8_a) { - void *a_scale_inv = - A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = - B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } - } + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); int64_t avg_k_val = - avg_k ? *avg_k : (transa_flag ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 246fb5fefd..02cf01853d 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -239,19 +239,27 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous * memory layout and shape metadata. * - * \param[in] transa Whether to transpose A matrices. - * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, - * or single element for uniform alpha). - * \param[in] A Input grouped tensor A. - * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, - * or single element for uniform beta). - * \param[in] C Input grouped tensor C (can be NULL for beta=0). 
- * \param[out] D Output grouped tensor D. - * \param[in] workspace Workspace tensor for intermediate computations. - * \param[in] config Matrix multiplication configuration. - * \param[in] stream CUDA stream for the operation. + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. + * \param[in] avg_m Optional hint for average M dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_n Optional hint for average N dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_k Optional hint for average K (reduction) dimension across all + * matrices in the group. Used by cuBLASLt for algorithm selection + * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: * - A, B, C (if provided), D must have the same num_tensors From 3b2fcdf3137cec31b83dc6dc0f64e2e367aa6f9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:57:26 +0000 Subject: [PATCH 10/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9d9a5097d4..7f2635943b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1265,11 +1265,9 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM inputs must be FP8, BF16, or FP16."); // Only check C dtype if C is provided if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), - "Grouped GEMM: C must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); } - NVTE_CHECK(is_output_dtype(outputD->dtype()), - "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1337,8 +1335,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). 
if (!has_row && has_col) { // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1369,8 +1368,7 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, - size_t num_tensors) { + const transformer_engine::GroupedTensor *D, size_t num_tensors) { const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); @@ -1420,17 +1418,15 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, if (!is_fp8_a && !is_fp8_b) return; if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise - ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise - ? B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; + void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); @@ -1604,8 +1600,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, - alpha_tensor, beta_tensor, num_tensors, stream); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1629,8 +1625,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = - avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, From 5b0582bbf0fd05773242df67836ec263014d52dd Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:15:12 +0100 Subject: [PATCH 11/61] Grouped GEMM: per-matrix alpha/beta support - Change alpha/beta from single values to per-matrix arrays - Validate alpha/beta have exactly num_tensors elements - Update kernel to index alpha_ptr[idx] and beta_ptr[idx] - Move alpha/beta validation to validate_grouped_gemm_inputs - Update tests to use per-matrix alpha/beta arrays - Update documentation Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 15 +++++++----- .../common/gemm/cublaslt_gemm.cu | 24 ++++++++++++++----- .../common/include/transformer_engine/gemm.h | 4 ++-- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 5e5144fa4c..82b5bd3803 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -427,12 +427,15 @@ void run_grouped_gemm_case(const TestParams& params) { } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); - const float alpha_val = 1.f; - const float beta_val = 0.f; - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f2635943b..caa394d549 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1237,7 +1237,9 @@ struct GroupedGemmSetupWorkspace { inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, const transformer_engine::GroupedTensor *inputB, const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD) { + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1250,6 +1252,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same 
num_tensors"); + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.shape.numel(); + const size_t beta_numel = beta_tensor->data.shape.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2 || @@ -1481,7 +1493,7 @@ __global__ void setup_grouped_gemm_kernel( TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (same for all groups) + // Alpha/beta pointers (per-matrix arrays) float *alpha_ptr, float *beta_ptr, // Transpose flags bool transa, bool transb, @@ -1519,9 +1531,9 @@ __global__ void setup_grouped_gemm_kernel( K[idx] = static_cast(transa ? a_last : a_first); N[idx] = static_cast(transb ? b_last : b_first); - // Fill alpha/beta pointers (same for all groups) - alpha_ptrs[idx] = alpha_ptr; - beta_ptrs[idx] = beta_ptr; + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; } // Launch the setup kernel to populate workspace arrays @@ -1578,7 +1590,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 02cf01853d..9dfa009115 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -241,10 +241,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * \param[in] transa Whether to transpose A matrices. * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). * \param[in] A Input grouped tensor A. * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). * \param[in] C Input grouped tensor C (can be NULL for beta=0). * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. 
From 101766bcb15e9cd6a9df01eaa6e5b5b9d9989f40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:17:48 +0000 Subject: [PATCH 12/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index caa394d549..1d63cf65cf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1255,12 +1255,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.shape.numel(); const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From 1167f7539fb91a7d8cb7de2ea252e89415967073 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:25:28 +0100 Subject: [PATCH 13/61] Fix alpha/beta numel - use SimpleTensor::numel() Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- transformer_engine/common/gemm/cublaslt_gemm.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d63cf65cf..b8aa2a8ba3 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1253,12 +1253,14 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM: A and D must have the same num_tensors"); // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.shape.numel(); - const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, - ") elements, got ", alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, - ") elements, got ", beta_numel); + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From 00eb18662846645875c9da5edaeb37b216c8833c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 16 Dec 2025 17:41:24 -0800 Subject: [PATCH 14/61] Einsum WIP 1 --- 
build_tools/build_ext.py | 6 ++ transformer_engine/jax/cpp_extensions/base.py | 12 +-- transformer_engine/jax/cpp_extensions/gemm.py | 10 +-- transformer_engine/jax/dense.py | 87 +++++++++++++------ transformer_engine/jax/sharding.py | 2 + 5 files changed, 79 insertions(+), 38 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 349858ac49..c269a29874 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -61,6 +61,12 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", ] + if bool(int(os.getenv("NVTE_USE_CCACHE", "0"))): + ccache_bin = os.getenv("NVTE_CCACHE_BIN", "ccache") + configure_command += [ + f"-DCMAKE_CXX_COMPILER_LAUNCHER={ccache_bin}", + f"-DCMAKE_CUDA_COMPILER_LAUNCHER={ccache_bin}", + ] configure_command += self.cmake_flags import pybind11 diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 22a4b7dda4..70734ad4c4 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -207,12 +207,12 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): if batch_dim is None: batch_dim = bdim batch_size = arg.shape[bdim] - elif bdim != batch_dim: - raise ValueError( - "All batched arguments must have the same batch dimension. " - f"Got batch_dims={batch_dims}" - ) - assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + # elif bdim != batch_dim: + # raise ValueError( + # "All batched arguments must have the same batch dimension. " + # f"Got batch_dims={batch_dims}" + # ) + # assert batch_dim is not None and batch_size is not None, "Invalid batching config!" # Loop over batch dimension and collect results all_results = [] diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 55a1700838..7d44643046 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -812,11 +812,11 @@ def batcher( lhs_bdims, _, rhs_bdims, *_ = batch_dims # Validate batch dimensions - if lhs_bdims is not None or rhs_bdims is not None: - assert lhs_bdims == rhs_bdims, ( - "Batched GEMM requires matching batch dimensions, " - f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" - ) + # if lhs_bdims is not None or rhs_bdims is not None: + # assert lhs_bdims == rhs_bdims, ( + # "Batched GEMM requires matching batch dimensions, " + # f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + # ) # Use general batcher from BasePrimitive return GemmPrimitive.batcher_impl( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index c499b0651e..f941e598ae 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -69,6 +69,7 @@ def dense( output_axes: Tuple[str, ...] = None, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, + batch_dims : Tuple[Sequence[int], Sequence[int]] = ((), ()), ): """Perform dense layer transformation with optional quantization. 
@@ -109,11 +110,12 @@ def dense( output_axes, collective_op_set, quantizer_set, + batch_dims, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 10)) def _dense( x, kernel, @@ -125,6 +127,7 @@ def _dense( output_axes, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management + batch_dims, ): """Internal implementation of dense layer transformation with custom VJP. @@ -157,6 +160,7 @@ def _dense( output_axes, collective_op_set, quantizer_set, + batch_dims, ) return output @@ -172,6 +176,7 @@ def _dense_fwd_rule( output_axes, collective_op_set, quantizer_set, + batch_dims, ): """Forward pass rule for dense layer transformation. @@ -185,9 +190,9 @@ def _dense_fwd_rule( # Check supported input layout x_is_transposed = x.ndim - 1 not in x_contracting_dims k_is_transposed = kernel.ndim - 1 in k_contracting_dims - assert ( - not x_is_transposed and not k_is_transposed - ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + # assert ( + # not x_is_transposed and not k_is_transposed + # ), f"Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel. {x_contracting_dims=},{x.ndim=},{k_contracting_dims=},{kernel.ndim=}" flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) @@ -237,6 +242,47 @@ def _dense_fwd_rule( ) return output, ctx +def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, + swap_ans=False): + # from: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py#L198 + import itertools + import numpy as np + def _remaining(original, *removed_lists): + removed = set(itertools.chain(*removed_lists)) + return [i for i in original if i not in removed] + + def _ranges_like(*xs): + start = 0 + for x in xs: + x_len = len(x) + yield range(start, start + x_len) + start += x_len + + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + x_ndim = x.ndim + x_kept = _remaining(range(x_ndim), x_contract, x_batch) + y_kept = _remaining(range(y.ndim), y_contract, y_batch) + if swap_ans: + ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) + else: + ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) + dims = ((ans_y, y_kept), (ans_batch, y_batch)) + x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + x_bar = jax.lax.transpose( + # TODO(jberchtold): I'm ignoring the batch_dims here, do I need to explicitly use vmap or something? 
+ tex.gemm(g, y, contracting_dims=dims[0]), + tuple(out_axes) + ) + return x_bar + +def dot_general_transpose_rhs(g, x, y, *, dimension_numbers): + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + y_bar = dot_general_transpose_lhs( + g, y, x, dimension_numbers=swapped_dimension_numbers, + swap_ans=True) + return y_bar def _dense_bwd_rule( contracting_dims, @@ -245,6 +291,7 @@ def _dense_bwd_rule( kernel_axes, output_axes, collective_op_set, + batch_dims, ctx, grad, ): @@ -277,35 +324,21 @@ def _dense_bwd_rule( transpose_batch_sequence=transpose_batch_sequence, ) - # GEMM NT - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) + fwd_cdims = (fwd_x_contracting_dims, fwd_k_contracting_dims) + dims = (fwd_cdims, batch_dims) - dgrad = tex.gemm( + dgrad = dot_general_transpose_lhs( casted_grad.get_tensor(usage=TensorUsage.LHS), + casted_x_lhs, casted_kernel_rhs, - contracting_dims=(g_contracting_dim, k_contracting_dim), - transpose_batch_sequence=transpose_batch_sequence, - collective_op=collective_op_set.backward, + dimension_numbers=dims, ) - # GEMM TN - # x_non_contracting_dims - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - - wgrad = tex.gemm( + wgrad = dot_general_transpose_rhs( + casted_grad.get_tensor(usage=TensorUsage.LHS), # TODO(jberchtold): should be RHS to use fused kernel for 2x layout? but would need to update dims accordingly casted_x_lhs, - casted_grad.get_tensor(usage=TensorUsage.RHS), - contracting_dims=(x_contracting_dim, g_contracting_dim), - transpose_batch_sequence=transpose_batch_sequence, + casted_kernel_rhs, + dimension_numbers=dims, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6cb0dd257c..01405ba87a 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -261,6 +261,8 @@ def get_mesh_axis_size(axis, mesh=None): if axis is None: return 1 + print(mesh) + assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" return mesh.shape[axis] From 38defb8ec354055f0a14017d5a525e1cc911d57c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 18 Dec 2025 08:45:19 -0800 Subject: [PATCH 15/61] Test --- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/dense.py | 9 ++------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 70734ad4c4..defdce7b68 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -212,7 +212,7 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): # "All batched arguments must have the same batch dimension. " # f"Got batch_dims={batch_dims}" # ) - # assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + assert batch_dim is not None and batch_size is not None, "Invalid batching config!" 
# Loop over batch dimension and collect results all_results = [] diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 53c6937fb4..c5d76cf28c 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -362,7 +362,7 @@ def batcher( use_rht, ): """Batch rule for quantization primitive using general batcher.""" - check_valid_batch_dims(batch_dims) + # check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None return BaseDBiasQuantizePrimitive.batcher_impl( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index f941e598ae..62b0e054aa 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -69,7 +69,6 @@ def dense( output_axes: Tuple[str, ...] = None, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, - batch_dims : Tuple[Sequence[int], Sequence[int]] = ((), ()), ): """Perform dense layer transformation with optional quantization. @@ -110,12 +109,11 @@ def dense( output_axes, collective_op_set, quantizer_set, - batch_dims, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 10)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) def _dense( x, kernel, @@ -127,7 +125,6 @@ def _dense( output_axes, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management - batch_dims, ): """Internal implementation of dense layer transformation with custom VJP. @@ -160,7 +157,6 @@ def _dense( output_axes, collective_op_set, quantizer_set, - batch_dims, ) return output @@ -176,7 +172,6 @@ def _dense_fwd_rule( output_axes, collective_op_set, quantizer_set, - batch_dims, ): """Forward pass rule for dense layer transformation. @@ -291,7 +286,6 @@ def _dense_bwd_rule( kernel_axes, output_axes, collective_op_set, - batch_dims, ctx, grad, ): @@ -325,6 +319,7 @@ def _dense_bwd_rule( ) fwd_cdims = (fwd_x_contracting_dims, fwd_k_contracting_dims) + batch_dims = ((), ()) # vmap is done outside dense VJP if needed dims = (fwd_cdims, batch_dims) dgrad = dot_general_transpose_lhs( From e4a80a3522b8d1b29199d807a4770ebc815ca487 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 09:57:33 +0100 Subject: [PATCH 16/61] Refactor: move grouped GEMM to separate file and cleanup API Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 12 +- .../common/gemm/cublaslt_gemm.cu | 549 +--------------- .../common/gemm/cublaslt_grouped_gemm.cu | 599 ++++++++++++++++++ .../common/gemm/cublaslt_grouped_gemm.cuh | 18 + .../common/include/transformer_engine/gemm.h | 12 +- 5 files changed, 635 insertions(+), 555 deletions(-) create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 82b5bd3803..0ea76946bc 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include @@ -314,9 +315,12 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { - if (params.input_case != InputCase::kBF16 && - getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer."; +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); @@ -451,7 +455,6 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - nullptr, 0, nullptr, nullptr, @@ -477,6 +480,7 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } +#endif // CUBLAS_VERSION >= 130200 } class GroupedGemmTest : public ::testing::TestWithParam {}; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b8aa2a8ba3..86f517af7d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -23,6 +23,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "./config.h" +#include "./cublaslt_grouped_gemm.cuh" #include "./cutlass_grouped_gemm.cuh" namespace { @@ -1104,551 +1105,3 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } - -// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) -struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr - - // Create from GroupedTensor - static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - const bool has_first = t->first_dims.has_data(); - const bool has_last = t->last_dims.has_data(); - // When per-tensor dims are not provided, we must be in the uniform-shape case. - NVTE_CHECK(has_first || t->all_same_first_dim(), - "GroupedTensor is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || t->all_same_last_dim(), - "GroupedTensor is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(t->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; - - const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - - return {first_ptr, last_ptr, - t->tensor_offsets.has_data() ? 
static_cast(t->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } - - // Create for C tensor (uses D's dimensions, only has offsets) - static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D) { - const bool has_first = D->first_dims.has_data(); - const bool has_last = D->last_dims.has_data(); - NVTE_CHECK(has_first || D->all_same_first_dim(), - "GroupedTensor D is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || D->all_same_last_dim(), - "GroupedTensor D is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(D->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; - const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - - return {first_ptr, last_ptr, - C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } -}; - -// Helper functions to compute average dimensions from logical_shape for heuristics -// These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { - // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) - // In both cases, dividing by num_tensors gives the average - return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); -} - -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { - if (t->all_same_last_dim()) { - // logical_shape[1] is the common N - return static_cast(t->logical_shape.data[1]); - } - // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. 
- return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); -} - -// Workspace layout for grouped GEMM -struct GroupedGemmSetupWorkspace { - void **A_ptrs; - void **B_ptrs; - void **C_ptrs; - void **D_ptrs; - int *M; - int *N; - int *K; - float **alpha_ptrs; - float **beta_ptrs; - - // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { - GroupedGemmSetupWorkspace ws; - size_t offset = 0; - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - - // Pointer arrays first (all 8-byte aligned) - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - - // Int arrays last (4-byte aligned, always satisfied after pointer arrays) - ws.M = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); - - return ws; - } - - // Calculate required size for setup workspace (pointer arrays + M/N/K) - static size_t required_setup_size(size_t num_tensors, size_t alignment) { - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) - size_t size = 6 * ptr_size + 3 * int_size; - size = ((size + alignment - 1) / alignment) * alignment; - return size; - } -}; - -// ----------------------------------------------------------------------------- -// Helper routines to keep nvte_grouped_gemm readable -// ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, - const transformer_engine::GroupedTensor *inputB, - const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor) { - const size_t num_tensors = inputA->num_tensors; - NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); - NVTE_CHECK(inputB->num_tensors == num_tensors, - "Grouped GEMM: A and B must have the same num_tensors"); - // C can be NULL (will use D as C when beta=0) - if (inputC != nullptr) { - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); - } - NVTE_CHECK(outputD->num_tensors == num_tensors, - "Grouped GEMM: A and D must have the same num_tensors"); - - // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.numel(); - const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); - - auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { - return 
dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2 || - dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16; - }; - auto is_output_dtype = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16 || - dtype == transformer_engine::DType::kFloat32; - }; - NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), - "Grouped GEMM inputs must be FP8, BF16, or FP16."); - // Only check C dtype if C is provided - if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); - } - NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); - NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), - "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); - NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), - "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); -} - -// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. -// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and -// fallback to column-wise data when row-wise is absent. -struct GroupedOperandSelection { - const transformer_engine::GroupedTensor *tensor = nullptr; - const char *base = nullptr; - transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; - bool trans = false; - bool use_columnwise = false; -}; - -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, - bool trans, bool is_A) { - using namespace transformer_engine; - const bool has_row = t->has_data(); - const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, - "Grouped GEMM operand is missing both row-wise and column-wise data"); - - // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. - const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && - !is_nvfp_scaling(sm), - "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); - - const DType row_dtype = t->data.dtype; - const DType col_dtype = t->columnwise_data.dtype; - GroupedOperandSelection sel; - sel.tensor = t; - sel.trans = trans; - - const DType rep_dtype = has_row ? row_dtype : col_dtype; - const bool is_fp8 = is_fp8_dtype(rep_dtype); - const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); - - // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. - if (is_fp8 && !non_tn_fp8_ok) { - if (is_A) { - if (!sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = true; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } else { // B - if (sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = false; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } - } - - // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). 
- if (!has_row && has_col) { - // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK( - !is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = !sel.trans; - sel.use_columnwise = true; - return sel; - } - - // Default: use row-wise data (column-wise case already handled above) - sel.base = static_cast(t->data.dptr); - sel.dtype = row_dtype; - sel.use_columnwise = false; - return sel; -} - -inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, - const char *workspace_name) { - NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); - const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, - ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); - return ws->data.dptr; -} - -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, size_t num_tensors) { - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); - - // For column-major layout: leading dimension is the number of rows in storage. - // If columnwise data was chosen, storage is already transposed. - int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); - int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); - int *lda = rowa; - int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); - int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); - int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); -} - -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, - cublasOperation_t op_B) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, - sizeof(op_A))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, - sizeof(op_B))); - - cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, sizeof(pointer_mode))); - - int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); -} - -inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel) { - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (!is_fp8_a && !is_fp8_b) return; - - if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? 
B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } -} - -// Constants for grouped GEMM workspace (declared early for use in heuristics) -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - -inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t &matmulDesc, - cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - int64_t avg_m, int64_t avg_n, int64_t avg_k) { - cublasLtMatmulPreferenceOpaque_t preference; - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); - - cublasLtMatmulHeuristicResult_t heuristicResult; - int returnedResults = 0; - auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, - &preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS grouped GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); - return heuristicResult.algo; -} - -// Single kernel that sets up all GEMM parameters. -// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, -// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. -// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. -__global__ void setup_grouped_gemm_kernel( - // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, - float **alpha_ptrs, float **beta_ptrs, - // Base pointers - const char *a_base, const char *b_base, const char *c_base, char *d_base, - // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, - // Element sizes - size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (per-matrix arrays) - float *alpha_ptr, float *beta_ptr, - // Transpose flags - bool transa, bool transb, - // Number of tensors - size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_tensors) return; - - // Get dimensions for this tensor (from array or uniform value) - int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; - int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; - int64_t b_first = B_meta.first_dims ? 
B_meta.first_dims[idx] : B_meta.uniform_first; - int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; - - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); - - // Compute data pointers - A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; - B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; - C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; - D_ptrs[idx] = d_base + d_offset * d_elem_size; - - // Compute M, N, K dimensions - // Test stores A as {K,M} when !transa, {M,K} when transa - // Test stores B as {N,K} when !transb, {K,N} when transb - M[idx] = static_cast(transa ? a_first : a_last); - K[idx] = static_cast(transa ? a_last : a_first); - N[idx] = static_cast(transb ? b_last : b_first); - - // Fill alpha/beta pointers (per-matrix) - alpha_ptrs[idx] = alpha_ptr + idx; - beta_ptrs[idx] = beta_ptr + idx; -} - -// Launch the setup kernel to populate workspace arrays -inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); - TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); - TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); - - const char *c_base = static_cast(C->data.dptr); - char *d_base = static_cast(D->data.dptr); - - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); - const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); - const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); - - const int threads_per_block = 256; - const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; - - setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, - b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); - - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { - return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); -} - -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, - const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k) { - 
NVTE_API_CALL(nvte_grouped_gemm); - using namespace transformer_engine; - - // Convert to internal types - const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); - const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL - GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); - const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); - const Tensor *beta_tensor = convertNVTETensorCheck(beta); - Tensor *wspace_setup = convertNVTETensor(workspace_setup); - Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); - - // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); - - // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) - const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; - const size_t num_tensors = inputA->num_tensors; - - // Select operand storage (row-wise vs column-wise) and adjust transpose flags to - // mirror the non-grouped GEMM logic for FP8 layout constraints. - const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); - - // Workspaces: setup (pointer arrays) and cuBLAS - const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); - const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; - - void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, - "Grouped GEMM setup workspace"); - void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, - "Grouped GEMM cuBLAS workspace"); - - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, - beta_tensor, num_tensors, stream); - - // Get cuBLAS handle - using cublasHandleManager = detail::HandleManager; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - - // Setup cuBLAS operations - cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - - // Create grouped matrix layouts - cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, - num_tensors); - - // Create matmul descriptor - cublasLtMatmulDescOpaque_t matmulDesc; - init_matmul_desc(matmulDesc, op_A, op_B); - set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); - - // Compute average dimensions for heuristics - // K dimension: if transa, K is A's first dim; if not, K is A's last dim - int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); - int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = avg_k ? *avg_k - : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); - - // Heuristic selection - cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, - descD, avg_m_val, avg_n_val, avg_k_val); - - // Execute the grouped GEMM - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, - setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, - kGroupedGemmCublasWorkspaceSize, stream)); -} diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu new file mode 100644 index 0000000000..4125bd82bf --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -0,0 +1,599 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/handle_manager.h" +#include "../util/logging.h" +#include "./cublaslt_grouped_gemm.cuh" + +namespace { + +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +} // namespace + +#if CUBLAS_VERSION >= 130100 + +namespace { + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, last_ptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? 
static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, last_ptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + // Pointer arrays first (all 8-byte aligned) + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, 
+ const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } + NVTE_CHECK(outputD->num_tensors == num_tensors, + "Grouped GEMM: A and D must have the same num_tensors"); + + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; + const char *dptr = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Currently only unquantized data and tensor-scaled FP8 are supported. 
+ const auto sm = t->scaling_mode; + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, + "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.tensor = t; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (column-wise case already handled above) + sel.dptr = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; + return sel; +} + +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); + int *lda = rowa; + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? 
B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (per-matrix arrays) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? 
B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? b_last : b_first); + + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +} // namespace + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + 
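For reference, a minimal host-side sketch of the dimension convention that setup_grouped_gemm_kernel encodes above; the struct and helper names (Dims, derive_mnk) are illustrative only and not part of the library. A's storage is {K, M} when transa is false and {M, K} when it is true, while B's storage is {N, K} when transb is false and {K, N} when it is true:

    #include <cstdint>

    struct Dims { int m, n, k; };

    // Mirrors the per-group M/N/K logic of setup_grouped_gemm_kernel.
    inline Dims derive_mnk(int64_t a_first, int64_t a_last,
                           int64_t b_first, int64_t b_last,
                           bool transa, bool transb) {
      Dims d;
      d.m = static_cast<int>(transa ? a_first : a_last);
      d.k = static_cast<int>(transa ? a_last : a_first);
      d.n = static_cast<int>(transb ? b_last : b_first);
      return d;
    }

    // Example: transa = transb = false, A stored as {K = 64, M = 128} and
    // B stored as {N = 256, K = 64} yield m = 128, n = 256, k = 64.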
using namespace transformer_engine; + + // Grouped GEMM requires Hopper (SM90) or newer + const int current_device = cuda::current_device(); + NVTE_CHECK(cuda::sm_arch(current_device) >= 90, + "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Setup cuBLAS operations + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} + +#else // CUBLAS_VERSION < 130100 + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); +} + +#endif // CUBLAS_VERSION >= 130100 + diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh new file mode 100644 index 0000000000..6514ba2f97 --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ +#define TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + +#include +#include +#include + +// nvte_grouped_gemm is declared in transformer_engine/gemm.h +// This header is for internal use only. + +#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 9dfa009115..b2e42bd66f 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" @@ -233,6 +233,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Hopper GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. @@ -262,6 +266,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. If NULL, computed automatically from A's logical shape. 
* * Requirements: + * - cuBLAS 13.2+ (CUDA 13.2+) + * - Hopper (SM90) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: @@ -270,8 +276,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k); + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 047a9f93bd5252241883077e0a904b2c7f1c6e57 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 12:29:12 +0100 Subject: [PATCH 17/61] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 5 +++-- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/include/transformer_engine/gemm.h | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0ea76946bc..3336dbc6d5 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -137,8 +137,9 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment - const size_t align_elements = (16 * 8) / typeToNumBits(dtype); // 16 bytes / element_size + // Calculate elements needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size return dist(gen) * static_cast(align_elements); }; diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78..e25bf02439 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -144,6 +144,7 @@ list(APPEND transformer_engine_cuda_sources fused_attn/fused_attn_fp8.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index b2e42bd66f..f1e2776158 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,7 +234,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on * a pre-Hopper GPU. 
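As a usage sketch of the requirement above (the helper name grouped_gemm_available is illustrative, and the fallback policy is left to the application), a caller might gate the grouped path on both the compile-time cuBLAS version and the device compute capability:

    #include <cublas_v2.h>   // provides CUBLAS_VERSION
    #include <cuda_runtime.h>

    bool grouped_gemm_available(int device) {
    #if CUBLAS_VERSION >= 130100
      int major = 0;
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
      return major >= 9;  // Hopper (SM90) or newer, per the note above
    #else
      (void)device;        // compiled against an older cuBLAS
      return false;
    #endif
    }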
* @@ -253,7 +253,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. - * \param[in] config Matrix multiplication configuration. * \param[in] stream CUDA stream for the operation. * \param[in] avg_m Optional hint for average M dimension across all matrices in the * group. Used by cuBLASLt for algorithm selection heuristics. From c490e06ab71f9919d69bfc2c67eb6b7cf6bc20ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:32:34 +0000 Subject: [PATCH 18/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 11 ++++------- .../common/gemm/cublaslt_grouped_gemm.cuh | 1 - .../common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 4125bd82bf..3647a4c39e 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -180,12 +180,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.numel(); const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || @@ -596,4 +594,3 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT } #endif // CUBLAS_VERSION >= 130100 - diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh index 6514ba2f97..a032e594d5 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -15,4 +15,3 @@ // This header is for internal use only. 
#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ - diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index f1e2776158..0c8d601d50 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" From e39784572a83cb560fca20f2e7f77f7f7795a834 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 19 Dec 2025 08:35:50 -0800 Subject: [PATCH 19/61] batching working correctly for quant and gemm but slow Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/base.py | 30 ++++-- transformer_engine/jax/cpp_extensions/gemm.py | 94 ++++++++++++++----- .../jax/cpp_extensions/quantization.py | 10 +- transformer_engine/jax/sharding.py | 2 - 4 files changed, 102 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index defdce7b68..335af2eb47 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -175,6 +175,7 @@ def batcher_impl( batched_args: Sequence[Any], batch_dims: Sequence[Union[int, None]], static_kwargs: dict, + output_bdims: Union[Sequence[Union[int, None]], None] = None, ) -> Tuple[Tuple[Any, ...], Tuple[Union[int, None], ...]]: """Batcher implementation for JAX primitives. @@ -207,13 +208,21 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): if batch_dim is None: batch_dim = bdim batch_size = arg.shape[bdim] - # elif bdim != batch_dim: - # raise ValueError( - # "All batched arguments must have the same batch dimension. " - # f"Got batch_dims={batch_dims}" - # ) + elif output_bdims is None and bdim != batch_dim: + raise ValueError( + "All batched arguments must have the same batch dimension. " + f"Got batch_dims={batch_dims}" + ) + elif arg.shape[bdim] != batch_size: + raise ValueError( + "All batched arguments must have the same batch size. " + f"Got sizes {[arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}. " + f"Got batched_args={[arg.shape for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}." + ) assert batch_dim is not None and batch_size is not None, "Invalid batching config!" 
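        # Illustrative note, not library code: once the checks above pass, the
        # batcher slices every batched argument along batch_dim, invokes the
        # primitive once per slice, and stacks the per-slice outputs back along
        # the batch dimension. Vmapping quantize over 4 experts therefore issues
        # 4 sequential primitive calls, which is why this path is still slow.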
+ print(f"[{cls.__name__}] Batching with size {batch_size}") + # Loop over batch dimension and collect results all_results = [] @@ -244,9 +253,14 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): transposed = tuple(zip(*all_results)) # Stack each output along the batch dimension - stacked_results = tuple( - jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed - ) + if output_bdims is not None: + stacked_results = tuple( + jnp.stack(list(out_list), axis=out_bdim) for out_list, out_bdim in zip(transposed, output_bdims) + ) + else: + stacked_results = tuple( + jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed + ) # Single output: return unwrapped result if len(stacked_results) == 1: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7d44643046..28100c9715 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,27 +583,27 @@ def lowering( ) lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - if lhs_transposed - else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - if rhs_transposed - else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) + # lhs_contracting_size = ( + # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + # if lhs_transposed + # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # lhs_contracting_size, + # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", + # ) + # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + # rhs_contracting_size = ( + # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + # if rhs_transposed + # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # rhs_contracting_size, + # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", + # ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { @@ -818,10 +818,60 @@ def batcher( # f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" # ) + f = partial(GemmPrimitive.outer_impl, + **{ + "out_dtype": out_dtype, + "contracting_dims": contracting_dims, + "scaling_mode": scaling_mode, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + "collective_op": collective_op, + "transpose_batch_sequence": transpose_batch_sequence, + "sequence_dim": sequence_dim, + "is_outer": is_outer, + }) + + lhs_cdims, rhs_cdims = contracting_dims + # Calculate output batch dimension based on input batch dims and contracting dims + # Both lhs and rhs have batch dimensions that may be at different indices + if lhs_bdims is not None and rhs_bdims is not None: + # Count non-contracting dimensions in LHS before the batch dimension + lhs_non_contracting_before_batch = sum( + 1 for i in range(lhs_bdims) + if i not in lhs_cdims + ) + # The output batch dimension will be at the position corresponding to + # the LHS batch dimension's position among 
non-contracting dimensions + output_bdim = lhs_non_contracting_before_batch + elif lhs_bdims is not None: + # LHS has a batch dimension - this will be the output batch dimension + output_bdim = 0 + elif rhs_bdims is not None: + # RHS has a batch dimension - need to account for LHS non-contracting dims + lhs_non_contracting = len([i for i in range(len(batched_args[0].shape)) + if i not in lhs_cdims and i != lhs_bdims]) + output_bdim = lhs_non_contracting + else: + # No batch dimensions in either operand + output_bdim = None + # Use general batcher from BasePrimitive return GemmPrimitive.batcher_impl( batched_args, - batch_dims, + batch_dims=( + lhs_bdims, # lhs + 0, # lhs_scale_inv + rhs_bdims, # rhs + 0, # rhs_scale_inv + *(None for _ in batched_args[4:]), # bias, gelu_input, alpha, beta + ), + output_bdims=( + output_bdim, # output + 0, # bias_grad + 0, # pre_gelu_out + ), static_kwargs={ "out_dtype": out_dtype, "contracting_dims": contracting_dims, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c5d76cf28c..a95afe8b8e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -20,7 +20,6 @@ from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, - check_valid_batch_dims, te_dtype_to_jax_dtype, jax_dtype_to_te_dtype, multidim_transpose, @@ -362,12 +361,19 @@ def batcher( use_rht, ): """Batch rule for quantization primitive using general batcher.""" - # check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None return BaseDBiasQuantizePrimitive.batcher_impl( batched_args, batch_dims, + output_bdims=( + batch_dims[0], # out + batch_dims[0], # colwise_out (probably need to transpose according if scaling mode does it) + 0, # scale_inv + 0, # colwise_scale_inv + 0, # updated_amax + 0, # dbias + ), static_kwargs={ "out_dtype": out_dtype, "scaling_mode": scaling_mode, diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 01405ba87a..6cb0dd257c 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -261,8 +261,6 @@ def get_mesh_axis_size(axis, mesh=None): if axis is None: return 1 - print(mesh) - assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" return mesh.shape[axis] From 59145cc2a7d4e4cb92addbd39c374541cbed5eb9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:21:19 +0100 Subject: [PATCH 20/61] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 7 ++++--- .../common/gemm/cublaslt_grouped_gemm.cu | 10 +++++----- .../common/include/transformer_engine/gemm.h | 6 +++--- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 3336dbc6d5..bdcfa68a4f 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -95,7 +95,8 @@ struct GroupedBuffers { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 3 int arrays (M, N, K) + size_t size = 6 * ptr_bytes + 3 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; 
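  // Worked check of the corrected formula, assuming 8-byte pointers: for
  // num_tensors = 4 the pointer arrays take 6 * 4 * 8 = 192 bytes and the int
  // arrays 3 * 4 * 4 = 48 bytes, so size = 240 and the 256-byte alignment
  // rounds it up to 256, matching required_setup_size on the library side.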
return size; @@ -320,8 +321,8 @@ void run_grouped_gemm_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 3647a4c39e..40180fe760 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -503,10 +503,10 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Hopper (SM90) or newer + // Grouped GEMM requires Blackwell (SM100) or newer const int current_device = cuda::current_device(); - NVTE_CHECK(cuda::sm_arch(current_device) >= 90, - "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + NVTE_CHECK(cuda::sm_arch(current_device) >= 100, + "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); @@ -589,8 +589,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, const int64_t *avg_k) { - NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } #endif // CUBLAS_VERSION >= 130100 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0c8d601d50..168141224c 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,9 +234,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on - * a pre-Hopper GPU. + * a pre-Blackwell GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. 
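For intuition, a naive reference of the per-group math described above, written for float operands with row-major, non-transposed storage; this is purely illustrative and not how the library executes the operation:

    #include <cstddef>

    // D[g] = alpha[g] * A[g] @ B[g] + beta[g] * C[g], where A[g] is M[g] x K[g],
    // B[g] is K[g] x N[g], and C[g], D[g] are M[g] x N[g].
    void grouped_gemm_reference(std::size_t num_groups, const int *M, const int *N,
                                const int *K, const float *const *A,
                                const float *const *B, const float *const *C,
                                float *const *D, const float *alpha,
                                const float *beta) {
      for (std::size_t g = 0; g < num_groups; ++g) {
        for (int i = 0; i < M[g]; ++i) {
          for (int j = 0; j < N[g]; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K[g]; ++k) {
              acc += A[g][i * K[g] + k] * B[g][k * N[g] + j];
            }
            D[g][i * N[g] + j] = alpha[g] * acc + beta[g] * C[g][i * N[g] + j];
          }
        }
      }
    }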
@@ -266,7 +266,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * Requirements: * - cuBLAS 13.2+ (CUDA 13.2+) - * - Hopper (SM90) or newer GPU architecture + * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: From 77b422ac8d6e33bb5d56651a2e956629c17a5db8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:47:19 +0100 Subject: [PATCH 21/61] Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM Signed-off-by: Pawel Gadzinski --- 3rdparty/cudnn-frontend | 2 +- tests/cpp/operator/test_grouped_gemm.cu | 4 ++-- transformer_engine/common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4d..be6c079be8 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bdcfa68a4f..2514f11ab3 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -317,8 +317,8 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130200 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " +#if CUBLAS_VERSION < 130100 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else if (getDeviceComputeCapability() < blackwellComputeCapability) { diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 168141224c..f4c60ca3fe 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -265,7 +265,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. If NULL, computed automatically from A's logical shape. 
* * Requirements: - * - cuBLAS 13.2+ (CUDA 13.2+) + * - cuBLAS 13.1+ (CUDA 13.1+) * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] From 9c8158ee86a30699710c0dc1cb17c5d9b9aa4ced Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 11:28:47 +0100 Subject: [PATCH 22/61] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 2514f11ab3..ada6980858 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -482,7 +482,7 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } -#endif // CUBLAS_VERSION >= 130200 +#endif // CUBLAS_VERSION >= 130100 } class GroupedGemmTest : public ::testing::TestWithParam {}; From b1e0893be9eb00495765f65c636b23eae698afc1 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 22 Dec 2025 11:22:11 -0800 Subject: [PATCH 23/61] fix --- transformer_engine/common/gemm/cublaslt_gemm.cu | 8 ++++---- transformer_engine/jax/dense.py | 13 ++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf19335..92d89b425f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.lda % 16 == 0, + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.ldb % 16 == 0, + // "Leading dimension requirement on B for FP8 GEMM. 
Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 62b0e054aa..9db60d3bd8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -244,28 +244,27 @@ def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, import numpy as np def _remaining(original, *removed_lists): removed = set(itertools.chain(*removed_lists)) - return [i for i in original if i not in removed] + return tuple(i for i in original if i not in removed) def _ranges_like(*xs): start = 0 for x in xs: x_len = len(x) - yield range(start, start + x_len) + yield tuple(range(start, start + x_len)) start += x_len (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.ndim - x_kept = _remaining(range(x_ndim), x_contract, x_batch) - y_kept = _remaining(range(y.ndim), y_contract, y_batch) + x_kept = _remaining(tuple(range(x_ndim)), x_contract, x_batch) + y_kept = _remaining(tuple(range(y.ndim)), y_contract, y_batch) if swap_ans: ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) else: ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) - x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) - out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + x_contract_sorted_by_y = tuple(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(tuple(x_batch) + x_kept + x_contract_sorted_by_y) x_bar = jax.lax.transpose( - # TODO(jberchtold): I'm ignoring the batch_dims here, do I need to explicitly use vmap or something? tex.gemm(g, y, contracting_dims=dims[0]), tuple(out_axes) ) From fb2067bacb9c21b71ff6cd329cae542415400887 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:03:29 -0800 Subject: [PATCH 24/61] move einsum logic into TE --- transformer_engine/jax/flax/__init__.py | 3 +- transformer_engine/jax/flax/module.py | 62 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index d1a9cb47f8..59a0958b7b 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,7 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +16,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_einsum_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dcfb812896..ca84d46d6b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1438,3 +1438,65 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + +def make_einsum_cls(quantization_recipe): + import functools + import jax + def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): + 
quantizer_set = generate_quantizer_set() + def dot_general(x, kernel, dims, *args, **kwargs): + # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") + contracting_dims, batch_dims = dims + ((x_bdim,), (k_bdim,)) = batch_dims + batch_dims = (x_bdim, k_bdim) + + if x_bdim != 0 or k_bdim != 0: + print(f"{x_bdim=}, {k_bdim=}") + return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + + if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: + # HACK: because x input is bool for dispatch mask + x = x.astype(kernel.dtype) + + # Adjust for unbatched + contracting_dims = tuple( + tuple(dim - (1 if dim > bdim else 0) for dim in cdims) + for bdim, cdims in zip(batch_dims, contracting_dims)) + + f = functools.partial( + dense, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set) + return jax.vmap(f, in_axes=(x_bdim, k_bdim))( + x, + kernel, + ) + + group_sizes = None + + # assuming x batch dim is axis 0, squash dims so we have (B*M, K) + # import math + # num_groups = x.shape[0] + # group_size = math.prod(x.shape[1:-1]) + # x_orig_ndim = x.ndim + # # FIXME: breaks partitioning + # x = x.reshape(x.shape[0] * group_size, x.shape[-1]) + # contracting_dims = ( + # tuple([c - (x_orig_ndim - x.ndim) for c in contracting_dims[0]]), + # *contracting_dims[1:], + # ) + + # group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + + # print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') + + # return transformer_engine.jax.dense.grouped_dense( + # x, + # kernel, + # group_sizes=group_sizes, + # contracting_dims=contracting_dims, + # # quantizer_set=quantizer_set + # ) + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) + + return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() From 30716a622c2d1f381de0e09800ef9936b030c420 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:42:36 -0800 Subject: [PATCH 25/61] einsum unit tests --- tests/jax/test_custom_call_compute.py | 41 +++++++++++++++++++++++++++ transformer_engine/jax/flax/module.py | 7 ++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 897d9f683e..7a81683bc7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1974,3 +1974,44 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + +class TestEinsum: + + def _te_einsum(self, eqn, a, b, quantization_recipe): + from transformer_engine.jax.flax import make_einsum_cls + + te_einsum = make_einsum_cls(quantization_recipe=quantization_recipe) + var_collect = te_einsum.init(jax.random.PRNGKey(0), eqn, a, b) + return te_einsum.apply(var_collect, eqn, a, b) + + def _ref_einsum(self, eqn, a, b): + return jnp.einsum(eqn, a, b) + + @pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ + # ('ij,jk->ik', (64, 32), (32, 128)), + # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), + # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), + ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), + ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), + ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 
4096, 8, 1024)), + ]) + @pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) + @pytest_parametrize_wrapper('quantization_recipe', supported_recipes) + def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + from transformer_engine.common.recipe import Float8CurrentScaling + import functools + + if not isinstance(quantization_recipe, Float8CurrentScaling): + pytest.skip("Einsum currently only supports Float8CurrentScaling recipe.") + return + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + a = jax.random.uniform(subkeys[0], a_shape, dtype=dtype) + b = jax.random.uniform(subkeys[1], b_shape, dtype=dtype) + + te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) + ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) + + assert_allclose(te_out, ref_out, dtype=dtype) \ No newline at end of file diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index ca84d46d6b..0399ccfabf 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1443,7 +1443,8 @@ def make_einsum_cls(quantization_recipe): import functools import jax def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): - quantizer_set = generate_quantizer_set() + # with open("/tmp/te_einsum_log.txt", "a") as f: + # f.write(f"{(s, x.shape, kernel.shape)}\n") def dot_general(x, kernel, dims, *args, **kwargs): # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") contracting_dims, batch_dims = dims @@ -1453,6 +1454,10 @@ def dot_general(x, kernel, dims, *args, **kwargs): if x_bdim != 0 or k_bdim != 0: print(f"{x_bdim=}, {k_bdim=}") return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + + quantizer_set = generate_quantizer_set() + print(f'{quantizer_set=}') + # import pdb; pdb.set_trace() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask From 349c3155fdd34b1fc1ca009252ac64105fc6c24e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:47:19 -0800 Subject: [PATCH 26/61] fwd bwd einsum test --- tests/jax/test_custom_call_compute.py | 56 ++++++++++++++++++++------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7a81683bc7..082a99cd8b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1975,6 +1975,18 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) +@pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ + # ('ij,jk->ik', (64, 32), (32, 128)), + # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), + # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), + ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), + ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), + ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), +]) +@pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) +@pytest_parametrize_wrapper('quantization_recipe', supported_recipes) class TestEinsum: def _te_einsum(self, eqn, a, b, quantization_recipe): @@ -1987,19 +1999,7 @@ def _te_einsum(self, eqn, a, b, quantization_recipe): def _ref_einsum(self, eqn, a, b): return 
jnp.einsum(eqn, a, b) - @pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ - # ('ij,jk->ik', (64, 32), (32, 128)), - # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), - # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), - ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), - ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), - ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), - ]) - @pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) - @pytest_parametrize_wrapper('quantization_recipe', supported_recipes) - def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + def test_einsum_fwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): from transformer_engine.common.recipe import Float8CurrentScaling import functools @@ -2014,4 +2014,32 @@ def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) - assert_allclose(te_out, ref_out, dtype=dtype) \ No newline at end of file + assert_allclose(te_out, ref_out, dtype=dtype) + + def test_einsum_fwd_and_bwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + from transformer_engine.common.recipe import Float8CurrentScaling + import functools + + if not isinstance(quantization_recipe, Float8CurrentScaling): + pytest.skip("Einsum currently only supports Float8CurrentScaling recipe.") + return + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + a = jax.random.uniform(subkeys[0], a_shape, dtype=dtype) + b = jax.random.uniform(subkeys[1], b_shape, dtype=dtype) + + def wrap_in_mean(f): + @functools.wraps(f) + def wrapped(*args): + return jnp.mean(f(*args)) + return wrapped + + te_fwd, te_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))))(a, b) + ref_fwd, ref_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._ref_einsum, eqn))))(a, b) + + assert_allclose(te_fwd, ref_fwd, dtype=dtype) + + assert len(te_grads) == len(ref_grads), f"Number of gradients differ: {len(te_grads)=} vs {len(ref_grads)=}" + + for te_grad, ref_grad in zip(te_grads, ref_grads): + assert_allclose(te_grad, ref_grad, dtype=dtype) \ No newline at end of file From 57ab3b09c9baf1587aaca4ecb5632b91021e1c14 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 11:12:59 -0800 Subject: [PATCH 27/61] unit tests passed with grouped gemm in bf16 --- transformer_engine/jax/flax/module.py | 78 +++++++++++++++------------ 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 0399ccfabf..733eaf513b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -1455,9 +1455,9 @@ def dot_general(x, kernel, dims, *args, **kwargs): print(f"{x_bdim=}, {k_bdim=}") return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + target_out_shape = jax.lax.dot_general(x, kernel, dims).shape + # TODO: add num groups to make grouped quantizer set quantizer_set = generate_quantizer_set() - 
print(f'{quantizer_set=}') - # import pdb; pdb.set_trace() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask @@ -1468,40 +1468,50 @@ def dot_general(x, kernel, dims, *args, **kwargs): tuple(dim - (1 if dim > bdim else 0) for dim in cdims) for bdim, cdims in zip(batch_dims, contracting_dims)) - f = functools.partial( - dense, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set) - return jax.vmap(f, in_axes=(x_bdim, k_bdim))( + group_sizes = None + print(f'{x.shape=}, {kernel.shape=}, {dims=}') + + def reorder_lhs_for_grouped_gemm(tensor, cdims): + # (B*M, K) + assert len(cdims) == 1, f"Only support single contracting dim for now, got {cdims}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose(tensor, tuple(range(cdim)) + tuple(range(cdim + 1, tensor.ndim)) + (cdim,)) + return out.reshape((-1, out.shape[-1])) + + + def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): + # (B, K, N) + assert len(bdims) == 1 and len(cdims) == 1, f"Only support single batch and contracting dim for now, got {bdims}, {cdims}" + bdim = bdims[0] + assert bdim == 0, f"Only support batch dim 0 for now, got {bdim}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose(tensor, (bdim, cdim) + tuple(i for i in range(tensor.ndim) if i != bdim and i != cdim)) + return out.reshape((*out.shape[:2], -1)) + + x = reorder_lhs_for_grouped_gemm(x, contracting_dims[0]) + kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) + + num_groups = kernel.shape[0] + group_size = x.shape[0] // num_groups + + group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + + print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') + + contracting_dims = ( + # (B*M, K) + (1,), + # (B, K, N) + (1,), + ) + out = grouped_dense( x, kernel, + group_sizes=group_sizes, + contracting_dims=contracting_dims, + # quantizer_set=quantizer_set ) - - group_sizes = None - - # assuming x batch dim is axis 0, squash dims so we have (B*M, K) - # import math - # num_groups = x.shape[0] - # group_size = math.prod(x.shape[1:-1]) - # x_orig_ndim = x.ndim - # # FIXME: breaks partitioning - # x = x.reshape(x.shape[0] * group_size, x.shape[-1]) - # contracting_dims = ( - # tuple([c - (x_orig_ndim - x.ndim) for c in contracting_dims[0]]), - # *contracting_dims[1:], - # ) - - # group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) - - # print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') - - # return transformer_engine.jax.dense.grouped_dense( - # x, - # kernel, - # group_sizes=group_sizes, - # contracting_dims=contracting_dims, - # # quantizer_set=quantizer_set - # ) + return out.reshape(target_out_shape) return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() From ab98852671870d1ebabeaf22eb65609d536ca744 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 11:26:56 -0800 Subject: [PATCH 28/61] grouped quantization working for single gpu --- transformer_engine/jax/flax/module.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 733eaf513b..cc6088e8d2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -377,6 +377,7 @@ def 
generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. @@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1456,8 +1459,6 @@ def dot_general(x, kernel, dims, *args, **kwargs): return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) target_out_shape = jax.lax.dot_general(x, kernel, dims).shape - # TODO: add num groups to make grouped quantizer set - quantizer_set = generate_quantizer_set() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask @@ -1496,6 +1497,8 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + quantizer_set = generate_quantizer_set(n_groups=num_groups) + print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') contracting_dims = ( @@ -1509,7 +1512,7 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): kernel, group_sizes=group_sizes, contracting_dims=contracting_dims, - # quantizer_set=quantizer_set + quantizer_set=quantizer_set ) return out.reshape(target_out_shape) return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) From ed540c8e5566d46f2ddb645fdd3940ff94d310c3 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 30 Dec 2025 09:19:40 +0000 Subject: [PATCH 29/61] fixes Signed-off-by: Pawel Gadzinski --- 3rdparty/cudnn-frontend | 2 +- tests/cpp/operator/test_grouped_gemm.cu | 6 +- transformer_engine/common/gemm/config.cpp | 86 +++++++++++ transformer_engine/common/gemm/config.h | 19 +++ .../common/gemm/cublaslt_gemm.cu | 1 - .../common/gemm/cublaslt_grouped_gemm.cu | 23 +-- .../common/include/transformer_engine/gemm.h | 134 ++++++++++++++++-- 7 files changed, 245 insertions(+), 26 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index be6c079be8..0258951d4d 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 +Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index ada6980858..4d6e1b7bb9 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -457,10 +457,8 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - 0, - nullptr, - nullptr, - nullptr); + nullptr, // config (use defaults) + 0); for (size_t i = 0; i < num_gemms; ++i) { Tensor grouped_split("grouped_D" + std::to_string(i), diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp 
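For reference, a minimal sketch of the calling convention used in the test hunk above: the average-dimension hints now travel through an NVTEGroupedMatmulConfig built with the attribute API added in config.cpp/gemm.h below, instead of the old trailing avg_m/avg_n/avg_k pointers. Grouped tensors, alpha/beta, and both workspaces are assumed to be prepared as in run_grouped_gemm_case; the hint value and header path are illustrative.

    #include <cstdint>
    #include <cuda_runtime.h>
    #include <transformer_engine/gemm.h>

    // Sketch: grouped GEMM with an explicit avg-K hint for the cuBLASLt
    // heuristics. A, B, D, alpha, beta and both workspaces are assumed to be
    // prepared by the caller; C may be NULL when beta is zero, as in the test.
    void grouped_gemm_with_hint(NVTEGroupedTensor A, NVTEGroupedTensor B,
                                NVTEGroupedTensor D, NVTETensor alpha,
                                NVTETensor beta, NVTETensor workspace_setup,
                                NVTETensor workspace_cublas, cudaStream_t stream) {
      NVTEGroupedMatmulConfig config = nvte_create_grouped_matmul_config();
      const int64_t avg_k = 4096;  // illustrative hint
      nvte_set_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigAvgK,
                                               &avg_k, sizeof(avg_k));
      nvte_grouped_gemm(/*transa=*/0, /*transb=*/0, alpha, A, B, beta,
                        /*C=*/nullptr, D, workspace_setup, workspace_cublas,
                        config, stream);
      nvte_destroy_grouped_matmul_config(config);
    }

Passing a null config, as the updated test does, keeps the automatic shape-derived defaults.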
index cf211beaf9..bf0b7bc2bf 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -114,3 +114,89 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) { delete reinterpret_cast(config); } } + +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() { + return new transformer_engine::GroupedMatmulConfig; +} + +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + const auto &config_ = + *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: + std::memcpy(buf, &config_.avg_m, attr_size); + break; + case kNVTEGroupedMatmulConfigAvgN: + std::memcpy(buf, &config_.avg_n, attr_size); + break; + case kNVTEGroupedMatmulConfigAvgK: + std::memcpy(buf, &config_.avg_k, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: + std::memcpy(&config_.avg_m, buf, attr_size); + config_.avg_m_set = true; + break; + case kNVTEGroupedMatmulConfigAvgN: + std::memcpy(&config_.avg_n, buf, attr_size); + config_.avg_n_set = true; + break; + case kNVTEGroupedMatmulConfigAvgK: + std::memcpy(&config_.avg_k, buf, attr_size); + config_.avg_k_set = true; + break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 54ccf06a53..4f93ff7fbc 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -31,6 +31,25 @@ struct MatmulConfig { }; }; +struct 
GroupedMatmulConfig { + // Average dimension hints for cuBLASLt algorithm selection heuristics. + // Value of 0 means "not set" - compute automatically from tensor shapes. + int64_t avg_m = 0; + int64_t avg_n = 0; + int64_t avg_k = 0; + + // Track which attributes have been explicitly set + bool avg_m_set = false; + bool avg_n_set = false; + bool avg_k_set = false; + + static constexpr size_t attr_sizes[] = { + sizeof(int64_t), // avg_m + sizeof(int64_t), // avg_n + sizeof(int64_t) // avg_k + }; +}; + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d931da4aa..118bf19335 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -23,7 +23,6 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "./config.h" -#include "./cublaslt_grouped_gemm.cuh" #include "./cutlass_grouped_gemm.cuh" namespace { diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 40180fe760..03692bf052 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -16,6 +16,7 @@ #include "../util/cuda_runtime.h" #include "../util/handle_manager.h" #include "../util/logging.h" +#include "./config.h" #include "./cublaslt_grouped_gemm.cuh" namespace { @@ -498,8 +499,7 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, - const int64_t *avg_k) { + NVTEGroupedMatmulConfig config, cudaStream_t stream) { NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; @@ -518,6 +518,12 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_setup = convertNVTETensor(workspace_setup); Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + // Parse config (if provided) + GroupedMatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + // Validate inputs and num_tensors validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); @@ -564,11 +570,11 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim - int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); - int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = avg_k ? *avg_k - : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_m_val = config_.avg_m_set ? config_.avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = config_.avg_n_set ? config_.avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = config_.avg_k_set ? config_.avg_k + : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, @@ -587,8 +593,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, - const int64_t *avg_k) { + NVTEGroupedMatmulConfig config, cudaStream_t stream) { NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index f4c60ca3fe..00fd0b7048 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -22,6 +22,9 @@ extern "C" { /*! \brief Configuration for matrix multiplication. */ typedef void *NVTEMatmulConfig; +/*! \brief Configuration for grouped matrix multiplication. */ +typedef void *NVTEGroupedMatmulConfig; + /*! \enum NVTEMatmulConfigAttribute * \brief Type of option for matrix multiplication. */ @@ -54,6 +57,34 @@ enum NVTEMatmulConfigAttribute { kNVTEMatmulConfigNumAttributes }; +/*! \enum NVTEGroupedMatmulConfigAttribute + * \brief Type of option for grouped matrix multiplication. + */ +enum NVTEGroupedMatmulConfigAttribute { + /*! Average M dimension hint + * + * Optional hint for average M dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgM = 0, + /*! Average N dimension hint + * + * Optional hint for average N dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgN = 1, + /*! Average K (reduction) dimension hint + * + * Optional hint for average K dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from A's logical shape. + */ + kNVTEGroupedMatmulConfigAvgK = 2, + kNVTEGroupedMatmulConfigNumAttributes +}; + /*! \brief Create a matrix multiplication configuration. */ NVTEMatmulConfig nvte_create_matmul_config(); @@ -84,6 +115,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA /*! \brief Destroy a matrix multiplication configuration. */ void nvte_destroy_matmul_config(NVTEMatmulConfig config); +/*! \brief Create a grouped matrix multiplication configuration. */ +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config(); + +/*! \brief Query an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. 
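As an illustration of the query convention documented here (a NULL buffer returns only the required size through size_written), a hedged sketch; the config handle is assumed to already exist:

    #include <cstddef>
    #include <cstdint>
    #include <transformer_engine/gemm.h>

    // Sketch: two-step read of the AvgM hint. The first call only reports the
    // attribute size; the second copies the value. A value of 0 means the hint
    // was never set and will be derived from the tensor shapes instead.
    int64_t read_avg_m(NVTEGroupedMatmulConfig config) {
      size_t size = 0;
      nvte_get_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigAvgM,
                                               nullptr, 0, &size);
      int64_t avg_m = 0;
      nvte_get_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigAvgM,
                                               &avg_m, size, &size);
      return avg_m;
    }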
+ */ +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a grouped matrix multiplication configuration. */ +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). * * This has been deprecated in favor of nvte_cublas_gemm_v2. @@ -253,16 +316,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Additional configuration (can be NULL for defaults). * \param[in] stream CUDA stream for the operation. - * \param[in] avg_m Optional hint for average M dimension across all matrices in the - * group. Used by cuBLASLt for algorithm selection heuristics. - * If NULL, computed automatically from D's logical shape. - * \param[in] avg_n Optional hint for average N dimension across all matrices in the - * group. Used by cuBLASLt for algorithm selection heuristics. - * If NULL, computed automatically from D's logical shape. - * \param[in] avg_k Optional hint for average K (reduction) dimension across all - * matrices in the group. Used by cuBLASLt for algorithm selection - * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: * - cuBLAS 13.1+ (CUDA 13.1+) @@ -275,8 +330,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, - const int64_t *avg_k); + NVTEGroupedMatmulConfig config, cudaStream_t stream); #ifdef __cplusplus } // extern "C" @@ -376,6 +430,64 @@ class MatmulConfigWrapper { NVTEMatmulConfig config_ = nullptr; }; +/*! \struct GroupedMatmulConfigWrapper + * \brief C++ wrapper for NVTEGroupedMatmulConfig. 
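A short usage sketch for the wrapper declared below: it converts implicitly to NVTEGroupedMatmulConfig and releases the underlying config when it leaves scope. Tensors, workspaces, hint values, and the header path are placeholders assumed to be set up by the caller.

    #include <cuda_runtime.h>
    #include <transformer_engine/gemm.h>

    // Sketch: scoped configuration via the RAII wrapper. Hints left unset are
    // computed automatically from the grouped tensor shapes.
    void launch_grouped_gemm(NVTEGroupedTensor A, NVTEGroupedTensor B,
                             NVTEGroupedTensor D, NVTETensor alpha, NVTETensor beta,
                             NVTETensor setup_ws, NVTETensor cublas_ws,
                             cudaStream_t stream) {
      transformer_engine::GroupedMatmulConfigWrapper config;
      config.set_avg_m(1024);  // illustrative hints
      config.set_avg_n(4096);
      config.set_avg_k(4096);
      nvte_grouped_gemm(/*transa=*/0, /*transb=*/0, alpha, A, B, beta,
                        /*C=*/nullptr, D, setup_ws, cublas_ws, config, stream);
    }  // config destroyed here, freeing the NVTEGroupedMatmulConfig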
+ */ +class GroupedMatmulConfigWrapper { + public: + GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {} + + GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete; + GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete; + + GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~GroupedMatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEGroupedMatmulConfig. + * + * \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper. + */ + operator NVTEGroupedMatmulConfig() const noexcept { return config_; } + + /*! \brief Set average M dimension hint for algorithm selection. */ + void set_avg_m(int64_t avg_m) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m, + sizeof(int64_t)); + } + + /*! \brief Set average N dimension hint for algorithm selection. */ + void set_avg_n(int64_t avg_n) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n, + sizeof(int64_t)); + } + + /*! \brief Set average K dimension hint for algorithm selection. */ + void set_avg_k(int64_t avg_k) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k, + sizeof(int64_t)); + } + + private: + /*! \brief Wrapped NVTEGroupedMatmulConfig. */ + NVTEGroupedMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine #endif // __cplusplus From 359a9f548fcc8d7089f7cab9af824976f4aac120 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 09:31:37 +0000 Subject: [PATCH 30/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/config.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index bf0b7bc2bf..5c1a899d59 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -143,8 +143,7 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, // Write to buffer NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); - const auto &config_ = - *reinterpret_cast(config); + const auto &config_ = *reinterpret_cast(config); switch (attr) { case kNVTEGroupedMatmulConfigAvgM: std::memcpy(buf, &config_.avg_m, attr_size); From a702426f1bddc0b4b1e2d0ce5dd808a19c039174 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 30 Dec 2025 12:08:30 +0100 Subject: [PATCH 31/61] fixes Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 50 ++++++++++++------- transformer_engine/common/gemm/config.cpp | 12 +++++ transformer_engine/common/gemm/config.h | 12 ++++- .../common/gemm/cublaslt_grouped_gemm.cu | 9 ++++ .../common/include/transformer_engine/gemm.h | 16 ++++++ 5 files changed, 81 insertions(+), 18 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 4d6e1b7bb9..1a85e54f82 100644 --- 
a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -293,7 +293,8 @@ struct TestParams { bool transa; bool transb; ShapeCase shape_case; - bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) + bool use_split_accumulator = false; // Whether to use split accumulator for FP8 GEMM }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. @@ -362,8 +363,6 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector A_ptrs(num_gemms); std::vector B_ptrs(num_gemms); std::vector D_ptrs(num_gemms); - std::vector bias_ptrs(num_gemms, nullptr); - std::vector gelu_ptrs(num_gemms, nullptr); std::vector workspaces(num_gemms); std::vector workspace_ptrs(num_gemms, nullptr); std::vector A_views; @@ -371,6 +370,10 @@ void run_grouped_gemm_case(const TestParams& params) { A_views.reserve(num_gemms); B_views.reserve(num_gemms); + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; for (size_t i = 0; i < num_gemms; ++i) { @@ -391,11 +394,11 @@ void run_grouped_gemm_case(const TestParams& params) { static_cast(num_gemms), params.transa, params.transb, - false, + false, // grad workspace_ptrs.data(), - false, - false, - 0, + false, // accumulate + params.use_split_accumulator, + 0, // sm_count 0); GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); @@ -447,6 +450,10 @@ void run_grouped_gemm_case(const TestParams& params) { Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + // Create config with use_split_accumulator setting + transformer_engine::GroupedMatmulConfigWrapper config; + config.set_use_split_accumulator(params.use_split_accumulator); + nvte_grouped_gemm(params.transa, params.transb, alpha_tensor.data(), @@ -457,7 +464,7 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - nullptr, // config (use defaults) + config, 0); for (size_t i = 0; i < num_gemms; ++i) { @@ -495,20 +502,29 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c + split_acc; } +// TestParams: {input_case, transa, transb, shape_case, use_null_c, use_split_accumulator} const std::vector kTestParams = { - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Basic tests (no split accumulator) + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false, false}, + 
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false, false}, // Test NULL C (valid when beta=0) - {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true, false}, + + // Split accumulator tests + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, true}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, true}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, true}, + {InputCase::kFP8Current, true, false, ShapeCase::kSameFirst, false, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index 5c1a899d59..2c7fc38129 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -154,6 +154,12 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, case kNVTEGroupedMatmulConfigAvgK: std::memcpy(buf, &config_.avg_k, attr_size); break; + case kNVTEGroupedMatmulConfigUseSplitAccumulator: + std::memcpy(buf, &config_.use_split_accumulator, attr_size); + break; + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); } @@ -189,6 +195,12 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, std::memcpy(&config_.avg_k, buf, attr_size); config_.avg_k_set = true; break; + case kNVTEGroupedMatmulConfigUseSplitAccumulator: + std::memcpy(&config_.use_split_accumulator, buf, attr_size); + break; + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 4f93ff7fbc..012de5e059 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ #define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ +#include + #include namespace transformer_engine { @@ -38,6 +40,12 @@ struct GroupedMatmulConfig { int64_t avg_n = 0; int64_t avg_k = 0; + // Whether to use split accumulator for FP8 GEMM (more accurate but slower) + bool use_split_accumulator = true; + + // Number of streaming multiprocessors to use in GEMM kernel + int sm_count = 0; + // Track which attributes have been explicitly set bool avg_m_set = false; bool avg_n_set = false; @@ -46,7 +54,9 @@ struct GroupedMatmulConfig { static constexpr size_t attr_sizes[] = { sizeof(int64_t), // avg_m sizeof(int64_t), // avg_n - sizeof(int64_t) // avg_k + sizeof(int64_t), // avg_k + sizeof(bool), // use_split_accumulator + sizeof(int) // sm_count }; }; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 03692bf052..0183752b55 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -568,6 +568,15 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT init_matmul_desc(matmulDesc, op_A, op_B); set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); + // Set fast accumulation mode 
for FP8 + // Fast accumulation: 0 = split accumulator (more accurate), 1 = fast accumulator + const bool is_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + if (is_fp8) { + int8_t fastAccuMode = config_.use_split_accumulator ? 0 : 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + } + // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = config_.avg_m_set ? config_.avg_m : compute_avg_first_dim(outputD); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 00fd0b7048..a596b77fde 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -82,6 +82,10 @@ enum NVTEGroupedMatmulConfigAttribute { * computed automatically from A's logical shape. */ kNVTEGroupedMatmulConfigAvgK = 2, + /*! Whether to use split accumulator for FP8 GEMM. */ + kNVTEGroupedMatmulConfigUseSplitAccumulator = 3, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEGroupedMatmulConfigSMCount = 4, kNVTEGroupedMatmulConfigNumAttributes }; @@ -483,6 +487,18 @@ class GroupedMatmulConfigWrapper { sizeof(int64_t)); } + /*! \brief Set whether to use split accumulator for FP8 GEMM. */ + void set_use_split_accumulator(bool use_split_accumulator) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigUseSplitAccumulator, + &use_split_accumulator, sizeof(bool)); + } + + /*! \brief Set number of streaming multiprocessors to use. */ + void set_sm_count(int sm_count) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, + &sm_count, sizeof(int)); + } + private: /*! \brief Wrapped NVTEGroupedMatmulConfig. 
*/ NVTEGroupedMatmulConfig config_ = nullptr; From fb027d0481dd5c4165013b38d9fed53bbc3141fc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 30 Dec 2025 12:15:08 +0100 Subject: [PATCH 32/61] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 42 +++++++------------ transformer_engine/common/gemm/config.cpp | 6 --- transformer_engine/common/gemm/config.h | 4 -- .../common/gemm/cublaslt_grouped_gemm.cu | 31 ++++++++------ .../common/include/transformer_engine/gemm.h | 10 +---- 5 files changed, 33 insertions(+), 60 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 1a85e54f82..46add9e5e1 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -279,8 +279,6 @@ Tensor make_fp8_operand(const std::string& name, const std::vector& shap Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { Tensor t(name, shape, DType::kBFloat16); - // Fill with ones for easier debugging - //fillUniform(&t); const size_t numel = shape[0] * shape[1]; std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), @@ -293,8 +291,7 @@ struct TestParams { bool transa; bool transb; ShapeCase shape_case; - bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) - bool use_split_accumulator = false; // Whether to use split accumulator for FP8 GEMM + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. @@ -397,7 +394,7 @@ void run_grouped_gemm_case(const TestParams& params) { false, // grad workspace_ptrs.data(), false, // accumulate - params.use_split_accumulator, + false, // use_split_accumulator 0, // sm_count 0); @@ -450,10 +447,6 @@ void run_grouped_gemm_case(const TestParams& params) { Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); - // Create config with use_split_accumulator setting - transformer_engine::GroupedMatmulConfigWrapper config; - config.set_use_split_accumulator(params.use_split_accumulator); - nvte_grouped_gemm(params.transa, params.transb, alpha_tensor.data(), @@ -464,7 +457,7 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - config, + nullptr, // config (use defaults) 0); for (size_t i = 0; i < num_gemms; ++i) { @@ -502,29 +495,22 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c + split_acc; + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; } -// TestParams: {input_case, transa, transb, shape_case, use_null_c, use_split_accumulator} +// TestParams: {input_case, transa, transb, shape_case, use_null_c} const std::vector kTestParams = { - // Basic tests (no split accumulator) - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, false}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, false}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, false}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false, false}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast, false, false}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame, false, false}, - {InputCase::kBF16, true, 
true, ShapeCase::kAllDifferent, false, false}, + // Basic tests + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, // Test NULL C (valid when beta=0) - {InputCase::kBF16, false, false, ShapeCase::kAllSame, true, false}, - - // Split accumulator tests - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, true}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, true}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, true}, - {InputCase::kFP8Current, true, false, ShapeCase::kSameFirst, false, true}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index 2c7fc38129..c305ce033d 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -154,9 +154,6 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, case kNVTEGroupedMatmulConfigAvgK: std::memcpy(buf, &config_.avg_k, attr_size); break; - case kNVTEGroupedMatmulConfigUseSplitAccumulator: - std::memcpy(buf, &config_.use_split_accumulator, attr_size); - break; case kNVTEGroupedMatmulConfigSMCount: std::memcpy(buf, &config_.sm_count, attr_size); break; @@ -195,9 +192,6 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, std::memcpy(&config_.avg_k, buf, attr_size); config_.avg_k_set = true; break; - case kNVTEGroupedMatmulConfigUseSplitAccumulator: - std::memcpy(&config_.use_split_accumulator, buf, attr_size); - break; case kNVTEGroupedMatmulConfigSMCount: std::memcpy(&config_.sm_count, buf, attr_size); break; diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 012de5e059..fd9b1266e6 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -40,9 +40,6 @@ struct GroupedMatmulConfig { int64_t avg_n = 0; int64_t avg_k = 0; - // Whether to use split accumulator for FP8 GEMM (more accurate but slower) - bool use_split_accumulator = true; - // Number of streaming multiprocessors to use in GEMM kernel int sm_count = 0; @@ -55,7 +52,6 @@ struct GroupedMatmulConfig { sizeof(int64_t), // avg_m sizeof(int64_t), // avg_n sizeof(int64_t), // avg_k - sizeof(bool), // use_split_accumulator sizeof(int) // sm_count }; }; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 0183752b55..98c78a304d 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -310,15 +310,17 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); - int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); - int *lda = rowa; - int *rowb = B_sel.use_columnwise ? 
ws.N : (B_sel.trans ? ws.N : ws.K); - int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); - int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + // Storage dimensions for A: rows_A x cols_A with leading dimension lda_storage + int *rows_A = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cols_A = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); + int *lda_storage = rows_A; + // Storage dimensions for B: rows_B x cols_B with leading dimension ldb_storage + int *rows_B = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *cols_B = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); + int *ldb_storage = rows_B; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rows_A, cols_A, lda_storage)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rows_B, cols_B, ldb_storage)); NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); } @@ -442,14 +444,15 @@ __global__ void setup_grouped_gemm_kernel( D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); // Compute data pointers + // Note: const_cast is safe here - cuBLAS requires void** but won't modify A/B/C data A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; D_ptrs[idx] = d_base + d_offset * d_elem_size; - // Compute M, N, K dimensions - // Test stores A as {K,M} when !transa, {M,K} when transa - // Test stores B as {N,K} when !transb, {K,N} when transb + // Compute M, N, K dimensions from tensor shapes + // Input A is stored as {K,M} when !transa, {M,K} when transa + // Input B is stored as {N,K} when !transb, {K,N} when transb M[idx] = static_cast(transa ? a_first : a_last); K[idx] = static_cast(transa ? a_last : a_first); N[idx] = static_cast(transb ? b_last : b_first); @@ -570,9 +573,11 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // Set fast accumulation mode for FP8 // Fast accumulation: 0 = split accumulator (more accurate), 1 = fast accumulator + // Note: cuBLASLt grouped GEMM API does not support configurable split accumulator, + // we always use fast accumulator for performance. const bool is_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); if (is_fp8) { - int8_t fastAccuMode = config_.use_split_accumulator ? 0 : 1; + int8_t fastAccuMode = 1; // Always use fast accumulator NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index a596b77fde..1311021185 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -82,10 +82,8 @@ enum NVTEGroupedMatmulConfigAttribute { * computed automatically from A's logical shape. */ kNVTEGroupedMatmulConfigAvgK = 2, - /*! Whether to use split accumulator for FP8 GEMM. */ - kNVTEGroupedMatmulConfigUseSplitAccumulator = 3, /*! 
Number of streaming multiprocessors to use in GEMM kernel. */ - kNVTEGroupedMatmulConfigSMCount = 4, + kNVTEGroupedMatmulConfigSMCount = 3, kNVTEGroupedMatmulConfigNumAttributes }; @@ -487,12 +485,6 @@ class GroupedMatmulConfigWrapper { sizeof(int64_t)); } - /*! \brief Set whether to use split accumulator for FP8 GEMM. */ - void set_use_split_accumulator(bool use_split_accumulator) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigUseSplitAccumulator, - &use_split_accumulator, sizeof(bool)); - } - /*! \brief Set number of streaming multiprocessors to use. */ void set_sm_count(int sm_count) { nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, From ae854151137f41571fbb8c921d627ed96dd0b301 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 11:16:29 +0000 Subject: [PATCH 33/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/config.h | 4 ++-- .../common/gemm/cublaslt_grouped_gemm.cu | 10 ++++++---- .../common/include/transformer_engine/gemm.h | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index fd9b1266e6..6f75a34b37 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -7,10 +7,10 @@ #ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ #define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ -#include - #include +#include + namespace transformer_engine { struct MatmulConfig { diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 98c78a304d..20c3e5222a 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -319,8 +319,10 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, int *cols_B = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); int *ldb_storage = rows_B; - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rows_A, cols_A, lda_storage)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rows_B, cols_B, ldb_storage)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rows_A, cols_A, lda_storage)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rows_B, cols_B, ldb_storage)); NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); } @@ -578,8 +580,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT const bool is_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); if (is_fp8) { int8_t fastAccuMode = 1; // Always use fast accumulator - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); } // Compute average dimensions for heuristics diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 1311021185..1971714621 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -487,8 +487,8 @@ class GroupedMatmulConfigWrapper { /*! \brief Set number of streaming multiprocessors to use. */ void set_sm_count(int sm_count) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, - &sm_count, sizeof(int)); + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count, + sizeof(int)); } private: From f1fc31c5d043f9b4224d0a6e95e0e55335788383 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 5 Jan 2026 08:59:34 -0800 Subject: [PATCH 34/61] wip --- .../jax/csrc/extensions/gemm.cpp | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6566ff1689..79418c138e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -768,10 +768,24 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - grad, workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + // nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + // grad, workspace_list.data(), accumulate, use_split_accumulator, + // num_math_sm, stream); + int64_t avg_m = 0, avg_n = 0, avg_k = 0; + nvte_grouped_gemm( + rhs_is_trans, lhs_is_trans, + alpha, + rhs_list, lhs_list, + beta, + C, + out_list, + workspace_setup, + workspace_cublas, + stream, + &avg_m, + &avg_n, + &avg_k); return ffi_with_cuda_error_check(); } From 43f7e60ecf449413e8fcfe77f4b09bb708c16f51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= 
<62263673+pggPL@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:22:41 +0100 Subject: [PATCH 35/61] Update transformer_engine/common/gemm/config.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/common/gemm/config.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 6f75a34b37..56c5db16c9 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -49,10 +49,10 @@ struct GroupedMatmulConfig { bool avg_k_set = false; static constexpr size_t attr_sizes[] = { - sizeof(int64_t), // avg_m - sizeof(int64_t), // avg_n - sizeof(int64_t), // avg_k - sizeof(int) // sm_count + sizeof(avg_m), + sizeof(avg_n), + sizeof(avg_k), + sizeof(sm_count) }; }; From 30468af1570212b254718e2cc25ea8ed64d0b9b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 15:23:47 +0000 Subject: [PATCH 36/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/config.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 56c5db16c9..2723bf2d30 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -48,12 +48,8 @@ struct GroupedMatmulConfig { bool avg_n_set = false; bool avg_k_set = false; - static constexpr size_t attr_sizes[] = { - sizeof(avg_m), - sizeof(avg_n), - sizeof(avg_k), - sizeof(sm_count) - }; + static constexpr size_t attr_sizes[] = {sizeof(avg_m), sizeof(avg_n), sizeof(avg_k), + sizeof(sm_count)}; }; } // namespace transformer_engine From 2ccaee5699af2df5cc1f60174578db8a071a3d41 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 7 Jan 2026 16:33:28 +0100 Subject: [PATCH 37/61] changed Signed-off-by: Pawel Gadzinski --- transformer_engine/common/gemm/config.cpp | 48 ++++++++++++------- transformer_engine/common/gemm/config.h | 14 ++---- .../common/gemm/cublaslt_grouped_gemm.cu | 21 ++------ .../common/gemm/cublaslt_grouped_gemm.cuh | 17 ------- 4 files changed, 39 insertions(+), 61 deletions(-) delete mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index c305ce033d..9cdfb29bbd 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -145,15 +145,21 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); const auto &config_ = *reinterpret_cast(config); switch (attr) { - case kNVTEGroupedMatmulConfigAvgM: - std::memcpy(buf, &config_.avg_m, attr_size); + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val = config_.avg_m.value_or(0); + std::memcpy(buf, &val, attr_size); break; - case kNVTEGroupedMatmulConfigAvgN: - std::memcpy(buf, &config_.avg_n, attr_size); + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val = config_.avg_n.value_or(0); + std::memcpy(buf, &val, attr_size); break; - case kNVTEGroupedMatmulConfigAvgK: - std::memcpy(buf, &config_.avg_k, attr_size); + } + case kNVTEGroupedMatmulConfigAvgK: { + 
int64_t val = config_.avg_k.value_or(0); + std::memcpy(buf, &val, attr_size); break; + } case kNVTEGroupedMatmulConfigSMCount: std::memcpy(buf, &config_.sm_count, attr_size); break; @@ -180,18 +186,24 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); auto &config_ = *reinterpret_cast(config); switch (attr) { - case kNVTEGroupedMatmulConfigAvgM: - std::memcpy(&config_.avg_m, buf, attr_size); - config_.avg_m_set = true; - break; - case kNVTEGroupedMatmulConfigAvgN: - std::memcpy(&config_.avg_n, buf, attr_size); - config_.avg_n_set = true; - break; - case kNVTEGroupedMatmulConfigAvgK: - std::memcpy(&config_.avg_k, buf, attr_size); - config_.avg_k_set = true; - break; + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_m = val; + break; + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_n = val; + break; + } + case kNVTEGroupedMatmulConfigAvgK: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_k = val; + break; + } case kNVTEGroupedMatmulConfigSMCount: std::memcpy(&config_.sm_count, buf, attr_size); break; diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 2723bf2d30..b1aaae2591 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -10,6 +10,7 @@ #include #include +#include namespace transformer_engine { @@ -35,19 +36,14 @@ struct MatmulConfig { struct GroupedMatmulConfig { // Average dimension hints for cuBLASLt algorithm selection heuristics. - // Value of 0 means "not set" - compute automatically from tensor shapes. - int64_t avg_m = 0; - int64_t avg_n = 0; - int64_t avg_k = 0; + // nullopt means "not set" - compute automatically from tensor shapes. + std::optional avg_m; + std::optional avg_n; + std::optional avg_k; // Number of streaming multiprocessors to use in GEMM kernel int sm_count = 0; - // Track which attributes have been explicitly set - bool avg_m_set = false; - bool avg_n_set = false; - bool avg_k_set = false; - static constexpr size_t attr_sizes[] = {sizeof(avg_m), sizeof(avg_n), sizeof(avg_k), sizeof(sm_count)}; }; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 20c3e5222a..d11e2221be 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -17,7 +17,6 @@ #include "../util/handle_manager.h" #include "../util/logging.h" #include "./config.h" -#include "./cublaslt_grouped_gemm.cuh" namespace { @@ -573,24 +572,12 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT init_matmul_desc(matmulDesc, op_A, op_B); set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); - // Set fast accumulation mode for FP8 - // Fast accumulation: 0 = split accumulator (more accurate), 1 = fast accumulator - // Note: cuBLASLt grouped GEMM API does not support configurable split accumulator, - // we always use fast accumulator for performance. 
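      // Hints are stored as std::optional<int64_t>: nullopt (never set) is
      // surfaced as 0 at the C API, and nvte_grouped_gemm later substitutes an
      // average computed from the tensor shapes via value_or.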
- const bool is_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); - if (is_fp8) { - int8_t fastAccuMode = 1; // Always use fast accumulator - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, - &fastAccuMode, sizeof(fastAccuMode))); - } - // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim - int64_t avg_m_val = config_.avg_m_set ? config_.avg_m : compute_avg_first_dim(outputD); - int64_t avg_n_val = config_.avg_n_set ? config_.avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = config_.avg_k_set ? config_.avg_k - : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); + int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); + int64_t avg_k_val = config_.avg_k.value_or(A_sel.trans ? compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh deleted file mode 100644 index a032e594d5..0000000000 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh +++ /dev/null @@ -1,17 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ -#define TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ - -#include -#include -#include - -// nvte_grouped_gemm is declared in transformer_engine/gemm.h -// This header is for internal use only. - -#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ From bd8fa3010fdad67fc6556063a3058852fe7e572e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 7 Jan 2026 17:18:25 +0100 Subject: [PATCH 38/61] suggestions Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 223 +----------------- tests/cpp/test_common.cu | 163 +++++++++++++ tests/cpp/test_common.h | 54 +++++ .../common/gemm/cublaslt_grouped_gemm.cu | 30 +-- .../common/include/transformer_engine/gemm.h | 17 +- 5 files changed, 247 insertions(+), 240 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 46add9e5e1..8ff7fa75aa 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -41,57 +41,6 @@ enum class ShapeCase { kAllDifferent, }; -// Custom deleters for RAII -struct CudaDeleter { - void operator()(void* p) const { if (p) cudaFree(p); } -}; -struct GroupedTensorDeleter { - void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } -}; - -template -using CudaPtr = std::unique_ptr; -using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; - -// Helper to allocate CUDA memory into a CudaPtr -template -CudaPtr cuda_alloc(size_t bytes) { - void* ptr = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); - return CudaPtr(static_cast(ptr)); -} - -// Helper owning GPU buffers that back NVTEGroupedTensor. 
-// NVTEGroupedTensor does not own memory; data/offsets/scales -// must be allocated and freed by the test. -struct GroupedBuffers { - GroupedTensorHandle handle; - CudaPtr<> data; - CudaPtr<> scale_inv; - CudaPtr first_dims_dev; - CudaPtr last_dims_dev; - CudaPtr offsets_dev; - CudaPtr<> columnwise_data; - NVTEShape logical_shape{}; - std::vector offsets_host; - std::vector tensor_bytes; - size_t num_tensors{0}; - size_t elem_size{0}; - DType dtype{DType::kFloat32}; - NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; - - GroupedBuffers() = default; - GroupedBuffers(const GroupedBuffers&) = delete; - GroupedBuffers& operator=(const GroupedBuffers&) = delete; - GroupedBuffers(GroupedBuffers&&) = default; - GroupedBuffers& operator=(GroupedBuffers&&) = default; - ~GroupedBuffers() = default; - - // Convenience accessors for raw pointers - NVTEGroupedTensor get_handle() const { return handle.get(); } - void* get_data() const { return data.get(); } -}; - size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); @@ -102,168 +51,6 @@ size_t grouped_setup_workspace_size(const size_t num_tensors) { return size; } -GroupedBuffers build_grouped_tensor(const std::vector& tensors, - const NVTEScalingMode scaling_mode) { - NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); - const NVTEShape shape = tensors[0]->rowwise_shape(); - const DType dtype = tensors[0]->dtype(); - const size_t num_tensors = tensors.size(); - const size_t elem_size = typeToNumBits(dtype) / 8; - GroupedBuffers grouped; - grouped.elem_size = elem_size; - grouped.num_tensors = num_tensors; - grouped.dtype = dtype; - grouped.scaling_mode = scaling_mode; - grouped.tensor_bytes.resize(num_tensors); - grouped.offsets_host.resize(num_tensors, 0); - - std::vector first_dims(num_tensors); - std::vector last_dims(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - const auto s = tensors[i]->rowwise_shape(); - NVTE_CHECK(s.ndim == 2, "Grouped GEMM test expects 2D tensors."); - first_dims[i] = static_cast(s.data[0]); - last_dims[i] = static_cast(s.data[1]); - grouped.tensor_bytes[i] = bytes(s, dtype); - } - - const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), - [&](int64_t v) { return v == first_dims[0]; }); - const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), - [&](int64_t v) { return v == last_dims[0]; }); - - std::vector offsets(num_tensors, 0); - auto random_padding = [&]() -> int64_t { - // Random padding ensuring 16-byte alignment regardless of element size - // cuBLAS requires aligned pointers for vectorized loads - static std::mt19937 gen(12345); - std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment in bytes, rounded up - const size_t align_elements = - std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size - return dist(gen) * static_cast(align_elements); - }; - - auto numel = [&](size_t idx) -> int64_t { - return first_dims[idx] * last_dims[idx]; - }; - - const bool need_offsets = !same_first || !same_last; - if (need_offsets) { - offsets[0] = 0; - for (size_t i = 1; i < num_tensors; ++i) { - offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); - } - } else { - for (size_t i = 0; i < num_tensors; ++i) { - offsets[i] = static_cast(i) * numel(0); - } - } - grouped.offsets_host = offsets; - - int64_t logical_first = 0; - int64_t logical_last = 0; - if (same_first && 
same_last) { - logical_first = first_dims[0] * static_cast(num_tensors); - logical_last = last_dims[0]; - } else if (same_first && !same_last) { - logical_first = first_dims[0]; - logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); - } else if (!same_first && same_last) { - logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); - logical_last = last_dims[0]; - } else { - logical_first = 1; - logical_last = 0; - for (size_t i = 0; i < num_tensors; ++i) { - logical_last += first_dims[i] * last_dims[i]; - } - } - size_t logical_data[2] = {static_cast(logical_first), - static_cast(logical_last)}; - grouped.logical_shape = nvte_make_shape(logical_data, 2); - grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); - - const int64_t last_idx = static_cast(num_tensors - 1); - const int64_t total_elems = need_offsets - ? (offsets[last_idx] + numel(last_idx)) - : (logical_first * logical_last); - const size_t total_bytes = static_cast(total_elems) * elem_size; - - grouped.data = cuda_alloc(total_bytes); - for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, - tensors[i]->rowwise_dptr(), - grouped.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - } - - NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; - NVTEGroupedTensor h = grouped.handle.get(); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); - - const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); - if (include_columnwise) { - grouped.columnwise_data = cuda_alloc(total_bytes); - for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, - tensors[i]->columnwise_dptr(), - grouped.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - } - NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), - static_cast(dtype), - grouped.logical_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); - } - - if (!same_first) { - grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); - } - - if (!same_last) { - grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); - } - - if (!same_first || !same_last) { - grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&h, 
kNVTEGroupedTensorOffsets, &off_tensor); - } - - if (isFp8Type(dtype)) { - std::vector scale_inv_cpu(num_tensors, 1.f); - for (size_t i = 0; i < num_tensors; ++i) { - tensors[i]->to_cpu(); - scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; - } - grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), - sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); - NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); - } - - return grouped; -} - Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); fillUniform(&input_fp32); @@ -447,14 +234,14 @@ void run_grouped_gemm_case(const TestParams& params) { Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); - nvte_grouped_gemm(params.transa, - params.transb, - alpha_tensor.data(), - grouped_A.get_handle(), + nvte_grouped_gemm(grouped_A.get_handle(), + params.transa, grouped_B.get_handle(), - beta_tensor.data(), + params.transb, params.use_null_c ? nullptr : grouped_C->get_handle(), grouped_D.get_handle(), + alpha_tensor.data(), + beta_tensor.data(), setup_ws.data(), cublas_ws.data(), nullptr, // config (use defaults) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index d70eb13536..21586fc499 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -1057,4 +1058,166 @@ std::array get_scale_tensor_dims(const size_t rows, return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + const NVTEShape shape = tensors[0]->rowwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToNumBits(dtype) / 8; + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = tensors[i]->rowwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + // Random padding ensuring 16-byte alignment regardless of element size + // cuBLAS requires aligned pointers for vectorized loads + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + // Calculate elements 
needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size + return dist(gen) * static_cast(align_elements); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + grouped.data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + NVTEGroupedTensor h = grouped.handle.get(); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); + + const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); + if (include_columnwise) { + grouped.columnwise_data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), + static_cast(dtype), + grouped.logical_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); + } + + if (!same_first) { + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); + } + + if (!same_last) { + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + 
NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); + } + + if (!same_first || !same_last) { + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); + } + + if (isFp8Type(dtype)) { + std::vector scale_inv_cpu(num_tensors, 1.f); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + } + + return grouped; +} + } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b8993dfb62..106c336405 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -500,6 +500,60 @@ int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. 
+struct GroupedBuffers { + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; + + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } +}; + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode); + } // namespace test #if FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index d11e2221be..5638dc772f 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -62,7 +62,7 @@ struct TensorShapeInfo { } // Create for C tensor (uses D's dimensions, only has offsets) - static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { const bool has_first = D->first_dims.has_data(); const bool has_last = D->last_dims.has_data(); @@ -166,16 +166,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor) { const size_t num_tensors = inputA->num_tensors; - NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, - "Grouped GEMM: A and B must have the same num_tensors"); + "Grouped GEMM: A and B must have the same number of tensors"); // C can be NULL (will use D as C when beta=0) if (inputC != nullptr) { NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); + "Grouped GEMM: A and C must have the same number of tensors"); } NVTE_CHECK(outputD->num_tensors == num_tensors, - "Grouped GEMM: A and D must have the same num_tensors"); + "Grouped GEMM: A and D must have the same number of tensors"); // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.numel(); @@ -471,7 +471,7 @@ inline void launch_grouped_gemm_setup( const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); - TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); const char *c_base = static_cast(C->data.dptr); @@ -500,10 +500,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { } // namespace -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, 
- const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEGroupedMatmulConfig config, cudaStream_t stream) { +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; @@ -593,10 +594,11 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT #else // CUBLAS_VERSION < 130100 -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, - const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEGroupedMatmulConfig config, cudaStream_t stream) { +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 1971714621..cc12fb1c6b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -308,14 +308,14 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous * memory layout and shape metadata. * - * \param[in] transa Whether to transpose A matrices. - * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). * \param[in] A Input grouped tensor A. + * \param[in] transa Whether to transpose A matrices. * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). + * \param[in] transb Whether to transpose B matrices. * \param[in] C Input grouped tensor C (can be NULL for beta=0). * \param[out] D Output grouped tensor D. + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). * \param[in] workspace_setup Workspace tensor for pointer array setup. * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. * \param[in] config Additional configuration (can be NULL for defaults). 
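As a reading aid, here is a hedged sketch of a call using the reordered parameter list documented above. It is not taken from the patch; the grouped tensors, per-matrix alpha/beta tensors, and workspaces are assumed to have been prepared elsewhere (for example with the test helpers in this series), and only behavior stated in the doc comment is relied on.

  /* Illustrative only -- not part of the patch. */
  nvte_grouped_gemm(A, /*transa=*/0,
                    B, /*transb=*/0,
                    /*C=*/NULL,       /* permitted when every beta entry is 0 */
                    D,
                    alpha,            /* NVTETensor holding num_tensors scale values */
                    beta,             /* NVTETensor holding num_tensors scale values */
                    workspace_setup,
                    workspace_cublas,
                    /*config=*/NULL,  /* defaults */
                    stream);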
@@ -329,10 +329,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, - const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEGroupedMatmulConfig config, cudaStream_t stream); +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" From f0df80e63b8a3cc60668da6c7124c2a4d5af6ae0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:19:17 +0000 Subject: [PATCH 39/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 5638dc772f..3861ebf857 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -63,7 +63,7 @@ struct TensorShapeInfo { // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D) { + const transformer_engine::GroupedTensor *D) { const bool has_first = D->first_dims.has_data(); const bool has_last = D->last_dims.has_data(); NVTE_CHECK(has_first || D->all_same_first_dim(), From 301874d31dc5d5cfee6d4e5cbaf1037161354222 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 7 Jan 2026 17:23:23 +0100 Subject: [PATCH 40/61] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 2 +- tests/cpp/test_common.cu | 2 +- tests/cpp/test_common.h | 2 +- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 8ff7fa75aa..90d89c77c8 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 21586fc499..af99d9c42f 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 106c336405..ac9f377ef4 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 3861ebf857..0d376c2e56 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ From c8cf7633aa29fbb93a05d7b70475ff1366fc43f0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 11:22:48 -0800 Subject: [PATCH 41/61] with many hacks grouped gemm with new api works for a particular hardcoded shape --- transformer_engine/jax/cpp_extensions/gemm.py | 15 ++- .../jax/csrc/extensions/gemm.cpp | 105 ++++++++++++++---- 2 files changed, 96 insertions(+), 24 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 28100c9715..38d21f26ec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1463,7 +1463,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -1476,6 +1476,8 @@ def abstract( bias_aval, group_sizes_aval, group_offset_aval, + alpha, + beta, *, M, N, @@ -1535,6 +1537,8 @@ def abstract( # We also pad scale_inv swizzle buffers size for 256 bytes alignment. 
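# Editorial note, not part of the patch: the extra 1024*1024 bytes added a few lines below act
# as a placeholder "setup" region. The FFI handler in gemm.cpp (PATCH 41) then splits the one
# workspace buffer as
#   [0, 1 MiB)                      -> workspace_setup  (pointer-array setup)
#   [1 MiB, 1 MiB + workspace_size) -> workspace_cublas (cuBLASLt workspace)
# so the two regions are assumed to stay contiguous and in this order.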
workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + + workspace_size += 1024*1024 # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) @@ -1587,6 +1591,8 @@ def impl( bias, group_sizes, group_offset, + alpha, + beta, M, N, K, @@ -1607,6 +1613,8 @@ def impl( bias, group_sizes, group_offset, + alpha, + beta, M=M, N=N, K=K, @@ -2115,6 +2123,9 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + num_gemms = group_sizes.shape[0] + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, @@ -2123,6 +2134,8 @@ def grouped_gemm( bias, group_sizes, group_offset, + alpha, + beta, M=M, N=N, K=K_lhs, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 79418c138e..7c2d4c81e6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,10 +399,62 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors) { + printf("make_grouped_tensor data shape: "); + for (auto dim : data.dimensions()) { + printf("%zu, ", dim); + } + printf("\n"); + NVTEShape logical_shape{}; + if (data.dimensions().size() == 1) { + // HACK + size_t cdim_size = 4096; + logical_shape.ndim = 2; + logical_shape.data[0] = data.dimensions()[0] / cdim_size; + logical_shape.data[1] = cdim_size; + } + else { + NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); + + logical_shape.ndim = 2; + logical_shape.data[0] = data.dimensions()[0]; + logical_shape.data[1] = data.dimensions()[1]; + } + + NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, logical_shape); + + NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), + logical_shape}; + nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); + + if (scale_inv.has_value()) { + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); + } + NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), + logical_scale_shape}; + nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); + } + + return grouped_tensor; +} + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type 
rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + Buffer_Type group_sizes, Buffer_Type group_offset, + Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, + size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: @@ -577,7 +629,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; - std::vector workspace_list; size_t lhs_sinv_total_size = 0; size_t rhs_sinv_total_size = 0; @@ -724,15 +775,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type out_list.push_back(out_wrapper_list.back().data()); } - auto workspace_shape = std::vector{workspace_size}; - for (int i = 0; i < num_streams; i++) { - auto workspace_i = - TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); - workspace_wrapper_list.push_back(std::move(workspace_i)); - workspace_list.push_back(workspace_wrapper_list.back().data()); - workspace_ptr += workspace_size; - } - if (is_fp8_gemm) { if (is_tensor_scaling) { lhs_sinv_size *= tensor_scaling_sinv_aligment; @@ -772,20 +814,35 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, // grad, workspace_list.data(), accumulate, use_split_accumulator, // num_math_sm, stream); - int64_t avg_m = 0, avg_n = 0, avg_k = 0; + + constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup + TensorWrapper workspace_setup(workspace_ptr, + std::vector{workspace_setup_size}, DType::kByte); + TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, + std::vector{workspace_size}, DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms); + NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms); + NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms); + nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, - alpha, - rhs_list, lhs_list, - beta, - C, - out_list, - workspace_setup, - workspace_cublas, + alpha_tensor.data(), + rhs_tensor, lhs_tensor, + beta_tensor.data(), + nullptr, + out_tensor, + workspace_setup.data(), + workspace_cublas.data(), stream, - &avg_m, - &avg_n, - &avg_k); + nullptr, + nullptr, + nullptr); return ffi_with_cuda_error_check(); } @@ -800,6 +857,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // bias .Arg() // group_sizes .Arg() // group_offset + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // workspace .Attr("M") From 21e7002991831ecd933388f4ad95a53d0d64d69b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 11:59:53 -0800 Subject: [PATCH 42/61] progress --- transformer_engine/jax/cpp_extensions/gemm.py | 7 +++ .../jax/csrc/extensions/gemm.cpp | 60 
++++++++++--------- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 38d21f26ec..25f3315ba7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2123,6 +2123,13 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") + print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") + # import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + # print(f"{lhs_is_trans=}, {rhs_is_trans=}") + # import pdb; pdb.set_trace() + num_gemms = group_sizes.shape[0] alpha = jnp.ones((num_gemms,), jnp.float32) beta = jnp.zeros((num_gemms,), jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7c2d4c81e6..9543c66356 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,33 +399,34 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); -NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors) { - printf("make_grouped_tensor data shape: "); - for (auto dim : data.dimensions()) { - printf("%zu, ", dim); - } - printf("\n"); - NVTEShape logical_shape{}; - if (data.dimensions().size() == 1) { - // HACK - size_t cdim_size = 4096; - logical_shape.ndim = 2; - logical_shape.data[0] = data.dimensions()[0] / cdim_size; - logical_shape.data[1] = cdim_size; - } - else { - NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); - - logical_shape.ndim = 2; - logical_shape.data[0] = data.dimensions()[0]; - logical_shape.data[1] = data.dimensions()[1]; - } - - NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, logical_shape); +NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { + // printf("make_grouped_tensor data shape: "); + // for (auto dim : data.dimensions()) { + // printf("%zu, ", dim); + // } + // printf("\n"); + // NVTEShape logical_shape{}; + // if (data.dimensions().size() == 1) { + // // HACK + // size_t cdim_size = 4096; + // logical_shape.ndim = 2; + // logical_shape.data[0] = data.dimensions()[0] / cdim_size; + // logical_shape.data[1] = cdim_size; + // printf("NUM TENSORS: %zu\n", num_tensors); + // } + // else { + // NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); + + // logical_shape.ndim = 2; + // logical_shape.data[0] = data.dimensions()[0]; + // logical_shape.data[1] = data.dimensions()[1]; + // } + + NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), - logical_shape}; + dataShape}; nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); if (scale_inv.has_value()) { @@ -826,9 +827,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, 
Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms); - NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms); - NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms); + NVTEShape rhsShape{.data={num_gemms * k, n}, .ndim=2}; + NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + NVTEShape lhsShape{.data={m, k}, .ndim=2}; + NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + NVTEShape outShape{.data={m, n}, .ndim=2}; + NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); + + NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, From 1ae08ddd7dfde42a9c2fea90128f19a74f9a191c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 13:46:24 -0800 Subject: [PATCH 43/61] more tests pass --- test_einsum.py | 74 +++++++++++++++++++ .../jax/csrc/extensions/gemm.cpp | 16 +++- 2 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 test_einsum.py diff --git a/test_einsum.py b/test_einsum.py new file mode 100644 index 0000000000..5bb05403f2 --- /dev/null +++ b/test_einsum.py @@ -0,0 +1,74 @@ +from enum import Enum + +import jax +import jax.numpy as jnp +import numpy as np +import transformer_engine.jax as te +from transformer_engine.common.recipe import Recipe, Float8CurrentScaling, MXFP8BlockScaling, DelayedScaling, NVFP4BlockScaling +from flax import linen as nn + +def make_einsum_cls(quantization_recipe): + def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): + def dot_general(x, kernel, dims, *args, **kwargs): + contracting_dims, batch_dims = dims + assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet" + + quantizer_set = generate_quantizer_set("quantizer_set_for_einsum") + return te.dense.dense( + x, + kernel, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) + + return te.flax.wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + +class EinsumType(Enum): + JAX = 'jax' + TE = 'te' + +def main(): + + class SimpleModel(nn.Module): + + einsum_type: EinsumType + quantization_recipe: Recipe = None + + def _einsum(self, *args, **kwargs): + if self.einsum_type == EinsumType.JAX: + return jnp.einsum(*args, **kwargs) + elif self.einsum_type == EinsumType.TE: + # It is important that we call make_einsum_cls(recipe) here each time einsum + # is called. If we were to call make_einsum_cls only once and re-use it, the state for some recipes such as DelayedScaling would become incorrectly shared instead of each call having its own state. 
+ return make_einsum_cls(self.quantization_recipe)(*args, **kwargs) + else: + raise ValueError(f"Unsupported einsum type: {self.einsum_type}") + + @nn.compact + def __call__(self, x): + kernel = self.param('kernel', jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16) + return self._einsum("ij,jk->ik", x, kernel) + + + def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None): + model = SimpleModel(einsum_type=einsum_type, quantization_recipe=quantization_recipe) + x = jax.random.uniform(jax.random.PRNGKey(2), (32, 32), jnp.bfloat16) + var_collect = model.init(jax.random.PRNGKey(3), x) + # It is important to use var_collect here to ensure all state (e.g., quantizer states) is properly handled. If you use var_collect['params'] only, TE's state management will not work correctly for recipes that require state (e.g. DelayedScaling). + y = model.apply(var_collect, x) + return y + + # einsum_cls = None, so standard JAX computation + ref_out = test_model(einsum_type=EinsumType.JAX) + + # einsum using Transformer Engine's Float8CurrentScaling recipe + te_out = test_model(einsum_type=EinsumType.TE, quantization_recipe=Float8CurrentScaling()) + + # Compare outputs + atol = float(jnp.finfo(jnp.float8_e4m3fn).eps) + np.testing.assert_allclose(ref_out, te_out, atol=atol) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9543c66356..61e241b197 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -827,14 +827,26 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - NVTEShape rhsShape{.data={num_gemms * k, n}, .ndim=2}; + NVTEShape rhsShape{.data={k, n}, .ndim=2}; + if (!is_grouped_dense_wgrad) { + rhsShape.data[0] *= num_gemms; + } + if (rhs_is_trans) { + std::swap(rhsShape.data[0], rhsShape.data[1]); + } NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); NVTEShape lhsShape{.data={m, k}, .ndim=2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); + } NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); NVTEShape outShape{.data={m, n}, .ndim=2}; + if (is_grouped_dense_wgrad) { + outShape.data[0] *= num_gemms; + } NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); + // NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, From fe39e39be1abfa46642fdde9e3ede365bc1dfb3c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 14:25:47 -0800 Subject: [PATCH 44/61] einsum tests pass --- transformer_engine/jax/csrc/extensions/gemm.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 61e241b197..f49530ee1c 100644 --- 
a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -828,15 +828,16 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type convert_ffi_datatype_to_te_dtype(beta.element_type())); NVTEShape rhsShape{.data={k, n}, .ndim=2}; - if (!is_grouped_dense_wgrad) { - rhsShape.data[0] *= num_gemms; - } if (rhs_is_trans) { std::swap(rhsShape.data[0], rhsShape.data[1]); } + if (!is_grouped_dense_wgrad) { + // If is_grouped_dense_wgrad, then n already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. + rhsShape.data[0] *= num_gemms; + } NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); NVTEShape lhsShape{.data={m, k}, .ndim=2}; - if (lhs_is_trans) { + if (lhs_is_trans && is_grouped_dense_wgrad) { std::swap(lhsShape.data[0], lhsShape.data[1]); } NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); From 5e47d57b3e670d86ce37e5e2e44397158360adb4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 09:37:17 -0800 Subject: [PATCH 45/61] more progress, works in maxtext single-gpu and is closer to bf16 batched gemm speed --- transformer_engine/jax/cpp_extensions/gemm.py | 4 +- .../jax/csrc/extensions/gemm.cpp | 246 +----------------- .../jax/csrc/extensions/quantization.cpp | 26 +- transformer_engine/jax/flax/module.py | 4 +- 4 files changed, 27 insertions(+), 253 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 25f3315ba7..5c53dedb8a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2123,8 +2123,8 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") - print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") + # print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") + # print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") # import pdb; pdb.set_trace() # import pdb; pdb.set_trace() # print(f"{lhs_is_trans=}, {rhs_is_trans=}") diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f49530ee1c..0bfab2d7dc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -534,22 +534,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - if (is_tensor_scaling) { - size_t dpitch = tensor_scaling_sinv_aligment; - size_t spitch = lhs_sinv_dtype_bytes; - size_t width = lhs_sinv_dtype_bytes; - size_t height = lhs_sinv_size; - cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - spitch = rhs_sinv_dtype_bytes; - width = rhs_sinv_dtype_bytes; - height = rhs_sinv_size; - cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - lhs_sinv_ptr = lhs_scatter_aligned_ptr; - rhs_sinv_ptr = rhs_scatter_aligned_ptr; - } - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); 
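// Editorial summary, not part of the patch, of the logical 2D shapes PATCH 44 settles on for
// the make_grouped_tensor calls (G = num_gemms):
//   lhs: (M, K), swapped to (K, M) only for the transposed grouped-dense-wgrad case;
//   rhs: (K, N), or (N, K) when rhs_is_trans, with the leading dim multiplied by G except in
//        the grouped-dense-wgrad path, where gemm.py has already folded G into N;
//   out: (M, N), or (G*M, N) for grouped dense wgrad.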
@@ -576,29 +560,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " = ", expected_out_size, ", got ", actual_out_size); } - size_t dim_list_bytes = sizeof(int32_t) * num_gemms; - std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; @@ -612,210 +573,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - // These lists are to keep the TensorWrapper objects alive - std::vector lhs_wrapper_list; - std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; - std::vector bias_wrapper_list; - std::vector pre_gelu_wrapper_list; - std::vector out_wrapper_list; - std::vector workspace_wrapper_list; - - // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM - std::vector lhs_list; - std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; - std::vector bias_list; - std::vector pre_gelu_list; - std::vector out_list; - - size_t lhs_sinv_total_size = 0; - size_t rhs_sinv_total_size = 0; - - std::vector zero_out_dptr_list; - std::vector zero_out_size_list; - - for (size_t i = 0; i < num_gemms; i++) { - // Matrix data shapes - size_t m_i = dim_list_host[i]; - auto lhs_shape_i = std::vector{m_i, k}; - auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { - size_t k_i = dim_list_host[i]; - lhs_shape_i[0] = lhs_is_trans ? k_i : m; - lhs_shape_i[1] = lhs_is_trans ? m : k_i; - rhs_shape_i[0] = rhs_is_trans ? n : k_i; - rhs_shape_i[1] = rhs_is_trans ? 
k_i : n; - out_shape_i[0] = m; - out_shape_i[1] = n; - } - - size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; - size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; - size_t out_size = out_shape_i[0] * out_shape_i[1]; - bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; - if (is_empty_gemm && out_size > 0) { - zero_out_dptr_list.push_back(out_ptr); - zero_out_size_list.push_back(out_size * out_dtype_bytes); - } - - // Set matrix data pointers - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); - void *lhs_vptr = static_cast(lhs_ptr); - void *rhs_vptr = static_cast(rhs_ptr); - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - - // Set scale_inv shapes and pointers - void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); - void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - size_t lhs_sinv_size_i = 0; - size_t rhs_sinv_size_i = 0; - if (is_tensor_scaling) { - auto tensor_scaling_sinv_shape = std::vector{1}; - // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers - if (!is_empty_gemm) { - lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; - rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; - } - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). 
- // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers - auto lhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); - auto rhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); - lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; - rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; - if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } - if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } - } else { - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Unsupported scaling mode: ", static_cast(scaling_mode)); - } - - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); - - // Update pointer for the next GEMM pair - lhs_ptr += lhs_size * lhs_dtype_bytes; - rhs_ptr += rhs_size * rhs_dtype_bytes; - out_ptr += out_size * out_dtype_bytes; - if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - lhs_sinv_total_size += lhs_sinv_size_i; - rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } - } - if (has_bias) bias_ptr += n * bias_dtype_bytes; - - // Move objects to the lists to keep them alive - if (is_empty_gemm) continue; - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - out_wrapper_list.push_back(std::move(out_i)); - bias_wrapper_list.push_back(std::move(bias_i)); - pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); - - lhs_list.push_back(lhs_wrapper_list.back().data()); - rhs_list.push_back(rhs_wrapper_list.back().data()); - bias_list.push_back(bias_wrapper_list.back().data()); - pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); - out_list.push_back(out_wrapper_list.back().data()); - } - - if (is_fp8_gemm) { - if (is_tensor_scaling) { - lhs_sinv_size *= tensor_scaling_sinv_aligment; - rhs_sinv_size *= tensor_scaling_sinv_aligment; - } - NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv 
size ", - lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); - NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", - rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); - } - - size_t num_non_empty_gemms = lhs_list.size(); - - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } - - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM - size_t num_zero_outs = zero_out_dptr_list.size(); - for (int i = 0; i < num_zero_outs; i++) { - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - void *dptr = zero_out_dptr_list[i]; - size_t count = zero_out_size_list[i]; - NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); - } - - // nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - // grad, workspace_list.data(), accumulate, use_split_accumulator, - // num_math_sm, stream); - constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, DType::kByte); @@ -888,7 +645,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 1f7db84383..ad3553313f 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -375,11 +375,24 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_groups = group_sizes.dimensions()[0]; size_t dim_list_bytes = group_size_dtype_bytes * num_groups; std::vector dim_list_host(num_groups); - auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); + // HACK: assumes batched gemm with equal group sizes + for (size_t i = 0; i < num_groups; i++) { + if (input_dims[0] == num_groups) { + dim_list_host[i] = 1; + continue; + } + dim_list_host[i] = m / num_groups; + } + // auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); + // cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + // stream); + // // Note: This may break cudaGraph. 
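The fallback above exists because the commented-out path (device-to-host copy of group_sizes plus a stream synchronize) defeats CUDA graph capture, which the FFI_CudaGraph_Traits registration in this hunk relies on. Below is a minimal stand-alone sketch of the equal-split assumption; the name equal_group_sizes is illustrative and not part of the FFI.

#include <cstddef>
#include <vector>

// Illustrative stand-alone version of the equal-split fallback: if the leading
// input dim already equals the number of groups, each group is one row;
// otherwise M is divided evenly. Valid only when every group really has the
// same size, which is exactly what the HACK comment assumes.
std::vector<int> equal_group_sizes(size_t leading_dim, size_t m, size_t num_groups) {
  std::vector<int> sizes(num_groups);
  for (size_t i = 0; i < num_groups; ++i) {
    sizes[i] = (leading_dim == num_groups) ? 1 : static_cast<int>(m / num_groups);
  }
  return sizes;
}

This only matches the real routing when every expert receives the same number of tokens; ragged group sizes still need the device-to-host path that is commented out here.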
+ // cudaStreamSynchronize(stream); + // printf("GroupedQuantizeFFI: m=%zu, n=%zu, group sizes = ", m, n); + // for (size_t i = 0; i < num_groups; i++) { + // printf("%d ", dim_list_host[i]); + // } + // printf("\n"); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, @@ -492,7 +505,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis")); + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index cc6088e8d2..3b4a5ef148 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1444,6 +1444,7 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_einsum_cls(quantization_recipe): import functools + import math import jax def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): # with open("/tmp/te_einsum_log.txt", "a") as f: @@ -1493,7 +1494,8 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) num_groups = kernel.shape[0] - group_size = x.shape[0] // num_groups + group_size = math.prod(x.shape[:-1]) // num_groups + print(f'{num_groups=}, {group_size=}, {x.shape=}, {kernel.shape=}') group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) From bc6cf66512bf4a4a35ce9e014768bb34f749744b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 10:44:12 -0800 Subject: [PATCH 46/61] attempt at passing thru stateful args for DS --- transformer_engine/jax/quantize/quantizer.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 4edc187795..6831758875 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -7,7 +7,7 @@ This module provides classes and utilities for quantizing tensors in JAX. """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass, field, InitVar from functools import partial from typing import Union, Optional, Tuple import warnings @@ -893,6 +893,7 @@ class GroupedQuantizer(Quantizer): data_layout: str = None n_groups: int = 1 quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,)) + extra_kwargs: InitVar[dict] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. 
@@ -911,10 +912,12 @@ def tree_flatten(self): ) return (children, aux_data) - def __post_init__(self): + def __post_init__(self, extra_kwargs: dict): + print(f"QuantizerFactory creating quantizers for GroupedQuantizer: {self.n_groups=}, {self.scaling_mode=}, {self.q_dtype=}, {self.q_layout=}, {extra_kwargs=}, {self.quantizers=}") if self.quantizers[0] is None: quantizers = QuantizerFactory.create( - self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout + n_quantizers=self.n_groups, + scaling_mode=self.scaling_mode, q_dtype=self.q_dtype, q_layout=self.q_layout, **extra_kwargs ) self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers self.data_layout = self.quantizers[0].data_layout @@ -1106,8 +1109,14 @@ def create( warnings.warn( "Using more than one GroupedQuantizer for a grouped input is not recommended" ) - quantizer_type = GroupedQuantizer - kwargs["n_groups"] = n_groups + quantizer_type = lambda q_dtype, scaling_mode, q_layout, checkpoint_name, **kwargs: GroupedQuantizer( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + checkpoint_name=checkpoint_name, + n_groups=n_groups, + extra_kwargs=kwargs, + ) else: quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) From bcbe864825fa8f40103c72b8b750a807490de28f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 10:44:18 -0800 Subject: [PATCH 47/61] Revert "attempt at passing thru stateful args for DS" This reverts commit bc6cf66512bf4a4a35ce9e014768bb34f749744b. --- transformer_engine/jax/quantize/quantizer.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 6831758875..4edc187795 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -7,7 +7,7 @@ This module provides classes and utilities for quantizing tensors in JAX. """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field, InitVar +from dataclasses import dataclass, field from functools import partial from typing import Union, Optional, Tuple import warnings @@ -893,7 +893,6 @@ class GroupedQuantizer(Quantizer): data_layout: str = None n_groups: int = 1 quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,)) - extra_kwargs: InitVar[dict] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. 
@@ -912,12 +911,10 @@ def tree_flatten(self): ) return (children, aux_data) - def __post_init__(self, extra_kwargs: dict): - print(f"QuantizerFactory creating quantizers for GroupedQuantizer: {self.n_groups=}, {self.scaling_mode=}, {self.q_dtype=}, {self.q_layout=}, {extra_kwargs=}, {self.quantizers=}") + def __post_init__(self): if self.quantizers[0] is None: quantizers = QuantizerFactory.create( - n_quantizers=self.n_groups, - scaling_mode=self.scaling_mode, q_dtype=self.q_dtype, q_layout=self.q_layout, **extra_kwargs + self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout ) self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers self.data_layout = self.quantizers[0].data_layout @@ -1109,14 +1106,8 @@ def create( warnings.warn( "Using more than one GroupedQuantizer for a grouped input is not recommended" ) - quantizer_type = lambda q_dtype, scaling_mode, q_layout, checkpoint_name, **kwargs: GroupedQuantizer( - q_dtype=q_dtype, - scaling_mode=scaling_mode, - q_layout=q_layout, - checkpoint_name=checkpoint_name, - n_groups=n_groups, - extra_kwargs=kwargs, - ) + quantizer_type = GroupedQuantizer + kwargs["n_groups"] = n_groups else: quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) From b40353fbad69d3b90197f1ea8dd28dee9263d593 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 14:06:45 -0800 Subject: [PATCH 48/61] batch gemm specialization for CS amax calc --- .../jax/cpp_extensions/quantization.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a95afe8b8e..b8ea3bd4f4 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1209,21 +1209,26 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - if amax is not None: - row_amax = amax - else: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - segment_ids = jnp.repeat( - jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] - ) - grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) + # TODO fixme, measure perf with always scale/amax of 1 to just isolate quant and gemm + # HACK: assumes equal group sizes + assert group_axis == 0, f"Currently only group_axis = 0 is supported for current-tensor-scaling, but received {group_axis=}" + grouped_amax = jnp.max(jnp.abs(x.reshape((n_groups, x.shape[0]//n_groups, *x.shape[1:]))), axis=tuple(range(1, x.ndim+1))) + # import pdb; pdb.set_trace() + # if amax is not None: + # row_amax = amax + # else: + # row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + # segment_ids = jnp.repeat( + # jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + # ) + # grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) From 
6c5d96941522cecfffad51d68a16c2a79428012b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 9 Jan 2026 17:27:20 +0100 Subject: [PATCH 49/61] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/common/gemm/config.h | 6 ++++-- .../common/gemm/cublaslt_grouped_gemm.cu | 21 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index b1aaae2591..cdea24ea7e 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -44,8 +44,10 @@ struct GroupedMatmulConfig { // Number of streaming multiprocessors to use in GEMM kernel int sm_count = 0; - static constexpr size_t attr_sizes[] = {sizeof(avg_m), sizeof(avg_n), sizeof(avg_k), - sizeof(sm_count)}; + // Note: API transfers the value type, not std::optional + static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type), + sizeof(decltype(avg_n)::value_type), + sizeof(decltype(avg_k)::value_type), sizeof(sm_count)}; }; } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 0d376c2e56..a03e5b516a 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -214,7 +214,7 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { const transformer_engine::GroupedTensor *tensor = nullptr; - const char *dptr = nullptr; + char *dptr = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; bool use_columnwise = false; @@ -248,7 +248,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: if (is_A) { if (!sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dptr = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = true; // using pre-transposed storage sel.use_columnwise = true; @@ -257,7 +257,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } else { // B if (sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dptr = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = false; // using pre-transposed storage sel.use_columnwise = true; @@ -272,7 +272,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: NVTE_CHECK( !is_fp8 || non_tn_fp8_ok, "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dptr = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; sel.use_columnwise = true; @@ -280,7 +280,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } // Default: use row-wise data (column-wise case already handled above) - sel.dptr = static_cast(t->data.dptr); + sel.dptr = static_cast(t->data.dptr); sel.dtype = row_dtype; sel.use_columnwise = false; return sel; @@ -414,7 +414,7 @@ __global__ void setup_grouped_gemm_kernel( void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, float **alpha_ptrs, float 
**beta_ptrs, // Base pointers - const char *a_base, const char *b_base, const char *c_base, char *d_base, + char *a_base, char *b_base, char *c_base, char *d_base, // Dimension info (per tensor) TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes @@ -445,10 +445,9 @@ __global__ void setup_grouped_gemm_kernel( D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); // Compute data pointers - // Note: const_cast is safe here - cuBLAS requires void** but won't modify A/B/C data - A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; - B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; - C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + A_ptrs[idx] = a_base + a_offset * a_elem_size; + B_ptrs[idx] = b_base + b_offset * b_elem_size; + C_ptrs[idx] = c_base + c_offset * c_elem_size; D_ptrs[idx] = d_base + d_offset * d_elem_size; // Compute M, N, K dimensions from tensor shapes @@ -474,7 +473,7 @@ inline void launch_grouped_gemm_setup( TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); - const char *c_base = static_cast(C->data.dptr); + char *c_base = static_cast(C->data.dptr); char *d_base = static_cast(D->data.dptr); const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); From c91cd8ffa2e5ec93247a716bede7851673c94b0c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 9 Jan 2026 17:32:19 +0100 Subject: [PATCH 50/61] fix Signed-off-by: Pawel Gadzinski --- .../common/gemm/cublaslt_gemm.cu | 33 ++++++++----------- .../common/gemm/cublaslt_grouped_gemm.cu | 5 ++- .../common/util/cuda_runtime.cpp | 7 ++++ transformer_engine/common/util/cuda_runtime.h | 6 ++++ 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf19335..7c04c14eff 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -302,13 +302,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } -/* cuBLAS version number at run-time */ -size_t cublas_version() { - // Cache version to avoid cuBLAS logging overhead - static size_t version = cublasLtGetVersion(); - return version; -} - } // namespace namespace transformer_engine { @@ -501,8 +494,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cublas_version() >= 120800, - "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + NVTE_CHECK(cuda::cublas_version() >= 120800, + "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cuda::cublas_version()); fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -515,7 +508,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
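For context on the version checks in this hunk, the pattern is a compile-time CUBLAS_VERSION guard plus a cached run-time query. A rough sketch follows, under the assumption that the run-time value comes from cublasLtGetVersion(); the helper names and the stubbed constant are illustrative only.

#include <cstddef>

// Sketch of the gating pattern in this hunk: the CUBLAS_VERSION macro guards
// compilation against older headers, while a cached run-time query guards
// execution against older libraries. cublasLtGetVersion() is the real entry
// point; it is stubbed with a constant so the sketch stands alone.
static size_t runtime_cublaslt_version() {
  static const size_t version = 120800;  // placeholder for cublasLtGetVersion()
  return version;
}

static bool needs_heuristic_cache_workaround() {
  // Mirrors the check below: only cuBLAS builds up to 12.8.3 need the
  // ALPHA_VECTOR_BATCH_STRIDE workaround for the MXFP8 heuristic-cache bug.
  return runtime_cublaslt_version() <= 120803;
}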
- if (cublas_version() <= 120803) { + if (cuda::cublas_version() <= 120803) { const int64_t dummy_a_vec_stride = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, @@ -527,8 +520,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (use_fp4) { // NVFP4 GEMM #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cublas_version() >= 120800, - "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + NVTE_CHECK(cuda::cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cuda::cublas_version()); // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE cublasDataType_t scale_type = CUDA_R_32F; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -558,9 +551,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { #if CUBLAS_VERSION >= 120900 - NVTE_CHECK(cublas_version() >= 120900, + NVTE_CHECK(cuda::cublas_version() >= 120900, "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ", - cublas_version()); + cuda::cublas_version()); float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -588,7 +581,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } #if CUBLAS_VERSION >= 120800 - if (cublas_version() >= 120800) { + if (cuda::cublas_version() >= 120800) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); @@ -605,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUBLAS_VERSION >= 120800 - if (cublas_version() >= 120800) { + if (cuda::cublas_version() >= 120800) { // NOTE: In all current cases where FP8 output is supported, the input is // scaled identically to the output. 
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -692,9 +685,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); - NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, + NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); + cuda::cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -920,9 +913,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", transformer_engine::cuda::cudart_version()); NVTE_CHECK( - cublas_version() >= 120205 && cublas_version() < 130000, + cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); + cuda::cublas_version()); const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index a03e5b516a..d4696c9127 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -507,10 +507,13 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.1+ const int current_device = cuda::current_device(); NVTE_CHECK(cuda::sm_arch(current_device) >= 100, "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(cuda::cublas_version() >= 130100, + "nvte_grouped_gemm requires cuBLAS 13.1+, but run-time cuBLAS version is ", + cuda::cublas_version()); // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 2e5ef8b8e1..47cfa6bc96 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -6,6 +6,7 @@ #include "../util/cuda_runtime.h" +#include #include #include @@ -210,6 +211,12 @@ int cudart_version() { return version; } +size_t cublas_version() { + // Cache version to avoid cuBLAS logging overhead + static size_t version = cublasLtGetVersion(); + return version; +} + } // namespace cuda } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 6b999870dd..b7b9680688 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -73,6 +73,12 @@ const std::string &include_directory(bool required = false); */ int cudart_version(); +/* \brief cuBLAS version number at run-time + * + * Versions may differ between compile-time and run-time. 
+ */ +size_t cublas_version(); + } // namespace cuda } // namespace transformer_engine From 0319e79ee06153c560e953548b302b4aee69b5f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 16:33:05 +0000 Subject: [PATCH 51/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 6 ++++-- transformer_engine/common/util/cuda_runtime.cpp | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7c04c14eff..b82fe82b63 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -495,7 +495,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cuda::cublas_version() >= 120800, - "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cuda::cublas_version()); + "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", + cuda::cublas_version()); fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -521,7 +522,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } else if (use_fp4) { // NVFP4 GEMM #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cuda::cublas_version() >= 120800, - "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cuda::cublas_version()); + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", + cuda::cublas_version()); // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE cublasDataType_t scale_type = CUDA_R_32F; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 47cfa6bc96..0e8ff58b7c 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -7,6 +7,7 @@ #include "../util/cuda_runtime.h" #include + #include #include From a14d5bc25a50ff8e6f1b68448ef35bed521049cc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 19:11:45 +0100 Subject: [PATCH 52/61] refactored hopper tensor selection Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 4 +- .../common/gemm/cublaslt_grouped_gemm.cu | 185 +++++++++++------- 2 files changed, 115 insertions(+), 74 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 90d89c77c8..35c4375cbe 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -44,8 +44,8 @@ enum class ShapeCase { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 3 int arrays (M, N, K) - size_t size = 6 * ptr_bytes + 3 * int_bytes; + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) + size_t size = 6 * ptr_bytes + 6 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; return size; diff --git 
a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index d4696c9127..c1a75f0523 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -107,11 +107,15 @@ struct GroupedGemmSetupWorkspace { void **B_ptrs; void **C_ptrs; void **D_ptrs; - int *M; - int *N; - int *K; float **alpha_ptrs; float **beta_ptrs; + // Storage dimensions for cuBLAS matrix layouts + int *a_rows; + int *a_cols; + int *b_rows; + int *b_cols; + int *d_rows; // M (first dim) - also used for C + int *d_cols; // N (last dim) - also used for C // Initialize from workspace buffer // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) @@ -135,22 +139,28 @@ struct GroupedGemmSetupWorkspace { ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - // Int arrays last (4-byte aligned, always satisfied after pointer arrays) - ws.M = reinterpret_cast(setup_ws_ptr + offset); + // Int arrays for storage dimensions (4-byte aligned) + ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); + ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); + ws.b_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.b_cols = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_cols = reinterpret_cast(setup_ws_ptr + offset); return ws; } - // Calculate required size for setup workspace (pointer arrays + M/N/K) + // Calculate required size for setup workspace static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) - size_t size = 6 * ptr_size + 3 * int_size; + // Layout: 6 ptr arrays, then 6 int arrays + size_t size = 6 * ptr_size + 6 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -212,14 +222,44 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. +// Contains all information needed for GEMM setup - shape already accounts for storage layout. struct GroupedOperandSelection { - const transformer_engine::GroupedTensor *tensor = nullptr; + TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed char *dptr = nullptr; + void *scale_inv = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; - bool use_columnwise = false; }; +// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims. +// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage. +// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor). 
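The shape-info convention used here (per-tensor arrays when shapes vary, one uniform value otherwise, with offsets reconstructed from the uniform dims) is easier to see in isolation. A minimal sketch with hypothetical names; the real TensorShapeInfo also tracks first and last dims per operand.

#include <cstddef>
#include <cstdint>

// Stand-alone sketch of the "per-tensor array or uniform value" convention:
// when a dims/offsets array is present it is read per group, otherwise every
// group shares one value and offsets follow a regular stride. The names
// DimSketch / resolve_dim / resolve_offset are illustrative only.
struct DimSketch {
  const int64_t *per_group = nullptr;  // nullptr when all groups share one value
  int64_t uniform = 0;                 // used only when per_group is nullptr
};

int64_t resolve_dim(const DimSketch &d, size_t idx) {
  return d.per_group ? d.per_group[idx] : d.uniform;
}

int64_t resolve_offset(const int64_t *offsets, size_t idx, int64_t uniform_first,
                       int64_t uniform_last) {
  // Uniform shapes imply the same element count per group, hence the product.
  return offsets ? offsets[idx]
                 : static_cast<int64_t>(idx) * uniform_first * uniform_last;
}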
+inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t, + bool swap_dims) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + const int64_t *offsets_ptr = t->tensor_offsets.has_data() + ? static_cast(t->tensor_offsets.dptr) + : nullptr; + + if (swap_dims) { + // Swap first/last to account for columnwise (transposed) storage + return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first}; + } + return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last}; +} + inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, bool trans, bool is_A) { using namespace transformer_engine; @@ -236,31 +276,42 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; GroupedOperandSelection sel; - sel.tensor = t; sel.trans = trans; const DType rep_dtype = has_row ? row_dtype : col_dtype; const bool is_fp8 = is_fp8_dtype(rep_dtype); const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + // Helper to select columnwise storage (swaps dims in shape) + auto use_columnwise = [&]() { + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.dtype = col_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/true); + }; + + // Helper to select row-wise storage + auto use_rowwise = [&]() { + sel.dptr = static_cast(t->data.dptr); + sel.scale_inv = t->scale_inv.dptr; + sel.dtype = row_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/false); + }; + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. 
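A compact way to read the TN-forcing logic that follows, assuming both storage layouts are available: the sketch below uses illustrative types, and it omits the fallback taken when row-wise data is missing.

// Decision sketch (illustrative types, not the real API): cuBLAS on this path
// must see A transposed and B non-transposed, so an operand in the "wrong"
// orientation is redirected to its pre-transposed column-wise copy and its
// transpose flag is adjusted to describe that storage.
struct OperandPickSketch {
  bool use_columnwise;
  bool trans_for_cublas;
};

OperandPickSketch pick_fp8_tn(bool is_A, bool requested_trans) {
  if (is_A && !requested_trans) return {true, true};   // A: column-wise copy, report T
  if (!is_A && requested_trans) return {true, false};  // B: column-wise copy, report N
  return {false, requested_trans};                     // already TN-compatible
}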
if (is_fp8 && !non_tn_fp8_ok) { if (is_A) { if (!sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.dptr = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; + use_columnwise(); sel.trans = true; // using pre-transposed storage - sel.use_columnwise = true; return sel; } } else { // B if (sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.dptr = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; + use_columnwise(); sel.trans = false; // using pre-transposed storage - sel.use_columnwise = true; return sel; } } @@ -272,17 +323,13 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: NVTE_CHECK( !is_fp8 || non_tn_fp8_ok, "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - sel.dptr = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = !sel.trans; - sel.use_columnwise = true; + use_columnwise(); + sel.trans = !trans; // flip transpose for pre-transposed storage return sel; } - // Default: use row-wise data (column-wise case already handled above) - sel.dptr = static_cast(t->data.dptr); - sel.dtype = row_dtype; - sel.use_columnwise = false; + // Default: use row-wise data + use_rowwise(); return sel; } @@ -307,23 +354,15 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); - // For column-major layout: leading dimension is the number of rows in storage. - // If columnwise data was chosen, storage is already transposed. - // Storage dimensions for A: rows_A x cols_A with leading dimension lda_storage - int *rows_A = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); - int *cols_A = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); - int *lda_storage = rows_A; - // Storage dimensions for B: rows_B x cols_B with leading dimension ldb_storage - int *rows_B = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); - int *cols_B = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); - int *ldb_storage = rows_B; - - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rows_A, cols_A, lda_storage)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rows_B, cols_B, ldb_storage)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); + // Storage dimensions computed by kernel, leading dimension = rows + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows, + ws.a_cols, ws.a_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows, + ws.b_cols, ws.b_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, @@ -356,15 +395,13 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, if (!is_fp8_a && !is_fp8_b) return; if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? 
A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; + void *a_scale_inv = A_sel.scale_inv; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; + void *b_scale_inv = B_sel.scale_inv; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); @@ -406,24 +443,19 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, } // Single kernel that sets up all GEMM parameters. -// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, // but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. -// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel. __global__ void setup_grouped_gemm_kernel( // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, + int *a_rows, int *a_cols, int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, - // Base pointers + // Inputs char *a_base, char *b_base, char *c_base, char *d_base, - // Dimension info (per tensor) TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, - // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (per-matrix arrays) float *alpha_ptr, float *beta_ptr, - // Transpose flags - bool transa, bool transb, - // Number of tensors size_t num_tensors) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -433,6 +465,8 @@ __global__ void setup_grouped_gemm_kernel( int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; + int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; // Compute offsets (from array or compute from uniform dims) int64_t a_offset = @@ -450,12 +484,16 @@ __global__ void setup_grouped_gemm_kernel( C_ptrs[idx] = c_base + c_offset * c_elem_size; D_ptrs[idx] = d_base + d_offset * d_elem_size; - // Compute M, N, K dimensions from tensor shapes - // Input A is stored as {K,M} when !transa, {M,K} when transa - // Input B is stored as {N,K} when !transb, {K,N} when transb - M[idx] = static_cast(transa ? a_first : a_last); - K[idx] = static_cast(transa ? a_last : a_first); - N[idx] = static_cast(transb ? b_last : b_first); + // Compute storage dimensions for cuBLAS matrix layouts. + // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, + // so rows=last, cols=first. For columnwise, dims are already swapped. 
+ a_rows[idx] = static_cast(a_last); + a_cols[idx] = static_cast(a_first); + b_rows[idx] = static_cast(b_last); + b_cols[idx] = static_cast(b_first); + // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). + d_rows[idx] = static_cast(d_first); + d_cols[idx] = static_cast(d_last); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; @@ -468,8 +506,9 @@ inline void launch_grouped_gemm_setup( const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); + // Use shape info from selection (already accounts for columnwise dimension swap) + TensorShapeInfo A_meta = A_sel.shape; + TensorShapeInfo B_meta = B_sel.shape; TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); @@ -485,10 +524,11 @@ inline void launch_grouped_gemm_setup( const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, - b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, + ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, + A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, + static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -577,10 +617,11 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim + // Use original inputA and transa for heuristics (not modified A_sel.trans) int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); - int64_t avg_k_val = config_.avg_k.value_or(A_sel.trans ? compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_k_val = config_.avg_k.value_or(transa ? 
compute_avg_first_dim(inputA) + : compute_avg_last_dim(inputA)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, From c5c2fbf59388234dd0f402d36eff708ec3fbb684 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:14:39 +0000 Subject: [PATCH 53/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index c1a75f0523..a1206474ea 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -243,15 +243,13 @@ inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor NVTE_CHECK(has_last || t->all_same_last_dim(), "GroupedTensor is missing last_dims for varying shapes"); - const int64_t *first_ptr = - has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - const int64_t *offsets_ptr = t->tensor_offsets.has_data() - ? static_cast(t->tensor_offsets.dptr) - : nullptr; + const int64_t *offsets_ptr = + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr; if (swap_dims) { // Swap first/last to account for columnwise (transposed) storage @@ -448,14 +446,12 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, // We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel. 
__global__ void setup_grouped_gemm_kernel( // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, - int *a_rows, int *a_cols, int *b_rows, int *b_cols, int *d_rows, int *d_cols, - float **alpha_ptrs, float **beta_ptrs, + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols, + int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, // Inputs - char *a_base, char *b_base, char *c_base, char *d_base, - TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, - size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - float *alpha_ptr, float *beta_ptr, + char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, + TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, + size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, size_t num_tensors) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -527,8 +523,8 @@ inline void launch_grouped_gemm_setup( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), num_tensors); + static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), + num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -620,8 +616,8 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT // Use original inputA and transa for heuristics (not modified A_sel.trans) int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); - int64_t avg_k_val = config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA) - : compute_avg_last_dim(inputA)); + int64_t avg_k_val = + config_.avg_k.value_or(transa ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, From ee71c96552c4065bee9826992e1cadfd9556c012 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 15 Jan 2026 10:35:06 -0800 Subject: [PATCH 54/61] multi-GPU grouped quantize working now in shard_map (with hack to use single-stream for multi tensor quantize) --- transformer_engine/common/cast/cast.cu | 22 +---------- .../jax/cpp_extensions/quantization.py | 21 ++++------ .../jax/csrc/extensions/quantization.cpp | 39 +++++++------------ transformer_engine/jax/flax/__init__.py | 3 +- transformer_engine/jax/flax/module.py | 20 ++++++++++ transformer_engine/jax/sharding.py | 6 ++- 6 files changed, 49 insertions(+), 62 deletions(-) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 73467d7275..dc77a35886 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -75,29 +75,9 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, constexpr bool IS_ACT = false; - const size_t num_streams = nvte_get_num_compute_streams(); - - int num_stream_used = std::min(num_streams, num_tensors); - // wait for current stream to finish - NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); - } - for (int i = 0; i < num_tensors; i++) { dispatch::quantize_fwd_helper( - inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); - } - - // record events on compute streams - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); - } - // wait for all compute streams to finish - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + inputs[i], outputs[i], quant_configs, stream); } } diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b8ea3bd4f4..4a2c001f5b 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1216,19 +1216,14 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - # TODO fixme, measure perf with always scale/amax of 1 to just isolate quant and gemm - # HACK: assumes equal group sizes - assert group_axis == 0, f"Currently only group_axis = 0 is supported for current-tensor-scaling, but received {group_axis=}" - grouped_amax = jnp.max(jnp.abs(x.reshape((n_groups, x.shape[0]//n_groups, *x.shape[1:]))), axis=tuple(range(1, x.ndim+1))) - # import pdb; pdb.set_trace() - # if amax is not None: - # row_amax = amax - # else: - # row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - # segment_ids = jnp.repeat( - # jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] - # ) - # grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) + if amax is not None: + row_amax = amax + else: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + segment_ids = jnp.repeat( + jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] 
+ ) + grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index ad3553313f..2b7beb8d6b 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -375,29 +375,19 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_groups = group_sizes.dimensions()[0]; size_t dim_list_bytes = group_size_dtype_bytes * num_groups; std::vector dim_list_host(num_groups); - // HACK: assumes batched gemm with equal group sizes - for (size_t i = 0; i < num_groups; i++) { - if (input_dims[0] == num_groups) { - dim_list_host[i] = 1; - continue; - } - dim_list_host[i] = m / num_groups; - } - // auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); - // cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - // stream); - // // Note: This may break cudaGraph. - // cudaStreamSynchronize(stream); - // printf("GroupedQuantizeFFI: m=%zu, n=%zu, group sizes = ", m, n); - // for (size_t i = 0; i < num_groups; i++) { - // printf("%d ", dim_list_host[i]); - // } - // printf("\n"); - - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, - "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, - input_dims[0]); + auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + + // For MaxText case, I think is okay if this check fails as we are expecting to overallocate the buffers in the current use_ring_of_experts impl, which will result in the group sizes not filling the whole tensor. + // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, + // "Unexpected group_sizes! Got ", sum_group_sizes, " (M=", m, ", input_dims[0] = ", input_dims[0], ")"); + + // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. 
+ cudaMemsetAsync(outputs->untyped_data(), 0, outputs->size_bytes(), stream); if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, @@ -505,8 +495,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis"), - FFI_CudaGraph_Traits); + .Attr("flatten_axis")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 59a0958b7b..1a19685697 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,7 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls +from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls, make_ragged_dot_cls from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -17,6 +17,7 @@ "wrap_function_in_te_state_module", "make_dot_general_cls", "make_einsum_cls", + "make_ragged_dot_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3b4a5ef148..03d5581ae6 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1520,3 +1520,23 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + +def make_ragged_dot_cls(quantization_recipe): + import jax + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + target_out_shape = jax.lax.ragged_dot(x, kernel, group_sizes=group_sizes).shape + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set + ) + + return out.reshape(target_out_shape) + + return wrap_function_in_te_state_module(te_grouped_dot_general, quantization_recipe, "ragged_dot")() diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index b4b8c42027..4171d1c7b0 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -51,7 +51,8 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource -def _validate_mesh_resource_configuration(mesh_resource): +# TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. 
+# def _validate_mesh_resource_configuration(mesh_resource): """Validate that the mesh resource configuration is consistent and conflict-free.""" is_tp_enabled = ( mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 @@ -375,7 +376,8 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. + # _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE From 9856862450547b2cbd688f30dc4fa8ecda111227 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 15 Jan 2026 11:11:55 -0800 Subject: [PATCH 55/61] reduce size of zero'ing memset to only uninitialized part of quantization buffer --- .../jax/csrc/extensions/quantization.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 2b7beb8d6b..3d98126290 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -382,13 +382,10 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty cudaStreamSynchronize(stream); // For MaxText case, I think is okay if this check fails as we are expecting to overallocate the buffers in the current use_ring_of_experts impl, which will result in the group sizes not filling the whole tensor. - // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); // NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, // "Unexpected group_sizes! Got ", sum_group_sizes, " (M=", m, ", input_dims[0] = ", input_dims[0], ")"); - // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. - cudaMemsetAsync(outputs->untyped_data(), 0, outputs->size_bytes(), stream); - if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, ", got ", amaxs->dimensions()[0]); @@ -402,6 +399,13 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_non_empty_groups = 0; size_t total_rowwise_sinv_size = 0; size_t total_colwise_sinv_size = 0; + + + // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. 
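+  // Example (hypothetical sizes): with sum(group_sizes) = 96, non_group_m = 1, n = 64 and a 1-byte FP8 output dtype, bytes [0, 96*64) hold quantized data and only bytes [96*64, size_bytes) are cleared below.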
+ size_t used_output_size = (sum_group_sizes*non_group_m) * n * output_dtype_bytes; + cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0, outputs->size_bytes() - used_output_size, stream); + + for (size_t i = 0; i < num_groups; i++) { size_t m_i = dim_list_host[i] * non_group_m; // Skip for zero-size input + shiff the scale ptr From 23b5de303865ec8c560f9f4fee55015edddf43cf Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 15 Jan 2026 15:12:14 -0800 Subject: [PATCH 56/61] fix TE/JAX to work compile with latest nvte_grouped_gemm API changes --- transformer_engine/jax/cpp_extensions/gemm.py | 7 ------- transformer_engine/jax/csrc/extensions/gemm.cpp | 16 ++++++---------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 5c53dedb8a..38d21f26ec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2123,13 +2123,6 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - # print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") - # print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") - # import pdb; pdb.set_trace() - # import pdb; pdb.set_trace() - # print(f"{lhs_is_trans=}, {rhs_is_trans=}") - # import pdb; pdb.set_trace() - num_gemms = group_sizes.shape[0] alpha = jnp.ones((num_gemms,), jnp.float32) beta = jnp.zeros((num_gemms,), jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0bfab2d7dc..13feef709a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -604,21 +604,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type } NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - // NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); - nvte_grouped_gemm( - rhs_is_trans, lhs_is_trans, - alpha_tensor.data(), - rhs_tensor, lhs_tensor, - beta_tensor.data(), + rhs_tensor, rhs_is_trans, + lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), + beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), - stream, - nullptr, - nullptr, - nullptr); + nullptr, // config (use defaults) + stream); return ffi_with_cuda_error_check(); } From 179aab63d57b31298b427929c83955052779e201 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 20 Jan 2026 16:21:46 -0800 Subject: [PATCH 57/61] some tests starting to work --- tests/jax/test_custom_call_compute.py | 93 +++++--- transformer_engine/jax/cpp_extensions/gemm.py | 28 ++- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/csrc/extensions/gemm.cpp | 221 ++++++++++++++---- transformer_engine/jax/flax/module.py | 4 +- transformer_engine/jax/permutation.py | 4 + 6 files changed, 270 insertions(+), 82 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 082a99cd8b..5c8c5d1b48 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1761,50 +1761,68 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): GROUPED_DENSE_INPUT_SHAPES = [ # (n_groups, m, n, k), the actual m will be multiplied by 32 
- (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 - (8, 64, 32, 128), - (8, 64, 128, 256), + # (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 + + # (4, 16, 4, 4), + + (3, 192, 64, 96), + + # (8, 64, 32, 128), + # (8, 64, 128, 256), ] @pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): - lhs_contract_dim, _ = contracting_dims - assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 - if bias is None: - bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) - else: - assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) - remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() - lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) - rhs = jnp.split(rhs, rhs.shape[0], axis=0) - bias = jnp.split(bias, bias.shape[0], axis=0) - ref_out = [] - dim_num = (contracting_dims, ((), ())) - for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): - out_i = jax.lax.dot_general( - lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST - ) + jnp.expand_dims(bias_i, axis=0) - ref_out.append(jnp.squeeze(out_i)) - return ref_out + out = jax.lax.ragged_dot(lhs, rhs, group_sizes) + print(f"In ref grouped dense: {lhs.shape=}, {rhs.shape=}, {out.shape=}") + return out + + dot_dimension_numbers = (((), ()), contracting_dims) + lhs_ragged_dimensions = (0,) + rhs_group_dimensions = (0,) + print(lhs.shape, rhs.shape, group_sizes, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions) + dims = jax.lax.RaggedDotDimensionNumbers(dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions) + return jax.lax.ragged_dot_general(lhs, rhs, group_sizes, dims) + # lhs_contract_dim, _ = contracting_dims + # assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + # if bias is None: + # bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + # else: + # assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + # remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + # lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + # rhs = jnp.split(rhs, rhs.shape[0], axis=0) + # bias = jnp.split(bias, bias.shape[0], axis=0) + # ref_out = [] + # dim_num = (contracting_dims, ((), ())) + # for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + # out_i = jax.lax.dot_general( + # lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST + # ) + jnp.expand_dims(bias_i, axis=0) + # ref_out.append(jnp.squeeze(out_i)) + # return ref_out def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape - group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) - group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) - group_sizes = jnp.diff(group_sizes) - # Make one empty input lhs to test empty GEMM handling - group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) - group_sizes = group_sizes.at[1].set(0) + # group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + # group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + # group_sizes = jnp.diff(group_sizes) + + # # Make one empty input lhs to test empty GEMM handling + # group_sizes = group_sizes.at[0].set(group_sizes[0] + 
group_sizes[1]) + # group_sizes = group_sizes.at[1].set(0) + + group_sizes = jnp.full((n_groups,), m // n_groups) assert group_sizes.sum() == m # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # group_sizes = group_sizes * 32 + # m = m * 32 lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1822,9 +1840,15 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): assert out.dtype == ref_list[0].dtype - out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - for i in range(len(ref_list)): - assert_allclose(out_list[i], ref_list[i], dtype=dtype) + import numpy as np + np.set_printoptions(threshold=10000) + jnp.set_printoptions(threshold=10000) + print("Actual:", out) + print("Expected:", ref_list) + assert_allclose(out, ref_list, dtype=dtype) + # out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + # for i in range(len(ref_list)): + # assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", ["NN"]) @@ -1979,7 +2003,7 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): # ('ij,jk->ik', (64, 32), (32, 128)), # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), - ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), + ('BSM,BSEC->EBCM', (2, 16, 16), (2, 16, 8, 8)), ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), @@ -2014,6 +2038,9 @@ def test_einsum_fwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) + # jax.config.update("jax_numpy_rank_promotion", "raise") + # jnp.set_printoptions(threshold=jnp.inf, linewidth=jnp.inf) + # print(te_out) assert_allclose(te_out, ref_out, dtype=dtype) def test_einsum_fwd_and_bwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 38d21f26ec..23d774b2ca 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1543,7 +1543,8 @@ def abstract( out_shape = (M, N) if is_grouped_dense_wgrad: - out_shape = (group_sizes_aval.size, M, N) + num_tensors = group_sizes_aval.size // 2 # packed int32 -> logical int64 shape + out_shape = (num_tensors, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) return (out_aval, workspace_aval) @@ -1980,8 +1981,6 @@ def grouped_gemm( lhs: [M, K] or [K, N] rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) # TODO(Phuong): implement the precision del precision @@ -2117,13 +2116,29 @@ def grouped_gemm( else: assert group_sizes.size == rhs_shape[0] - assert group_offset.size == 1 - has_bias = bias is not None assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - num_gemms = 
group_sizes.shape[0] + + if group_offset is None: + # Compute group_offset as cumulative sum of group_sizes, starting with 0 + group_offset = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(group_sizes, dtype=jnp.int32)[:-1]]) + group_offset *= K_lhs # Offset is by number of elements total, not number of rows + + jax.debug.print("group_sizes: {}, group_offset: {}", group_sizes, group_offset) + jax.debug.print("M={}, jnp.sum(group_sizes)={}, N={}, K_lhs={}", M, jnp.sum(group_sizes), N, K_lhs) + jax.debug.print("lhs_data.size={}, group_offset={}", lhs_data.size, group_offset) + + # print(f"{lhs_data.shape=}, {rhs_data.shape=}, {M=}, {N=}, {K_lhs=}") + + # Interlace zeros with group_sizes to upcast packed int32s to int64 + # This ensures proper alignment and prevents overflow issues + zeros = jnp.zeros_like(group_sizes, dtype=jnp.int32) + group_sizes = jnp.stack([group_sizes, zeros], axis=1).flatten() + group_offset = jnp.stack([group_offset, zeros], axis=1).flatten() + + num_gemms = group_sizes.shape[0] // 2 # Due to interlaced zeros to support int64 alpha = jnp.ones((num_gemms,), jnp.float32) beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( @@ -2147,4 +2162,5 @@ def grouped_gemm( is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) + print(f"GroupedGemm: {lhs_data.shape=}, {rhs_data.shape=}, {out.shape=}") return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 4a2c001f5b..8add335fbf 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -96,7 +96,7 @@ def abstract( dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape - assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert scale_aval is None or scale_aval.dtype == jnp.float32, f"scale must be float32 but received {scale_aval}" if stochastic_rounding: assert ScalingMode( scaling_mode diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 13feef709a..108a6b6843 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,37 +399,58 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); -NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { - // printf("make_grouped_tensor data shape: "); - // for (auto dim : data.dimensions()) { - // printf("%zu, ", dim); - // } - // printf("\n"); - // NVTEShape logical_shape{}; - // if (data.dimensions().size() == 1) { - // // HACK - // size_t cdim_size = 4096; - // logical_shape.ndim = 2; - // logical_shape.data[0] = data.dimensions()[0] / cdim_size; - // logical_shape.data[1] = cdim_size; - // printf("NUM TENSORS: %zu\n", num_tensors); - // } - // else { - // NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); - - // logical_shape.ndim = 2; - // logical_shape.data[0] = data.dimensions()[0]; - // logical_shape.data[1] = data.dimensions()[1]; - // } - - NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, 
dataShape); - - NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), - static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), - dataShape}; - nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); +class JAXX_GroupedTensorWrapper { +public: + JAXX_GroupedTensorWrapper() = delete; + JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const& dataShape); + ~JAXX_GroupedTensorWrapper() = default; + + void set_rowwise(Buffer_Type const& data, std::optional const& scale_inv); + void set_group_info(Buffer_Type const& group_sizes, Buffer_Type const& group_offsets); + + operator NVTEGroupedTensor() const { return m_grouped_tensor; } + NVTEGroupedTensor const& get_grouped_tensor() const; + +private: + NVTEShape m_data_shape{}; + NVTEGroupedTensor m_grouped_tensor{}; + + // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. + NVTEBasicTensor m_data_tensor{}; + NVTEBasicTensor m_scale_inv_tensor{}; + + NVTEBasicTensor m_sizes_tensor{}; + NVTEBasicTensor m_offsets_tensor{}; +}; + +JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, + size_t num_tensors, + NVTEShape const& dataShape) { + m_data_shape = dataShape; + m_grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); +} + +void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const& data, + std::optional const& scale_inv) { + printf("set_rowwise data shape: XLA buffer shape: "); + for (auto dim : data.dimensions()) { + printf("%zu, ", dim); + } + printf("NVTEShape: "); + for (int i = 0; i < m_data_shape.ndim; ++i) { + printf("%d, ", m_data_shape.data[i]); + } + printf("\n"); + NVTEDType data_dtype = static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_data_tensor = NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, + m_data_shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor); if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); NVTEShape logical_scale_shape{}; if (scale_inv->dimensions().size() == 1) { logical_scale_shape.ndim = 1; @@ -439,20 +460,116 @@ NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optionaldimensions()[0]; logical_scale_shape.data[1] = scale_inv->dimensions()[1]; } else { - NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); } - NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), - static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), - logical_scale_shape}; - nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); + m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), + scale_inv_dtype, logical_scale_shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor); } +} + +void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const& group_sizes, + Buffer_Type const& group_offsets) { + NVTEDType sizes_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); + NVTEDType offsets_dtype = + 
static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); + + NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt32, + "group_sizes must be of type int32."); + NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt32, + "group_offsets must be of type int32."); + + // JAX only supports int32 but cuBLAS requires int64 so we pack two int32 into one int64 + size_t num_tensors = group_sizes.dimensions()[0] / 2; + NVTE_CHECK(group_sizes.dimensions().size() == 1, + "group_sizes must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions().size() == 1, + "group_offsets must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions()[0] == 2 * num_tensors, + "group_sizes and group_offsets must have the same number of elements."); + + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), + NVTEDType::kNVTEInt64, + shape}; + m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), + NVTEDType::kNVTEInt64, + shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedFirstDims, &m_sizes_tensor); + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor); +} + +NVTEGroupedTensor const& JAXX_GroupedTensorWrapper::get_grouped_tensor() const { + return m_grouped_tensor; +} + +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { + JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + grouped_tensor_wrapper.set_rowwise(data, scale_inv); - return grouped_tensor; + return grouped_tensor_wrapper; } +// NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { +// // printf("make_grouped_tensor data shape: "); +// // for (auto dim : data.dimensions()) { +// // printf("%zu, ", dim); +// // } +// // printf("\n"); +// // NVTEShape logical_shape{}; +// // if (data.dimensions().size() == 1) { +// // // HACK +// // size_t cdim_size = 4096; +// // logical_shape.ndim = 2; +// // logical_shape.data[0] = data.dimensions()[0] / cdim_size; +// // logical_shape.data[1] = cdim_size; +// // printf("NUM TENSORS: %zu\n", num_tensors); +// // } +// // else { +// // NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); + +// // logical_shape.ndim = 2; +// // logical_shape.data[0] = data.dimensions()[0]; +// // logical_shape.data[1] = data.dimensions()[1]; +// // } + +// NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); + +// NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), +// static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), +// dataShape}; +// nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); + +// if (scale_inv.has_value()) { +// NVTEShape logical_scale_shape{}; +// if (scale_inv->dimensions().size() == 1) { +// logical_scale_shape.ndim = 1; +// logical_scale_shape.data[0] = scale_inv->dimensions()[0]; +// } else if (scale_inv->dimensions().size() == 2) { +// logical_scale_shape.ndim = 2; +// logical_scale_shape.data[0] = scale_inv->dimensions()[0]; +// logical_scale_shape.data[1] 
= scale_inv->dimensions()[1]; +// } else { +// NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); +// } +// NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), +// static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), +// logical_scale_shape}; +// nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); +// } + +// return grouped_tensor; +// } Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, + Buffer_Type group_sizes, Buffer_Type group_offsets, Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, @@ -487,7 +604,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; + size_t num_gemms = group_sizes.dimensions()[0] / 2; // JAX only supports int32 but cuBLAS requires int64 so we pack two int32 into one int64 // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); @@ -584,6 +701,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); + + printf("Num gemms: %zu, M: %zu, N: %zu, K: %zu, group_sizes: %zu\n", num_gemms, m, n, k, group_sizes.dimensions()[0] / 2); + + //// RHS NVTEShape rhsShape{.data={k, n}, .ndim=2}; if (rhs_is_trans) { std::swap(rhsShape.data[0], rhsShape.data[1]); @@ -592,17 +713,37 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // If is_grouped_dense_wgrad, then n already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. 
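+  // Non-wgrad case: rhs stacks the G per-group matrices along the leading dim, so the grouped logical shape is [G*K, N] (or [G*N, K] when rhs is transposed), matching the [G, K, N] / [G, N, K] layouts documented in gemm.py.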
rhsShape.data[0] *= num_gemms; } - NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + auto rhs_tensor = make_grouped_tensor(rhs_data, std::nullopt, JAXX_Scaling_Mode::NO_SCALING,/*rhs_sinv, scaling_mode,*/ num_gemms, rhsShape); + + //// LHS NVTEShape lhsShape{.data={m, k}, .ndim=2}; if (lhs_is_trans && is_grouped_dense_wgrad) { std::swap(lhsShape.data[0], lhsShape.data[1]); } - NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + auto lhs_tensor = make_grouped_tensor(lhs_data, std::nullopt, JAXX_Scaling_Mode::NO_SCALING,/*lhs_sinv, scaling_mode,*/ num_gemms, lhsShape); + if (!is_grouped_dense_wgrad) { + lhs_tensor.set_group_info(group_sizes, group_offsets); + } + + //// OUTPUT NVTEShape outShape{.data={m, n}, .ndim=2}; if (is_grouped_dense_wgrad) { outShape.data[0] *= num_gemms; } - NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); + if (is_grouped_dense_wgrad) { + out_tensor.set_group_info(group_sizes, group_offsets); + } + + printf("rhs_shape: [%zu, %zu], lhs_shape: [%zu, %zu], out_shape: [%zu, %zu]\n", + rhsShape.data[0], rhsShape.data[1], + lhsShape.data[0], lhsShape.data[1], + outShape.data[0], outShape.data[1]); + + printf("rhs_is_trans: %d, lhs_is_trans: %d\n", rhs_is_trans, lhs_is_trans); + + // HACK: jberchtold FIXME + // cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); nvte_grouped_gemm( rhs_tensor, rhs_is_trans, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 03d5581ae6..f9757d29b4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1494,7 +1494,7 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) num_groups = kernel.shape[0] - group_size = math.prod(x.shape[:-1]) // num_groups + group_size = x.shape[1] print(f'{num_groups=}, {group_size=}, {x.shape=}, {kernel.shape=}') group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) @@ -1534,7 +1534,7 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa kernel, group_sizes=group_sizes, contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set + # quantizer_set=quantizer_set ) return out.reshape(target_out_shape) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 55a59a1650..636740922e 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -73,6 +73,10 @@ def token_dispatch( Permuted probabilities of shape [num_out_tokens], or None if probs was not provided. row_id_map : jnp.ndarray Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]). 
+ + [num_tokens, 0:num_experts] = expert indices for each token + + [num_experts] = max([num_tokens, 0:num_experts], axis=0) + 1 """ return _token_dispatch(inp, routing_map, probs, num_out_tokens) From 6a54ff8f7a602497a98dadd0ee15d4442f4f52ba Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 20 Jan 2026 16:51:43 -0800 Subject: [PATCH 58/61] wip --- tests/jax/test_custom_call_compute.py | 73 ++++++++----------- .../jax/csrc/extensions/gemm.cpp | 7 +- 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5c8c5d1b48..a45b7fd4af 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1771,38 +1771,33 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): # (8, 64, 128, 256), ] +# TODO(jberchtold): Support MXFP8 and NVFP4 +grouped_gemm_supported_scaling_modes = [ + # ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING +] @pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): - out = jax.lax.ragged_dot(lhs, rhs, group_sizes) - print(f"In ref grouped dense: {lhs.shape=}, {rhs.shape=}, {out.shape=}") - return out - - dot_dimension_numbers = (((), ()), contracting_dims) - lhs_ragged_dimensions = (0,) - rhs_group_dimensions = (0,) - print(lhs.shape, rhs.shape, group_sizes, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions) - dims = jax.lax.RaggedDotDimensionNumbers(dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions) - return jax.lax.ragged_dot_general(lhs, rhs, group_sizes, dims) - # lhs_contract_dim, _ = contracting_dims - # assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 - # if bias is None: - # bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) - # else: - # assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) - # remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() - # lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) - # rhs = jnp.split(rhs, rhs.shape[0], axis=0) - # bias = jnp.split(bias, bias.shape[0], axis=0) - # ref_out = [] - # dim_num = (contracting_dims, ((), ())) - # for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): - # out_i = jax.lax.dot_general( - # lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST - # ) + jnp.expand_dims(bias_i, axis=0) - # ref_out.append(jnp.squeeze(out_i)) - # return ref_out + lhs_contract_dim, _ = contracting_dims + assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + if bias is None: + bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + else: + assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + rhs = jnp.split(rhs, rhs.shape[0], axis=0) + bias = jnp.split(bias, bias.shape[0], axis=0) + ref_out = [] + dim_num = (contracting_dims, ((), ())) + for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + out_i = jax.lax.dot_general( + lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST + ) + jnp.expand_dims(bias_i, axis=0) + ref_out.append(jnp.squeeze(out_i)) + return ref_out def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) @@ -1840,15 +1835,11 @@ def 
_generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): assert out.dtype == ref_list[0].dtype - import numpy as np - np.set_printoptions(threshold=10000) - jnp.set_printoptions(threshold=10000) - print("Actual:", out) - print("Expected:", ref_list) - assert_allclose(out, ref_list, dtype=dtype) - # out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - # for i in range(len(ref_list)): - # assert_allclose(out_list[i], ref_list[i], dtype=dtype) + out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + print([o.shape for o in out_list]) + print([r.shape for r in ref_list]) + for i in range(len(ref_list)): + assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", ["NN"]) @@ -1878,7 +1869,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) - @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", grouped_gemm_supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): fwd_dtype, bwd_dtype = fwd_bwd_dtype @@ -1933,7 +1924,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, - with_bias=True, + with_bias=False, ) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) @@ -1959,14 +1950,14 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): "fwd_bwd_dtype", [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], ) - @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", grouped_gemm_supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, - with_bias=True, + with_bias=False, ) quantizer_set = QuantizerFactory.create_set( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 108a6b6843..2ab578b8d0 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -512,6 +512,9 @@ NVTEGroupedTensor const& JAXX_GroupedTensorWrapper::get_grouped_tensor() const { JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { + scale_inv = std::nullopt; + } grouped_tensor_wrapper.set_rowwise(data, scale_inv); return grouped_tensor_wrapper; @@ -713,14 +716,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // If is_grouped_dense_wgrad, then n already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. 
rhsShape.data[0] *= num_gemms; } - auto rhs_tensor = make_grouped_tensor(rhs_data, std::nullopt, JAXX_Scaling_Mode::NO_SCALING,/*rhs_sinv, scaling_mode,*/ num_gemms, rhsShape); + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); //// LHS NVTEShape lhsShape{.data={m, k}, .ndim=2}; if (lhs_is_trans && is_grouped_dense_wgrad) { std::swap(lhsShape.data[0], lhsShape.data[1]); } - auto lhs_tensor = make_grouped_tensor(lhs_data, std::nullopt, JAXX_Scaling_Mode::NO_SCALING,/*lhs_sinv, scaling_mode,*/ num_gemms, lhsShape); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); if (!is_grouped_dense_wgrad) { lhs_tensor.set_group_info(group_sizes, group_offsets); } From 8c86a86003cc1aaeb7cd9e95309a91879dd86685 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 21 Jan 2026 11:44:57 -0800 Subject: [PATCH 59/61] wip --- tests/jax/test_custom_call_compute.py | 80 ++++++++++++++----- .../jax/csrc/extensions/gemm.cpp | 69 +++------------- 2 files changed, 73 insertions(+), 76 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a45b7fd4af..bcee1d4860 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1765,8 +1765,9 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): # (4, 16, 4, 4), - (3, 192, 64, 96), + # (3, 192, 64, 96), + (8, 64*8, 128*8, 128*8), # (8, 64, 32, 128), # (8, 64, 128, 256), ] @@ -1780,6 +1781,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): @pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): + # return jax.lax.ragged_dot(lhs, rhs, group_sizes) lhs_contract_dim, _ = contracting_dims assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 if bias is None: @@ -1797,34 +1799,35 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST ) + jnp.expand_dims(bias_i, axis=0) ref_out.append(jnp.squeeze(out_i)) - return ref_out + return jnp.concatenate(ref_out, axis=0) def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape - # group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) - # group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) - # group_sizes = jnp.diff(group_sizes) + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) - # # Make one empty input lhs to test empty GEMM handling - # group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) - # group_sizes = group_sizes.at[1].set(0) - - group_sizes = jnp.full((n_groups,), m // n_groups) - assert group_sizes.sum() == m + # Make one empty input lhs to test empty GEMM handling + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) # *32 to make sure that input shape works for MXFP8 # group_sizes = group_sizes * 32 # m = m * 32 + group_sizes = jnp.full((n_groups,), m // n_groups) + assert group_sizes.sum() == m + lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if 
data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) bias_shape = (n_groups, n) - lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) - rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) + lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) / jnp.sqrt(k) + rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) / jnp.sqrt(k) + # rhs = jnp.concatenate([i/n_groups*jnp.identity(k, dtype=dtype).reshape(1, k, k) for i in range(n_groups)], axis=0) bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) @@ -1833,13 +1836,50 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi return lhs, rhs, group_sizes, contracting_dims, bias + def _diff_to_image(self, a, b): + import numpy as np + from PIL import Image + # Convert to numpy and compute diff + a_np = np.array(a) + b_np = np.array(b) + diff = np.abs(a_np - b_np) + + # Normalize diff to 0-255 range for visualization + diff_normalized = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8) * 255 + diff_uint8 = diff_normalized.astype(np.uint8) + + # Create heatmap image + img = Image.fromarray(diff_uint8, mode='L') + return img + def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): + import numpy as np + from PIL import Image assert out.dtype == ref_list[0].dtype - out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - print([o.shape for o in out_list]) - print([r.shape for r in ref_list]) - for i in range(len(ref_list)): - assert_allclose(out_list[i], ref_list[i], dtype=dtype) + self._diff_to_image(out, ref_list).save('output_diff.png') + assert_allclose(out, ref_list, dtype=dtype) + + + + # ref_list = jnp.split(ref_list, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + # out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + # print([o.shape for o in out_list]) + # print([r.shape for r in ref_list]) + # for i in range(len(ref_list)): + # print(f"Asserting output for group {i}, output shape: {out_list[i].shape}, ref shape: {ref_list[i].shape}") + # # Convert to numpy and compute diff + # out_np = np.array(out_list[i]) + # ref_np = np.array(ref_list[i]) + # diff = np.abs(out_np - ref_np) + + # # Normalize diff to 0-255 range for visualization + # diff_normalized = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8) * 255 + # diff_uint8 = diff_normalized.astype(np.uint8) + + # # Create heatmap image + # img = Image.fromarray(diff_uint8, mode='L') + # img.save(f'output_group_{i}.png') + # assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", ["NN"]) @@ -1943,7 +1983,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + # assert_allclose(prim_dbias, ref_dbias, dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize( @@ -1988,7 +2028,7 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + # 
assert_allclose(prim_dbias, ref_dbias, dtype=dtype) @pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ # ('ij,jk->ik', (64, 32), (32, 128)), diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2ab578b8d0..e10a9b9ac6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -519,56 +519,6 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const& data, std::opti return grouped_tensor_wrapper; } -// NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { -// // printf("make_grouped_tensor data shape: "); -// // for (auto dim : data.dimensions()) { -// // printf("%zu, ", dim); -// // } -// // printf("\n"); -// // NVTEShape logical_shape{}; -// // if (data.dimensions().size() == 1) { -// // // HACK -// // size_t cdim_size = 4096; -// // logical_shape.ndim = 2; -// // logical_shape.data[0] = data.dimensions()[0] / cdim_size; -// // logical_shape.data[1] = cdim_size; -// // printf("NUM TENSORS: %zu\n", num_tensors); -// // } -// // else { -// // NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); - -// // logical_shape.ndim = 2; -// // logical_shape.data[0] = data.dimensions()[0]; -// // logical_shape.data[1] = data.dimensions()[1]; -// // } - -// NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); - -// NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), -// static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), -// dataShape}; -// nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); - -// if (scale_inv.has_value()) { -// NVTEShape logical_scale_shape{}; -// if (scale_inv->dimensions().size() == 1) { -// logical_scale_shape.ndim = 1; -// logical_scale_shape.data[0] = scale_inv->dimensions()[0]; -// } else if (scale_inv->dimensions().size() == 2) { -// logical_scale_shape.ndim = 2; -// logical_scale_shape.data[0] = scale_inv->dimensions()[0]; -// logical_scale_shape.data[1] = scale_inv->dimensions()[1]; -// } else { -// NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); -// } -// NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), -// static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), -// logical_scale_shape}; -// nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); -// } - -// return grouped_tensor; -// } Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, @@ -705,11 +655,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type convert_ffi_datatype_to_te_dtype(beta.element_type())); - printf("Num gemms: %zu, M: %zu, N: %zu, K: %zu, group_sizes: %zu\n", num_gemms, m, n, k, group_sizes.dimensions()[0] / 2); + printf("Num gemms: %zu, M: %zu, N: %zu, K: %zu, group_sizes: %zu, lhs_is_trans: %d, rhs_is_trans: %d, is_grouped_dense_wgrad: %d\n", num_gemms, m, n, k, group_sizes.dimensions()[0] / 2, lhs_is_trans, rhs_is_trans, is_grouped_dense_wgrad); //// RHS NVTEShape rhsShape{.data={k, n}, .ndim=2}; - if (rhs_is_trans) { + if (rhs_is_trans && !is_grouped_dense_wgrad) 
{ std::swap(rhsShape.data[0], rhsShape.data[1]); } if (!is_grouped_dense_wgrad) { @@ -717,12 +667,19 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type rhsShape.data[0] *= num_gemms; } auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - + if (is_grouped_dense_wgrad) { + rhs_tensor.set_group_info(group_sizes, group_offsets); + } + //// LHS - NVTEShape lhsShape{.data={m, k}, .ndim=2}; + NVTEShape lhsShape{.data={k, m}, .ndim=2}; if (lhs_is_trans && is_grouped_dense_wgrad) { std::swap(lhsShape.data[0], lhsShape.data[1]); } + if (is_grouped_dense_wgrad) { + // If is_grouped_dense_wgrad, then m already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. + lhsShape.data[0] *= num_gemms; + } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); if (!is_grouped_dense_wgrad) { lhs_tensor.set_group_info(group_sizes, group_offsets); @@ -734,7 +691,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type outShape.data[0] *= num_gemms; } auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - if (is_grouped_dense_wgrad) { + if (!is_grouped_dense_wgrad) { out_tensor.set_group_info(group_sizes, group_offsets); } @@ -746,7 +703,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type printf("rhs_is_trans: %d, lhs_is_trans: %d\n", rhs_is_trans, lhs_is_trans); // HACK: jberchtold FIXME - // cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); + cudaMemsetAsync(output->untyped_data(), 0xFF, output->size_bytes(), stream); nvte_grouped_gemm( rhs_tensor, rhs_is_trans, From d8247da9d67f6176e49d68b399958c8004b05dec Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 21 Jan 2026 13:53:03 -0800 Subject: [PATCH 60/61] wip --- .../jax/csrc/extensions/gemm.cpp | 302 ++++++++---------- 1 file changed, 129 insertions(+), 173 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e10a9b9ac6..485a643bbc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,127 +399,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); -class JAXX_GroupedTensorWrapper { -public: - JAXX_GroupedTensorWrapper() = delete; - JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, - NVTEShape const& dataShape); - ~JAXX_GroupedTensorWrapper() = default; - - void set_rowwise(Buffer_Type const& data, std::optional const& scale_inv); - void set_group_info(Buffer_Type const& group_sizes, Buffer_Type const& group_offsets); - - operator NVTEGroupedTensor() const { return m_grouped_tensor; } - NVTEGroupedTensor const& get_grouped_tensor() const; - -private: - NVTEShape m_data_shape{}; - NVTEGroupedTensor m_grouped_tensor{}; - - // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. 
- NVTEBasicTensor m_data_tensor{}; - NVTEBasicTensor m_scale_inv_tensor{}; - - NVTEBasicTensor m_sizes_tensor{}; - NVTEBasicTensor m_offsets_tensor{}; -}; - -JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, - size_t num_tensors, - NVTEShape const& dataShape) { - m_data_shape = dataShape; - m_grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); -} - -void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const& data, - std::optional const& scale_inv) { - printf("set_rowwise data shape: XLA buffer shape: "); - for (auto dim : data.dimensions()) { - printf("%zu, ", dim); - } - printf("NVTEShape: "); - for (int i = 0; i < m_data_shape.ndim; ++i) { - printf("%d, ", m_data_shape.data[i]); - } - printf("\n"); - NVTEDType data_dtype = static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); - m_data_tensor = NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, - m_data_shape}; - - nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor); - - if (scale_inv.has_value()) { - NVTEDType scale_inv_dtype = - static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); - NVTEShape logical_scale_shape{}; - if (scale_inv->dimensions().size() == 1) { - logical_scale_shape.ndim = 1; - logical_scale_shape.data[0] = scale_inv->dimensions()[0]; - } else if (scale_inv->dimensions().size() == 2) { - logical_scale_shape.ndim = 2; - logical_scale_shape.data[0] = scale_inv->dimensions()[0]; - logical_scale_shape.data[1] = scale_inv->dimensions()[1]; - } else { - NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", - scale_inv->dimensions().size()); - } - m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), - scale_inv_dtype, logical_scale_shape}; - nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, - &m_scale_inv_tensor); - } -} - -void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const& group_sizes, - Buffer_Type const& group_offsets) { - NVTEDType sizes_dtype = - static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); - NVTEDType offsets_dtype = - static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); - - NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt32, - "group_sizes must be of type int32."); - NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt32, - "group_offsets must be of type int32."); - - // JAX only supports int32 but cuBLAS requires int64 so we pack two int32 into one int64 - size_t num_tensors = group_sizes.dimensions()[0] / 2; - NVTE_CHECK(group_sizes.dimensions().size() == 1, - "group_sizes must be a 1D tensor with length equal to the number of tensors."); - NVTE_CHECK(group_offsets.dimensions().size() == 1, - "group_offsets must be a 1D tensor with length equal to the number of tensors."); - NVTE_CHECK(group_offsets.dimensions()[0] == 2 * num_tensors, - "group_sizes and group_offsets must have the same number of elements."); - - NVTEShape shape{}; - shape.ndim = 1; - shape.data[0] = num_tensors; - - m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), - NVTEDType::kNVTEInt64, - shape}; - m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), - NVTEDType::kNVTEInt64, - shape}; - - nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedFirstDims, &m_sizes_tensor); - nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, 
&m_offsets_tensor); -} - -NVTEGroupedTensor const& JAXX_GroupedTensorWrapper::get_grouped_tensor() const { - return m_grouped_tensor; -} - -JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { - JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); - if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { - scale_inv = std::nullopt; - } - grouped_tensor_wrapper.set_rowwise(data, scale_inv); - - return grouped_tensor_wrapper; -} - Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offsets, @@ -654,68 +533,145 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - - printf("Num gemms: %zu, M: %zu, N: %zu, K: %zu, group_sizes: %zu, lhs_is_trans: %d, rhs_is_trans: %d, is_grouped_dense_wgrad: %d\n", num_gemms, m, n, k, group_sizes.dimensions()[0] / 2, lhs_is_trans, rhs_is_trans, is_grouped_dense_wgrad); - - //// RHS - NVTEShape rhsShape{.data={k, n}, .ndim=2}; - if (rhs_is_trans && !is_grouped_dense_wgrad) { - std::swap(rhsShape.data[0], rhsShape.data[1]); - } - if (!is_grouped_dense_wgrad) { - // If is_grouped_dense_wgrad, then n already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. - rhsShape.data[0] *= num_gemms; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - if (is_grouped_dense_wgrad) { - rhs_tensor.set_group_info(group_sizes, group_offsets); + // Grouped GEMM currently only supports tensor scaling + NVTE_CHECK(is_tensor_scaling, + "Grouped GEMM only supports tensor scaling (DELAYED_TENSOR_SCALING or CURRENT_TENSOR_SCALING)"); + + // To make the output compatible with JAX row-major, we swap A and B in cuBLAS GEMM call. + // JAX: C = LHS @ RHS => cuBLAS: C^T = RHS^T @ LHS^T + // So we pass: A = rhs (swapped), B = lhs (swapped) + // trans_a = !rhs_is_trans (flip because of the swap) + // trans_b = !lhs_is_trans (flip because of the swap) + bool trans_a = !rhs_is_trans; + bool trans_b = !lhs_is_trans; + + // Calculate logical shapes for grouped tensors + // After swap: A (from rhs) has shape [num_gemms, K, N] or [num_gemms, N, K] if transposed + // B (from lhs) has shape [M, K] or [K, M] if transposed (M varies across gemms) + // D has shape [M, N] (M varies across gemms for non-wgrad case, or [num_gemms, M, N] for wgrad) + + // The group_sizes array contains the M dimension for each GEMM (packed as int32 pairs for int64) + // For grouped GEMM: lhs is [sum(M_i), K], rhs is [num_gemms, K, N], output is [sum(M_i), N] + // For grouped dense wgrad: lhs is [K, M], rhs is [K, N], output is [num_gemms, M, N] + + auto nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + + // Create grouped tensor A (from rhs - after swap for cuBLAS) + // For non-wgrad: rhs is [num_gemms, K, N] -> each tensor is [K, N] + // Shape: uniform first dim = K (if !rhs_trans) or N (if rhs_trans) + // uniform last dim = N (if !rhs_trans) or K (if rhs_trans) + size_t a_first_dim = rhs_is_trans ? n : k; + size_t a_last_dim = rhs_is_trans ? 
+  size_t a_logical_shape_data[2] = {num_gemms * a_first_dim, a_last_dim};
+  NVTEShape a_logical_shape = nvte_make_shape(a_logical_shape_data, 2);
+  NVTEGroupedTensor grouped_a = nvte_create_grouped_tensor(nvte_scaling_mode, num_gemms, a_logical_shape);
+
+  NVTEBasicTensor a_data_tensor{rhs_ptr, static_cast<NVTEDType>(rhs_dtype), a_logical_shape};
+  nvte_set_grouped_tensor_param(&grouped_a, kNVTEGroupedRowwiseData, &a_data_tensor);
+
+  // Set scale_inv for A (rhs) if FP8
+  if (is_fp8_gemm) {
+    NVTEShape a_scale_shape = nvte_make_shape(&num_gemms, 1);
+    NVTEBasicTensor a_scale_tensor{rhs_scatter_aligned_ptr, kNVTEFloat32, a_scale_shape};
+    nvte_set_grouped_tensor_param(&grouped_a, kNVTEGroupedRowwiseScaleInv, &a_scale_tensor);
+    // Scatter rhs scale_inv values to aligned positions
+    for (size_t i = 0; i < num_gemms; ++i) {
+      cudaMemcpyAsync(rhs_scatter_aligned_ptr + i * tensor_scaling_sinv_aligment,
+                      rhs_sinv_ptr + i * sizeof(float),
+                      sizeof(float), cudaMemcpyDeviceToDevice, stream);
+    }
   }
-  //// LHS
-  NVTEShape lhsShape{.data={k, m}, .ndim=2};
-  if (lhs_is_trans && is_grouped_dense_wgrad) {
-    std::swap(lhsShape.data[0], lhsShape.data[1]);
-  }
+  // Create grouped tensor B (from lhs - after swap for cuBLAS)
+  // For non-wgrad: lhs is [sum(M_i), K] with varying M_i per gemm
+  // For wgrad: lhs is [K, M] uniform
+  size_t b_first_dim, b_last_dim;
+  bool b_has_varying_first_dim = !is_grouped_dense_wgrad;
   if (is_grouped_dense_wgrad) {
-    // If is_grouped_dense_wgrad, then m already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here.
-    lhsShape.data[0] *= num_gemms;
+    b_first_dim = lhs_is_trans ? m : k;
+    b_last_dim = lhs_is_trans ? k : m;
+  } else {
+    // Varying M: first_dim varies, last_dim is K
+    b_first_dim = lhs_is_trans ? k : m;  // total M for logical shape
+    b_last_dim = lhs_is_trans ? m : k;   // but this is the sum, we need to handle varying dims
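+    // Here m is the total row count summed over all groups (lhs is packed as [sum(M_i), K]);
+    // the per-group M_i values are supplied below via kNVTEGroupedFirstDims, with their
+    // cumulative offsets via kNVTEGroupedTensorOffsets.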
   }
-  auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape);
-  if (!is_grouped_dense_wgrad) {
-    lhs_tensor.set_group_info(group_sizes, group_offsets);
+  size_t b_logical_shape_data[2] = {b_first_dim, b_last_dim};
+  NVTEShape b_logical_shape = nvte_make_shape(b_logical_shape_data, 2);
+  NVTEGroupedTensor grouped_b = nvte_create_grouped_tensor(nvte_scaling_mode, num_gemms, b_logical_shape);
+
+  NVTEBasicTensor b_data_tensor{lhs_ptr, static_cast<NVTEDType>(lhs_dtype), b_logical_shape};
+  nvte_set_grouped_tensor_param(&grouped_b, kNVTEGroupedRowwiseData, &b_data_tensor);
+
+  // Set first_dims for B if varying (non-wgrad case)
+  // group_sizes contains the M values as int64 (packed as pairs of int32)
+  // group_offsets contains cumulative offsets
+  if (b_has_varying_first_dim) {
+    NVTEShape first_dims_shape = nvte_make_shape(&num_gemms, 1);
+    NVTEBasicTensor first_dims_tensor{group_sizes.untyped_data(), kNVTEInt64, first_dims_shape};
+    nvte_set_grouped_tensor_param(&grouped_b, kNVTEGroupedFirstDims, &first_dims_tensor);
+
+    NVTEBasicTensor offsets_tensor{group_offsets.untyped_data(), kNVTEInt64, first_dims_shape};
+    nvte_set_grouped_tensor_param(&grouped_b, kNVTEGroupedTensorOffsets, &offsets_tensor);
   }
-  //// OUTPUT
-  NVTEShape outShape{.data={m, n}, .ndim=2};
+  // Set scale_inv for B (lhs) if FP8
+  if (is_fp8_gemm) {
+    NVTEShape b_scale_shape = nvte_make_shape(&num_gemms, 1);
+    NVTEBasicTensor b_scale_tensor{lhs_scatter_aligned_ptr, kNVTEFloat32, b_scale_shape};
+    nvte_set_grouped_tensor_param(&grouped_b, kNVTEGroupedRowwiseScaleInv, &b_scale_tensor);
+    // Scatter lhs scale_inv values to aligned positions
+    for (size_t i = 0; i < num_gemms; ++i) {
+      cudaMemcpyAsync(lhs_scatter_aligned_ptr + i * tensor_scaling_sinv_aligment,
+                      lhs_sinv_ptr + i * sizeof(float),
+                      sizeof(float), cudaMemcpyDeviceToDevice, stream);
+    }
+  }
+
+  // Create grouped tensor D (output)
+  // For non-wgrad: output is [sum(M_i), N] with varying M_i
+  // For wgrad: output is [num_gemms, M, N] with uniform M
+  size_t d_first_dim, d_last_dim;
+  bool d_has_varying_first_dim = !is_grouped_dense_wgrad;
   if (is_grouped_dense_wgrad) {
-    outShape.data[0] *= num_gemms;
+    d_first_dim = num_gemms * m;
+    d_last_dim = n;
+  } else {
+    d_first_dim = m;  // total M
+    d_last_dim = n;
   }
-  auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape);
-  if (!is_grouped_dense_wgrad) {
-    out_tensor.set_group_info(group_sizes, group_offsets);
+  size_t d_logical_shape_data[2] = {d_first_dim, d_last_dim};
+  NVTEShape d_logical_shape = nvte_make_shape(d_logical_shape_data, 2);
+  NVTEGroupedTensor grouped_d = nvte_create_grouped_tensor(nvte_scaling_mode, num_gemms, d_logical_shape);
+
+  NVTEBasicTensor d_data_tensor{out_ptr, static_cast<NVTEDType>(out_dtype), d_logical_shape};
+  nvte_set_grouped_tensor_param(&grouped_d, kNVTEGroupedRowwiseData, &d_data_tensor);
+
+  // Set first_dims and offsets for D if varying
+  if (d_has_varying_first_dim) {
+    NVTEShape first_dims_shape = nvte_make_shape(&num_gemms, 1);
+    NVTEBasicTensor first_dims_tensor{group_sizes.untyped_data(), kNVTEInt64, first_dims_shape};
+    nvte_set_grouped_tensor_param(&grouped_d, kNVTEGroupedFirstDims, &first_dims_tensor);
+
+    NVTEBasicTensor offsets_tensor{group_offsets.untyped_data(), kNVTEInt64, first_dims_shape};
+    nvte_set_grouped_tensor_param(&grouped_d, kNVTEGroupedTensorOffsets, &offsets_tensor);
   }
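+  // Note: D reuses the same group_sizes / group_offsets buffers as B above, since each group's
+  // output has the same number of rows (M_i) as its corresponding lhs slice.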
-  printf("rhs_shape: [%zu, %zu], lhs_shape: [%zu, %zu], out_shape: [%zu, %zu]\n",
-         rhsShape.data[0], rhsShape.data[1],
-         lhsShape.data[0], lhsShape.data[1],
-         outShape.data[0], outShape.data[1]);
-
-  printf("rhs_is_trans: %d, lhs_is_trans: %d\n", rhs_is_trans, lhs_is_trans);
-
-  // HACK: jberchtold FIXME
-  cudaMemsetAsync(output->untyped_data(), 0xFF, output->size_bytes(), stream);
-
-  nvte_grouped_gemm(
-      rhs_tensor, rhs_is_trans,
-      lhs_tensor, lhs_is_trans,
-      nullptr,
-      out_tensor,
-      alpha_tensor.data(),
-      beta_tensor.data(),
-      workspace_setup.data(),
-      workspace_cublas.data(),
-      nullptr,  // config (use defaults)
-      stream);
+  // Call nvte_grouped_gemm
+  // Note: C is nullptr since beta=0 (no accumulation)
+  nvte_grouped_gemm(grouped_a, trans_a, grouped_b, trans_b,
+                    nullptr,  // C tensor (nullptr for beta=0)
+                    grouped_d,
+                    alpha_tensor.data(),
+                    beta_tensor.data(),
+                    workspace_setup.data(),
+                    workspace_cublas.data(),
+                    nullptr,  // config (use defaults)
+                    stream);
+
+  // Clean up grouped tensors
+  nvte_destroy_grouped_tensor(grouped_a);
+  nvte_destroy_grouped_tensor(grouped_b);
+  nvte_destroy_grouped_tensor(grouped_d);

   return ffi_with_cuda_error_check();
 }

From d799a29c1f963af153c385bfb9f791a954c3eeee Mon Sep 17 00:00:00 2001
From: Jeremy Berchtold
Date: Thu, 22 Jan 2026 08:39:16 -0800
Subject: [PATCH 61/61] wip

---
 transformer_engine/jax/csrc/extensions/gemm.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index 485a643bbc..aed0802f2b 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -259,6 +259,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
                ", out_shape[1]=", out_shape[1]);

   // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
+
+  // If A's contracting dim is its innermost dimension, then A is in "T" layout; it wants to be M x K.
+  // If B's contracting dim is its innermost dimension, then B is in "N" layout; it wants to be K x M, which the column-major view already reads as N, so no transpose is needed.
+  // TN case: A_trans = true, B_trans = false
   nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr,
                       rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/,
                       out_.data() /*D*/, workspace_.data(), config, stream);
@@ -533,9 +537,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   TensorWrapper beta_tensor(static_cast<float*>(beta.untyped_data()), std::vector<size_t>{num_gemms},
                             convert_ffi_datatype_to_te_dtype(beta.element_type()));

-  // Grouped GEMM currently only supports tensor scaling
-  NVTE_CHECK(is_tensor_scaling,
-             "Grouped GEMM only supports tensor scaling (DELAYED_TENSOR_SCALING or CURRENT_TENSOR_SCALING)");
+  // Grouped GEMM supports NO_SCALING (for BF16/FP16) and tensor scaling (for FP8)
+  const bool is_no_scaling = scaling_mode == JAXX_Scaling_Mode::NO_SCALING;
+  NVTE_CHECK(is_tensor_scaling || is_no_scaling,
+             "Grouped GEMM only supports NO_SCALING, DELAYED_TENSOR_SCALING, or CURRENT_TENSOR_SCALING");

   // To make the output compatible with JAX row-major, we swap A and B in cuBLAS GEMM call.
   // JAX: C = LHS @ RHS => cuBLAS: C^T = RHS^T @ LHS^T