From 5175aad7172294244848dcc26a1b0bcc339c9671 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 7 Jan 2026 00:15:10 +0000 Subject: [PATCH 1/8] Naive implementation of grouped linear op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 142 ++++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/grouped_linear.py | 450 ++++++++++++++++++ 3 files changed, 593 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/grouped_linear.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ce15dd1421..d2c84403c4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,7 @@ from collections.abc import Iterable import io import math +import random from typing import Optional import pytest @@ -1924,6 +1925,147 @@ def test_dropout( abs(z_score) < 2.5758 ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple[int, int] = (32, 32), + split_alignment: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_compute: bool = False, + quantized_weight: bool = False, + input_requires_grad: bool, + weight_requires_grad: bool, + ) -> None: + """Grouped GEMM""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu") + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or 
weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for group_idx in range(group_size): + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass with op + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + if weight_requires_grad: + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + else: + assert w_test.grad is None + if bias: + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359c..a74f02e3a0 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -24,6 +24,7 @@ from .bias import Bias from .constant_scale import ConstantScale from .dropout import Dropout +from .grouped_linear import GroupedLinear from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py new file mode 100644 index 0000000000..e03710189f --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusible operation for bias.""" + +from __future__ import annotations +from collections.abc import Iterable +import contextlib +import math +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...module.base import get_dummy_wgrad +from ...quantization import FP8GlobalStateManager +from ...tensor import Quantizer +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) +from .._common import is_quantized_tensor +from ..op import BasicOperation, OperationContext + + +class GroupedLinear(BasicOperation): + + # Operation expects input split sizes + num_extra_inputs: int = 1 + + def __init__( + self, + group_size: int, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.group_size: int = group_size + self.in_features: int = in_features + self.out_features: int = out_features + if self.group_size <= 0: + raise ValueError(f"Invalid group size ({self.group_size})") + if self.in_features <= 0: + raise ValueError(f"Invalid input size ({self.in_features})") + if self.out_features <= 0: + raise ValueError(f"Invalid output size ({self.out_features})") + + # Weight tensor attributes + device = canonicalize_device(device) + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Initialize recipe state if needed for natively quantized weight + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() + if self._with_quantized_weight: + self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + + # RNG state tracker + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + + # Register weights + self.weight0: torch.nn.Parameter + for group_idx in range(self.group_size): + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=dtype, + ) + self.register_parameter( + f"weight{group_idx}", + torch.nn.Parameter(weight_tensor), + ) + + # Register biases + self.bias0: Optional[torch.nn.Parameter] + for group_idx in range(self.group_size): + bias_tensor = None + if bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=dtype, + ) + bias_tensor = torch.nn.Parameter(bias_tensor) + self.register_parameter(f"bias{group_idx}", bias_tensor) + + # Initialize weights if needed + if device.type != "meta": + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 * self.group_size + if mode == "backward": + return self.group_size + return 0 + + @property + def has_bias(self) -> bool: + return self.bias0 is not None + + @torch.no_grad + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + for group_idx in range(self.group_size): + + # Parameters + weight = getattr(self, f"weight{group_idx}") + bias = getattr(self, f"bias{group_idx}") + + # Parameter device + device = weight.device + if device.type 
== "meta": + device = canonicalize_device(None) + + # Allocate buffers if needed + if is_quantized_tensor(weight): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if bias is not None and not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + + # Initialize values + init_context = contextlib.nullcontext() + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork() + with init_context: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + if bias is not None: + bias.zero_() + + # Quantize weight if needed + if self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 1) + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within quantized_model_init, but the forward pass was not " + "performed within autocast." + ) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), + ) + quantizer.internal = False + with torch.no_grad(): + weight = quantizer(weight) + + # Save updated parameters + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + setattr(self, f"weight{group_idx}", weight) + if bias is not None: + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + setattr(self, f"bias{group_idx}", bias) + + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + + # Initialize params if needed + if any(param.device.type == "meta" for param in self.parameters()): + self.reset_parameters() + + # Check that weights are consistent + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.group_size): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype " + f"(expected {dtype}, got {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." + ) + + # Check that biases are consistent + for group_idx in range(self.group_size): + bias = getattr(self, f"bias{group_idx}") + if self.has_bias: + if bias is None: + raise RuntimeError( + f"Expected biases, but bias {group_idx} is uninitialized" + ) + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype " + f"(expected {dtype}, got {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." 
+ ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + else: + if bias is not None: + raise RuntimeError( + f"Expected no biases, but bias {group_idx} is initialized" + ) + + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = requires_grad and self.weight0.requires_grad + + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + for group_idx in range(self.group_size): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + def op_forward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Check which grads are required + ctx = basic_op_ctxs[0] + input_requires_grad = ctx.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + + # Quantizers + input_quantizers = None + weight_quantizers = None + grad_output_quantizers = None + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + input_quantizers = [] + weight_quantizers = [] + grad_output_quantizers = [] + for group_idx in range(self.group_size): + input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) + weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) + grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight0.dtype + + # Extract split sizes from extra input + # TODO Support splits on GPU + split_sizes = basic_op_extra_inputs[0][0] + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != self.group_size: + raise ValueError( + f"Expected {self.group_size} splits, but got {len(split_sizes_int)}." 
+ ) + + # Extract params + weights = [] + biases = [] + for group_idx in range(self.group_size): + weights.append(getattr(self, f"weight{group_idx}")) + biases.append(getattr(self, f"bias{group_idx}")) + + # Perform GEMMs + # TODO: Fused impl, quantization + xs = torch.split(input_, split_sizes_int) + ys = [] + for x, w, b in zip(xs, weights, biases): + y = torch.nn.functional.linear(x, w, bias=b) + ys.append(y) + out = torch.cat(ys) + + # Save state for backward pass + if ctx.requires_grad: + ctx.save_for_backward(split_sizes, *xs, *weights) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + group_size = self.group_size + has_bias = self.has_bias + + # Saved tensors from forward pass + ctx = basic_op_ctxs[0] + saved_tensors = ctx.saved_tensors + split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + xs, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + weights, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + + # Split grad output tensor + # TODO Support splits on GPU + split_sizes_int = [int(s) for s in split_sizes.tolist()] + dys = torch.split(grad_output, split_sizes_int) + + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can accumulate + # directly into it. 
+ accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weights = [None] * group_size + if ctx.weight_requires_grad and accumulate_into_main_grad: + for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + if not hasattr(weight_param, "main_grad"): + raise RuntimeError( + "GroupLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weights[group_idx] = weight_param.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Compute grad biases + # TODO: Fuse with quantization + grad_biases = [None] * group_size + if ctx.weight_requires_grad and has_bias: + for group_idx in range(group_size): + dy = dys[group_idx] + grad_biases[group_idx] = dy.reshape(-1, dy.size(-1)).sum(0) + + # Perform GEMMs + # TODO: Fused impl, quantization + grad_input = None + if ctx.input_requires_grad: + dxs = [] + for group_idx in range(group_size): + dy_shape = list(dys[group_idx].size()) + dx = torch.matmul( + dys[group_idx].reshape(-1, dy_shape[-1]), + weights[group_idx], + ) + dxs.append(dx.reshape(dy_shape[:-1] + [dx.size(-1)])) + grad_input = torch.cat(dxs) + if ctx.weight_requires_grad: + for group_idx in range(group_size): + grad_weights[group_idx] = torch.matmul( + dys[group_idx].reshape(-1, dys[group_idx].size(-1)).T, + xs[group_idx].reshape(-1, xs[group_idx].size(-1)), + out=grad_weights[group_idx], + ) + + # Clear input tensors if possible + clear_tensor_data(*xs) + + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
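+ # The actual wgrad already lives in each param's main_grad at this point;
+ # returning a placeholder of the right shape keeps autograd bookkeeping
+ # consistent, and grad_added_to_main_grad tells downstream code to read
+ # main_grad rather than .grad.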
+ if accumulate_into_main_grad: + grad_weights = [None] * group_size + for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weights[group_idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + grad_params = grad_weights + grad_biases if has_bias else grad_weights + return grad_input, [grad_params], [(None,)] From 5ffd57e74e16835fc3b1039bb66d6c63fdbadab2 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 7 Jan 2026 02:09:25 +0000 Subject: [PATCH 2/8] Use grouped GEMM tex functions Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 149 +++++++++++------- 1 file changed, 95 insertions(+), 54 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index e03710189f..325db168ce 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -13,7 +13,13 @@ import torch import transformer_engine_torch as tex -from ...module.base import get_dummy_wgrad +from ...cpp_extensions import general_grouped_gemm +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, +) from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ...utils import ( @@ -288,6 +294,8 @@ def fuser_forward( next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + group_size = self.group_size + has_bias = self.has_bias # Check which grads are required ctx = basic_op_ctxs[0] @@ -303,7 +311,7 @@ def fuser_forward( input_quantizers = [] weight_quantizers = [] grad_output_quantizers = [] - for group_idx in range(self.group_size): + for group_idx in range(group_size): input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) @@ -318,26 +326,40 @@ def fuser_forward( # TODO Support splits on GPU split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_int) != self.group_size: + if len(split_sizes_int) != group_size: raise ValueError( - f"Expected {self.group_size} splits, but got {len(split_sizes_int)}." + f"Expected {group_size} splits, but got {len(split_sizes_int)}." 
) # Extract params weights = [] - biases = [] - for group_idx in range(self.group_size): + biases = [] if has_bias else None + for group_idx in range(group_size): weights.append(getattr(self, f"weight{group_idx}")) - biases.append(getattr(self, f"bias{group_idx}")) + if has_bias: + biases.append(getattr(self, f"bias{group_idx}")) - # Perform GEMMs - # TODO: Fused impl, quantization + # Split input tensor xs = torch.split(input_, split_sizes_int) - ys = [] - for x, w, b in zip(xs, weights, biases): - y = torch.nn.functional.linear(x, w, bias=b) - ys.append(y) - out = torch.cat(ys) + + # Allocate output tensor + in_shape = list(input_.size()) + out_shape = in_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=input_.device) + + # Perform GEMMs + general_grouped_gemm( + weights, + xs, + [out], + [None] * group_size, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=biases, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_FPROP, + single_output=True, + ) # Save state for backward pass if ctx.requires_grad: @@ -379,55 +401,74 @@ def fuser_backward( split_sizes_int = [int(s) for s in split_sizes.tolist()] dys = torch.split(grad_output, split_sizes_int) - # Megatron-LM wgrad fusion - # Note: Get grad tensors from params so we can accumulate - # directly into it. accumulate_into_main_grad = self._accumulate_into_main_grad grad_weights = [None] * group_size - if ctx.weight_requires_grad and accumulate_into_main_grad: - for group_idx in range(group_size): - weight_param = getattr(self, f"weight{group_idx}") - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "GroupLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" + if ctx.weight_requires_grad: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. 
+ for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + if not hasattr(weight_param, "main_grad"): + raise RuntimeError( + "GroupLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + else: + weight_shape = weights[0].size() + device = weights[0].device + for group_idx in range(group_size): + grad_weights[group_idx] = torch.empty( + weight_shape, + dtype=ctx.dtype, + device=device, ) - grad_weights[group_idx] = weight_param.main_grad.detach() else: accumulate_into_main_grad = False - # Compute grad biases - # TODO: Fuse with quantization - grad_biases = [None] * group_size - if ctx.weight_requires_grad and has_bias: - for group_idx in range(group_size): - dy = dys[group_idx] - grad_biases[group_idx] = dy.reshape(-1, dy.size(-1)).sum(0) - - # Perform GEMMs - # TODO: Fused impl, quantization + # Perform dgrad GEMMs grad_input = None if ctx.input_requires_grad: - dxs = [] - for group_idx in range(group_size): - dy_shape = list(dys[group_idx].size()) - dx = torch.matmul( - dys[group_idx].reshape(-1, dy_shape[-1]), - weights[group_idx], - ) - dxs.append(dx.reshape(dy_shape[:-1] + [dx.size(-1)])) - grad_input = torch.cat(dxs) + out_shape = list(grad_output.size()) + in_shape = out_shape[:-1] + [self.in_features] + grad_input = torch.empty( + in_shape, + dtype=ctx.dtype, + device=grad_output.device, + ) + general_grouped_gemm( + weights, + dys, + [grad_input], + [None] * group_size, # quantization_params + ctx.dtype, + layout="NN", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_DGRAD, + single_output=True, + ) + + # Perform wgrad GEMMs + grad_biases = [None] * group_size if ctx.weight_requires_grad: - for group_idx in range(group_size): - grad_weights[group_idx] = torch.matmul( - dys[group_idx].reshape(-1, dys[group_idx].size(-1)).T, - xs[group_idx].reshape(-1, xs[group_idx].size(-1)), - out=grad_weights[group_idx], - ) + _, grad_biases, _ = general_grouped_gemm( + xs, + dys, + grad_weights, + [None] * group_size, # quantization_params + ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + grad=True, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) # Clear input tensors if possible clear_tensor_data(*xs) From 2ee42da1bdc4cfcd326a9e55e785828b0eb19e79 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 8 Jan 2026 05:50:36 +0000 Subject: [PATCH 3/8] Support quantized compute Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 18 +- .../pytorch/ops/basic/grouped_linear.py | 156 ++++++++++++++---- 2 files changed, 136 insertions(+), 38 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d2c84403c4..e7af692098 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1926,6 +1926,10 @@ def test_dropout( ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) 
@pytest.mark.parametrize("weight_requires_grad", (False, True)) def test_grouped_linear( @@ -1933,13 +1937,13 @@ def test_grouped_linear( *, group_size: int = 4, bias: bool, - weight_shape: tuple[int, int] = (32, 32), - split_alignment: int = 32, - dtype: torch.dtype = torch.float32, + weight_shape: tuple[int, int] = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, device: torch.device = "cuda", - quantization: Optional[str] = None, - quantized_compute: bool = False, - quantized_weight: bool = False, + quantization: Optional[str], + quantized_compute: bool, + quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, ) -> None: @@ -1962,6 +1966,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not specified") if quantization is not None and not (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 325db168ce..3851b6e3ef 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -28,7 +28,7 @@ clear_tensor_data, devices_match, ) -from .._common import is_quantized_tensor +from .._common import is_quantized_tensor, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -268,6 +268,49 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: weight_quantizer.set_usage(rowwise=True, columnwise=False) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) + + for group_idx in range(self.group_size): + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, "weight", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + self.weight.update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) + + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. 
+ if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + def op_forward(self, *args, **kwargs): raise RuntimeError( "{self.__class__.__name__} operation has " @@ -303,18 +346,15 @@ def fuser_forward( weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad # Quantizers - input_quantizers = None - weight_quantizers = None - grad_output_quantizers = None + input_quantizers = [None] * group_size + weight_quantizers = [None] * group_size + grad_output_quantizers = [None] * group_size with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - input_quantizers = [] - weight_quantizers = [] - grad_output_quantizers = [] for group_idx in range(group_size): - input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) - weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) - grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -332,15 +372,34 @@ def fuser_forward( ) # Extract params - weights = [] - biases = [] if has_bias else None - for group_idx in range(group_size): - weights.append(getattr(self, f"weight{group_idx}")) - if has_bias: - biases.append(getattr(self, f"bias{group_idx}")) - - # Split input tensor - xs = torch.split(input_, split_sizes_int) + weights = [getattr(self, f"weight{idx}") for idx in range(group_size)] + bs = None + if has_bias: + bs = [ + maybe_dequantize(getattr(self, f"bias{idx}"), dtype) + for idx in range(group_size) + ] + + # Convert weight dtype if needed + ws = [] + for w, quantizer in zip(weights, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer = weight_quantizers[group_idx] + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + ws.append(w) + + # Split input tensor and convert dtypes if needed + x = maybe_dequantize(input_, dtype) + xs = None + if with_quantized_compute: + for quantizer in input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + xs = tex.split_quantize(x, split_sizes_int, input_quantizers) + else: + xs = torch.split(x, split_sizes_int) # Allocate output tensor in_shape = list(input_.size()) @@ -349,21 +408,36 @@ def fuser_forward( # Perform GEMMs general_grouped_gemm( - weights, + ws, xs, [out], [None] * group_size, # quantization_params dtype, m_splits=split_sizes_int, - bias=biases, + bias=bs, use_bias=has_bias, use_split_accumulator=_2X_ACC_FPROP, single_output=True, ) + # Prepare weight tensors for backward pass + if not input_requires_grad: + ws = [None] * group_size + elif with_quantized_compute: + for w, weight_param in zip(ws, weights): + if w is not weight_param: + 
w.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Prepare input tensor for backward pass + if not weight_requires_grad: + xs = [None] * group_size + elif with_quantized_compute: + for x in xs: + x.update_usage(rowwise_usage=False, columnwise_usage=True) + # Save state for backward pass if ctx.requires_grad: - ctx.save_for_backward(split_sizes, *xs, *weights) + ctx.save_for_backward(split_sizes, *xs, *ws) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizers = input_quantizers ctx.weight_quantizers = weight_quantizers @@ -394,13 +468,34 @@ def fuser_backward( saved_tensors = ctx.saved_tensors split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] xs, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] - weights, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + ws, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] - # Split grad output tensor + # Split grad output tensor and convert dtypes if needed # TODO Support splits on GPU split_sizes_int = [int(s) for s in split_sizes.tolist()] - dys = torch.split(grad_output, split_sizes_int) + dy = maybe_dequantize(grad_output, ctx.dtype) + dys = None + grad_biases = [None] * group_size + if ctx.with_quantized_compute: + for quantizer in ctx.grad_output_quantizers: + quantizer.set_usage( + rowwise=ctx.input_requires_grad, + columnwise=ctx.weight_requires_grad, + ) + dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) + for dy in torch.split(grad_output, split_sizes_int) + ] + else: + dys = torch.split(grad_output, split_sizes_int) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys + ] + # Initialize grad weight grads accumulate_into_main_grad = self._accumulate_into_main_grad grad_weights = [None] * group_size if ctx.weight_requires_grad: @@ -420,8 +515,8 @@ def fuser_backward( "but weight parameter does not have main_grad attribute" ) else: - weight_shape = weights[0].size() - device = weights[0].device + weight_shape = ws[0].size() + device = grad_output.device for group_idx in range(group_size): grad_weights[group_idx] = torch.empty( weight_shape, @@ -442,7 +537,7 @@ def fuser_backward( device=grad_output.device, ) general_grouped_gemm( - weights, + ws, dys, [grad_input], [None] * group_size, # quantization_params @@ -454,9 +549,8 @@ def fuser_backward( ) # Perform wgrad GEMMs - grad_biases = [None] * group_size if ctx.weight_requires_grad: - _, grad_biases, _ = general_grouped_gemm( + general_grouped_gemm( xs, dys, grad_weights, @@ -464,8 +558,6 @@ def fuser_backward( ctx.dtype, layout="NT", m_splits=split_sizes_int, - grad=True, - use_bias=has_bias, use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_into_main_grad, ) From 93e71df5dbe073a17ec46a1c08557beb3dbf9d92 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 8 Jan 2026 06:06:12 +0000 Subject: [PATCH 4/8] Debug test failures with MXFP8 or NVFP4 params Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 3851b6e3ef..1c71e4de73 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -122,20 +122,17 @@ def num_quantizers(self, mode: str) -> int: def 
has_bias(self) -> bool: return self.bias0 is not None - @torch.no_grad def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - for group_idx in range(self.group_size): + # Parameter device + device = self.weight0.device + if device.type == "meta": + device = canonicalize_device(None) - # Parameters + # Initialize weights + for group_idx in range(self.group_size): weight = getattr(self, f"weight{group_idx}") - bias = getattr(self, f"bias{group_idx}") - - # Parameter device - device = weight.device - if device.type == "meta": - device = canonicalize_device(None) # Allocate buffers if needed if is_quantized_tensor(weight): @@ -146,8 +143,6 @@ def reset_parameters(self) -> None: ) elif not devices_match(weight.device, device): weight = torch.empty_like(weight, device=device) - if bias is not None and not devices_match(bias.device, device): - bias = torch.empty_like(bias, device=device) # Initialize values init_context = contextlib.nullcontext() @@ -155,12 +150,10 @@ def reset_parameters(self) -> None: init_context = self._rng_state_tracker_function().fork() with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - if bias is not None: - bias.zero_() # Quantize weight if needed if self._with_quantized_weight: - quantizer = self.get_quantizer("forward", 1) + quantizer = self.get_quantizer("forward", 2 * group_idx + 1) if quantizer is None: raise RuntimeError( "Tried to quantize weight with deferred initialization " @@ -181,10 +174,18 @@ def reset_parameters(self) -> None: if not isinstance(weight, torch.nn.Parameter): weight = torch.nn.Parameter(weight) setattr(self, f"weight{group_idx}", weight) - if bias is not None: - if not isinstance(bias, torch.nn.Parameter): - bias = torch.nn.Parameter(bias) - setattr(self, f"bias{group_idx}", bias) + + # Initialize biases if needed + if self.bias0 is not None: + with torch.no_grad(): + for group_idx in range(self.group_size): + bias = getattr(self, f"bias{group_idx}") + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + bias.zero_() + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + setattr(self, f"bias{group_idx}", bias) def pre_first_fuser_forward(self) -> None: super().pre_first_fuser_forward() From fdddc479482a91a20e74feb47368046a3c3f1725 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 10 Jan 2026 00:12:52 +0000 Subject: [PATCH 5/8] Add multiply op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 64 ++++++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/multiply_extra_input.py | 152 ++++++++++++++++++ 3 files changed, 217 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/multiply_extra_input.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e7af692098..dfb92c6863 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2072,6 +2072,70 @@ def test_grouped_linear( else: assert b_test.grad is None + @pytest.mark.parametrize( + "input_shape,extra_input_shape", + ( + ((3,4,5), (3,4,5)), + ((6,7), ()), + ((), (8,9)), + ((10,11,12), (11,1)), + ((1,15), (13,14,15)), + ) + ) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("extra_input_requires_grad", (False, True)) + def test_multiply_extra_input( + self, + *, + input_shape: Iterable[int], + extra_input_shape: Iterable[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: 
bool, + extra_input_requires_grad: bool, + ) -> None: + """Multiply two tensors""" + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + input_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + extra_input_shape, + test_dtype=dtype, + test_device=device, + requires_grad=extra_input_requires_grad, + ) + + # Plain PyTorch implementation + y_ref = x1_ref * x2_ref + if input_requires_grad or extra_input_requires_grad: + torch.square(y_ref).sum().backward() + + # Implementation with fusible operation + op = te_ops.MultiplyExtraInput() + y_test = op(x1_test, x2_test) + if input_requires_grad or extra_input_requires_grad: + torch.square(y_test).sum().backward() + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + else: + assert x1_test.grad is None + if extra_input_requires_grad: + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + else: + assert x2_test.grad is None + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index a74f02e3a0..c119682151 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -29,6 +29,7 @@ from .l2normalization import L2Normalization from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput +from .multiply_extra_input import MultiplyExtraInput from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py new file mode 100644 index 0000000000..b4a763bde1 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for multiplying with extra input tensor.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + + +def _reduce_broadcast_dims( + x: torch.Tensor, + target_shape: Iterable[int], +) -> torch.Tensor: + """Reduce a tensor down to a target shape. + + The input tensor shape and target shape are assumed to be + broadcast-compatible. In other words, a tensor with the target + shape can be broadcast to match the input tensor shape. + + """ + shape = tuple(x.size()) + target_shape = tuple(target_shape) + + # Return immediately if tensor already has correct shape + if shape == target_shape: + return x + + # Determine reduction dimensions + reduce_dims = [] + if len(shape) < len(target_shape): + raise ValueError( + "Invalid target shape " + f"(shape={shape} cannot be broadcast to shape={target_shape})." 
+ ) + elif len(shape) > len(target_shape): + reduce_dims.extend(range(len(shape) - len(target_shape))) + for idx in range(-len(target_shape), 0): + if shape[idx] == target_shape[idx]: + pass + elif target_shape[idx] != 1: + raise ValueError( + "Invalid target shape " + f"(shape={shape} cannot be broadcast to shape={target_shape})." + ) + else: + reduce_dims.append(idx) + + # Perform reduction + return x.sum(reduce_dims).reshape(target_shape) + + +class MultiplyExtraInput(BasicOperation): + """Multiply with extra input tensor. + + If the tensor shapes do not match, they will follow NumPy + broadcasting semantics. + + """ + + # Operation expects extra input tensor + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Perform multiplication + x1 = maybe_dequantize(input_, dtype) + x2 = maybe_dequantize(extra_input, dtype) + output = input_ * extra_input + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + ctx.input_shape = x1.size() + ctx.extra_input_shape = extra_input.size() + ctx.input_requires_grad = True + if isinstance(input_, torch.Tensor): + ctx.input_requires_grad = input_.requires_grad + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.save_for_backward( + x1 if ctx.extra_input_requires_grad else None, + x2 if ctx.input_requires_grad else None, + ) + + return output, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, extra_input = ctx.saved_tensors + grad_input = None + if ctx.input_requires_grad: + grad_input = _reduce_broadcast_dims( + grad_output * extra_input, + ctx.input_shape, + ) + grad_extra_input = None + if ctx.extra_input_requires_grad: + grad_extra_input = _reduce_broadcast_dims( + grad_output * input_, + ctx.extra_input_shape, + ) + return grad_input, [()], [(grad_extra_input,)] From b448a17d2be841eaa02ade6975dd0b4f401fa6fc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 10 Jan 2026 03:21:49 +0000 Subject: [PATCH 6/8] Bug fixes Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 9 ++------- .../pytorch/ops/basic/multiply_extra_input.py | 2 -- 2 files changed, 
2 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 1c71e4de73..d2a4b379e5 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -490,7 +490,7 @@ def fuser_backward( for dy in torch.split(grad_output, split_sizes_int) ] else: - dys = torch.split(grad_output, split_sizes_int) + dys = torch.split(dy, split_sizes_int) if has_bias: grad_biases = [ dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys @@ -509,12 +509,7 @@ def fuser_backward( if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "GroupLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) + grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() device = grad_output.device diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index b4a763bde1..c1846f5e0d 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -114,8 +114,6 @@ def fuser_forward( ctx.input_shape = x1.size() ctx.extra_input_shape = extra_input.size() ctx.input_requires_grad = True - if isinstance(input_, torch.Tensor): - ctx.input_requires_grad = input_.requires_grad ctx.extra_input_requires_grad = extra_input.requires_grad ctx.save_for_backward( x1 if ctx.extra_input_requires_grad else None, From 9348138111c03202b82857ba944b7ce24306180e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Jan 2026 22:56:16 +0000 Subject: [PATCH 7/8] Fix linter warnings Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 44 ++++++++++++++++--- .../pytorch/ops/basic/multiply_extra_input.py | 3 +- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index d2a4b379e5..e90482b399 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -5,7 +5,7 @@ """Fusible operation for bias.""" from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable import contextlib import math from typing import Any, Optional @@ -14,13 +14,14 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm +from ...distributed import CudaRNGStatesTracker from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, get_dummy_wgrad, ) -from ...quantization import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager, Recipe from ...tensor import Quantizer from ...utils import ( canonicalize_device, @@ -33,6 +34,40 @@ class GroupedLinear(BasicOperation): + """Apply multiple linear transformations: :math:``y_i = x_i W_i^T + b_i`` + + This is equivalent to splitting the input tensor along its first + dimension, applying a separate ``torch.nn.Linear`` to each split, + and concatenating along the first dimension. + + Paramters + --------- + group_size : int + Number of linear transformations. 
+ in_features : int + Inner dimension of input tensor. + out_features : int + Inner dimension of output tensor. + bias : bool, default = ``True`` + Apply additive bias. + device : torch.device, default = default CUDA device + Tensor device. + dtype : torch.dtype, default = default dtype + Tensor datatype. + rng_state_tracker_function : callable + Function that returns ``CudaRNGStatesTracker``, which is used + for model-parallel weight initialization. + accumulate_into_main_grad : bool, default = ``False`` + Whether to directly accumulate weight gradients into the + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that `grad` will be set or be + meaningful. This is primarily intented to integrate with + Megatron-LM. This argument along with weight tensor having + attribute ``overwrite_main_grad`` set to True will overwrite + ``main_grad`` instead of accumulating. + + """ # Operation expects input split sizes num_extra_inputs: int = 1 @@ -120,6 +155,7 @@ def num_quantizers(self, mode: str) -> int: @property def has_bias(self) -> bool: + """Whether an additive bias is being applied""" return self.bias0 is not None def reset_parameters(self) -> None: @@ -216,7 +252,7 @@ def pre_first_fuser_forward(self) -> None: f"Weight {group_idx} has requires_grad={weight.requires_grad}, " f"but expected requires_grad={weight_requires_grad}." ) - if type(weight.data) != weight_tensor_type: + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck raise RuntimeError( f"Weight {group_idx} has invalid tensor type " f"(expected {weight_tensor_type.__name__}, " @@ -364,7 +400,6 @@ def fuser_forward( dtype = self.weight0.dtype # Extract split sizes from extra input - # TODO Support splits on GPU split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] if len(split_sizes_int) != group_size: @@ -472,7 +507,6 @@ def fuser_backward( ws, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] # Split grad output tensor and convert dtypes if needed - # TODO Support splits on GPU split_sizes_int = [int(s) for s in split_sizes.tolist()] dy = maybe_dequantize(grad_output, ctx.dtype) dys = None diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index c1846f5e0d..1209963872 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -10,6 +10,7 @@ import torch +from ...tensor import Quantizer from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize @@ -39,7 +40,7 @@ def _reduce_broadcast_dims( "Invalid target shape " f"(shape={shape} cannot be broadcast to shape={target_shape})." 
) - elif len(shape) > len(target_shape): + if len(shape) > len(target_shape): reduce_dims.extend(range(len(shape) - len(target_shape))) for idx in range(-len(target_shape), 0): if shape[idx] == target_shape[idx]: From 536672976b0994c6d2e6689936cbfcb9c1bea8aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 23:02:25 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 12 +++--- .../pytorch/ops/basic/grouped_linear.py | 39 ++++++++----------- .../pytorch/ops/basic/multiply_extra_input.py | 6 +-- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index dfb92c6863..a39d4e521a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2075,12 +2075,12 @@ def test_grouped_linear( @pytest.mark.parametrize( "input_shape,extra_input_shape", ( - ((3,4,5), (3,4,5)), - ((6,7), ()), - ((), (8,9)), - ((10,11,12), (11,1)), - ((1,15), (13,14,15)), - ) + ((3, 4, 5), (3, 4, 5)), + ((6, 7), ()), + ((), (8, 9)), + ((10, 11, 12), (11, 1)), + ((1, 15), (13, 14, 15)), + ), ) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("extra_input_requires_grad", (False, True)) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index e90482b399..b7a6b843e4 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -239,8 +239,7 @@ def pre_first_fuser_forward(self) -> None: weight = getattr(self, f"weight{group_idx}") if weight.dtype != dtype: raise RuntimeError( - f"Weight {group_idx} has invalid dtype " - f"(expected {dtype}, got {weight.dtype})." + f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." ) if not devices_match(weight.device, device): raise RuntimeError( @@ -264,13 +263,10 @@ def pre_first_fuser_forward(self) -> None: bias = getattr(self, f"bias{group_idx}") if self.has_bias: if bias is None: - raise RuntimeError( - f"Expected biases, but bias {group_idx} is uninitialized" - ) + raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") if bias.dtype != dtype: raise RuntimeError( - f"Bias {group_idx} has invalid dtype " - f"(expected {dtype}, got {bias.dtype})." + f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." 
) if not devices_match(bias.device, device): raise RuntimeError( @@ -284,9 +280,7 @@ def pre_first_fuser_forward(self) -> None: ) else: if bias is not None: - raise RuntimeError( - f"Expected no biases, but bias {group_idx} is initialized" - ) + raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") def pre_fuser_forward(self, *, requires_grad: bool) -> None: super().pre_fuser_forward(requires_grad=requires_grad) @@ -345,8 +339,12 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + grad_output_quantizer.force_pow_2_scales = ( + recipe.fp8_quant_bwd_grad.power_2_scale + ) + grad_output_quantizer.amax_epsilon_scales = ( + recipe.fp8_quant_bwd_grad.amax_epsilon + ) def op_forward(self, *args, **kwargs): raise RuntimeError( @@ -403,18 +401,13 @@ def fuser_forward( split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] if len(split_sizes_int) != group_size: - raise ValueError( - f"Expected {group_size} splits, but got {len(split_sizes_int)}." - ) + raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") # Extract params weights = [getattr(self, f"weight{idx}") for idx in range(group_size)] bs = None if has_bias: - bs = [ - maybe_dequantize(getattr(self, f"bias{idx}"), dtype) - for idx in range(group_size) - ] + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(group_size)] # Convert weight dtype if needed ws = [] @@ -526,9 +519,7 @@ def fuser_backward( else: dys = torch.split(dy, split_sizes_int) if has_bias: - grad_biases = [ - dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys - ] + grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys] # Initialize grad weight grads accumulate_into_main_grad = self._accumulate_into_main_grad @@ -542,7 +533,9 @@ def fuser_backward( weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index 1209963872..f9dfef4d81 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -37,8 +37,7 @@ def _reduce_broadcast_dims( reduce_dims = [] if len(shape) < len(target_shape): raise ValueError( - "Invalid target shape " - f"(shape={shape} cannot be broadcast to shape={target_shape})." + f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." ) if len(shape) > len(target_shape): reduce_dims.extend(range(len(shape) - len(target_shape))) @@ -47,8 +46,7 @@ def _reduce_broadcast_dims( pass elif target_shape[idx] != 1: raise ValueError( - "Invalid target shape " - f"(shape={shape} cannot be broadcast to shape={target_shape})." 
+ f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." ) else: reduce_dims.append(idx)