113 changes: 95 additions & 18 deletions tests/pytorch/test_fusible_ops.py
@@ -2480,6 +2480,59 @@ def test_scaled_swiglu(
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_srelu(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SReLU with post-scale"""

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y = torch.nn.functional.relu(x_ref).square()
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)

# Implementation with fusible operation
op = te_ops.ScaledSReLU()
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
@@ -3570,7 +3623,9 @@ def test_layernorm_mlp(
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (128, 256))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
@pytest.mark.parametrize(
"activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu")
)
def test_grouped_mlp(
self,
*,
@@ -3588,7 +3643,7 @@ def test_grouped_mlp(
delay_wgrad_compute: bool,
activation: str,
) -> None:
"""GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear"""
"""GroupedLinear + scaled activation + GroupedLinear"""

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
@@ -3608,9 +3663,14 @@ def test_grouped_mlp(
pytest.skip("single_grouped_bias requires bias=True")
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization != "mxfp8":
pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8")
if activation == "scaled_srelu" and glu_interleave_size is not None:
Collaborator review comment: Nit: This is assuming that activations are GLUs by default, and SReLU is weird. Isn't that kind of backward? In any case, it would be more logical to have a single point where we check is_glu_activation, and then use that everywhere.
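# A minimal sketch of the reviewer's suggestion (hypothetical helper, not part of
# this PR): decide "is this a GLU activation?" in one place and reuse it at every
# skip/branch site instead of special-casing "scaled_srelu".
def _is_glu_activation(activation: str) -> bool:
    """Whether the activation splits the FC1 output into two halves (GLU family)."""
    return activation in ("scaled_swiglu", "scaled_clamped_qgeglu")
# The skip below would then read, e.g.:
#     if not _is_glu_activation(activation) and glu_interleave_size is not None:
#         pytest.skip("Non-GLU activations do not use GLU interleaving")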

pytest.skip("SReLU does not use GLU interleaving")
if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
fc1_out_features = hidden_size if activation == "scaled_srelu" else 2 * hidden_size

# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -3641,7 +3701,7 @@ def test_grouped_mlp(
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
(fc1_out_features, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
@@ -3660,7 +3720,7 @@ def test_grouped_mlp(
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
(fc1_out_features,),
min=-0.5,
max=0.5,
test_dtype=dtype,
@@ -3689,7 +3749,7 @@ def test_grouped_mlp(
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
if activation != "scaled_srelu" and glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
@@ -3698,15 +3758,20 @@ def test_grouped_mlp(
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
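# Illustrative shape walk-through (assumes hidden_size=128, glu_interleave_size=32):
# the FC1 output has 256 columns stored as 4 interleaved blocks of 2*32 (32 gate
# columns, then 32 linear columns). The reshape/transpose above regroups them so
# that chunk(2, dim=-1) below recovers the two contiguous GLU halves.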
x1, x2 = x.chunk(2, dim=-1)
if activation == "scaled_swiglu":
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
else:
elif activation == "scaled_clamped_qgeglu":
x1, x2 = x.chunk(2, dim=-1)
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
geglu_alpha = 1.702
x1c = torch.minimum(x1, lim)
x2c = torch.clamp(x2, -lim, lim)
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
elif activation == "scaled_srelu":
x = torch.nn.functional.relu(x).square()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx])
if bias:
@@ -3717,16 +3782,19 @@ def test_grouped_mlp(

# Construct operations
recipe = make_recipe(quantization)
scaled_act = (
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
if activation == "scaled_swiglu"
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
)
if activation == "scaled_swiglu":
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_clamped_qgeglu":
scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_srelu":
scaled_act = te_ops.ScaledSReLU()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
fc1_out_features,
bias=bias,
device=device,
dtype=dtype,
@@ -3810,22 +3878,31 @@ def test_grouped_mlp(
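# Note on the condition below: the fused MXFP8 grouped-MLP modules are only
# expected when quantization is MXFP8 with BF16/FP16 compute and the interleave
# setting matches the activation family: GLU activations require
# glu_interleave_size == 32, while SReLU takes no interleaving.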
if (
quantization == "mxfp8"
and dtype in (torch.bfloat16, torch.float16)
and glu_interleave_size == 32
and (
(activation == "scaled_srelu" and glu_interleave_size is None)
or (activation != "scaled_srelu" and glu_interleave_size == 32)
)
and _cudnn_frontend_version_supported()
):
if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if activation == "scaled_srelu":
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8
else:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8
if forward_cls.is_supported():
forward_ops = module._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
forward_cls,
)
if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if backward_cls is not None and backward_cls.is_supported():
backward_ops = module._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
backward_cls,
)

# Loose tols for sanity checking
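For reference, the semantics exercised by the new ScaledSReLU test reduce to one line of PyTorch. A minimal sketch of the reference math from the test above (the eager path, not the fused Transformer Engine kernel):

import torch

def scaled_srelu_ref(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Squared ReLU with a per-row post-scale, matching the test's reference path."""
    return scales.unsqueeze(-1) * torch.nn.functional.relu(x).square()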
55 changes: 33 additions & 22 deletions transformer_engine/pytorch/ops/_common.py
@@ -182,8 +182,13 @@ def get_dummy_wgrads_for_params(
return out


def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None:
"""Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP."""
def validate_grouped_mlp_dims(fc1, activation_op, fc2) -> None:
"""Validate FC1 / activation / FC2 dimensions for fused grouped MLP."""
from .basic import ( # pylint: disable=import-outside-toplevel
ScaledSReLU,
ScaledClampedQGeGLU,
ScaledSwiGLU,
)

if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0:
raise ValueError(
@@ -195,17 +200,27 @@ def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None:
f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
)
if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups:
if isinstance(activation_op, (ScaledSwiGLU, ScaledClampedQGeGLU)):
expected_fc1_out_features = 2 * fc2.in_features
elif isinstance(activation_op, ScaledSReLU):
expected_fc1_out_features = fc2.in_features
else:
raise TypeError(f"Unsupported grouped MLP activation ({activation_op.__class__.__name__}).")

if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups:
raise ValueError(
f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, "
f"out_features={fc1.out_features}) "
f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, "
f"out_features={fc2.out_features}) do not match."
)
if glu_op.glu_interleave_size != 32:
if (
isinstance(activation_op, (ScaledSwiGLU, ScaledClampedQGeGLU))
and activation_op.glu_interleave_size != 32
):
raise ValueError(
"Fused kernel requires 32-wide GLU interleaving, "
f"but got glu_interleave_size={glu_op.glu_interleave_size}."
f"but got glu_interleave_size={activation_op.glu_interleave_size}."
)


@@ -214,8 +229,10 @@ def fuse_grouped_mlp_ops(
*,
recipe,
fused_op_cls,
activation_op_types=None,
activation_kwarg: str = "swiglu",
):
"""Sliding-window fusion for GroupedLinear + scaled GLU + GroupedLinear.
"""Sliding-window fusion for GroupedLinear + activation + GroupedLinear.

Parameters
----------
@@ -225,9 +242,7 @@
Quantization recipe.
fused_op_cls : type
Fused operation class with ``is_supported()`` classmethod and
constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The
``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU`
or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`.
constructor accepting ``fc1``, activation op, and ``fc2`` keyword args.

Returns
-------
@@ -244,6 +259,8 @@
return ops
if recipe is None or not recipe.mxfp8():
return ops
if activation_op_types is None:
activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU)

out = []
window, ops = ops[:3], ops[3:]
@@ -252,31 +269,25 @@
matches_pattern = True
if not (
isinstance(window[0], GroupedLinear)
and isinstance(window[1], (ScaledSwiGLU, ScaledClampedQGeGLU))
and isinstance(window[1], activation_op_types)
and isinstance(window[2], GroupedLinear)
):
matches_pattern = False
elif isinstance(window[1], ScaledClampedQGeGLU) and (
abs(window[1]._clamped.alpha - 1.702) > 0.001
):
matches_pattern = False
elif window[0].num_groups != window[2].num_groups:
matches_pattern = False
elif (
window[0].in_features % 64 != 0
or window[0].out_features % 64 != 0
or window[2].in_features % 64 != 0
or window[2].out_features % 64 != 0
):
matches_pattern = False
elif window[1].glu_interleave_size != 32:
matches_pattern = False
else:
try:
validate_grouped_mlp_dims(window[0], window[1], window[2])
except (TypeError, ValueError):
matches_pattern = False

if matches_pattern:
op = fused_op_cls(
fc1=window[0],
swiglu=window[1],
fc2=window[2],
**{activation_kwarg: window[1]},
)
window = [op]
else:
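A hedged sketch of how the generalized helper might be wired up for the new SReLU path. The keyword name "srelu" and the exact import paths are assumptions inferred from the tests above, not confirmed by this diff:

from transformer_engine.pytorch.ops.basic import ScaledSReLU
import transformer_engine.pytorch.ops as te_ops

# Hypothetical call site; fuse_grouped_mlp_ops is the helper defined above.
ops = fuse_grouped_mlp_ops(
    ops,
    recipe=recipe,
    fused_op_cls=te_ops.fused.ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8,
    activation_op_types=(ScaledSReLU,),
    activation_kwarg="srelu",  # assumption: the fused op's constructor kwarg
)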
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -13,6 +13,7 @@
ReLU,
ReGLU,
SReLU,
ScaledSReLU,
SReGLU,
SiLU,
)