[JAX] Collective GEMM with FP8 and MXFP8 support#2740
phu0ngng wants to merge 16 commits into NVIDIA:main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci JAX L1
Greptile Summary: This PR extends the JAX Collective GEMM infrastructure to support FP8 (delayed and current tensor scaling) and MXFP8 quantization.
Confidence Score: 1/5
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["tex.gemm() / dense() / layernorm_mlp()"] --> B["_te_gemm()"]
B --> C{collective_op?}
C -- "is_none=True" --> D["Standard GEMM path\napply_padding_to_scale_inv (lhs + rhs)"]
C -- "is_none=False" --> E{scaling_mode?}
E -- "MXFP8_1D_SCALING" --> F["Assert dims divisible by 128\nSkip padding\nAssert scale seq_dim % tpsp == 0"]
E -- "DELAYED / CURRENT" --> G["apply_padding_to_scale_inv (lhs + rhs)"]
E -- "NVFP4" --> H["🚫 Assert fails — not supported"]
F --> I{need_reorder?}
G --> I
I -- "RS + lhs.shape[0]!=1" --> J["_reorder_tpsp_leading(lhs)"]
I -- "RS or AG + scale + 1D_block" --> K["_reorder_tpsp_leading(lhs_scale_inv)"]
J --> L["GemmPrimitive.inner_primitive.bind()"]
K --> L
I -- "no reorder" --> L
L --> M{need_reorder?}
M -- "AG + output.shape[0]!=1" --> N["_reorder_dp_leading(output)"]
M -- "no reorder" --> O["Return output"]
N --> O
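The branch structure in the flowchart can be sketched in plain Python. This is a hedged illustration with invented names (`select_gemm_path`, a plain `None` standing in for `collective_op.is_none`); the real dispatch lives in `_te_gemm` and `GemmPrimitive` inside TransformerEngine:

```python
from enum import Enum, auto

class ScalingMode(Enum):
    # Stand-ins for the real scaling-mode enum referenced in the flowchart.
    DELAYED_TENSOR_SCALING = auto()
    CURRENT_TENSOR_SCALING = auto()
    MXFP8_1D_SCALING = auto()
    NVFP4_1D_SCALING = auto()

def select_gemm_path(collective_op, scaling_mode, lhs_shape, rhs_shape):
    """Mirror the flowchart's top-level branches (sketch, not TE's API)."""
    if collective_op is None:
        # Standard GEMM path: pad scale_inv for both operands.
        return "standard_gemm_with_padding"
    if scaling_mode is ScalingMode.NVFP4_1D_SCALING:
        raise AssertionError(
            "Collective GEMM is not yet supported with NVFP4 quantization."
        )
    if scaling_mode is ScalingMode.MXFP8_1D_SCALING:
        # Per the PR note, CGEMM + MXFP8 requires every operand dim
        # to be divisible by 128; padding is skipped on this path.
        assert all(d % 128 == 0 for d in (*lhs_shape, *rhs_shape))
        return "collective_gemm_no_padding"
    # DELAYED / CURRENT tensor scaling: pad scale_inv for lhs and rhs.
    return "collective_gemm_with_padding"
```

The reorder steps (`_reorder_tpsp_leading`, `_reorder_dp_leading`) would follow the path selection, as the flowchart shows.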
Last reviewed commit: b0cc000
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
    if not collective_op.is_none:
    if not collective_op.is_none:
        assert not scaling_mode.is_nvfp4_scaling, (
            f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
            "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
        )
Duplicate if statement causes Python SyntaxError
Lines 1252 and 1253 contain two identical if not collective_op.is_none: guards at the same indentation level. In Python, an if statement requires an indented block as its body — but the second if sits at the same indentation depth, not inside the first. Python raises IndentationError: expected an indented block after 'if' statement on line 1252, which prevents the entire module from being imported.
The first duplicate line must be removed:
Suggested change (with the first duplicate guard removed):

    if not collective_op.is_none:
        assert not scaling_mode.is_nvfp4_scaling, (
            f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
            "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
        )
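The failure the reviewer describes is easy to reproduce standalone: a second `if` at the same indentation leaves the first `if` without a body, which Python rejects at compile time (IndentationError is a subclass of SyntaxError):

```python
# Minimal reproduction of the duplicated-guard bug (simplified names).
broken = (
    "if not collective_op_is_none:\n"
    "if not collective_op_is_none:\n"
    "    pass\n"
)
try:
    compile(broken, "<snippet>", "exec")
except IndentationError as exc:
    # Raised before any code runs, so the module cannot be imported.
    print(type(exc).__name__)  # prints "IndentationError"
```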
Description
This PR extends the JAX Collective GEMM support with DelayedScalingFP8, CurrentScalingFP8, and MXFP8.
Unit tests for those quantization recipes are added. In addition, this PR also cleans up the test infrastructure in the collective gemm tests.
Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
Additionally, in the case of CGEMM + MXFP8 + AllGather, the block scales are still all-gathered in the critical path, unlike the quantized data, whose collective gather is overlapped with the computation.
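To make the block-scale note concrete: MXFP8 (per the OCP Microscaling spec) carries one shared E8M0 scale per 32-element block along the scaled axis, so the scale tensor is 32x smaller than the data along that axis. The helper below is a hypothetical sketch, not TransformerEngine's API:

```python
def mxfp8_scale_shape(shape, block=32, axis=-1):
    """Shape of the MXFP8 block-scale tensor for a given operand shape.

    Sketch under the assumption of 1D scaling along `axis` with one
    E8M0 scale per `block` contiguous elements (OCP MX block size 32).
    """
    s = list(shape)
    assert s[axis] % block == 0, "scaled axis must be a multiple of the block size"
    s[axis] //= block
    return tuple(s)
```

For example, a (256, 512) operand has a (256, 16) scale tensor; it is this smaller tensor that is still all-gathered in the critical path, while the gather of the (256, 512) quantized data overlaps with compute.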