Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
abacbe5
code init
pggPL Feb 5, 2026
ca67a05
fix
pggPL Feb 6, 2026
55c9fd5
code drop
pggPL Feb 19, 2026
09bb7ea
Remove redundant nvte_set/get_grouped_tensor_swizzled_scales
pggPL Mar 16, 2026
63df695
Merge gitlab/main into grouped_gemm_nvfp4_and_hopper
pggPL Mar 17, 2026
44ab70d
Add Hopper support for grouped GEMM and refactor cuBLAS version checks
pggPL Mar 17, 2026
3689c10
Add NVFP4 support for discrete-input grouped GEMM and skip FP8 tensor…
pggPL Mar 18, 2026
d6d26bc
Add alignment assertions for MXFP8/NVFP4 scale offsets in grouped GEM…
pggPL Mar 18, 2026
c3ba64b
Merge remote-tracking branch 'origin/main' into grouped_gemm_nvfp4_an…
pggPL Apr 26, 2026
8bdd739
Fix grouped GEMM: NVFP4 columnwise transa=N + relax MXFP8 alignment f…
pggPL Apr 26, 2026
eba3468
Clarify swap_dims comment in build_grouped_gemm_multi_inputA_args
pggPL Apr 26, 2026
4375cf2
Merge remote-tracking branch 'upstream/main' into grouped_gemm_nvfp4_…
pggPL May 8, 2026
6c7a515
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2026
526a04a
Fix grouped GEMM scale_inv offsets for NVFP4 and FP8 block scaling
pggPL May 8, 2026
0f49cc3
Relax NVFP4 amax contiguity; consolidate scale_inv offset helpers; te…
pggPL May 11, 2026
ce342dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2026
a4df7bd
Remove unused float_size in GroupedGemmSetupWorkspace::from_buffers
pggPL May 11, 2026
d6a1597
Fix Hopper grouped GEMM alpha beta handling
pggPL Mar 16, 2026
d71d614
fix
pggPL Mar 16, 2026
b86fc7e
Address code review: NVFP4 amax check, swap_dims default, test refactor
pggPL May 11, 2026
59b90b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2026
7b1c8aa
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
pggPL May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
801 changes: 404 additions & 397 deletions tests/cpp/operator/test_grouped_gemm.cu

Large diffs are not rendered by default.

21 changes: 14 additions & 7 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,10 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const
output_tensors.emplace_back(std::move(output));
}

GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
const uint8_t input_swizzled = 0;
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
Expand Down Expand Up @@ -369,8 +371,10 @@ void performTestGroupedUnswizzleMXFP8(const int num_tensors, const size_t M, con
output_tensors.emplace_back(std::move(output));
}

GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
const uint8_t input_swizzled = 1;
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
Expand Down Expand Up @@ -459,9 +463,12 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
final_tensors.emplace_back(std::move(fin));
}

GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);
GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING,
/*enforce_grouped_gemm_alignment=*/false);

const NVTEShape row_shape = orig_tensors[0]->rowwise_scale_inv_shape();
const NVTEShape col_shape = orig_tensors[0]->columnwise_scale_inv_shape();
Expand Down
225 changes: 172 additions & 53 deletions tests/cpp/test_common.cu

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ struct GroupedBuffers {
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
CudaPtr<> amax_dev; // Per-tensor amax for NVFP4 grouped GEMM
CudaPtr<> columnwise_amax_dev; // Per-tensor columnwise amax for NVFP4 grouped GEMM
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
Expand All @@ -614,7 +616,8 @@ struct GroupedBuffers {
};

// Packs the given per-tensor buffers into a single grouped tensor.
// When enforce_grouped_gemm_alignment is true (default), per-tensor offsets are
// padded to the alignment required by grouped GEMM; callers that only need the
// grouped layout (e.g. swizzle tests) may pass false to use tight packing.
// NOTE(review): semantics inferred from call sites in test_swizzle.cu — confirm
// against the definition in test_common.cu.
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
                                    const NVTEScalingMode scaling_mode,
                                    bool enforce_grouped_gemm_alignment = true);

} // namespace test

Expand Down
7 changes: 5 additions & 2 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,10 +2883,13 @@ def _apply_grouped_bias_ref(
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize("use_bias_scale", [False, True])
def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None:
if torch.cuda.get_device_capability() < (9, 0):
pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.")
if torch.cuda.get_device_capability() < (10, 0):
if tex.get_cublasLt_version() < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
if tex.get_cublasLt_version() < 130300:
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")

Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_M

// True iff the scaling mode is the NVFP4 1-D scheme.
inline bool is_nvfp_scaling(const NVTEScalingMode &mode) {
  const bool nvfp4_1d = (mode == NVTE_NVFP4_1D_SCALING);
  return nvfp4_1d;
}

// True iff the scaling mode is one of the FP8 block-scaling schemes (1-D or 2-D).
inline bool is_fp8_block_scaling(const NVTEScalingMode &mode) {
  switch (mode) {
    case NVTE_BLOCK_SCALING_1D:
    case NVTE_BLOCK_SCALING_2D:
      return true;
    default:
      return false;
  }
}

inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
Expand Down
Loading
Loading