
[Common, pyTorch] Grouped MXFP8 dequantize support #2722

Open
ptrendx wants to merge 11 commits into NVIDIA:main from ptrendx:pr_grouped_dequantize

Conversation

@ptrendx
Member

@ptrendx ptrendx commented Mar 2, 2026

Description

Support dequantization for MXFP8 grouped tensors.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Grouped dequantization kernel for MXFP8
  • Exposed the functionality in PyTorch

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from Oleg-Goncharov March 2, 2026 19:13
pre-commit-ci bot and others added 3 commits March 2, 2026 19:19
@ptrendx ptrendx linked an issue Mar 2, 2026 that may be closed by this pull request
ptrendx added 3 commits March 3, 2026 13:46
@ptrendx ptrendx marked this pull request as ready for review March 10, 2026 18:00
@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR adds grouped MXFP8 dequantization support by introducing a new TMA-based CUDA kernel (group_dequantize_mxfp8.cuh) and wiring it through the C API (nvte_group_dequantize) and the PyTorch Python binding (tex.group_dequantize). The kernel reuses the TMA descriptor infrastructure and shape-representation helpers from the existing group_quantize path, handles all four shape modes (SAME, VARYING_FIRST, VARYING_LAST, VARYING_BOTH), and is gated on compute capability ≥ 10.0.

Key findings:

  • Bug in test offsets shape (test_dequantize_mxfp8_grouped.cu, line 165): offsets_shape.data[0] is set to num_tensors but the CSR-format offsets array has num_tensors + 1 elements. The declared shape metadata is incorrect and should be num_tensors + 1.
  • Missing last_dims in Python binding (cast.cpp): group_dequantize extracts and forwards first_dims and tensor_offsets from the Python GroupedTensor, but silently omits last_dims. This leaves input->last_dims.dptr as nullptr inside the C++ wrapper. For VARYING_LAST_DIM or VARYING_BOTH_DIMS shapes, the kernel would dereference that null pointer. If those cases are intentionally unsupported in Python, an explicit NVTE_CHECK should guard against them.
  • Dead code in kernel (group_dequantize_mxfp8.cuh, line 985): parity ^= 1 appears after the main loop but parity is never read again; this is a minor cleanup item.
  • Python tests cover SAME_BOTH_DIMS and VARYING_FIRST_DIM shapes only; VARYING_LAST_DIM and VARYING_BOTH_DIMS are exercised only by the C++ tests.
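The CSR-style offsets convention behind the first finding can be sketched in plain Python. This is illustrative only: the helper names below are made up, loosely mirroring the PR's tensor_offsets field and get_current_tensor_id helper, and are not Transformer Engine code.

```python
# Illustrative sketch (not TE code): how CSR-style tensor offsets work for a
# grouped tensor, and why the offsets array has num_tensors + 1 entries.

def build_offsets(num_elems_per_tensor):
    """Prefix-sum offsets: offsets[i] is where tensor i starts;
    offsets[num_tensors] is the sentinel holding the total element count."""
    offsets = [0]
    for n in num_elems_per_tensor:
        offsets.append(offsets[-1] + n)
    return offsets

def current_tensor_id(offsets, elem_index):
    """Analog of get_current_tensor_id: find which tensor owns elem_index by
    searching over offsets[0 .. num_tensors] inclusive — which is why a
    declared shape of num_tensors would be one element short."""
    for i in range(len(offsets) - 1):
        if offsets[i] <= elem_index < offsets[i + 1]:
            return i
    raise IndexError(elem_index)

offsets = build_offsets([6, 4, 10])   # 3 tensors
assert len(offsets) == 3 + 1          # num_tensors + 1, hence the off-by-one
assert current_tensor_id(offsets, 7) == 1
```

Declaring the offsets shape as num_tensors instead of num_tensors + 1 leaves the sentinel entry outside the declared metadata, even though the lookup reads it.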

Confidence Score: 3/5

  • PR is mostly safe for SAME_BOTH_DIMS and VARYING_FIRST_DIM paths, but has a null-pointer risk in the Python binding for VARYING_LAST_DIM/VARYING_BOTH_DIMS shapes and incorrect metadata in the test harness.
  • The kernel logic and the C API layer are well-written and closely mirror the established group-quantize pattern. The two issues that lower the score are: (1) the missing last_dims extraction in the Python binding, which would cause a null-pointer dereference for any future or external caller using VARYING_LAST_DIM/VARYING_BOTH_DIMS shapes; and (2) the off-by-one in offsets_shape.data[0] in the test, which means the declared shape metadata does not match the allocated buffer. The dead-code parity ^= 1 is cosmetic. The currently exercised code paths (all is_single_tensor) would pass, but the latent issues need fixing before broader use.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (missing last_dims) and tests/cpp/operator/test_dequantize_mxfp8_grouped.cu (incorrect offsets shape) need the most attention.

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh New CUDA kernel implementing grouped MXFP8 dequantization using TMA; handles four shape representations (SAME, VARYING_FIRST, VARYING_LAST, VARYING_BOTH). Kernel logic is sound and mirrors the existing group-quantize pattern. Minor: parity ^= 1 at the end of the kernel body is dead code.
transformer_engine/pytorch/csrc/extensions/cast.cpp New group_dequantize Python binding added; correctly extracts rowwise/columnwise data, scale_inv, first_dims and tensor_offsets, but is missing extraction and forwarding of last_dims. This silently leaves input->last_dims.dptr as nullptr, which would crash for VARYING_LAST_DIM or VARYING_BOTH_DIMS shaped tensors.
tests/cpp/operator/test_dequantize_mxfp8_grouped.cu Comprehensive C++ test comparing grouped vs single-tensor dequantize with bitwise equality. Bug: offsets_shape.data[0] is set to num_tensors but the CSR offsets array has num_tensors + 1 entries; the shape metadata should be num_tensors + 1. Also missing CUDA error checking on malloc/memcpy calls.
transformer_engine/common/cast/dispatch/dequantize.cuh New group_dequantize_helper dispatch function added; correctly routes to the MXFP8 grouped kernel and gates on CC >= 10.0. Default error case is handled.
transformer_engine/common/cast/cast.cu New nvte_group_dequantize C API entry point; trivially wraps dispatch::group_dequantize_helper and correctly uses convertNVTEGroupedTensorCheck.
tests/pytorch/test_grouped_tensor.py Two new Python tests added: functional correctness (test_group_dequantize) and CUDA-graph capturability (test_group_dequantize_cudagraph_capturable). Tests are well-structured but only cover SAME_BOTH_DIMS and VARYING_FIRST_DIM shapes.
transformer_engine/common/include/transformer_engine/cast.h New nvte_group_dequantize declared and documented correctly; docstring cleanup for nvte_dequantize (removed stale MXFP8-specific note).
transformer_engine/pytorch/csrc/extensions/pybind.cpp group_dequantize correctly registered in pybind11 module with appropriate argument names.

Sequence Diagram

sequenceDiagram
    participant PY as Python Caller
    participant BIND as pybind11 (cast.cpp)
    participant API as nvte_group_dequantize (cast.cu)
    participant DISP as dispatch::group_dequantize_helper
    participant TMA as update_tma_descriptors kernel
    participant KERN as group_dequantize_mxfp8_kernel

    PY->>BIND: tex.group_dequantize(grouped_tensor, otype)
    BIND->>BIND: Extract num_tensors, logical_shape, rowwise/colwise data,\nscale_inv, first_dims, tensor_offsets
    BIND->>BIND: Build GroupedTensorWrapper (input_cpp)\nCreate output via NoneQuantizer
    BIND->>API: nvte_group_dequantize(input, output, stream)
    API->>DISP: group_dequantize_helper(input, output, stream)
    DISP->>DISP: Validate inputs, check scaling mode\nDetermine shape_rep (SAME / VARYING_FIRST / VARYING_LAST / VARYING_BOTH)
    alt is_single_tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        DISP->>KERN: group_dequantize_mxfp8_kernel<<<blocks, 128>>>\n(static TMA descriptors)
    else multi-tensor (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        DISP->>TMA: update_tma_descriptors<<<num_tensors, 32>>>\n(write per-tensor CUtensorMap into device globals)
        TMA-->>DISP: done
        DISP->>KERN: group_dequantize_mxfp8_kernel<<<blocks, 128>>>\n(per-tensor TMA descriptors via g_tensor_maps_*)
    end
    KERN->>KERN: TMA load FP8 tile → shmem\nApply e8m0 scale → higher-precision tile\nTMA store to output
    KERN-->>DISP: done
    DISP-->>API: done
    API-->>BIND: done
    BIND-->>PY: Python GroupedTensor (BF16/FP16/FP32 rowwise data)

Comments Outside Diff (2)

  1. transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh, line 985 (link)

    Dead code — parity ^= 1 is never used

    parity is toggled here but is never read again after the loop. The variable is only consumed inside the loop by ptx::mbarrier_wait_parity(&mbar[iter], parity), where it always holds the value 0 (single-phase, not a true ping-pong). This leftover flip should be removed to avoid confusion.
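    The control flow can be illustrated with a plain-Python analogy (not CUDA; the list append stands in for mbarrier_wait_parity). In a true ping-pong scheme, parity flips once per completed phase so the next wait targets the other phase; a flip after the final loop iteration is dead because nothing reads parity afterwards.

    ```python
    # Plain-Python analogy of the parity pattern the review describes.

    def run_phases(num_iters, toggle_inside_loop):
        parity = 0
        observed = []
        for _ in range(num_iters):
            observed.append(parity)  # stands in for mbarrier_wait_parity(&mbar[iter], parity)
            if toggle_inside_loop:
                parity ^= 1          # true ping-pong: alternate phases each iteration
        parity ^= 1                  # the flagged line: toggled here, never read again
        return observed

    assert run_phases(4, toggle_inside_loop=False) == [0, 0, 0, 0]  # single-phase, as in the kernel
    assert run_phases(4, toggle_inside_loop=True) == [0, 1, 0, 1]   # what a real ping-pong would see
    ```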

  2. transformer_engine/pytorch/csrc/extensions/cast.cpp, line 1200-1238 (link)

    last_dims is never extracted or forwarded to the C++ wrapper

    first_dims and tensor_offsets are pulled from the Python GroupedTensor and forwarded to input_cpp, but last_dims is silently dropped. The underlying C++ kernel (group_dequantize_mxfp8_kernel) accepts a last_dims_ptr parameter and passes it through to get_tensor_cols_num, which dereferences it for both VARYING_LAST_DIM and VARYING_BOTH_DIMS shapes.

    For the currently-tested shapes (all SAME_BOTH_DIMS / VARYING_FIRST_DIM) the pointer is never dereferenced and the omission goes unnoticed. But any future caller that supplies a tensor with varying last dimensions from the Python side will get a null-pointer dereference in the kernel.

    The fix is to extract last_dims and set it on input_cpp, mirroring the first_dims block:

    auto last_dims = get_optional_tensor("last_dims");
    // ...
    if (last_dims.has_value()) {
        input_cpp.set_last_dims(last_dims->data_ptr(), DType::kInt64,
                                getTensorShape(*last_dims));
    }

    If VARYING_LAST_DIM / VARYING_BOTH_DIMS are intentionally unsupported in the Python path for now, that constraint should be enforced with an explicit NVTE_CHECK rather than silently allowing last_dims_ptr == nullptr to reach the kernel.
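    The shape of that guard can be sketched in plain Python with hypothetical names (shape_mode as a string, last_dims as an optional list); the real check would be an NVTE_CHECK in the C++ binding, not Python code.

    ```python
    # Sketch of the suggested guard: fail fast at the binding boundary instead
    # of letting a null last_dims pointer reach the kernel.

    VARYING_LAST_MODES = {"VARYING_LAST_DIM", "VARYING_BOTH_DIMS"}

    def check_last_dims(shape_mode, last_dims):
        """Raise if a shape mode that varies the last dimension arrives
        without last_dims metadata."""
        if shape_mode in VARYING_LAST_MODES and last_dims is None:
            raise ValueError(
                f"{shape_mode} requires last_dims; refusing to forward a "
                "null last_dims pointer to the kernel")
        return True

    assert check_last_dims("SAME_BOTH_DIMS", None)
    assert check_last_dims("VARYING_LAST_DIM", [64, 128])
    ```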

Last reviewed commit: 702adc5

    nvte_set_grouped_tensor_param(in_group_tensor,
                                  NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
                                  &in_data_tensor, sizeof(in_data_tensor));
  } else {

Incorrect offsets shape — off-by-one

offsets_shape.data[0] is set to num_tensors, but the offsets array is a standard CSR-style sentinel array with num_tensors + 1 entries (the last entry stores the total element count). The allocation uses (num_tensors + 1) * sizeof(size_t) on line 132 and offsets_h is declared with num_tensors + 1 on line 408. get_current_tensor_id (borrowed from the quantize path) searches over offsets_ptr[0 .. num_tensors], so it will access one element past the declared shape.

Suggested change
} else {
offsets_shape.data[0] = num_tensors + 1;



Development

Successfully merging this pull request may close these issues.

Dequantization support for the grouped tensor - MXFP8

1 participant