
[JAX] Collective GEMM with FP8 and MXFP8 support #2740

Open

phu0ngng wants to merge 16 commits into NVIDIA:main from phu0ngng:cgemm_fp8

Conversation

@phu0ngng
Collaborator

@phu0ngng phu0ngng commented Mar 5, 2026

Description

This PR extends the JAX Collective GEMM support with DelayedScalingFP8, CurrentScalingFP8, and MXFP8.
Unit tests for these quantization recipes are added, and the test infrastructure of the collective GEMM tests is cleaned up.

Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
In addition, for CGEMM + MXFP8 + AllGather, the block scales are still all-gathered in the critical path, unlike the quantized data, which is gathered collectively and overlapped with the computation.
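The divisibility constraint above can be expressed as a small shape guard. The function name and error text below are hypothetical, not part of this PR:

```python
def check_mxfp8_cgemm_shapes(lhs_shape, rhs_shape, block=128):
    """Hypothetical guard: CGEMM + MXFP8 requires every GEMM operand
    dimension to be a multiple of the 128-element MXFP8 block size."""
    for dim in (*lhs_shape, *rhs_shape):
        if dim % block != 0:
            raise ValueError(
                f"CGEMM + MXFP8 needs dims divisible by {block}, got {dim}"
            )

check_mxfp8_cgemm_shapes((256, 512), (512, 1024))  # ok: all multiples of 128
```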

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8 → [JAX] CGEMM + FP8MXFP8 Mar 10, 2026
phu0ngng and others added 14 commits March 10, 2026 15:35
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8MXFP8 → [JAX] CGEMM + FP8/MXFP8 Mar 10, 2026
@phu0ngng
Collaborator Author

/te-ci JAX L1

@phu0ngng phu0ngng marked this pull request as ready for review March 10, 2026 23:23
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8/MXFP8 → [JAX] Collective GEMM with FP8 and MXFP8 support Mar 10, 2026
@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR extends the JAX Collective GEMM infrastructure to support DelayedScalingFP8, CurrentScalingFP8, and MXFP8 quantization recipes, rounding out FP8 support alongside the existing BF16 path. The core changes thread a quantizer_set argument through the GEMM/dense/layernorm-MLP primitives, add _reorder_tpsp_leading / _reorder_dp_leading helper functions (refactored from previously-inlined reshape logic), and introduce proper scale-sharding specs for block-scaling modes during sharding inference. The test infrastructure is significantly cleaned up: per-file execution is replaced by per-test-case execution in the shell harness, tolerances are now FP8-aware, and new test cases are added for each recipe × collective-op combination.

Key points:

  • Critical defect: transformer_engine/jax/cpp_extensions/gemm.py lines 1252–1253 contain a duplicated if not collective_op.is_none: at the same indentation level. Python raises IndentationError: expected an indented block and the module cannot be imported until one copy is removed.
  • The new _reorder_tpsp_leading / _reorder_dp_leading helper functions correctly extract the previously inlined reshape+transpose logic and reuse it for both the data tensor and the scale-inverse tensor.
  • helper.py gains two clean public utilities (is_quantize_recipe_supported, get_quantization_recipe) that map string recipe names to TE's internal ScalingMode and recipe objects.
  • NVFP4 + Collective GEMM tests are intentionally commented out pending future support, which matches the runtime assert not scaling_mode.is_nvfp4_scaling guard.
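The reorder helpers can be pictured as a reshape+transpose on the leading axis. The sketch below is a generic NumPy illustration under that assumption, not the PR's actual `_reorder_tpsp_leading`/`_reorder_dp_leading` code:

```python
import numpy as np

def reorder_shard_leading(x, n_shards):
    # Illustrative reorder: split the leading axis into (chunk, n_shards),
    # swap so the shard axis leads, then flatten back. The PR's
    # _reorder_tpsp_leading is described as a similar reshape+transpose
    # applied to both the data tensor and the scale-inverse tensor.
    chunk = x.shape[0] // n_shards
    y = x.reshape(chunk, n_shards, *x.shape[1:]).swapaxes(0, 1)
    return y.reshape(x.shape)

def reorder_chunk_leading(x, n_shards):
    # Inverse regrouping, analogous in spirit to _reorder_dp_leading.
    chunk = x.shape[0] // n_shards
    y = x.reshape(n_shards, chunk, *x.shape[1:]).swapaxes(0, 1)
    return y.reshape(x.shape)

x = np.arange(24).reshape(8, 3)
# The two reorders are inverses of each other on the leading axis.
assert np.array_equal(reorder_chunk_leading(reorder_shard_leading(x, 4), 4), x)
```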

Confidence Score: 1/5

  • Not safe to merge — a Python SyntaxError in the core GEMM module prevents the entire transformer_engine.jax.cpp_extensions.gemm module from loading.
  • The duplicate if not collective_op.is_none: at lines 1252–1253 of gemm.py is a hard Python SyntaxError (IndentationError) that blocks module import. This single defect makes all functionality in this PR non-functional until resolved. The rest of the changes are well-structured and the test coverage is good.
  • transformer_engine/jax/cpp_extensions/gemm.py — the duplicate if guard at lines 1252-1253 must be fixed before merging.

Important Files Changed

Filename | Overview
transformer_engine/jax/cpp_extensions/gemm.py | Extends GEMM collective operations with MXFP8 support: adds _reorder_tpsp_leading/_reorder_dp_leading helper functions, scale-sharding logic for block-scaling modes, and a correctness guard for NVFP4; however, it introduces a Python SyntaxError via a duplicate if not collective_op.is_none: at lines 1252-1253 that prevents the module from loading.
transformer_engine/jax/quantize/helper.py | Adds is_quantize_recipe_supported and get_quantization_recipe public helpers that map string recipe names to ScalingMode/recipe objects; clean implementation with proper error messages.
examples/jax/collective_gemm/common.py | Adds FP8 tolerance handling (dtype_tols FP8 branch, get_tolerance_dtype), moves imports to the top, removes unused PARAMS_KEY/assert_allclose_print_index/_is_distributed_initialized, and extends cgemm_parser with --quantize-recipe; clean refactor.
examples/jax/collective_gemm/test_gemm.py | Adds FP8 (DelayedScaling, Float8CurrentScaling) and MXFP8 test cases for Collective GEMM with AllGather/ReduceScatter; passes quantizer_set through _jitted_cgemm and adjusts the tolerance dtype accordingly.
examples/jax/collective_gemm/test_dense_grad.py | Adds FP8/MXFP8 gradient tests for dense collective GEMM; threads quantizer_set through _mean_dense / _value_and_grad_dense and switches the tolerance dtype based on the active quantizer.
examples/jax/collective_gemm/test_layernorm_mlp_grad.py | Adds FP8/MXFP8 gradient tests for the LayerNorm MLP collective GEMM path; threads a quantizer_sets tuple through _mean_layernorm_mlp / _value_and_grad_layernorm_mlp.
examples/jax/collective_gemm/run_test_cgemm.sh | Refactors from per-file to per-test-case execution, extracts a TEST_NAME from the pytest node ID for log/XML naming, and adds new FP8/MXFP8 test cases while commenting out unsupported NVFP4 cases.
transformer_engine/jax/csrc/extensions/gemm.cpp | Adds blank lines for readability after the scale_dtype assignment and the workspace pointer setup; no logic changes.
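A minimal sketch of the string-to-recipe lookup that is_quantize_recipe_supported/get_quantization_recipe provide. The mapping values below are placeholder strings standing in for TE's internal ScalingMode/recipe objects, and the exact recipe-name keys are assumptions:

```python
# Placeholder table: the real helpers in helper.py return TE's internal
# ScalingMode and recipe objects rather than these stand-in strings.
_SUPPORTED_RECIPES = {
    "DelayedScaling": "DELAYED_TENSOR_SCALING",
    "Float8CurrentScaling": "CURRENT_TENSOR_SCALING",
    "MXFP8BlockScaling": "MXFP8_1D_SCALING",
}

def is_quantize_recipe_supported(name):
    # True if the string names a recipe the collective GEMM path handles.
    return name in _SUPPORTED_RECIPES

def get_quantization_recipe(name):
    # Look up the recipe, failing with an explicit message otherwise.
    if name not in _SUPPORTED_RECIPES:
        raise ValueError(
            f"Unsupported quantize recipe '{name}'. "
            f"Supported: {sorted(_SUPPORTED_RECIPES)}"
        )
    return _SUPPORTED_RECIPES[name]
```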

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["tex.gemm() / dense() / layernorm_mlp()"] --> B["_te_gemm()"]
    B --> C{collective_op?}
    C -- "is_none=True" --> D["Standard GEMM path\napply_padding_to_scale_inv (lhs + rhs)"]
    C -- "is_none=False" --> E{scaling_mode?}
    E -- "MXFP8_1D_SCALING" --> F["Assert dims divisible by 128\nSkip padding\nAssert scale seq_dim % tpsp == 0"]
    E -- "DELAYED / CURRENT" --> G["apply_padding_to_scale_inv (lhs + rhs)"]
    E -- "NVFP4" --> H["🚫 Assert fails — not supported"]
    F --> I{need_reorder?}
    G --> I
    I -- "RS + lhs.shape[0]!=1" --> J["_reorder_tpsp_leading(lhs)"]
    I -- "RS or AG + scale + 1D_block" --> K["_reorder_tpsp_leading(lhs_scale_inv)"]
    J --> L["GemmPrimitive.inner_primitive.bind()"]
    K --> L
    I -- "no reorder" --> L
    L --> M{need_reorder?}
    M -- "AG + output.shape[0]!=1" --> N["_reorder_dp_leading(output)"]
    M -- "no reorder" --> O["Return output"]
    N --> O

Last reviewed commit: b0cc000

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment on lines +1252 to +1257
if not collective_op.is_none:
if not collective_op.is_none:
    assert not scaling_mode.is_nvfp4_scaling, (
        f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
        "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
    )
Duplicate if statement causes Python SyntaxError

Lines 1252 and 1253 contain two identical if not collective_op.is_none: guards at the same indentation level. In Python, an if statement requires an indented block as its body — but the second if sits at the same indentation depth, not inside the first. Python raises IndentationError: expected an indented block after 'if' statement on line 1252, which prevents the entire module from being imported.

The first duplicate line must be removed:

Suggested change

-  if not collective_op.is_none:
   if not collective_op.is_none:
       assert not scaling_mode.is_nvfp4_scaling, (
           f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
           "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
       )
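The failure mode is easy to reproduce: Python's parser rejects an if header whose next statement sits at the same indentation level, so the module fails at compile/import time. This is a minimal, self-contained reproduction, not the actual gemm.py code:

```python
# Two identical `if` headers at the same indentation: the first one
# has no indented body, so compilation fails before any code runs.
broken = (
    "if not collective_op.is_none:\n"
    "if not collective_op.is_none:\n"
    "    assert True\n"
)
try:
    compile(broken, "gemm.py", "exec")
except IndentationError as exc:
    # IndentationError is a subclass of SyntaxError, so the module
    # cannot even be imported until the duplicate line is removed.
    print(type(exc).__name__)  # prints "IndentationError"
```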
