Changes from all commits
49 commits
5175aad
Naive implementation of grouped linear op
timmoon10 Jan 7, 2026
5ffd57e
Use grouped GEMM tex functions
timmoon10 Jan 7, 2026
2ee42da
Support quantized compute
timmoon10 Jan 8, 2026
93e71df
Debug test failures with MXFP8 or NVFP4 params
timmoon10 Jan 8, 2026
fdddc47
Add multiply op
timmoon10 Jan 10, 2026
b448a17
Bug fixes
timmoon10 Jan 10, 2026
3f38897
Expose option for custom op fusions
timmoon10 Jan 14, 2026
a359b67
Add tests for custom ops
timmoon10 Jan 14, 2026
5f7204f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
8ddb8ce
Fix linter warnings and numerical test failures
timmoon10 Jan 14, 2026
cfc2617
Tweak pattern matching logic with fixed window sizes
timmoon10 Jan 15, 2026
0ce5dfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
9bf5843
Merge branch 'main' into tmoon/custom-fused-ops
timmoon10 Jan 15, 2026
4992903
Use TF32 tols in fused op tests
timmoon10 Jan 15, 2026
9ab7751
Review suggestion from @greptile-apps
timmoon10 Jan 15, 2026
a086d81
Merge branch 'main' into tmoon/custom-fused-ops
timmoon10 Jan 15, 2026
f05f7a8
Merge branch 'main' into tmoon/grouped-linear-op
timmoon10 Jan 15, 2026
9348138
Fix linter warnings
timmoon10 Jan 15, 2026
5366729
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
1b0b229
Merge branch 'tmoon/grouped-linear-op' into tmoon/cute-gemm-swiglu
timmoon10 Jan 15, 2026
3bbe881
Merge branch 'tmoon/custom-fused-ops' into tmoon/cute-gemm-swiglu
timmoon10 Jan 15, 2026
321646e
Initial impl of fused op for grouped MLP
timmoon10 Jan 16, 2026
e137451
Import group GEMM+SwiGLU kernel
timmoon10 Jan 17, 2026
11da59d
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Jan 20, 2026
cb728bb
Add unit test for grouped MLP op
timmoon10 Jan 20, 2026
e7459cc
Call fused group GEMM + SwiGLU kernel
timmoon10 Jan 21, 2026
b15ca0d
Debug test failures
timmoon10 Jan 21, 2026
3da2c17
Get test to not pass trivially
timmoon10 Jan 22, 2026
0270eb1
Handle interleaving for SwiGLU
timmoon10 Jan 22, 2026
0b09790
Fix numeric tests, except for probs grad
timmoon10 Jan 22, 2026
7c40290
Use pre-swizzled scales from GEMM+SwiGLU output
timmoon10 Jan 22, 2026
a098cc0
Add scaled SwiGLU op
timmoon10 Jan 23, 2026
e4f51d3
Avoid CPU splits in group GEMM+SwiGLU kernel
timmoon10 Jan 23, 2026
fb28b6e
Debug scaled SwiGLU
timmoon10 Jan 23, 2026
b0bf34d
Handle case where fused kernel is not available
timmoon10 Jan 24, 2026
000c273
Revert to plain tensor concat
timmoon10 Jan 24, 2026
e2ea4d2
Support GLU interleaving in plain SwiGLU op
timmoon10 Jan 24, 2026
4c6c35f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2026
caf580b
Remove MultiplyExtraInput op
timmoon10 Jan 24, 2026
b36007e
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Jan 25, 2026
36e6918
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2026
ba28c6f
Fix linter warnings
timmoon10 Jan 25, 2026
575da6e
Review suggestions from @greptile-apps
timmoon10 Jan 26, 2026
46294be
Apply suggestion from @greptile-apps[bot]
timmoon10 Jan 26, 2026
fccb0bb
Tweak variable names
timmoon10 Jan 29, 2026
4259e27
Fix f-strings
timmoon10 Jan 30, 2026
2442d34
Fix bug when grouped MLP is not being trained
timmoon10 Jan 30, 2026
a7351e5
Fix f-string
timmoon10 Jan 31, 2026
9be1c49
Replace explicit concat with optional concat
timmoon10 Jan 31, 2026
489 changes: 483 additions & 6 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

19 changes: 12 additions & 7 deletions transformer_engine/pytorch/module/_common.py
@@ -125,14 +125,19 @@ def forward(
return torch.cat(tensors, dim=dim)
data_ptr += tensor.size(dim) * data_ptr_stride

# Out-of-place concatenation when view tensors have different storage
# Note: This works around an edge case with the split_quantize
# function, which might allocate a buffer and construct
# subviews. However, in order to reduce CPU overheads, these
# views are configured manually outside of PyTorch. PyTorch
# doesn't know these views share the same memory, and it
# blocks us from reconstructing the full tensor because it
# thinks we are accessing out-of-bounds memory.
if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
return torch.cat(tensors, dim=dim)

# No-op concatenation
out = tensors[0].new()
out.set_(
tensors[0].untyped_storage(),
tensors[0].storage_offset(),
out_shape,
strides,
)
out = tensors[0].as_strided(out_shape, strides)
out.requires_grad = any(tensor.requires_grad for tensor in tensors)
return out
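
For illustration, here is a minimal standalone sketch of the no-op concatenation idea above: when the sub-views already sit back to back in one storage, a single as_strided call can expose them as the concatenated tensor without copying; otherwise we fall back to torch.cat. The helper name noop_concat and its simplifying assumptions (2D row-major views split along dim 0) are illustrative and not part of the patch.

import torch


def noop_concat(tensors: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper (not part of the patch): assumes 2D row-major
    # views split along dim 0 and laid out back to back in one storage,
    # roughly as split_quantize may arrange them.
    out_rows = sum(t.size(0) for t in tensors)
    out_shape = (out_rows, tensors[0].size(1))
    strides = tensors[0].stride()
    # If the first view's storage cannot cover the full output, PyTorch
    # would reject the strided view as out-of-bounds, so copy instead.
    bytes_needed = out_rows * strides[0] * tensors[0].element_size()
    if tensors[0].untyped_storage().nbytes() < bytes_needed:
        return torch.cat(tensors, dim=0)
    # Otherwise reinterpret the shared storage as the concatenated tensor.
    return tensors[0].as_strided(out_shape, strides)


# Views carved out of one buffer concatenate without any copy:
buf = torch.empty(6, 4)
full = noop_concat([buf[:2], buf[2:]])
assert full.data_ptr() == buf.data_ptr()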

4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -14,8 +14,6 @@
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
@@ -24,6 +22,7 @@
from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .grouped_linear import GroupedLinear
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
@@ -32,3 +31,4 @@
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
from .rmsnorm import RMSNorm
from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU
75 changes: 0 additions & 75 deletions transformer_engine/pytorch/ops/basic/activation.py
@@ -27,8 +27,6 @@
"SReLU",
"SReGLU",
"SiLU",
"SwiGLU",
"ClampedSwiGLU",
]


@@ -355,76 +353,3 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:

def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsilu(*args, **kwargs)


class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit

The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:

.. math::

\text{SwiGLU}(a,b) = \text{SiLU}(a) * b

where

.. math::

\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}

.. warning::

Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.

The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
`GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__
and `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`__.

"""

def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.swiglu(*args, **kwargs)

def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs)
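
As a plain-PyTorch reference for the math above, here is a minimal sketch of the SwiGLU computation following Transformer Engine's convention of gating the first half of the input. It is not the tex.swiglu kernel; the name swiglu_reference is illustrative only.

import torch


def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into the gated half (a) and linear half (b).
    a, b = x.chunk(2, dim=-1)
    # SiLU(a) * b, i.e. the swish-gated linear unit.
    return torch.nn.functional.silu(a) * b


# Note the opposite convention from torch.nn.functional.glu,
# which applies the gating function to the second half.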


class ClampedSwiGLU(_ActivationOperation):
r"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.

This activation has two differences compared to the original SwiGLU:
1. Both the gate and pre-activation values are clamped according to the ``limit`` parameter.
2. The gating uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish activation.

.. warning:: The input tensor is chunked along the last dimension to obtain the gates and pre-activations, which differs
from the GPT-OSS implementation, where the gates and pre-activations are assumed to be interleaved in the input tensor.

Parameters
----------
limit : float
The clamp limit.
alpha : float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""

def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha

def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)

def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
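
For orientation, a minimal plain-PyTorch sketch of the clamped variant described above. The clamp bounds and the helper name clamped_swiglu_reference are assumptions for illustration; the exact behavior of tex.clamped_swiglu is not reproduced here.

import torch


def clamped_swiglu_reference(
    x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702
) -> torch.Tensor:
    # Chunk (not interleave) the last dimension, per the warning above.
    a, b = x.chunk(2, dim=-1)
    # Assumed clamp bounds, for illustration only.
    a = a.clamp(max=limit)
    b = b.clamp(min=-limit, max=limit)
    # Swish gate with scaled sigmoid, sigmoid(alpha * x).
    return a * torch.sigmoid(alpha * a) * b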