[Common, pyTorch] Grouped MXFP8 dequantize support #2722
ptrendx wants to merge 11 commits into NVIDIA:main from
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds grouped MXFP8 dequantization support by introducing a new TMA-based CUDA kernel (`group_dequantize_mxfp8_kernel`).

Key findings:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as Python Caller
    participant BIND as pybind11 (cast.cpp)
    participant API as nvte_group_dequantize (cast.cu)
    participant DISP as dispatch::group_dequantize_helper
    participant TMA as update_tma_descriptors kernel
    participant KERN as group_dequantize_mxfp8_kernel
    PY->>BIND: tex.group_dequantize(grouped_tensor, otype)
    BIND->>BIND: Extract num_tensors, logical_shape, rowwise/colwise data,\nscale_inv, first_dims, tensor_offsets
    BIND->>BIND: Build GroupedTensorWrapper (input_cpp)\nCreate output via NoneQuantizer
    BIND->>API: nvte_group_dequantize(input, output, stream)
    API->>DISP: group_dequantize_helper(input, output, stream)
    DISP->>DISP: Validate inputs, check scaling mode\nDetermine shape_rep (SAME / VARYING_FIRST / VARYING_LAST / VARYING_BOTH)
    alt is_single_tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        DISP->>KERN: group_dequantize_mxfp8_kernel<<<blocks, 128>>>\n(static TMA descriptors)
    else multi-tensor (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        DISP->>TMA: update_tma_descriptors<<<num_tensors, 32>>>\n(write per-tensor CUtensorMap into device globals)
        TMA-->>DISP: done
        DISP->>KERN: group_dequantize_mxfp8_kernel<<<blocks, 128>>>\n(per-tensor TMA descriptors via g_tensor_maps_*)
    end
    KERN->>KERN: TMA load FP8 tile → shmem\nApply e8m0 scale → higher-precision tile\nTMA store to output
    KERN-->>DISP: done
    DISP-->>API: done
    API-->>BIND: done
    BIND-->>PY: Python GroupedTensor (BF16/FP16/FP32 rowwise data)
```
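The kernel's scaling step ("Apply e8m0 scale → higher-precision tile") can be illustrated with a small framework-free sketch. This is not the actual CUDA code; the function names here are hypothetical stand-ins. MXFP8 stores one shared e8m0 exponent byte per block of elements, and dequantization multiplies each element by `2^(e8m0 - 127)`.

```python
def e8m0_to_scale(e8m0: int) -> float:
    """Decode an e8m0 scale byte: an unsigned 8-bit exponent with bias 127."""
    return 2.0 ** (e8m0 - 127)

def dequantize_mxfp8_block(fp8_values, e8m0_scale):
    """Dequantize one MXFP8 block to higher precision.

    `fp8_values` are assumed to be already decoded to floats; the shared
    block scale is applied elementwise, as the kernel does per tile.
    """
    scale = e8m0_to_scale(e8m0_scale)
    return [v * scale for v in fp8_values]

# A scale byte of 127 encodes 2^0 = 1.0, so values pass through unchanged;
# a scale byte of 128 encodes 2^1 = 2.0 and doubles every element.
print(dequantize_mxfp8_block([0.5, -1.0], 127))  # [0.5, -1.0]
print(dequantize_mxfp8_block([0.5, -1.0], 128))  # [1.0, -2.0]
```

The real kernel performs the same multiply on whole shared-memory tiles between the TMA load and the TMA store.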
```cpp
  nvte_set_grouped_tensor_param(in_group_tensor,
                                NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor,
                                sizeof(in_data_tensor));
} else {
```
Incorrect offsets shape (off-by-one)

`offsets_shape.data[0]` is set to `num_tensors`, but the offsets array is a standard CSR-style sentinel array with `num_tensors + 1` entries (the last entry stores the total element count). The allocation uses `(num_tensors + 1) * sizeof(size_t)` on line 132 and `offsets_h` is declared with `num_tensors + 1` on line 408. `get_current_tensor_id` (borrowed from the quantize path) searches over `offsets_ptr[0 .. num_tensors]`, so it will access one element past the declared shape.
Suggested change:

```cpp
} else {
  offsets_shape.data[0] = num_tensors + 1;
```
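The off-by-one noted above is easy to see in a small stand-in sketch (hypothetical Python mirroring the C++ helpers, not the actual implementation): a CSR-style offsets array for `num_tensors` tensors needs `num_tensors + 1` entries, because the per-element tensor lookup reads `offsets[i + 1]` and must be able to reach the final sentinel.

```python
def build_offsets(tensor_sizes):
    """CSR-style offsets: num_tensors + 1 entries; the last is the total count."""
    offsets = [0]
    for size in tensor_sizes:
        offsets.append(offsets[-1] + size)
    return offsets

def get_current_tensor_id(offsets, flat_index):
    """Find which tensor a flat element index belongs to.

    Mirrors the search in the quantize path: it compares against
    offsets[i + 1], so the sentinel at offsets[num_tensors] must exist.
    """
    for i in range(len(offsets) - 1):
        if offsets[i] <= flat_index < offsets[i + 1]:
            return i
    raise IndexError("flat_index out of range")

offsets = build_offsets([4, 2, 3])
print(offsets)                            # [0, 4, 6, 9] -- 3 tensors, 4 entries
print(get_current_tensor_id(offsets, 5))  # element 5 lives in tensor 1
```

Declaring the shape as `num_tensors` while the buffer actually holds `num_tensors + 1` entries means any bounds checking based on the declared shape would flag the final sentinel read.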
Description
Support dequantization for MXFP8 grouped tensors.
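As a rough end-to-end illustration of what "grouped dequantize" means, here is a pure-Python sketch (a hypothetical stand-in for `nvte_group_dequantize`, not the library API): each member tensor of the group is stored flattened with one e8m0 scale byte per block of consecutive elements, and dequantization applies the matching block scale to every element.

```python
def dequantize_grouped(tensors_fp8, scales_e8m0, block=32):
    """Dequantize a group of MXFP8 tensors (flattened, one scale per block).

    `tensors_fp8` holds the already-decoded FP8 element values per tensor;
    `scales_e8m0` holds one e8m0 scale byte per `block` consecutive elements.
    """
    out = []
    for data, scales in zip(tensors_fp8, scales_e8m0):
        dq = []
        for i, v in enumerate(data):
            # e8m0 decodes to 2^(byte - 127); one scale covers a whole block.
            dq.append(v * 2.0 ** (scales[i // block] - 127))
        out.append(dq)
    return out

# Two tensors; block size 2 for brevity (real MXFP8 blocks are 32 wide).
# Scale byte 128 doubles its block, scale byte 127 is the identity.
print(dequantize_grouped([[1.0, 2.0], [3.0]], [[128], [127]], block=2))
# [[2.0, 4.0], [3.0]]
```

The CUDA path in this PR does the same work tile by tile, using TMA to stage FP8 tiles and their scales through shared memory.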