
GGEMM+srelu kernels for MxFP8 Nemotron #2981

Open
sraman-rgb wants to merge 1 commit into NVIDIA:main from sraman-rgb:fc1-srelu-main

Conversation

@sraman-rgb sraman-rgb commented May 12, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman (Member)

/te-ci pytorch

@ksivaman (Member)

Please sign-off your commits @sraman-rgb

greptile-apps Bot (Contributor) commented May 12, 2026

Greptile Summary

This PR adds ScaledSReLU (squared-ReLU with per-row post-scaling) as a new activation option for the MXFP8 GGEMM fused kernel path, mirroring the existing SwiGLU/GeGLU fusion. The implementation introduces ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8 and BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8 subclasses that redirect the kernel import to the srelu/dsrelu sm100 wrappers and correctly adjust the FC1-to-FC2 width ratio (1:1 for SReLU vs 2:1 for GLU).

  • _common.py: fuse_grouped_mlp_ops and validate_grouped_mlp_dims are generalised to accept any activation type, eliminating the hard-coded GLU assumptions while preserving backward compatibility for SwiGLU/GeGLU paths.
  • activation.py: ScaledSReLU is added as a first-class BasicOperation, overriding fuser_forward/fuser_backward and correctly threading the prev_op_grad_output_quantizer through to tex.dsrelu.
  • backward_grouped_mlp.py: beta_tensor and act_func are now conditionally injected into the dGLU kernel kwargs (they are omitted for SReLU), and grad_scales handling adds a None guard whose else () branch is currently unreachable but would be semantically incorrect for num_extra_inputs = 1.
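
For readers unfamiliar with the activation, here is a minimal PyTorch sketch of what "squared-ReLU with per-row post-scaling" is assumed to mean; the shapes and the exact placement of the scale are assumptions, not taken from the kernel.

```python
import torch

def scaled_srelu_reference(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Assumed ScaledSReLU math: squared ReLU followed by a per-row post-scale.

    x:      (num_tokens, ffn_hidden) FC1 output
    scales: (num_tokens,) per-row scaling factors
    """
    y = torch.nn.functional.relu(x) ** 2  # squared-ReLU
    return y * scales.unsqueeze(-1)       # per-row post-scaling
```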

Confidence Score: 5/5

The forward and backward fused paths for ScaledSReLU are structurally sound and mirror the well-tested SwiGLU pattern; no correctness-breaking logic was found in the changed code.

The two findings are a dead-code branch in the backward grad-extras tuple (which would misbehave if reached, but is not reachable today) and a minor inconsistency in extra_input_requires_grad between the fused and unfused paths that could waste compute when scales are frozen but does not corrupt gradients.

The relevant spots are the activation_grad_extra construction in backward_grouped_mlp.py and the extra_input_requires_grad assignment in forward_grouped_mlp.py.

Important Files Changed

  • transformer_engine/pytorch/ops/basic/activation.py: Adds the ScaledSReLU operation with fuser_forward/fuser_backward overrides; gradient logic matches the reference SwiGLU pattern and handles quantizer propagation correctly.
  • transformer_engine/pytorch/ops/_common.py: Refactors fuse_grouped_mlp_ops to accept activation_op_types and activation_kwarg; centralises dimension validation in validate_grouped_mlp_dims, correctly handling the 1:1 (SReLU) vs 2:1 (GLU) FC1-to-FC2 width ratio.
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: Extends ForwardGroupedMLP to accept either a swiglu or srelu activation; adds an SReLU subclass overriding the kernel import; the fused forward saves a context compatible with both the unfused SReLU backward and the new fused dSReLU backward.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Refactors BackwardGroupedMLP to handle both SwiGLU/GeGLU and SReLU activations; adds BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8; beta_tensor and act_func are now conditionally added to kernel kwargs; the grad_scales None guard has a dead else-branch that would silently under-deliver extra-input gradients if reached.
  • tests/pytorch/test_fusible_ops.py: Adds test_scaled_srelu for the unfused op and extends test_grouped_mlp to cover scaled_srelu with appropriate skip guards; fusion-detection asserts are updated for the new forward/backward class pair.

Sequence Diagram

sequenceDiagram
    participant User
    participant Fuser
    participant FwdSReLU as ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8
    participant BwdSReLU as BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8
    participant srelu_kernel as grouped_gemm_srelu_wrapper_sm100
    participant dsrelu_kernel as grouped_gemm_dsrelu_wrapper_sm100

    User->>Fuser: forward(fc1, ScaledSReLU, fc2)
    Fuser->>FwdSReLU: fuser_forward(fc1_glu_kwargs, scales)
    Note over FwdSReLU: prob_tensor = scales (or ones if None)<br/>act_func NOT added (SReLU path)
    FwdSReLU->>srelu_kernel: grouped_gemm_srelu_wrapper_sm100(...)
    srelu_kernel-->>FwdSReLU: fc2_out
    FwdSReLU-->>User: output, saves (swiglu_in, scales) in ctx

    User->>Fuser: backward(grad_output)
    Fuser->>BwdSReLU: fuser_backward(activation_ctx)
    Note over BwdSReLU: prob_tensor = scales_tensor<br/>dprob_tensor = zeros_like(scales_tensor)<br/>beta_tensor/act_func NOT added
    BwdSReLU->>dsrelu_kernel: grouped_gemm_dsrelu_wrapper_sm100(...)
    dsrelu_kernel-->>BwdSReLU: d_row/col tensors, dprob_tensor (grad_scales)
    BwdSReLU-->>User: grad_input, grad_scales
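The backward leg of the diagram returns both input gradients and dprob_tensor (the gradient with respect to the per-row scales). Under the same assumed formulation as the sketch above, the reference gradients would look roughly like this; again a sketch, not the kernel's actual code:

```python
import torch

def scaled_srelu_backward_reference(grad_out, x, scales):
    """Assumed dSReLU math for y = scales[:, None] * relu(x) ** 2."""
    relu_x = torch.nn.functional.relu(x)
    grad_x = grad_out * scales.unsqueeze(-1) * 2.0 * relu_x  # dL/dx
    grad_scales = (grad_out * relu_x ** 2).sum(dim=-1)       # dL/dscales, reduced per row
    return grad_x, grad_scales
```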

Reviews (4): Last reviewed commit: "Add MXFP8 grouped MLP SReLU fusion"

Signed-off-by: sraman-rgb <sraman@nvidia.com>

@timmoon10 timmoon10 (Collaborator) left a comment


Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.

Comment on lines +309 to +310
swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
srelu: Optional[ScaledSReLU] = None,

Why not have a single arg?

Suggested change
- swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
- srelu: Optional[ScaledSReLU] = None,
+ activation: Optional[FusibleOperation] = None,

It seems like we're adding one activation function after another, so we want interfaces that scale gracefully. Also, fused ops are basically internal to TE and these ops in particular are experimental, so backward compatibility is not a major concern.

The forward fused op should have a similar design. Changing to a consistent arg name would also let us get rid of the kwarg name messiness in the op fusion function.
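
One way the single-argument interface could look, purely as an illustration (the helper below and its placeholder classes are hypothetical, not the actual TE API):

```python
from typing import Optional

# Placeholder stand-ins for the ops referenced in the diff.
class ScaledSwiGLU: ...
class ScaledClampedQGeGLU: ...
class ScaledSReLU: ...

def fc1_to_fc2_width_ratio(activation: Optional[object]) -> int:
    """Hypothetical single dispatch point: 2:1 for GLU-style activations, 1:1 otherwise."""
    if isinstance(activation, (ScaledSwiGLU, ScaledClampedQGeGLU)):
        return 2
    if activation is None or isinstance(activation, ScaledSReLU):
        return 1
    raise ValueError(f"Unsupported activation: {type(activation).__name__}")
```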

return fc2_out, [(), (), ()]


class ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8(ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8):

This is an awkward class hierarchy. It would be better to have a virtual base class that both the GLU and non-GLU functions inherit from. The backward fused ops should have a similar design.

While we're messing with the existing classes, we should reconsider the names. The "SwiGLU" op is actually used for both SwiGLU and ClampedQGeGLU, so a name like "GLU" would be better. And there's no reason to expect the "SReLU" path won't later be applied to other activations, so maybe "Unary" would be more general.
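
A rough sketch of the suggested hierarchy, with placeholder names (nothing here mirrors the actual class contents):

```python
from abc import ABC, abstractmethod

class ForwardGroupedMLPBase(ABC):
    """Hypothetical shared base for the fused forward grouped-MLP ops."""

    @abstractmethod
    def _activation_kernel(self):
        """Return the grouped-GEMM + activation kernel for this variant."""

class ForwardGroupedMLPGLU(ForwardGroupedMLPBase):
    """GLU-style activations (SwiGLU, ClampedQGeGLU): FC1 output is 2x the FC2 input width."""

    def _activation_kernel(self):
        ...

class ForwardGroupedMLPUnary(ForwardGroupedMLPBase):
    """Unary activations (e.g. SReLU): FC1 output width equals the FC2 input width."""

    def _activation_kernel(self):
        ...
```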

pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization != "mxfp8":
pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8")
if activation == "scaled_srelu" and glu_interleave_size is not None:

Nit: This is assuming that activations are GLUs by default, and SReLU is weird. Isn't that kind of backward? In any case, it would be more logical to have a single point where we check is_glu_activation, and then use that everywhere.
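
For example, a single classification helper (the name and activation strings below are hypothetical) could replace the scattered string comparisons:

```python
import pytest

def is_glu_activation(activation: str) -> bool:
    """Hypothetical helper: classify an activation once, use the result everywhere."""
    return activation in ("swiglu", "geglu", "scaled_swiglu", "clamped_qgeglu")

def maybe_skip_interleave(activation: str, glu_interleave_size) -> None:
    # Skip based on the classification rather than special-casing scaled_srelu.
    if not is_glu_activation(activation) and glu_interleave_size is not None:
        pytest.skip("GLU interleaving only applies to GLU activations")
```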
