
Conversation

Oleg-Goncharov (Collaborator) commented Jan 12, 2026

Description

This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.
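
For context, MXFP8 stores each tensor as FP8 (E4M3) elements plus one shared power-of-two (E8M0) scale per block of 32 consecutive values. Below is a minimal CPU sketch of that scheme following the OCP MX convention; it is illustrative only and is not the PR's CUDA kernel (the function name and single-block scope are invented):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // One MXFP8 block: 32 values share an E8M0 scale derived from the block amax.
    // OCP MX convention: shared exponent = floor(log2(amax)) - emax(E4M3), emax = 8.
    void mxfp8_quantize_block(const float in[32], float scaled[32], uint8_t* e8m0) {
      float amax = 0.0f;
      for (int i = 0; i < 32; ++i) amax = std::max(amax, std::fabs(in[i]));
      int shared_exp = (amax > 0.0f)
          ? static_cast<int>(std::floor(std::log2(amax))) - 8
          : -127;
      shared_exp = std::max(-127, std::min(127, shared_exp));  // E8M0 range (255 = NaN)
      *e8m0 = static_cast<uint8_t>(shared_exp + 127);          // biased encoding
      const float scale = std::ldexp(1.0f, shared_exp);
      for (int i = 0; i < 32; ++i)
        scaled[i] = in[i] / scale;  // each result is then rounded/saturated to FP8 E4M3
    }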

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added MXFP8 cast kernel for grouped tensors
  • Added the test suite

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch 4 times, most recently from e6bf02a to fc2a53f Compare January 15, 2026 16:15
@Oleg-Goncharov Oleg-Goncharov added the enhancement (New feature or request) and MoE labels Jan 15, 2026
@ptrendx ptrendx linked an issue Jan 16, 2026 that may be closed by this pull request
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 74a7917 to 88cf1b2 Compare January 21, 2026 17:00
pre-commit-ci bot and others added 6 commits January 21, 2026 17:00
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 7c4fda7 to 39bb24f Compare January 22, 2026 18:12
@Oleg-Goncharov Oleg-Goncharov marked this pull request as ready for review January 24, 2026 00:53
greptile-apps bot (Contributor) commented Jan 24, 2026

Greptile Summary

  • Implements MXFP8 quantization kernel for grouped tensors, enabling efficient processing of multiple tensors with varying dimensions in a single operation
  • Extends all major activation functions (GeLU, SiLU, ReLU, etc.) to support grouped tensor variants with comprehensive C API coverage (a hedged sketch of the assumed call pattern follows this list)
  • Adds extensive test coverage with 777 test configurations validating correctness across different tensor shapes, scaling modes, and activation functions
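
For orientation, the grouped entry points named in the sequence diagram presumably follow the existing NVTE C API shape (opaque tensor handles plus a CUDA stream). The signature below is an assumption; the actual declarations live in transformer_engine/common/include/transformer_engine/cast.h and may take additional arguments (workspace, quantization config, per-tensor offsets):

    #include <transformer_engine/cast.h>

    // Assumed call pattern only -- not the verified signature.
    void run_grouped_quantize(NVTETensor grouped_input, NVTETensor grouped_output,
                              cudaStream_t stream) {
      // Plain cast path: quantize every member tensor to MXFP8 in one launch.
      nvte_group_quantize(grouped_input, grouped_output, stream);
    }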

Important Files Changed

Filename | Overview
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | New CUDA kernel implementation for grouped-tensor MXFP8 quantization with TMA operations and multi-stage processing
tests/cpp/operator/test_cast_mxfp8_grouped.cu | New comprehensive test suite with 777 configurations validating grouped-tensor quantization functionality
transformer_engine/common/cast/dispatch/quantize.cuh | Adds dispatch helper functions for grouped-tensor quantization operations (a generic sketch of this pattern follows the table)
transformer_engine/common/include/transformer_engine/cast.h | Extends the C API with 7 new function declarations for grouped-tensor quantization operations
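
The dispatch-helper row above likely refers to the usual TransformerEngine pattern of lifting fusion flags (dbias, dactivation) into template parameters so each kernel instantiation is branch-free. A generic, self-contained CUDA sketch of that pattern; all names here are invented for illustration:

    #include <cuda_runtime.h>
    #include <cstddef>

    // Fusion flags become template parameters, so the inner loop of each
    // instantiation carries no runtime branches for the epilogue.
    template <bool IS_DBIAS, bool IS_DACT>
    __global__ void quantize_kernel(const float* in, float* out, float* dbias,
                                    size_t n) {
      const size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
      if (i >= n) return;
      float v = in[i];
      if constexpr (IS_DACT) v *= 0.5f;  // stand-in for a real dactivation
      out[i] = v;                        // stand-in for the MXFP8 cast
      if constexpr (IS_DBIAS) atomicAdd(dbias, v);
    }

    void dispatch(bool with_dbias, bool with_dact, const float* in, float* out,
                  float* dbias, size_t n, cudaStream_t stream) {
      const unsigned blocks = static_cast<unsigned>((n + 255) / 256);
      if (with_dbias && with_dact)
        quantize_kernel<true, true><<<blocks, 256, 0, stream>>>(in, out, dbias, n);
      else if (with_dbias)
        quantize_kernel<true, false><<<blocks, 256, 0, stream>>>(in, out, dbias, n);
      else if (with_dact)
        quantize_kernel<false, true><<<blocks, 256, 0, stream>>>(in, out, dbias, n);
      else
        quantize_kernel<false, false><<<blocks, 256, 0, stream>>>(in, out, dbias, n);
    }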

Confidence score: 4/5

  • This PR is generally safe to merge but requires careful review due to complex GPU kernel implementation and API extensions
  • Score reflects sophisticated CUDA kernel code with architecture-specific optimizations and potential edge cases in TMA descriptor handling
  • Pay close attention to the main CUDA kernel implementation and verify the API parameter type consistency for gradient functions

Sequence Diagram

sequenceDiagram
    participant User
    participant TestSuite as "GroupedFusedCastMXFP8TestSuite"
    participant Helper as "performTest<IType, OType>"
    participant KernelDispatch as "group_quantize_*_helper"
    participant Kernel as "group_quantize_mxfp8_kernel"
    participant GPU as "CUDA Device"

    User->>TestSuite: "Run test with parameters"
    TestSuite->>TestSuite: "Configure tensors and validation"
    TestSuite->>Helper: "performTest(processing_method, OP, ...)"
    Helper->>Helper: "Setup input/output tensors and reference data"
    Helper->>Helper: "Create NVTE grouped tensors"
    
    alt Processing Method: CAST_ONLY
        Helper->>KernelDispatch: "nvte_group_quantize(input, output)"
    else Processing Method: CAST_DBIAS
        Helper->>KernelDispatch: "nvte_group_quantize_dbias(grad, output, dbias)"
    else Processing Method: CAST_ACT
        Helper->>KernelDispatch: "nvte_group_gelu/silu/relu(input, output)"
    else Processing Method: CAST_DACT
        Helper->>KernelDispatch: "nvte_group_dgelu/dsilu/drelu(grad, input, output)"
    else Processing Method: CAST_DBIAS_DACT
        Helper->>KernelDispatch: "nvte_group_quantize_dbias_dgelu/dsilu/drelu(grad, input, output, dbias)"
    end
    
    KernelDispatch->>KernelDispatch: "group_quantize_fwd/bwd_helper"
    KernelDispatch->>Kernel: "update_tma_descriptors<<<num_tensors, 32>>>"
    GPU->>KernelDispatch: "TMA descriptors updated"
    KernelDispatch->>Kernel: "group_quantize_mxfp8_kernel<<<blocks, threads>>>"
    GPU->>Kernel: "Execute MXFP8 quantization with scaling"
    Kernel->>GPU: "Write quantized data and scaling factors"
    GPU->>KernelDispatch: "Kernel execution complete"
    
    alt IS_DBIAS enabled
        KernelDispatch->>GPU: "reduce_dbias(workspace, dbias)"
    end
    
    KernelDispatch->>Helper: "Quantized tensors ready"
    Helper->>Helper: "Compare GPU results with CPU reference"
    Helper->>TestSuite: "Validation complete"
    TestSuite->>User: "Test result"
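
Per the diagram, TMA descriptors are refreshed on-device (update_tma_descriptors<<<num_tensors, 32>>>) before the main kernel runs, presumably because each member tensor carries its own shape. For reference only, the host-side CUDA driver call that encodes such a 2D tiled TMA descriptor is sketched here; the 1-byte element type and tile sizes are hypothetical choices:

    #include <cuda.h>

    // Host-side reference: encoding a 2D tiled TMA descriptor via the driver API.
    // This only shows which fields such a descriptor holds; it is not the PR's
    // device-side update path.
    CUresult encode_2d_map(CUtensorMap* map, void* base, cuuint64_t rows,
                           cuuint64_t cols, cuuint32_t tile_y, cuuint32_t tile_x) {
      const cuuint64_t global_dim[2]    = {cols, rows};  // innermost dimension first
      const cuuint64_t global_stride[1] = {cols};  // row pitch in bytes (multiple of 16)
      const cuuint32_t box_dim[2]       = {tile_x, tile_y};  // shared-memory tile shape
      const cuuint32_t elem_stride[2]   = {1, 1};
      return cuTensorMapEncodeTiled(
          map, CU_TENSOR_MAP_DATA_TYPE_UINT8, /*rank=*/2, base, global_dim,
          global_stride, box_dim, elem_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
          CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
          CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    }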

greptile-apps bot (Contributor) left a comment


10 files reviewed, 6 comments


Comment on lines +104 to +107
case ShapeRepresentation::SAME_BOTH_DIMS: // rows_num = first_logical_dim / num_tensors; break;
case ShapeRepresentation::VARYING_LAST_DIM:
rows_num = first_logical_dim;
break;

logic: commented-out code in case statement creates ambiguity

The commented calculation rows_num = first_logical_dim / num_tensors; break; suggests this case should compute rows differently, but the fallthrough to VARYING_LAST_DIM (which sets rows_num = first_logical_dim) may not be the intended behavior. If fallthrough is intentional, remove the comment to clarify. If the division by num_tensors is needed for SAME_BOTH_DIMS, uncomment and add the break statement.
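
Either resolution removes the ambiguity; both are sketched here using the names from the quoted snippet:

    // (a) Fallthrough is intended: keep it, but drop the dead comment.
    case ShapeRepresentation::SAME_BOTH_DIMS:  // intentional fallthrough
    case ShapeRepresentation::VARYING_LAST_DIM:
      rows_num = first_logical_dim;
      break;

    // (b) SAME_BOTH_DIMS needs its own row count: restore the computation and break.
    case ShapeRepresentation::SAME_BOTH_DIMS:
      rows_num = first_logical_dim / num_tensors;
      break;
    case ShapeRepresentation::VARYING_LAST_DIM:
      rows_num = first_logical_dim;
      break;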

greptile-apps bot (Contributor) left a comment

10 files reviewed, 1 comment

create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
} constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: missing newline after closing brace before constexpr declarations

This formatting issue breaks the visual separation between the conditional compilation block and the subsequent constant declarations. Add a newline for better readability.

Suggested change
-} constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
+}
+constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


Labels

enhancement (New feature or request), MoE

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Quantization support for GroupedTensor: MXFP8

2 participants