[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
[pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Expose option for custom op fusions. Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
* Add tests for custom ops
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix linter warnings and numerical test failures
* Tweak pattern matching logic with fixed window sizes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Use TF32 tols in fused op tests
* Review suggestion from @greptile-apps
* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
/te-ci pytorch L1
Greptile Overview
Greptile Summary
This PR adds grouped linear operations and an experimental fusion for the grouped MLP block in Mixture-of-Experts models. The implementation includes a grouped linear op and an experimental fused grouped-MLP op that uses a CuTe DSL kernel to compute an MXFP8 grouped GEMM and SwiGLU.
All previously identified issues from earlier review threads have been addressed.
The implementation follows established patterns in the codebase, includes comprehensive tests, and properly handles FP8/MXFP8 quantization.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Input as Input Tensor
participant FC1 as GroupedLinear (FC1)
participant SwiGLU as ScaledSwiGLU
participant FC2 as GroupedLinear (FC2)
participant Output as Output Tensor
Note over Input,Output: Standard Path (Unfused)
Input->>FC1: Split by group sizes
FC1->>FC1: Apply grouped linear transformations
FC1->>SwiGLU: Interleaved gate/activation output
SwiGLU->>SwiGLU: Apply SwiGLU + post-scale
SwiGLU->>FC2: Split activations by group sizes
FC2->>FC2: Apply grouped linear transformations
FC2->>Output: Concatenated output
Note over Input,Output: Fused Path (MXFP8 + SM100)
Input->>FC1: Quantize input (MXFP8)
FC1->>FC1: CuTe kernel: GEMM + SwiGLU + post-scale
FC1-->>FC2: Direct MXFP8 intermediate tensors
FC2->>FC2: Apply grouped GEMM
FC2->>Output: Final output
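For orientation, here is roughly what the unfused path above amounts to in plain PyTorch. This is only an illustrative reference: the function and argument names are made up for this sketch and are not the Transformer Engine API, and the real ops additionally handle the interleaved gate/activation layout and FP8/MXFP8 quantization.

```python
import torch
import torch.nn.functional as F

def unfused_grouped_mlp(x, group_sizes, fc1_weights, fc2_weights, probs):
    """Reference grouped MLP: per-group FC1 -> SwiGLU -> post-scale -> FC2."""
    outputs = []
    for i, chunk in enumerate(torch.split(x, group_sizes)):
        h = chunk @ fc1_weights[i].t()              # FC1 for this expert group
        gate, up = h.chunk(2, dim=-1)               # gate/activation halves (interleaving ignored here)
        act = F.silu(gate) * up                     # SwiGLU
        act = act * probs[i]                        # post-scale, e.g. routing probabilities
        outputs.append(act @ fc2_weights[i].t())    # FC2 for this expert group
    return torch.cat(outputs)
```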
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)
...
# Pack data tensors
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
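For context on why the permutes are free: in PyTorch a permute only rewrites strides and returns a view of the same storage, so handing the kernel a tensor with non-contiguous dims involves no data movement. A toy illustration, independent of the kernel API itself:

```python
import torch

x = torch.randn(16, 32)
y = x.permute(1, 0)                    # stride-only view; no data is copied
assert y.data_ptr() == x.data_ptr()    # same underlying storage
assert not y.is_contiguous()           # dims are non-contiguous, as the kernel expects
```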
# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
After SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?
Yes, the probs are passed into the kernel here: https://github.com/timmoon10/TransformerEngine/blob/46294be478f6551e2cf251283adc7529ddb2964e/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py#L264
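For reference, the weighting reduces to an elementwise multiply after the activation. A minimal unfused sketch (the tensor names here are illustrative, not the kernel's actual argument names):

```python
import torch
import torch.nn.functional as F

def weighted_swiglu(fc1_out: torch.Tensor, permuted_probs: torch.Tensor) -> torch.Tensor:
    """SwiGLU followed by per-token scaling with routing probabilities."""
    gate, up = fc1_out.chunk(2, dim=-1)
    return F.silu(gate) * up * permuted_probs.unsqueeze(-1)
```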
Review suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 files reviewed, no comments
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
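As a rough usage sketch only: the module names below follow the sequence diagram in the review summary (GroupedLinear, ScaledSwiGLU), but the constructor arguments and call convention are assumptions for illustration, not the verified API.

```python
import torch
import transformer_engine.pytorch as te

num_groups, hidden_size, ffn_size = 8, 1024, 4096

# Hypothetical composition of the new ops; the signatures below are assumed.
grouped_mlp = te.ops.Sequential(
    te.ops.GroupedLinear(num_groups, hidden_size, 2 * ffn_size),  # FC1
    te.ops.ScaledSwiGLU(),                                         # SwiGLU + post-scale
    te.ops.GroupedLinear(num_groups, ffn_size, hidden_size),       # FC2
)

x = torch.randn(4096, hidden_size, device="cuda", dtype=torch.bfloat16)
group_sizes = [512] * num_groups       # tokens routed to each expert group
out = grouped_mlp(x, group_sizes)      # assumed call convention for passing group sizes
```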
Type of change
Changes
Checklist: