GGEMM + SReLU kernels for MXFP8 Nemotron #2981
Conversation
/te-ci pytorch
Please sign off your commits @sraman-rgb
Greptile Summary
This PR adds MXFP8 grouped MLP SReLU fusion: fused forward and backward grouped GEMM + SReLU (ScaledSReLU) paths.
Confidence Score: 5/5
The forward and backward fused paths for ScaledSReLU are structurally sound and mirror the well-tested SwiGLU pattern; no correctness-breaking logic was found in the changed code. The two findings are a dead-code branch in the backward grad-extras tuple (which would misbehave if reached, but is not reachable today) and a minor inconsistency in extra_input_requires_grad between the fused and unfused paths that could waste compute when scales are frozen but does not corrupt gradients. Relevant code: the activation_grad_extra construction in backward_grouped_mlp.py and the extra_input_requires_grad assignment in forward_grouped_mlp.py.
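To make the extra_input_requires_grad finding concrete, here is a minimal, hypothetical sketch (not the TE code; the function name and tensor shapes are illustrative only) of why gating the scale gradient on requires_grad saves work without affecting correctness:

```python
import torch

def backward_scaled_activation(grad_out, act_out, scales, scales_require_grad):
    # Toy backward for y = scales * act_out. The concern above is analogous to
    # always producing a gradient for the scales; gating it on
    # scales_require_grad (as the unfused path does) skips that reduction
    # when the scales are frozen, without changing any other gradient.
    grad_act = grad_out * scales
    grad_scales = None
    if scales_require_grad:
        grad_scales = (grad_out * act_out).sum(dim=-1, keepdim=True)
    return grad_act, grad_scales
```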
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant Fuser
participant FwdSReLU as ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8
participant BwdSReLU as BackwardGroupedMLP_CuTeGEMMDSReLU_MXFP8
participant srelu_kernel as grouped_gemm_srelu_wrapper_sm100
participant dsrelu_kernel as grouped_gemm_dsrelu_wrapper_sm100
User->>Fuser: forward(fc1, ScaledSReLU, fc2)
Fuser->>FwdSReLU: fuser_forward(fc1_glu_kwargs, scales)
Note over FwdSReLU: prob_tensor = scales (or ones if None)<br/>act_func NOT added (SReLU path)
FwdSReLU->>srelu_kernel: grouped_gemm_srelu_wrapper_sm100(...)
srelu_kernel-->>FwdSReLU: fc2_out
FwdSReLU-->>User: output, saves (swiglu_in, scales) in ctx
User->>Fuser: backward(grad_output)
Fuser->>BwdSReLU: fuser_backward(activation_ctx)
Note over BwdSReLU: prob_tensor = scales_tensor<br/>dprob_tensor = zeros_like(scales_tensor)<br/>beta_tensor/act_func NOT added
BwdSReLU->>dsrelu_kernel: grouped_gemm_dsrelu_wrapper_sm100(...)
dsrelu_kernel-->>BwdSReLU: d_row/col tensors, dprob_tensor (grad_scales)
BwdSReLU-->>User: grad_input, grad_scales
```
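For reference, a rough unfused sketch of what the fused kernels above compute, assuming SReLU denotes the squared ReLU used in Nemotron-style models and that scales holds per-token routing probabilities (both assumptions; the real kernels fuse this elementwise step into the grouped GEMMs and the function names here are placeholders):

```python
import torch
import torch.nn.functional as F

def scaled_srelu_forward(fc1_out, scales):
    # Elementwise activation followed by per-token scaling, applied between
    # the FC1 and FC2 grouped GEMMs. fc1_out: [tokens, hidden], scales: [tokens, 1].
    return F.relu(fc1_out) ** 2 * scales

def scaled_srelu_backward(grad, fc1_out, scales):
    # d/dx of scales * relu(x)^2, plus the gradient w.r.t. the scales
    # themselves (the dprob_tensor / grad_scales in the diagram).
    act_out = F.relu(fc1_out) ** 2
    grad_in = grad * scales * 2.0 * F.relu(fc1_out)
    grad_scales = (grad * act_out).sum(dim=-1, keepdim=True)
    return grad_in, grad_scales
```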
Reviews (4): Last reviewed commit: "Add MXFP8 grouped MLP SReLU fusion"
Force-pushed 8373402 to 765d2e9
Signed-off-by: sraman-rgb <sraman@nvidia.com>
Force-pushed 765d2e9 to 43093cc
timmoon10 left a comment:
Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.
```python
swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
srelu: Optional[ScaledSReLU] = None,
```
Why not have a single arg?
Suggested change:
```diff
-swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
-srelu: Optional[ScaledSReLU] = None,
+activation: Optional[FusibleOperation] = None,
```
It seems like we're adding one activation function after another, so we want interfaces that scale gracefully. Also, fused ops are basically internal to TE and these ops in particular are experimental, so backward compatibility is not a major concern.
The forward fused op should have a similar design. Changing to a consistent arg name would also let us get rid of the kwarg name messiness in the op fusion function.
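A minimal sketch of the single-argument interface being suggested (the function and stub classes here are placeholders, not the actual TE API); dispatching on the op type keeps the fused-op entry points stable as more activations are added:

```python
from typing import Optional

# Placeholder stand-ins for the TE op classes referenced above.
class FusibleOperation: ...
class ScaledSwiGLU(FusibleOperation): ...
class ScaledClampedQGeGLU(FusibleOperation): ...
class ScaledSReLU(FusibleOperation): ...

def select_fused_backward(activation: Optional[FusibleOperation] = None):
    # One `activation` kwarg instead of separate swiglu=/srelu= kwargs;
    # a new activation only needs a new branch (or registry entry) here.
    if isinstance(activation, (ScaledSwiGLU, ScaledClampedQGeGLU)):
        return "glu_backward_path"
    if isinstance(activation, ScaledSReLU):
        return "unary_backward_path"
    raise NotImplementedError(f"No fused backward for {type(activation).__name__}")
```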
```python
return fc2_out, [(), (), ()]


class ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8(ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8):
```
This is an awkward class hierarchy. It would be better to have a virtual base class that both the GLU and non-GLU functions inherit from. The backward fused ops should have a similar design.
While we're messing with the existing classes, we should reconsider the names. The "SwiGLU" op is actually used for both SwiGLU and ClampedQGeGLU, so a name like "GLU" would be better. And there's no reason to expect "SReLU" won't be applied to other activations later, so maybe "Unary" would be more general.
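One possible shape for the hierarchy being described, sketched with placeholder class and method names (the launch_activation_kernel hook is hypothetical): a shared base class owns the MXFP8/grouped-GEMM plumbing, and the GLU vs. unary subclasses only choose the activation kernel. The backward fused ops would mirror the same split.

```python
from abc import ABC, abstractmethod

class ForwardGroupedMLP_CuTeGEMM_MXFP8(ABC):
    """Shared base: MXFP8 quantization, grouped-GEMM setup, ctx bookkeeping."""

    @abstractmethod
    def launch_activation_kernel(self, fc1_out, scales):
        """Each subclass launches its own fused activation kernel."""

class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(ForwardGroupedMLP_CuTeGEMM_MXFP8):
    # Gated activations: ScaledSwiGLU, ScaledClampedQGeGLU, ...
    def launch_activation_kernel(self, fc1_out, scales):
        raise NotImplementedError("placeholder for the GLU kernel launch")

class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(ForwardGroupedMLP_CuTeGEMM_MXFP8):
    # Non-gated (unary) activations: ScaledSReLU and future additions.
    def launch_activation_kernel(self, fc1_out, scales):
        raise NotImplementedError("placeholder for the unary kernel launch")
```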
| pytest.skip("Quantized group GEMM is only supported with BF16/FP16") | ||
| if activation == "scaled_srelu" and quantization != "mxfp8": | ||
| pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8") | ||
| if activation == "scaled_srelu" and glu_interleave_size is not None: |
Nit: This is assuming that activations are GLUs by default, and SReLU is weird. Isn't that kind of backward? In any case, it would be more logical to have a single point where we check is_glu_activation, and then use that everywhere.
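A hedged sketch of the single-check idea (the helper name, activation-string set, and skip messages are illustrative, not the test file's actual code):

```python
import pytest

# Assumed set of GLU-style activation strings used by the test parametrization.
GLU_ACTIVATIONS = {"scaled_swiglu", "scaled_clamped_qgeglu"}

def is_glu_activation(activation: str) -> bool:
    return activation in GLU_ACTIVATIONS

def maybe_skip(activation, quantization, glu_interleave_size):
    # Branch once on is_glu_activation instead of special-casing
    # "scaled_srelu" at several call sites.
    if not is_glu_activation(activation):
        if quantization != "mxfp8":
            pytest.skip("non-GLU grouped MLP fusion is only supported with MXFP8")
        if glu_interleave_size is not None:
            pytest.skip("glu_interleave_size only applies to GLU activations")
```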