Conversation

@timmoon10 (Collaborator) commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
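For intuition, the following is a minimal plain-PyTorch sketch of the grouped-linear semantics (split the input along its first dimension into per-group chunks, apply a separate weight to each chunk, and concatenate the results). The function name and signature are illustrative only and do not reflect the actual GroupedLinear API added by this PR.

```python
import torch

def grouped_linear_reference(x, weights, split_sizes, biases=None):
    """Illustrative reference: y_i = x_i @ W_i^T (+ b_i) per group, then concat.

    x:           (sum(split_sizes), in_features)
    weights:     list of (out_features, in_features) tensors, one per group
    split_sizes: list of ints summing to x.shape[0]
    """
    chunks = torch.split(x, split_sizes, dim=0)
    outs = []
    for i, chunk in enumerate(chunks):
        y = chunk @ weights[i].t()
        if biases is not None:
            y = y + biases[i]
        outs.append(y)
    return torch.cat(outs, dim=0)

# Example: 3 "experts", 16 -> 32 features, token counts 4/2/6
x = torch.randn(12, 16)
ws = [torch.randn(32, 16) for _ in range(3)]
y = grouped_linear_reference(x, ws, [4, 2, 6])
assert y.shape == (12, 32)
```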

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op with support for interleaving SwiGLU gate and linear units (see the sketch after this list)
  • Add a fused operation for grouped MLP
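As a rough illustration of the post-scaled SwiGLU bullet above, the sketch below applies SwiGLU to an input whose gate and linear units are interleaved in fixed-width blocks, then multiplies the result by a per-token scale. The 32-wide block interleaving, the gate/linear ordering, and the scale shape are assumptions for illustration; the ScaledSwiGLU op in this PR may differ in these details.

```python
import torch
import torch.nn.functional as F

def scaled_swiglu_reference(x, scale, interleave_width=32):
    """SwiGLU with interleaved gate/linear units and a per-token post-scale.

    x:     (num_tokens, 2 * hidden), with gate and linear units interleaved
           in blocks of `interleave_width` along the last dimension (assumed layout)
    scale: (num_tokens,) per-token scale applied after the activation
    """
    num_tokens, two_hidden = x.shape
    # De-interleave: view as (tokens, num_block_pairs, 2, interleave_width) and
    # separate the gate blocks from the linear blocks.
    blocks = x.view(num_tokens, -1, 2, interleave_width)
    gate = blocks[:, :, 0, :].reshape(num_tokens, two_hidden // 2)
    lin = blocks[:, :, 1, :].reshape(num_tokens, two_hidden // 2)
    out = F.silu(gate) * lin            # standard SwiGLU
    return scale.unsqueeze(-1) * out    # post-scale each token

x = torch.randn(8, 256)                 # hidden = 128, interleaved as 4 block pairs
probs = torch.rand(8)
y = scaled_swiglu_reference(x, probs)
assert y.shape == (8, 128)
```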

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

@timmoon10 added the performance (Performance issues) label Jan 24, 2026
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Jan 24, 2026
timmoon10 added a commit that referenced this pull request Jan 25, 2026
* Expose option for custom op fusions

  Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

* Add tests for custom ops

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* Fix linter warnings and numerical test failures

* Tweak pattern matching logic with fixed window sizes

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* Use TF32 tols in fused op tests

* Review suggestion from @greptile-apps

* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@timmoon10 mentioned this pull request Jan 25, 2026
@timmoon10 changed the title from "[PyTorch] Prototype of fused operation for grouped MLP" to "[PyTorch] Add grouped linear op and experimental fusion for grouped MLP" Jan 25, 2026
@timmoon10 marked this pull request as ready for review January 25, 2026 01:00
@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch L1

greptile-apps bot (Contributor) commented Jan 25, 2026

Greptile Overview

Greptile Summary

This PR adds grouped linear operations and experimental fusion for grouped MLP blocks in Mixture-of-Experts models. The implementation includes:

  • GroupedLinear: A new operation that applies multiple linear transformations by splitting input along the first dimension, applying separate weights per group, and concatenating results
  • ScaledSwiGLU: A post-scaled SwiGLU activation with support for interleaved gate/activation units (32-wide interleaving)
  • ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8: An experimental fused operation for SM100+ GPUs using a CuTe DSL kernel that combines grouped GEMM, SwiGLU, and post-scaling in a single kernel call

All previously identified issues from earlier review threads have been successfully addressed:

  • Fixed undefined group_idx variable usage
  • Corrected duplicate condition checks in dimension validation
  • Fixed missing f prefixes for f-strings
  • Corrected attribute access for per-group weights

The implementation follows established patterns in the codebase, includes comprehensive tests, and properly handles FP8/MXFP8 quantization.

Confidence Score: 5/5

  • This PR is safe to merge with no blocking issues found
  • All previously identified issues have been resolved, the implementation follows established codebase patterns, includes comprehensive test coverage, and properly handles edge cases
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py: New GroupedLinear op implementing multiple linear transformations with FP8 quantization support
  • transformer_engine/pytorch/ops/basic/swiglu.py: New ScaledSwiGLU op with post-scaling and support for gate interleaving
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: Experimental fused operation for MXFP8 grouped MLP using a CuTe DSL kernel

Sequence Diagram

sequenceDiagram
    participant Input as Input Tensor
    participant FC1 as GroupedLinear (FC1)
    participant SwiGLU as ScaledSwiGLU
    participant FC2 as GroupedLinear (FC2)
    participant Output as Output Tensor
    
    Note over Input,Output: Standard Path (Unfused)
    Input->>FC1: Split by group sizes
    FC1->>FC1: Apply grouped linear transformations
    FC1->>SwiGLU: Interleaved gate/activation output
    SwiGLU->>SwiGLU: Apply SwiGLU + post-scale
    SwiGLU->>FC2: Split activations by group sizes
    FC2->>FC2: Apply grouped linear transformations
    FC2->>Output: Concatenated output
    
    Note over Input,Output: Fused Path (MXFP8 + SM100)
    Input->>FC1: Quantize input (MXFP8)
    FC1->>FC1: CuTe kernel: GEMM + SwiGLU + post-scale
    FC1-->>FC2: Direct MXFP8 intermediate tensors
    FC2->>FC2: Apply grouped GEMM
    FC2->>Output: Final output
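Putting the pieces together, here is a compact plain-PyTorch walk-through of the standard (unfused) path in the diagram above: split by group sizes, FC1, SwiGLU with a per-token post-scale, FC2, concatenate. The shapes, the use of routing probabilities as the post-scale, and the non-interleaved gate/linear split are assumptions for illustration, not the PR's implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_groups, hidden, ffn = 3, 16, 64
split_sizes = [4, 2, 6]
tokens = sum(split_sizes)

x = torch.randn(tokens, hidden)
probs = torch.rand(tokens)                              # per-token post-scale
fc1_w = [torch.randn(2 * ffn, hidden) for _ in range(num_groups)]
fc2_w = [torch.randn(hidden, ffn) for _ in range(num_groups)]

outs = []
for chunk, p, w1, w2 in zip(torch.split(x, split_sizes, dim=0),
                            torch.split(probs, split_sizes, dim=0),
                            fc1_w, fc2_w):
    h = chunk @ w1.t()                                  # FC1: (n_i, 2*ffn)
    gate, lin = h.chunk(2, dim=-1)                      # gate/linear halves (non-interleaved here for brevity)
    act = p.unsqueeze(-1) * (F.silu(gate) * lin)        # SwiGLU + per-token post-scale
    outs.append(act @ w2.t())                           # FC2: (n_i, hidden)
y = torch.cat(outs, dim=0)
assert y.shape == (tokens, hidden)
```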

Review thread:

    quantizer.optimize_for_gemm = True
    fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

    # Pack data tensors

Member: Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?

@timmoon10 (Collaborator, Author): I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
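The claim that the permutes are no-ops is easy to check in plain PyTorch: permute returns a strided view over the same storage, so no data is copied even though the result is non-contiguous (which is the layout the kernel API linked above expects).

```python
import torch

t = torch.randn(4, 8, 16)
p = t.permute(0, 2, 1)                 # swap the last two dims

assert p.data_ptr() == t.data_ptr()    # same underlying storage: no copy was made
assert not p.is_contiguous()           # only the stride layout changed
assert p.shape == (4, 16, 8)
```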

Review thread:

    )

    # Fused kernel for FC1 + SwiGLU + post-scale
    fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(

Contributor: After SwiGLU, the output usually needs to be multiplied with permuted_probs. Is this weighted SwiGLU supported?
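The permuted_probs mentioned above are the per-token routing probabilities in an MoE layer, applied as a multiplicative weight right after the activation. The snippet below only illustrates that weighting pattern in plain PyTorch; it is not taken from this PR and says nothing about what the fused kernel supports.

```python
import torch
import torch.nn.functional as F

gate, lin = torch.randn(2, 10, 64)   # per-token gate and linear units for 10 routed tokens
permuted_probs = torch.rand(10)      # routing probabilities, permuted into token order

# "Weighted SwiGLU": scale each token's SwiGLU output by its routing probability
weighted = permuted_probs.unsqueeze(-1) * (F.silu(gate) * lin)
assert weighted.shape == (10, 64)
```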
greptile-apps bot (Contributor) left a comment

3 files reviewed, no comments


Labels

performance (Performance issues)

3 participants