diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7691582f97..6c2c5eb43c 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2480,6 +2480,59 @@ def test_scaled_swiglu( assert_close_grads(x_test, x_ref, **tols) assert_close_grads(scales_test, scales_ref, **tols) + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_srelu( + self, + *, + in_shape: Iterable[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """SReLU with post-scale""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y = torch.nn.functional.relu(x_ref).square() + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSReLU() + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(scales_test, scales_ref, **tols) + def test_interleaved_scaled_swiglu(self): """SwiGLU with post-scale and block interleaved input format""" self.test_scaled_swiglu( @@ -3570,7 +3623,9 @@ def test_layernorm_mlp( @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) @pytest.mark.parametrize("hidden_size", (128, 256)) - @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) + @pytest.mark.parametrize( + "activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu") + ) def test_grouped_mlp( self, *, @@ -3588,7 +3643,7 @@ def test_grouped_mlp( delay_wgrad_compute: bool, activation: str, ) -> None: - """GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear""" + """GroupedLinear + scaled activation + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] @@ -3608,9 +3663,14 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if activation == "scaled_srelu" and quantization != "mxfp8": + pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8") + if activation == "scaled_srelu" and glu_interleave_size is not None: + pytest.skip("SReLU does not use GLU interleaving") if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: # TODO: ksivaman: Need to debug numerics for this case. 
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + fc1_out_features = hidden_size if activation == "scaled_srelu" else 2 * hidden_size # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3641,7 +3701,7 @@ def test_grouped_mlp( fc2_bs_ref, fc2_bs_test = [], [] for _ in range(group_size): fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( - (2 * hidden_size, hidden_size), + (fc1_out_features, hidden_size), min=-0.25, max=0.25, quantization=quantization, @@ -3660,7 +3720,7 @@ def test_grouped_mlp( fc2_b_ref, fc2_b_test = None, None if bias: fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( - (2 * hidden_size,), + (fc1_out_features,), min=-0.5, max=0.5, test_dtype=dtype, @@ -3689,7 +3749,7 @@ def test_grouped_mlp( for group_idx in range(group_size): x = xs[group_idx] x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) - if glu_interleave_size is not None: + if activation != "scaled_srelu" and glu_interleave_size is not None: x = x.reshape( -1, 2 * hidden_size // (2 * glu_interleave_size), @@ -3698,15 +3758,20 @@ def test_grouped_mlp( ) x = x.transpose(1, 2) x = x.reshape(-1, 2 * hidden_size) - x1, x2 = x.chunk(2, dim=-1) if activation == "scaled_swiglu": + x1, x2 = x.chunk(2, dim=-1) x = torch.nn.functional.silu(x1) * x2 - else: + elif activation == "scaled_clamped_qgeglu": + x1, x2 = x.chunk(2, dim=-1) lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype) geglu_alpha = 1.702 x1c = torch.minimum(x1, lim) x2c = torch.clamp(x2, -lim, lim) x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + elif activation == "scaled_srelu": + x = torch.nn.functional.relu(x).square() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") x = x * probs[group_idx].unsqueeze(-1) x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx]) if bias: @@ -3717,16 +3782,19 @@ def test_grouped_mlp( # Construct operations recipe = make_recipe(quantization) - scaled_act = ( - te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - ) + if activation == "scaled_swiglu": + scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_clamped_qgeglu": + scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_srelu": + scaled_act = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te_ops.GroupedLinear( group_size, hidden_size, - 2 * hidden_size, + fc1_out_features, bias=bias, device=device, dtype=dtype, @@ -3810,22 +3878,31 @@ def test_grouped_mlp( if ( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) - and glu_interleave_size == 32 + and ( + (activation == "scaled_srelu" and glu_interleave_size is None) + or (activation != "scaled_srelu" and glu_interleave_size == 32) + ) and _cudnn_frontend_version_supported() ): - if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + if activation == "scaled_srelu": + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8 + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8 + else: + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8 + if forward_cls.is_supported(): forward_ops = 
module._module_groups[0]._forward_ops assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + forward_cls, ) - if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + if backward_cls is not None and backward_cls.is_supported(): backward_ops = module._module_groups[0]._backward_ops assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + backward_cls, ) # Loose tols for sanity checking diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 9325d87ae7..246b0c4e68 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -182,8 +182,13 @@ def get_dummy_wgrads_for_params( return out -def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: - """Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP.""" +def validate_grouped_mlp_dims(fc1, activation_op, fc2) -> None: + """Validate FC1 / activation / FC2 dimensions for fused grouped MLP.""" + from .basic import ( # pylint: disable=import-outside-toplevel + ScaledSReLU, + ScaledClampedQGeGLU, + ScaledSwiGLU, + ) if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: raise ValueError( @@ -195,17 +200,27 @@ def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " f"in_features={fc2.in_features}, out_features={fc2.out_features})." ) - if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + if isinstance(activation_op, (ScaledSwiGLU, ScaledClampedQGeGLU)): + expected_fc1_out_features = 2 * fc2.in_features + elif isinstance(activation_op, ScaledSReLU): + expected_fc1_out_features = fc2.in_features + else: + raise TypeError(f"Unsupported grouped MLP activation ({activation_op.__class__.__name__}).") + + if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups: raise ValueError( f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " f"out_features={fc1.out_features}) " f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " f"out_features={fc2.out_features}) do not match." ) - if glu_op.glu_interleave_size != 32: + if ( + isinstance(activation_op, (ScaledSwiGLU, ScaledClampedQGeGLU)) + and activation_op.glu_interleave_size != 32 + ): raise ValueError( "Fused kernel requires 32-wide GLU interleaving, " - f"but got glu_interleave_size={glu_op.glu_interleave_size}." + f"but got glu_interleave_size={activation_op.glu_interleave_size}." ) @@ -214,8 +229,10 @@ def fuse_grouped_mlp_ops( *, recipe, fused_op_cls, + activation_op_types=None, + activation_kwarg: str = "swiglu", ): - """Sliding-window fusion for GroupedLinear + scaled GLU + GroupedLinear. + """Sliding-window fusion for GroupedLinear + activation + GroupedLinear. Parameters ---------- @@ -225,9 +242,7 @@ def fuse_grouped_mlp_ops( Quantization recipe. fused_op_cls : type Fused operation class with ``is_supported()`` classmethod and - constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The - ``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU` - or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`. + constructor accepting ``fc1``, activation op, and ``fc2`` keyword args. 
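+    activation_op_types : tuple of type, optional
+        Basic operation classes accepted as the middle activation op.
+        Defaults to ``(ScaledSwiGLU, ScaledClampedQGeGLU)``.
+    activation_kwarg : str, default = "swiglu"
+        Name of the ``fused_op_cls`` constructor keyword argument that
+        receives the activation op.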
Returns ------- @@ -244,6 +259,8 @@ def fuse_grouped_mlp_ops( return ops if recipe is None or not recipe.mxfp8(): return ops + if activation_op_types is None: + activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU) out = [] window, ops = ops[:3], ops[3:] @@ -252,7 +269,7 @@ def fuse_grouped_mlp_ops( matches_pattern = True if not ( isinstance(window[0], GroupedLinear) - and isinstance(window[1], (ScaledSwiGLU, ScaledClampedQGeGLU)) + and isinstance(window[1], activation_op_types) and isinstance(window[2], GroupedLinear) ): matches_pattern = False @@ -260,23 +277,17 @@ def fuse_grouped_mlp_ops( abs(window[1]._clamped.alpha - 1.702) > 0.001 ): matches_pattern = False - elif window[0].num_groups != window[2].num_groups: - matches_pattern = False - elif ( - window[0].in_features % 64 != 0 - or window[0].out_features % 64 != 0 - or window[2].in_features % 64 != 0 - or window[2].out_features % 64 != 0 - ): - matches_pattern = False - elif window[1].glu_interleave_size != 32: - matches_pattern = False + else: + try: + validate_grouped_mlp_dims(window[0], window[1], window[2]) + except (TypeError, ValueError): + matches_pattern = False if matches_pattern: op = fused_op_cls( fc1=window[0], - swiglu=window[1], fc2=window[2], + **{activation_kwarg: window[1]}, ) window = [op] else: diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 45c938ede8..6def36ffc7 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -13,6 +13,7 @@ ReLU, ReGLU, SReLU, + ScaledSReLU, SReGLU, SiLU, ) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 13cb519c19..3d927c24a1 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -6,7 +6,8 @@ from __future__ import annotations import abc -from typing import Optional +from collections.abc import Iterable +from typing import Any, Optional import torch @@ -26,6 +27,7 @@ "ReLU", "ReGLU", "SReLU", + "ScaledSReLU", "SReGLU", "SiLU", ] @@ -345,6 +347,100 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsrelu(*args, **kwargs) +class ScaledSReLU(BasicOperation): + r"""Squared ReLU with per-row post-scaling. + + If the SReLU output has shape ``(d_1, ..., d_n)``, it is multiplied + with an extra input tensor of shape ``(d_1, ..., d_{n-1})``. + """ + + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." 
+ ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], # pylint: disable=unused-argument + basic_op_kwargs: list[dict[str, Any]], # pylint: disable=unused-argument + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + x = maybe_dequantize(input_.contiguous(), dtype) + scales = maybe_dequantize(extra_input, dtype) + y = tex.srelu(x, None) * scales.unsqueeze(-1) + + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + ctx.save_for_backward(x, scales) + + return y, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + del basic_op_grad_extra_outputs + + ctx = basic_op_ctxs[0] + x, scales = ctx.saved_tensors + x = maybe_dequantize(x.contiguous(), ctx.dtype) + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + grad_input = None + if ctx.input_requires_grad: + grad_srelu_out = grad_output * scales.unsqueeze(-1) + grad_input = tex.dsrelu(grad_srelu_out, x, ctx.prev_op_grad_output_quantizer) + + grad_extra_input = None + if ctx.extra_input_requires_grad: + srelu_out = tex.srelu(x, None) + grad_extra_input = torch.linalg.vecdot(srelu_out, grad_output) + + clear_tensor_data(ctx.saved_tensors[0]) + + return grad_input, [()], [(grad_extra_input,)] + + class SReGLU(_ActivationOperation): r"""Squared Rectified Gated Linear Unit diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19a090f121..84bf9421fa 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -33,7 +33,9 @@ # Note: Registration logic is non-trivial, so submodule handles it internally. 
from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8, ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a11d0505c1..aab52ef8f7 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -18,7 +18,7 @@ from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU, ScaledSwiGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -306,19 +306,32 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, + swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None, + srelu: Optional[ScaledSReLU] = None, fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, fc2)) + if swiglu is not None and srelu is not None: + raise TypeError( + "Expected exactly one activation op, but both swiglu and srelu were provided." + ) + activation = swiglu if swiglu is not None else srelu + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) if not self.is_supported(): self.grouped_gemm_dglu_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") - validate_grouped_mlp_dims(fc1, swiglu, fc2) - # The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU. - # The act_func string should be fixed on the cuDNN FE side. - self._cudnn_dact_func: str = ( - "dgeglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "dswiglu" - ) + validate_grouped_mlp_dims(fc1, activation, fc2) + if isinstance(activation, ScaledSReLU): + # grouped_gemm_dsrelu_wrapper_sm100 is dSReLU-specific and does not + # take the GLU ``act_func`` selector. + self._cudnn_dact_func: Optional[str] = None + else: + # The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_dact_func = ( + "dgeglu" if isinstance(activation, ScaledClampedQGeGLU) else "dswiglu" + ) def fuser_backward( self, @@ -333,7 +346,7 @@ def fuser_backward( # Get basic operations fc1_op, _, fc2_op = self.basic_ops - fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) @@ -358,8 +371,8 @@ def fuser_backward( saved_tensors[num_groups:], ) - # Saved tensors from scaled SwiGLU forward - swiglu_in, scales = swiglu_ctx.saved_tensors + # Saved tensors from activation forward + swiglu_in, scales = activation_ctx.saved_tensors # Saved tensors from FC2 forward. 
# Layout: [split_sizes, base_split_offsets, split_points, @@ -459,7 +472,6 @@ def fuser_backward( "sfa_tensor": fc2_dy_scales, "padded_offsets": split_points, "alpha_tensor": alpha_tensor, - "beta_tensor": alpha_tensor, "prob_tensor": scales_tensor, "dprob_tensor": dscales_tensor, "generate_dbias": fc1_op.has_bias, @@ -469,9 +481,11 @@ def fuser_backward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": self._cudnn_dact_func, "use_dynamic_sched": True, } + if self._cudnn_dact_func is not None: + fc2_dglu_kwargs["beta_tensor"] = alpha_tensor + fc2_dglu_kwargs["act_func"] = self._cudnn_dact_func if fc2_op.single_grouped_weight: # Clone and swizzle scales for GEMM @@ -547,7 +561,8 @@ def fuser_backward( else: fc2_bias_grads = [fc2_dbias_packed[idx] for idx in range(num_groups)] - grad_scales = grad_scales.to(dtype=dtype) + if grad_scales is not None: + grad_scales = grad_scales.to(dtype=dtype) fc1_bias_grads: Optional[list[Optional[torch.Tensor]]] = None fc1_bias_grad_packed: Optional[torch.Tensor] = None @@ -703,13 +718,26 @@ def fuser_backward( ) fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,) + activation_grad_extra = (grad_scales,) if grad_scales is not None else () return ( grad_input, [fc1_grad_params, (), fc2_grad_params], - [(None,), (grad_scales,), fc2_grad_extra], + [(None,), activation_grad_extra, fc2_grad_extra], ) +class BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8(BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8): + """Fused backward op for GroupedLinear + ScaledSReLU + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_dglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM and dSReLU activation backward.""" + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_dsrelu_wrapper_sm100 + + def fuse_backward_ops( ops: list[FusibleOperation], *, @@ -739,6 +767,25 @@ def fuse_backward_ops( ) +def fuse_backward_srelu_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for backward pass.""" + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8, + activation_op_types=(ScaledSReLU,), + activation_kwarg="srelu", + ) + + # Register fusion if available if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): register_backward_fusion(fuse_backward_ops, prepend=True) +if BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8.is_supported(): + register_backward_fusion(fuse_backward_srelu_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 91db2ff9b7..59509fc3ae 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -19,7 +19,7 @@ from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU, ScaledSwiGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -89,17 +89,32 @@ def __init__( self, *, fc1: GroupedLinear, 
- swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, + swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None, + srelu: Optional[ScaledSReLU] = None, fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, fc2)) + if swiglu is not None and srelu is not None: + raise TypeError( + "Expected exactly one activation op, but both swiglu and srelu were provided." + ) + activation = swiglu if swiglu is not None else srelu + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) if not self.is_supported(): self.grouped_gemm_glu_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") - validate_grouped_mlp_dims(fc1, swiglu, fc2) - # The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU. - # The act_func string should be fixed on the cuDNN FE side. - self._cudnn_act_func: str = "geglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "swiglu" + validate_grouped_mlp_dims(fc1, activation, fc2) + if isinstance(activation, ScaledSReLU): + # grouped_gemm_srelu_wrapper_sm100 is SReLU-specific and does not + # take the GLU ``act_func`` selector. + self._cudnn_act_func: Optional[str] = None + else: + # The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_act_func = ( + "geglu" if isinstance(activation, ScaledClampedQGeGLU) else "swiglu" + ) def fuser_forward( self, @@ -113,7 +128,7 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations fc1_op, _, fc2_op = self.basic_ops - fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) @@ -164,7 +179,7 @@ def fuser_forward( split_points = base_split_offsets[1:].to(dtype=torch.int) fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] - # Extract post-scales from extra input + # Extract per-row activation probabilities from the middle op. scales = basic_op_extra_inputs[1][0] # Prepare FC1 grouped weight tensor for fused kernels. @@ -294,7 +309,11 @@ def fuser_forward( "alpha_tensor": alpha_tensor, "bias_tensor": fc1_bias_packed, "norm_const_tensor": norm_const_tensor, - "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "prob_tensor": ( + scales.detach().to(dtype=dtype).reshape(-1, 1, 1) + if scales is not None + else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) + ), "acc_dtype": torch.float32, "c_dtype": torch.bfloat16, "d_dtype": torch.float8_e4m3fn, @@ -302,9 +321,10 @@ def fuser_forward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": self._cudnn_act_func, "use_dynamic_sched": True, } + if self._cudnn_act_func is not None: + fc1_glu_kwargs["act_func"] = self._cudnn_act_func if fc1_op.single_grouped_weight: # Clone and swizzle scales for GEMM. @@ -451,6 +471,7 @@ def fuser_forward( # Save state for backward pass if requires_grad: mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) + activation_op = self.basic_ops[1] # Save the input ``GroupedTensor``s themselves for the activations. 
for grouped_fc_x in (grouped_fc1_x, grouped_fc2_x): @@ -481,11 +502,13 @@ def fuser_forward( fc1_ctx.input_requires_grad = input_requires_grad fc1_ctx.weight_requires_grad = weight_requires_grad - # Scaled SwiGLU - swiglu_ctx.save_for_backward(swiglu_in, scales) - swiglu_ctx.input_requires_grad = True - swiglu_ctx.extra_input_requires_grad = True - swiglu_ctx.dtype = dtype + # Activation + activation_ctx.save_for_backward(swiglu_in, scales) + activation_ctx.extra_input_requires_grad = True + if isinstance(activation_op, ScaledSReLU): + activation_ctx.prev_op_grad_output_quantizer = fc1_grad_output_quantizer + activation_ctx.input_requires_grad = True + activation_ctx.dtype = dtype # FC2 saved-tensor layout. Matches the unfused # ``GroupedLinear._fuser_forward_grouped_tensor`` layout so the @@ -520,6 +543,18 @@ def fuser_forward( return fc2_out, [(), (), ()] +class ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8(ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8): + """Fused op for MXFP8 GroupedLinear + ScaledSReLU + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_glu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SReLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_srelu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_srelu_wrapper_sm100 + + def fuse_forward_ops( ops: list[FusibleOperation], *, @@ -549,6 +584,25 @@ def fuse_forward_ops( ) +def fuse_forward_srelu_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for forward pass.""" + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8, + activation_op_types=(ScaledSReLU,), + activation_kwarg="srelu", + ) + + # Register fusion if available if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) +if ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8.is_supported(): + register_forward_fusion(fuse_forward_srelu_ops, prepend=True)
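
Note for reviewers: a minimal PyTorch sketch of the semantics the new ScaledSReLU op and its fused grouped-MLP kernels are tested against, mirroring the reference path in test_scaled_srelu / test_grouped_mlp above. The helper name scaled_srelu_reference and the concrete tensor shapes are illustrative only, not part of the patch.

    import torch

    def scaled_srelu_reference(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
        """Squared ReLU with per-row post-scaling: y = relu(x)^2 * scales[..., None]."""
        return torch.nn.functional.relu(x).square() * scales.unsqueeze(-1)

    # Unlike the GLU variants, SReLU does not chunk its input in half, which
    # is why the grouped-MLP test sets fc1_out_features = hidden_size
    # (instead of 2 * hidden_size) and why GLU interleaving does not apply.
    x = torch.randn(5, 7, 128, requires_grad=True)   # shape (d_1, ..., d_n)
    scales = torch.rand(5, 7, requires_grad=True)    # shape (d_1, ..., d_{n-1})
    y = scaled_srelu_reference(x, scales)
    y.backward(torch.ones_like(y))
    # The grad w.r.t. scales reduces over the last dim, matching the
    # torch.linalg.vecdot(srelu_out, grad_output) in ScaledSReLU.fuser_backward.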