
[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes#2748

Open
jberchtold-nvidia wants to merge 12 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/te-common-mxfp8-grouped-gemm-plus-fixes

Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 2 commits March 9, 2026 15:47
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR adds MXFP8 grouped GEMM support to the cuBLAS-LT backend and fixes tensor-scaled FP8 grouped GEMM scale pointer handling. The core change introduces per-tensor E8M0 scale pointer arrays in the GPU setup kernel, proper cuBLAS VEC32_UE8M0 scale mode configuration, and 16-byte workspace alignment for pointer arrays required by cuBLAS 13.2+.

Key changes:

  • New MXFP8 operand selection logic (select_grouped_operand): routes A/B to rowwise vs. columnwise data based on transpose flags so scales always run along the K dimension.
  • GPU kernel extension (setup_grouped_gemm_kernel): per-tensor scale pointers now populated for both tensor-scaled FP8 (float* indexed by tensor) and MXFP8 (E8M0 base + a_offset/32 byte offset).
  • Tensor-scaled FP8 fix: switches from a single flat float* scale pointer to a PER_BATCH_SCALAR_32F pointer array (float**), aligning with the grouped GEMM API requirement.
  • Workspace layout: upgraded from 6 to 8 pointer arrays with 16-byte per-array alignment; required_setup_size accounts for this but the test helper grouped_setup_workspace_size does not (covered in a prior review thread).
  • Test infrastructure: build_grouped_tensor refactored to support rowwise-only / columnwise-only layouts, MXFP8 scale gathering added, and random padding disabled for MXFP8 to keep data and scale offsets consistent.
  • New public API: nvte_set/get_grouped_tensor_swizzled_scales to mark whether GEMM-ready swizzled scales are present.
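The per-tensor E8M0 scale-pointer arithmetic described above (scale base plus a_offset/32) can be sketched in host code as follows. This is a simplified illustration, not the actual setup_grouped_gemm_kernel: the helper name, loop structure, and types are invented for clarity; only the /32 block relationship comes from the summary. It also assumes each a_offset is a multiple of 32, which is exactly the edge case flagged in prior review threads.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch: each E8M0 scale byte covers a 32-element block along K,
// so a tensor starting a_offset elements into the packed MXFP8 data buffer has
// its first scale at byte offset a_offset / 32 in the packed scale buffer.
std::vector<const uint8_t*> fill_scale_ptrs(const uint8_t* e8m0_scale_base,
                                            const std::vector<size_t>& a_offsets) {
  std::vector<const uint8_t*> scale_ptrs(a_offsets.size());
  for (size_t i = 0; i < a_offsets.size(); ++i) {
    // Assumes a_offsets[i] % 32 == 0; otherwise the scale offset is truncated
    // and the pointer lands on the wrong scale block.
    scale_ptrs[i] = e8m0_scale_base + a_offsets[i] / 32;
  }
  return scale_ptrs;
}
```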

Confidence Score: 3/5

  • Merging carries moderate risk: the tensor-scaled FP8 scale-pointer change is a behaviour-altering fix that needs careful validation, and the MXFP8 scale-offset computation in the GPU kernel contains known edge-case issues (non-32-divisible tensor sizes) documented in prior review threads.
  • The overall structure is sound and the test coverage is reasonable for the happy-path shapes chosen (all multiples of 128). However, a few known issues remain open from prior review rounds (workspace size underestimation for odd num_tensors, scale offset assumes exact 32-divisibility, unconditional FAST_ACCUM=0). The new concern about build_grouped_tensor unconditionally flagging MXFP8 scales as swizzled without verifying individual tensor state could silently produce wrong GEMM results if callers don't follow the expected workflow.
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu (scale offset computation and FAST_ACCUM scope) and tests/cpp/test_common.cu (unconditional swizzled-scale assumption in build_grouped_tensor).
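The workspace-size concern above (16-byte per-array alignment vs. a naive estimate) can be modeled with a small sketch. The helper names below are hypothetical, not the production aligned_ptr_size or required_setup_size; the sketch only shows why an unaligned size estimate undercounts when num_tensors * sizeof(void*) is not already a multiple of 16.

```cpp
#include <cassert>
#include <cstddef>

// Round n up to the next multiple of a (a must be a power of two or any
// positive alignment; this formulation works for either).
constexpr size_t align_up(size_t n, size_t a) { return (n + a - 1) / a * a; }

// Hypothetical workspace estimate: num_arrays pointer arrays, each holding
// num_tensors pointers, each array padded so it starts 16-byte aligned
// (cuBLAS 13.2+ requires 16-byte-aligned pointer arrays per the summary).
constexpr size_t workspace_bytes(size_t num_tensors, size_t num_arrays) {
  // With an odd num_tensors the raw array size (num_tensors * sizeof(void*))
  // is not a multiple of 16, so the aligned total is strictly larger than
  // the naive num_arrays * num_tensors * sizeof(void*).
  return num_arrays * align_up(num_tensors * sizeof(void*), 16);
}
```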

Important Files Changed

  • tests/cpp/operator/test_grouped_gemm.cu — Adds a kMXFP8 input case and a make_mxfp8_operand helper; updates test shapes to multiples of 128 for MXFP8 alignment; adds 8 new MXFP8 test configurations. Minor: the grouped_setup_workspace_size test helper still diverges from production's aligned_ptr_size calculation for non-power-of-2 num_tensors (addressed in a prior review thread).
  • tests/cpp/test_common.cu — Refactors build_grouped_tensor to support rowwise-only or columnwise-only layouts; adds MXFP8 scale gathering via a gather_scales lambda; disables random padding for MXFP8. Concern: the swizzled-scale flag is set unconditionally for all MXFP8 grouped tensors without verifying individual tensor flags, which could silently pack un-swizzled scales and mark them as swizzled.
  • tests/cpp/test_common.h — Adds a columnwise_scale_inv CudaPtr field to GroupedBuffers to hold the MXFP8 columnwise scale buffer. Straightforward struct addition.
  • transformer_engine/common/common.h — Adds with_gemm_swizzled_scales(false) to the GroupedTensor constructor initializer list. Minimal, correct change.
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — Core MXFP8 grouped GEMM implementation: adds MXFP8 operand selection logic, per-tensor E8M0 scale pointer computation in the GPU kernel (a_offset/32), the cuBLAS VEC32_UE8M0 scale mode, and a workspace pointer-array alignment upgrade from 8 bytes to 16 bytes. Several concerns pre-exist in prior review threads (scale divisibility, unconditional FAST_ACCUM, workspace size mismatch).
  • transformer_engine/common/include/transformer_engine/transformer_engine.h — Adds nvte_set_grouped_tensor_swizzled_scales and nvte_get_grouped_tensor_swizzled_scales API declarations. Correctly marked as EXPERIMENTAL.
  • transformer_engine/common/transformer_engine.cpp — Implements the two new swizzled-scales getter and setter. The null-guard pattern (explicit nullptr check before convertNVTEGroupedTensorCheck) matches existing APIs such as nvte_grouped_tensor_scaling_mode.

Sequence Diagram

sequenceDiagram
    participant User
    participant make_mxfp8_operand
    participant nvte_quantize
    participant nvte_swizzle_scaling_factors
    participant build_grouped_tensor
    participant nvte_grouped_gemm
    participant setup_grouped_gemm_kernel
    participant cublasLtMatmul

    User->>make_mxfp8_operand: BF16 tensor + shape + (is_A, transposed)
    make_mxfp8_operand->>nvte_quantize: BF16 → MXFP8 (rowwise or columnwise)
    make_mxfp8_operand->>nvte_swizzle_scaling_factors: swizzle E8M0 scales for GEMM
    make_mxfp8_operand-->>User: mxfp8_swizzled tensor

    User->>build_grouped_tensor: [mxfp8_swizzled tensors], NVTE_MXFP8_1D_SCALING
    build_grouped_tensor->>build_grouped_tensor: gather_scales() — pack E8M0 per-tensor scales contiguously
    build_grouped_tensor->>build_grouped_tensor: set use_random_padding=false (offsets = sum of numel, no gaps)
    build_grouped_tensor->>build_grouped_tensor: nvte_set_grouped_tensor_swizzled_scales(h, 1)
    build_grouped_tensor-->>User: GroupedBuffers (data + scale_inv + columnwise_scale_inv)

    User->>nvte_grouped_gemm: GroupedTensor A, B, C, D
    nvte_grouped_gemm->>nvte_grouped_gemm: select_grouped_operand — pick rowwise/columnwise per A/B transpose
    nvte_grouped_gemm->>nvte_grouped_gemm: GroupedGemmSetupWorkspace::from_buffers (16-byte aligned ptr arrays)
    nvte_grouped_gemm->>setup_grouped_gemm_kernel: launch (a_mxfp8_scale_base, b_mxfp8_scale_base)
    setup_grouped_gemm_kernel->>setup_grouped_gemm_kernel: a_scale_inv_ptrs[i] = base + a_offset/32
    setup_grouped_gemm_kernel-->>nvte_grouped_gemm: per-tensor A/B/C/D/alpha/beta/scale pointers filled
    nvte_grouped_gemm->>nvte_grouped_gemm: set_fp8_scale_pointers — CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0
    nvte_grouped_gemm->>cublasLtMatmul: execute grouped GEMM
    cublasLtMatmul-->>User: output D tensors

Comments Outside Diff (1)

  1. tests/cpp/test_common.cu, lines 1306-1308

    Swizzled-scale flag set unconditionally for all MXFP8 grouped tensors

    nvte_set_grouped_tensor_swizzled_scales(h, 1) is called at the end of the MXFP8 branch regardless of whether the individual tensors actually contain swizzled scales. build_grouped_tensor gathers raw scale bytes from the tensors via rowwise_cpu_scale_inv_ptr<uint8_t>() / columnwise_cpu_scale_inv_ptr<uint8_t>() without checking each tensor's with_gemm_swizzled_scales flag.

    If a caller passes MXFP8 tensors whose scales have NOT been swizzled (e.g., forgetting to call nvte_swizzle_scaling_factors), the function will:

    1. Pack the un-swizzled scale bytes into the grouped buffer.
    2. Mark the grouped tensor as having swizzled scales.
    3. Cause the subsequent GEMM to silently consume incorrectly-formatted scale data.

    A defensive check would catch this early:

    // Before gathering scales in the MXFP8 branch:
    for (size_t i = 0; i < num_tensors; ++i) {
      NVTE_CHECK(tensors[i]->with_gemm_swizzled_scales(),
                 "build_grouped_tensor: MXFP8 tensor ", i,
                 " does not have swizzled scales. Call nvte_swizzle_scaling_factors first.");
    }

Last reviewed commit: d103bbd

@jberchtold-nvidia
Collaborator Author

/te-ci

vthumbe1503 previously approved these changes Mar 9, 2026
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Comment on lines +1262 to +1268
for (size_t i = 0; i < num_tensors; ++i) {
  tensors[i]->to_cpu();
  NVTE_CHECK_CUDA(cudaGetLastError());
  void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
  const void* src = get_cpu_ptr_fn(tensors[i]);
  NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}

Redundant CPU sync for swizzled MXFP8 scales.

The loop calls tensors[i]->to_cpu() on line 1263, then immediately passes the tensor to get_cpu_ptr_fn(tensors[i]) on line 1267. However, both rowwise_cpu_scale_inv_ptr<uint8_t>() and columnwise_cpu_scale_inv_ptr<uint8_t>() internally call to_cpu() themselves (test_common.h lines 249 and 264), making the explicit call on line 1263 redundant.

Additionally, the GPU pointers are available directly via get_rowwise_scale_inv().data_ptr and get_columnwise_scale_inv().data_ptr, allowing a device-to-device copy that avoids the round-trip entirely:

Suggested change

for (size_t i = 0; i < num_tensors; ++i) {
  void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
  NVTE_CHECK_CUDA(cudaMemcpy(dst,
                             has_rowwise ? tensors[i]->tensor_.get_rowwise_scale_inv().data_ptr
                                         : tensors[i]->tensor_.get_columnwise_scale_inv().data_ptr,
                             numels[i],
                             cudaMemcpyDeviceToDevice));
}

This improves both clarity and efficiency in test code.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Add documentation for scaling factors in common.h

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

1 similar comment
@vthumbe1503
Collaborator

/te-ci

@vthumbe1503
Collaborator

/te-ci
