From 817940c94aacfb18e8cd0b83febbbd3aa8d294b8 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 5 Mar 2026 14:20:49 -0800 Subject: [PATCH 01/22] add cgemm + FP8 tests Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 40 ++++++- .../jax/collective_gemm/run_test_cgemm.sh | 79 +++++++++---- .../jax/collective_gemm/test_dense_grad.py | 111 +++++++++++++++++- examples/jax/collective_gemm/test_gemm.py | 111 +++++++++++++++++- .../test_layernorm_mlp_grad.py | 65 +++++++++- transformer_engine/jax/cpp_extensions/gemm.py | 6 + 6 files changed, 375 insertions(+), 37 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 2965896d07..e3904221a4 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -77,6 +77,8 @@ def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e- import argparse import jax from jax.experimental import mesh_utils +from transformer_engine.common import recipe as te_recipe +from transformer_engine.jax.quantize import ScalingMode from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap # Global flag to track if distributed has been initialized @@ -183,6 +185,36 @@ def _create_mesh(args): return mesh +def get_scaling_mode_from_recipe_name(name: str) -> ScalingMode: + """Get ScalingMode from a recipe name string.""" + match name: + case "DelayedScaling": + return ScalingMode.DELAYED_TENSOR_SCALING + case "Float8CurrentScaling": + return ScalingMode.CURRENT_TENSOR_SCALING + case "MXFP8BlockScaling": + return ScalingMode.MXFP8_1D_SCALING + case "NVFP4BlockScaling": + return ScalingMode.NVFP4_1D_SCALING + case _: + raise ValueError(f"Invalid recipe name, got {name}") + + +def get_quantization_recipe_from_name_string(name: str): + """Query recipe from a given name string""" + match name: + case "DelayedScaling": + return te_recipe.DelayedScaling() + case "MXFP8BlockScaling": + return te_recipe.MXFP8BlockScaling() + case "Float8CurrentScaling": + return te_recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return te_recipe.NVFP4BlockScaling() + case _: + raise ValueError(f"Invalid quantization_recipe, got {name}") + + def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): """Create common argument parser for all collective GEMM tests.""" parser = argparse.ArgumentParser(description=description) @@ -229,7 +261,7 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + "--quantize-recipe", type=str, default="DelayedScaling", help="Quantization recipe to use" ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" @@ -237,5 +269,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Enable FP8 quantization", + ) return parser diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 388c878376..08fc92d2d8 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -23,11 +23,43 @@ else echo "NVLINK support detected" fi -# Define the test files to run -TEST_FILES=( -"test_gemm.py" -"test_dense_grad.py" -"test_layernorm_mlp_grad.py" +# Define individual test cases to run (file::class::method) +# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all +# the time. +TEST_CASES=( +# test_gemm.py cases +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" +# TODO(Phuong): Enable when supported +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" + +# test_dense_grad.py cases +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +# TODO(Phuong): Enable when supported +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" + +# test_layernorm_mlp_grad.py cases +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +# TODO(Phuong): Enable when supported +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo @@ -57,24 +89,27 @@ cleanup() { # Set up signal handlers to cleanup on exit trap cleanup EXIT INT TERM -# Run each test file across all GPUs -for TEST_FILE in "${TEST_FILES[@]}"; do +# Run each test case across all GPUs +for TEST_CASE in "${TEST_CASES[@]}"; do echo - echo "=== Starting test file: $TEST_FILE ..." + echo "=== Starting test: $TEST_CASE ..." - # Clear PIDs array for this test file + # Extract just the test method name for log/xml file naming + TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}') + + # Clear PIDs array for this test case PIDS=() for i in $(seq 0 $(($NUM_GPUS - 1))); do # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" + LOG_FILE="${TEST_NAME}_gpu_${i}.log" if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -82,7 +117,7 @@ for TEST_FILE in "${TEST_FILES[@]}"; do else # For other processes: redirect to log files only pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & PID=$! @@ -93,22 +128,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Wait for all processes to finish wait - # Check and print the log content from process 0 (now has log file thanks to tee) - if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE SKIPPED" - elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE FAILED" + # Check and print the log content from process 0 + if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE PASSED" + elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" else - echo "... $TEST_FILE INVALID" + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 fi # Remove the log files after processing them wait - rm ${TEST_FILE}_gpu_*.log + rm ${TEST_NAME}_gpu_*.log done wait diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 94c7dc5b66..980a9a7df9 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -20,11 +20,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.dense import dense -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOp, CollectiveOpSet, @@ -56,7 +63,7 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): output = dense( x, weight, @@ -66,13 +73,14 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv kernel_axes=weight_axes, output_axes=output_axes, collective_op_set=collective_op_set, + quantizer_set=quantizer_set, ) return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( - x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -98,11 +106,16 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -123,6 +136,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, noop_collective_op_set, + quantizer_set, ) output, sharded_grads = _value_and_grad_dense( x_sharded, @@ -132,6 +146,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, collective_op_set, + quantizer_set, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -187,6 +202,90 @@ def test_te_bf16_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_all_gather(self): + # """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_reduce_scatter(self): + # """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_all_gather(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_reduce_scatter(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index ea119713e3..abefafe0a0 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -29,10 +29,17 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp from transformer_engine.jax.sharding import MeshResource @@ -72,13 +79,14 @@ def _get_dp_and_tp_sizes(args): @partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) -def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): +def _jitted_cgemm(x, weight, bias, quantizer_set, contracting_dims, collective_op, output_sharding): output = tex.gemm( x, weight, bias=bias, contracting_dims=contracting_dims, collective_op=collective_op, + quantizer_set=quantizer_set, ) if output_sharding is not None: output = jax.lax.with_sharding_constraint(output, output_sharding) @@ -107,11 +115,20 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + + # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource + # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() can read the global recipe + # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw + # tex.gemm() calls, so we must pass quantizer_set explicitly. + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -125,6 +142,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=CollectiveOp.NONE, output_sharding=output_sharding, @@ -133,6 +151,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=collective_op, output_sharding=output_sharding, @@ -186,6 +205,90 @@ def test_te_bf16_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_all_gather_with_dp(self): + # """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_reduce_scatter_with_dp(self): + # """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_all_gather_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_reduce_scatter_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 84cb011da1..7a46487325 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -20,11 +20,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOpSet, CollectiveOp, @@ -68,6 +75,7 @@ def _mean_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): output = layernorm_mlp( x, @@ -82,6 +90,7 @@ def _mean_layernorm_mlp( kernel_2_axes=weight_2_axes, activation_type=("gelu",), collective_op_sets=collective_op_sets, + quantizer_sets=quantizer_sets, ) return jnp.mean(output) @@ -98,6 +107,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): return jax.jit( jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) @@ -113,6 +123,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) @@ -149,11 +160,17 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_sets = (quantizer_set, quantizer_set) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -181,6 +198,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, noop_collective_op_sets, + quantizer_sets, ) output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, @@ -194,6 +212,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -240,9 +259,47 @@ def tearDown(self): os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) def test_te_bf16_layernorm_mlp_grad(self): - """Test Collective Dense Gradient with AllGather""" + """Test Collective LayerNorm MLP Gradient with BF16""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_current_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4506adf33b..804d89f156 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1247,6 +1247,12 @@ def _te_gemm( rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + if not collective_op.is_none and scaling_mode.is_1d_block_scaling(): + raise ValueError( + f"Collective GEMM is not yet supported with {scaling_mode} quantization. " + "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." + ) + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype if bias is None: bias = jnp.empty(0, dtype=out_dtype) From 1b3519b3acf1d225758a63ca81b66605dd41ef8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:22:46 +0000 Subject: [PATCH 02/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/collective_gemm/test_dense_grad.py | 24 ++++++++++++++----- examples/jax/collective_gemm/test_gemm.py | 16 +++++++++---- .../test_layernorm_mlp_grad.py | 8 +++++-- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 980a9a7df9..dcb51b34f4 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -63,7 +63,9 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): +def _mean_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): output = dense( x, weight, @@ -78,7 +80,9 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): +def _value_and_grad_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -205,7 +209,9 @@ def test_te_bf16_reduce_scatter(self): def test_te_delayed_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -215,7 +221,9 @@ def test_te_delayed_scaling_fp8_all_gather(self): def test_te_delayed_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -225,7 +233,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self): def test_te_current_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -235,7 +245,9 @@ def test_te_current_scaling_fp8_all_gather(self): def test_te_current_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index abefafe0a0..f8c9dd18a8 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -208,7 +208,9 @@ def test_te_bf16_reduce_scatter_with_dp(self): def test_te_delayed_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -218,7 +220,9 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self): def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -228,7 +232,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): def test_te_current_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -238,7 +244,9 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self): def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 7a46487325..ebf55ea5df 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -265,7 +265,9 @@ def test_te_bf16_layernorm_mlp_grad(self): def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -274,7 +276,9 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): def test_te_current_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True From 2e3bbebf81e54c679fa6bbf1260d6f48075adaa0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:00:58 -0700 Subject: [PATCH 03/22] cgemm+mxfp8 passed for AG Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 2 + .../jax/collective_gemm/run_test_cgemm.sh | 64 +++++++++---------- examples/jax/collective_gemm/test_gemm.py | 19 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 60 ++++++++++++++--- .../jax/csrc/extensions/gemm.cpp | 5 +- transformer_engine/jax/quantize/helper.py | 10 +-- 6 files changed, 105 insertions(+), 55 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index e3904221a4..95452e85bc 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -250,7 +250,9 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + # parser.add_argument("--batch-size", type=int, default=2, help="Batch size for testing") parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + # parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length for testing") parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") parser.add_argument( diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 08fc92d2d8..b1a8816703 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,38 +28,38 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" -# TODO(Phuong): Enable when supported -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" - -# test_dense_grad.py cases -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -# TODO(Phuong): Enable when supported -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" - -# test_layernorm_mlp_grad.py cases -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" -# TODO(Phuong): Enable when supported -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" +# # TODO(Phuong): Enable when supported +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# +# # test_dense_grad.py cases +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +# # TODO(Phuong): Enable when supported +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" +# +# # test_layernorm_mlp_grad.py cases +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +# # TODO(Phuong): Enable when supported +# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index f8c9dd18a8..2a968f7ba4 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -253,16 +253,15 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_all_gather_with_dp(self): - # """Test Collective GEMM with MXFP8BlockScaling + AllGather""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.use_fp8 = True - # self.args.collective_type = "all_gather" - # run_gemm_tests(self.args, self.mesh) + def test_te_mxfp8_all_gather_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported # def test_te_mxfp8_reduce_scatter_with_dp(self): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 804d89f156..7a21eb1521 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -681,6 +681,34 @@ def impl( reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) lhs = reordered.reshape(original_shape) + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not lhs_scale_inv.shape[0] == 1 + and scaling_mode.is_1d_block_scaling() + ): + + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = lhs_scale_inv.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs_scale_inv.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs_scale_inv = reordered.reshape(original_shape) + (output, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -812,6 +840,7 @@ def _parse_operand_output_specs( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -955,12 +984,24 @@ def _parse_operand_output_specs( # Bias sharding is based on GEMM output before any scatter bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4 + # Scale shardings are based on the scaling_mode and collective_op + lhs_scale_specs = rhs_scale_specs = (None,) + if scaling_mode.is_1d_block_scaling(): + rhs_scale_specs = rhs_specs + if collective_op.is_all_gather: + lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) + else: + lhs_scale_specs = lhs_specs + print(lhs_scale_specs) + print(rhs_scale_specs) + + if not collective_op.is_none: if sequence_dim < 0: raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") return ( - (lhs_specs, rhs_specs, bias_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_specs), out_specs, reduce_spec, sequence_dim, @@ -982,7 +1023,6 @@ def infer_sharding_from_operands( ): del ( out_dtype, - scaling_mode, use_split_accumulator, result_infos, is_outer, @@ -990,7 +1030,7 @@ def infer_sharding_from_operands( ) (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op + arg_infos, contracting_dims, transpose_batch_sequence, collective_op, scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -1013,7 +1053,7 @@ def partition( del result_infos, is_outer, sequence_dim ( - (lhs_specs, rhs_specs, bias_input_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_input_specs), out_specs, reduce_spec, inferred_sequence_dim, @@ -1022,17 +1062,21 @@ def partition( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ) # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + lhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*lhs_scale_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + rhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*rhs_scale_specs)) + arg_shardings = ( lhs_sharding, - lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + lhs_scale_sharding, rhs_sharding, - rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_scale_sharding, ) # Bias @@ -1247,8 +1291,8 @@ def _te_gemm( rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv - if not collective_op.is_none and scaling_mode.is_1d_block_scaling(): - raise ValueError( + if not collective_op.is_none: + assert not scaling_mode.is_nvfp4_scaling, ( f"Collective GEMM is not yet supported with {scaling_mode} quantization. " "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 737dd65622..268307ea83 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -58,6 +58,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape = {1}; auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), @@ -73,7 +74,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } input.set_with_gemm_swizzled_scales(true); - } else if (is_nvfp4) { // Swizzle for NVFP4 + } + else if (is_nvfp4) { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor @@ -202,6 +204,7 @@ Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; + if (is_nvfp4_scaling(config.scaling_mode)) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c491bb8638..4ef9426433 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -916,10 +916,12 @@ def apply_padding_to_scale_inv( unpadded_scale_shape = scaling_mode.get_scale_shape( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) - assert scale_inv.shape == unpadded_scale_shape, ( - f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " - f"{scale_inv.shape}." - ) + + # TODO + # assert scale_inv.shape == unpadded_scale_shape, ( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + # f"{scale_inv.shape}." + # ) # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) From 76e0d78efe38582cc491f7c26c705eb2cc789d61 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:13:00 -0700 Subject: [PATCH 04/22] refactor code Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 110 ++++++++---------- 1 file changed, 47 insertions(+), 63 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7a21eb1521..9a47dc3e25 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -407,6 +407,48 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): ) +def _reorder_tpsp_leading(tensor, original_shape): + """Reorder tensor so the tpsp axis is leading: reshape (dp, n, tpsp, m, ...), transpose (2, 0, 1, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) + + +def _reorder_dp_leading(tensor, original_shape): + """Reorder tensor so the dp axis is leading: reshape (tpsp, dp, n, m, ...), transpose (1, 2, 0, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -658,28 +700,8 @@ def impl( and not is_outer and not lhs.shape[0] == 1 ): - if sequence_dim != 1: - raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") - original_shape = lhs.shape - if original_shape[0] % dp_or_fsdp_axis_size() != 0 and original_shape[0] != 1: - raise ValueError( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - if original_shape[1] % tpsp_axis_size() != 0 and original_shape[1] != 1: - raise ValueError( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs = reordered.reshape(original_shape) + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + lhs = _reorder_tpsp_leading(lhs, lhs.shape) if ( collective_op.is_all_gather @@ -688,26 +710,8 @@ def impl( and not lhs_scale_inv.shape[0] == 1 and scaling_mode.is_1d_block_scaling() ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = lhs_scale_inv.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs_scale_inv.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs_scale_inv = reordered.reshape(original_shape) + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) (output, _) = GemmPrimitive.inner_primitive.bind( lhs, @@ -733,28 +737,8 @@ def impl( and not is_outer and not output.shape[0] == 1 ): - if sequence_dim != 1: - raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") - original_shape = output.shape - if original_shape[0] % dp_or_fsdp_axis_size() != 0 and original_shape[0] != 1: - raise ValueError( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - if original_shape[1] % tpsp_axis_size() != 0 and original_shape[1] != 1: - raise ValueError( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = output.reshape( - tpsp_axis_size(), - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) - output = reordered.reshape(original_shape) + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + output = _reorder_dp_leading(output, output.shape) return (output,) From 3d687393ece3c7c21f12e28b6d34aa7929b2dc44 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:22:33 -0700 Subject: [PATCH 05/22] mxfp8 + rs passed Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 2 -- .../jax/collective_gemm/run_test_cgemm.sh | 4 ++-- examples/jax/collective_gemm/test_gemm.py | 19 +++++++++---------- transformer_engine/jax/cpp_extensions/gemm.py | 10 ++++++++++ 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 95452e85bc..e3904221a4 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -250,9 +250,7 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") - # parser.add_argument("--batch-size", type=int, default=2, help="Batch size for testing") parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") - # parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length for testing") parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") parser.add_argument( diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index b1a8816703..d5de55287b 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -35,8 +35,8 @@ TEST_CASES=( # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" # # TODO(Phuong): Enable when supported -"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" # diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 2a968f7ba4..fed69ec636 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -263,16 +263,15 @@ def test_te_mxfp8_all_gather_with_dp(self): self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_reduce_scatter_with_dp(self): - # """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.use_fp8 = True - # self.args.collective_type = "reduce_scatter" - # run_gemm_tests(self.args, self.mesh) + def test_te_mxfp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather_with_dp(self): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9a47dc3e25..364bd9f905 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -703,6 +703,16 @@ def impl( assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs_scale_inv.shape[0] == 1 + and scaling_mode.is_1d_block_scaling() + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) + if ( collective_op.is_all_gather and not transpose_batch_sequence From f53cebbf087b381bd63524b4cf47673b54b8ead1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:55:02 -0700 Subject: [PATCH 06/22] simplify the conditions Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 35 ++++--------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 364bd9f905..19de43aed6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -693,33 +693,15 @@ def impl( lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) + # Determine if we need to reorder the tensor so that the input/output are in the correct layout for the collective operation + need_reorder = not transpose_batch_sequence and not is_outer and not collective_op.is_none + # Alter lhs blocks so that CGEMM RS outputs correctly - if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs.shape[0] == 1 - ): + if need_reorder and collective_op.is_reduce_scatter and lhs.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) - if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs_scale_inv.shape[0] == 1 - and scaling_mode.is_1d_block_scaling() - ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) - - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not lhs_scale_inv.shape[0] == 1 - and scaling_mode.is_1d_block_scaling() - ): + if need_reorder and (collective_op.is_reduce_scatter or collective_op.is_all_gather) and lhs_scale_inv.shape[0] != 1 and scaling_mode.is_1d_block_scaling(): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) @@ -741,12 +723,7 @@ def impl( collective_op=collective_op, ) # Alter output blocks for CGEMM AG - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not output.shape[0] == 1 - ): + if need_reorder and collective_op.is_all_gather and output.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" output = _reorder_dp_leading(output, output.shape) From 07ac8896767c65e0cf5ec3a2726557f4951e9fea Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 10:30:35 -0700 Subject: [PATCH 07/22] added size check for mxfp8 Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 19de43aed6..6441cfa8ac 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -681,12 +681,26 @@ def impl( lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis - ) + if not collective_op.is_none and not is_outer: + # MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128. + # No padding is needed in this case + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod(lhs.shape[lhs_flatten_axis:]) + assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( + f"MXFP8 + Collective AG requires LHS dimensions before and after the flatten axis to be multiples of 128. " + f"Got lhs.shape={lhs.shape}, lhs_flatten_axis={lhs_flatten_axis}" + ) + # The scale needs to be in good shape for reordering + assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( + f"MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be multiples of tpsp_axis_size. " + f"Got lhs_scale_inv.shape={lhs_scale_inv.shape}, tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + ) + else: + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis, + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel if scaling_mode.is_mxfp8_scaling: @@ -963,9 +977,6 @@ def _parse_operand_output_specs( lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) else: lhs_scale_specs = lhs_specs - print(lhs_scale_specs) - print(rhs_scale_specs) - if not collective_op.is_none: if sequence_dim < 0: From 593d7ae983f4c5de35b5c8678d88d26a917a7745 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:12:03 -0700 Subject: [PATCH 08/22] added tols for assertions Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 14 ++++++++++++++ examples/jax/collective_gemm/run_test_cgemm.sh | 2 +- examples/jax/collective_gemm/test_dense_grad.py | 6 ++++-- examples/jax/collective_gemm/test_gemm.py | 3 ++- .../jax/collective_gemm/test_layernorm_mlp_grad.py | 6 ++++-- 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index e3904221a4..355b5f11b7 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -21,10 +21,24 @@ def dtype_tols(dtype, rtol=None, atol=None): return {"rtol": 1e-3, "atol": 1e-6} elif dtype in [jnp.bfloat16, "bfloat16"]: return {"rtol": 1e-2, "atol": 1e-5} + elif dtype in [jnp.float8_e4m3fn, "float8_e4m3fn", jnp.float8_e5m2, "float8_e5m2"]: + # FP8 quantization introduces ~1% error; match C++ getTolerances for fp8 types + return {"rtol": 1e-2, "atol": 1e-2} else: return {"rtol": 1e-5, "atol": 1e-8} +def get_tolerance_dtype(quantizer_set): + """Return the dtype used to select numerical tolerances based on the active quantizer. + + Reads q_dtype from quantizer_set.x; falls back to bfloat16 when no quantizer is + active (NO_SCALING / noop path, where quantizer_set.x is None). + """ + if quantizer_set.x is not None: + return quantizer_set.x.q_dtype + return jnp.bfloat16 + + def assert_allclose( actual, desired, diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index d5de55287b..04553a0174 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -35,7 +35,7 @@ TEST_CASES=( # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" # # TODO(Phuong): Enable when supported -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index dcb51b34f4..7c2ec5d607 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -167,9 +168,10 @@ def run_dense_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveDenseGradient(unittest.TestCase): diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index fed69ec636..3df85ab87c 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -22,6 +22,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -169,7 +170,7 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output) + assert_allclose(gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)) class TestCollectiveGemmWithDP(unittest.TestCase): diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index ebf55ea5df..927f3e99b2 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -229,9 +230,10 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveLayerNormMLPGradient(unittest.TestCase): From e96e86b39fb10cad10adeccbd6cdf6e6cb11cc52 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:26:07 -0700 Subject: [PATCH 09/22] update tests with recipes Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 12 +++++----- .../jax/collective_gemm/test_dense_grad.py | 20 +++++++---------- examples/jax/collective_gemm/test_gemm.py | 22 +++++++++---------- .../test_layernorm_mlp_grad.py | 14 +++++------- 4 files changed, 29 insertions(+), 39 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 355b5f11b7..483f3e60af 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -275,7 +275,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--quantize-recipe", type=str, default="DelayedScaling", help="Quantization recipe to use" + "--quantize-recipe", + type=str, + default=None, + choices=["DelayedScaling", "Float8CurrentScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"], + help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" @@ -283,11 +287,5 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) - parser.add_argument( - "--use-fp8", - action="store_true", - default=False, - help="Enable FP8 quantization", - ) return parser diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 7c2ec5d607..adc97b1790 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -111,16 +111,16 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -216,7 +216,7 @@ def test_te_delayed_scaling_fp8_all_gather(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_dense_grad_tests(self.args, self.mesh) @@ -228,7 +228,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) @@ -240,7 +240,7 @@ def test_te_current_scaling_fp8_all_gather(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_dense_grad_tests(self.args, self.mesh) @@ -252,7 +252,7 @@ def test_te_current_scaling_fp8_reduce_scatter(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) @@ -263,7 +263,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) @@ -274,7 +273,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_dense_grad_tests(self.args, self.mesh) @@ -285,7 +283,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) @@ -296,7 +293,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_dense_grad_tests(self.args, self.mesh) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 3df85ab87c..d06fa9e75e 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -116,20 +116,20 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() can read the global recipe # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw # tex.gemm() calls, so we must pass quantizer_set explicitly. - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -214,7 +214,7 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -226,7 +226,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -238,7 +238,7 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -250,7 +250,7 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -260,7 +260,7 @@ def test_te_mxfp8_all_gather_with_dp(self): is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -270,7 +270,7 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -281,7 +281,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_gemm_tests(self.args, self.mesh) @@ -292,7 +291,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_gemm_tests(self.args, self.mesh) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 927f3e99b2..599c88e3d9 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -161,16 +161,16 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set quantizer_sets = (quantizer_set, quantizer_set) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() @@ -272,7 +272,7 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) def test_te_current_scaling_fp8_layernorm_mlp_grad(self): @@ -283,7 +283,7 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported @@ -293,7 +293,6 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # run_layernorm_mlp_grad_tests(self.args, self.mesh) # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported @@ -303,7 +302,6 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # run_layernorm_mlp_grad_tests(self.args, self.mesh) From 1d49bd581fd2f5d66d536dd94b95b4efefa71ae1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:26:25 -0700 Subject: [PATCH 10/22] shape check when padding mxfp8 scales Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/helper.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 4ef9426433..f922671def 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -917,11 +917,10 @@ def apply_padding_to_scale_inv( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) - # TODO - # assert scale_inv.shape == unpadded_scale_shape, ( - # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " - # f"{scale_inv.shape}." - # ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) From 16adffe21fc34793968d463e4c7bdad95723e0e3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:38:20 -0700 Subject: [PATCH 11/22] cleanup Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 41 ++++++++----------- .../jax/collective_gemm/test_dense_grad.py | 36 ++++++++-------- examples/jax/collective_gemm/test_gemm.py | 2 - .../test_layernorm_mlp_grad.py | 16 ++++---- 4 files changed, 40 insertions(+), 55 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 04553a0174..a098515af9 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,38 +28,31 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" -# # TODO(Phuong): Enable when supported +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" # # # test_dense_grad.py cases -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -# # TODO(Phuong): Enable when supported -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" # # # test_layernorm_mlp_grad.py cases -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" -# # TODO(Phuong): Enable when supported -# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" -# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index adc97b1790..35a4b9457f 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -256,27 +256,24 @@ def test_te_current_scaling_fp8_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_all_gather(self): - # """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.collective_type = "all_gather" - # run_dense_grad_tests(self.args, self.mesh) + def test_te_mxfp8_all_gather(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_reduce_scatter(self): - # """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.collective_type = "reduce_scatter" - # run_dense_grad_tests(self.args, self.mesh) + def test_te_mxfp8_reduce_scatter(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" @@ -286,7 +283,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_reduce_scatter(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d06fa9e75e..c969062376 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -274,7 +274,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" @@ -284,7 +283,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # self.args.collective_type = "all_gather" # run_gemm_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_reduce_scatter_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 599c88e3d9..fe245f436b 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -286,16 +286,14 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): run_layernorm_mlp_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_layernorm_mlp_grad(self): - # """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_mxfp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_layernorm_mlp_grad(self): # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" # self.args.quantize_recipe = "NVFP4BlockScaling" From 4ba6974e8e5a4ac648719323ea01e15b03a2cd5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:39:54 +0000 Subject: [PATCH 12/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/collective_gemm/common.py | 7 +++- .../jax/collective_gemm/test_dense_grad.py | 12 ++++-- examples/jax/collective_gemm/test_gemm.py | 16 ++++++-- .../test_layernorm_mlp_grad.py | 8 +++- transformer_engine/jax/cpp_extensions/gemm.py | 37 ++++++++++++++----- .../jax/csrc/extensions/gemm.cpp | 3 +- 6 files changed, 62 insertions(+), 21 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 483f3e60af..d0fad8c48b 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -278,7 +278,12 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--quantize-recipe", type=str, default=None, - choices=["DelayedScaling", "Float8CurrentScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"], + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 35a4b9457f..b6a5422470 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -112,7 +112,9 @@ def run_dense_grad_tests(args, mesh=None): collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( enabled=use_quantization, recipe=recipe, @@ -259,7 +261,9 @@ def test_te_current_scaling_fp8_reduce_scatter(self): def test_te_mxfp8_all_gather(self): """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.collective_type = "all_gather" @@ -268,7 +272,9 @@ def test_te_mxfp8_all_gather(self): def test_te_mxfp8_reduce_scatter(self): """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.collective_type = "reduce_scatter" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index c969062376..8f0e9a44cf 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -117,7 +117,9 @@ def run_gemm_tests(args, mesh=None): ) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource # (via global_shard_guard) required for collective GEMM sharding axis resolution. @@ -170,7 +172,9 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)) + assert_allclose( + gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set) + ) class TestCollectiveGemmWithDP(unittest.TestCase): @@ -257,7 +261,9 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): def test_te_mxfp8_all_gather_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) @@ -267,7 +273,9 @@ def test_te_mxfp8_all_gather_with_dp(self): def test_te_mxfp8_reduce_scatter_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index fe245f436b..f242840ba0 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -162,7 +162,9 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( enabled=use_quantization, recipe=recipe, @@ -289,7 +291,9 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): def test_te_mxfp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) run_layernorm_mlp_grad_tests(self.args, self.mesh) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 6441cfa8ac..a8d2f2aa69 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -684,19 +684,27 @@ def impl( if not collective_op.is_none and not is_outer: # MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128. # No padding is needed in this case - lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod(lhs.shape[lhs_flatten_axis:]) + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod( + lhs.shape[lhs_flatten_axis:] + ) assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( - f"MXFP8 + Collective AG requires LHS dimensions before and after the flatten axis to be multiples of 128. " - f"Got lhs.shape={lhs.shape}, lhs_flatten_axis={lhs_flatten_axis}" + "MXFP8 + Collective AG requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got lhs.shape={lhs.shape}," + f" lhs_flatten_axis={lhs_flatten_axis}" ) # The scale needs to be in good shape for reordering assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( - f"MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be multiples of tpsp_axis_size. " - f"Got lhs_scale_inv.shape={lhs_scale_inv.shape}, tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + "MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be" + f" multiples of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape}," + f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" ) else: lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis, + lhs_scale_inv, + scaling_mode, + lhs.shape, + lhs_transposed, + lhs_flatten_axis, ) rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis @@ -715,7 +723,12 @@ def impl( assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) - if need_reorder and (collective_op.is_reduce_scatter or collective_op.is_all_gather) and lhs_scale_inv.shape[0] != 1 and scaling_mode.is_1d_block_scaling(): + if ( + need_reorder + and (collective_op.is_reduce_scatter or collective_op.is_all_gather) + and lhs_scale_inv.shape[0] != 1 + and scaling_mode.is_1d_block_scaling() + ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) @@ -974,7 +987,9 @@ def _parse_operand_output_specs( if scaling_mode.is_1d_block_scaling(): rhs_scale_specs = rhs_specs if collective_op.is_all_gather: - lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) + lhs_scale_specs = tuple( + None if i == sequence_dim else s for i, s in enumerate(lhs_specs) + ) else: lhs_scale_specs = lhs_specs @@ -1012,7 +1027,11 @@ def infer_sharding_from_operands( ) (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op, scaling_mode, + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 268307ea83..2acefa2d30 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -74,8 +74,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } input.set_with_gemm_swizzled_scales(true); - } - else if (is_nvfp4) { // Swizzle for NVFP4 + } else if (is_nvfp4) { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor From 463d9dae70ca2297aec248e72695133a90686b2b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 14:54:23 -0700 Subject: [PATCH 13/22] enable tests + is_quantize_recipe_supported Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 80 ++++--------------- .../jax/collective_gemm/test_dense_grad.py | 32 +++----- examples/jax/collective_gemm/test_gemm.py | 32 +++----- .../test_layernorm_mlp_grad.py | 18 ++--- transformer_engine/jax/quantize/helper.py | 27 +++++++ 5 files changed, 64 insertions(+), 125 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index d0fad8c48b..3028f96672 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -1,20 +1,24 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Shared functions for the comm_overlap tests""" +"""Shared functions for the collective GEMM tests""" +import argparse + +import jax import jax.numpy as jnp import numpy as np +from jax.experimental import mesh_utils + +from transformer_engine.common import recipe as te_recipe +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap -# Add this after your existing imports def dtype_tols(dtype, rtol=None, atol=None): """Expected numerical tolerance for a data type.""" - # Return immediately if tolerances are fully specified if rtol is not None and atol is not None: return {"rtol": rtol, "atol": atol} - # Default tolerances for common dtypes if dtype in [jnp.float32, "float32"]: return {"rtol": 1e-5, "atol": 1e-8} elif dtype in [jnp.float16, "float16"]: @@ -39,23 +43,11 @@ def get_tolerance_dtype(quantizer_set): return jnp.bfloat16 -def assert_allclose( - actual, - desired, - rtol=None, - atol=None, - dtype=None, - **kwargs, -): +def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs): """Check if two tensors are close.""" - # Infer data type if needed if dtype is None: - if isinstance(actual, float): - dtype = "float32" - else: - dtype = actual.dtype + dtype = "float32" if isinstance(actual, float) else actual.dtype - # Determine tolerances tols = {} if rtol is None or atol is None: tols = dtype_tols(dtype) @@ -64,51 +56,26 @@ def assert_allclose( if atol is not None: tols["atol"] = atol - # Cast tensors to fp32 if not isinstance(actual, float): actual = actual.astype(jnp.float32) if not isinstance(desired, float): desired = desired.astype(jnp.float32) - # Check if tensors are close np.testing.assert_allclose(actual, desired, **tols, **kwargs) -def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): - if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): - diff = jnp.abs(ref_output - gathered_output) - mask = diff > (atol + rtol * jnp.abs(gathered_output)) - print(mask.astype(int)) - print(jnp.where(mask, diff, 0)) - - -# Shared constants for all tests +# Shared constants DP_AXIS = "data" TPSP_AXIS = "tensor_sequence" -PARAMS_KEY = "params" - -# Shared functions for distributed testing -import argparse -import jax -from jax.experimental import mesh_utils -from transformer_engine.common import recipe as te_recipe -from transformer_engine.jax.quantize import ScalingMode -from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap # Global flag to track if distributed has been initialized _distributed_initialized = False -def _is_distributed_initialized(): - """Check if JAX distributed has been initialized.""" - return _distributed_initialized - - def _initialize_distributed(args): """Initialize JAX distributed with custom arguments.""" global _distributed_initialized - # Check if already initialized if _distributed_initialized: return @@ -121,14 +88,10 @@ def _initialize_distributed(args): assert ( args.num_devices_per_process is not None ), "Either local_device_ids or num_devices_per_process must be provided" - # Calculate device range for this process - # Single process single device: each process gets one unique device - # Single process multiple devices: each process gets a unique range of devices start_device = args.process_id * args.num_devices_per_process device_range = range(start_device, start_device + args.num_devices_per_process) global_device_ids_for_this_process = ",".join(map(str, device_range)) else: - # Use explicitly provided global device IDs global_device_ids_for_this_process = args.local_device_ids args.num_devices_per_process = len(args.local_device_ids.split(",")) @@ -199,30 +162,15 @@ def _create_mesh(args): return mesh -def get_scaling_mode_from_recipe_name(name: str) -> ScalingMode: - """Get ScalingMode from a recipe name string.""" - match name: - case "DelayedScaling": - return ScalingMode.DELAYED_TENSOR_SCALING - case "Float8CurrentScaling": - return ScalingMode.CURRENT_TENSOR_SCALING - case "MXFP8BlockScaling": - return ScalingMode.MXFP8_1D_SCALING - case "NVFP4BlockScaling": - return ScalingMode.NVFP4_1D_SCALING - case _: - raise ValueError(f"Invalid recipe name, got {name}") - - def get_quantization_recipe_from_name_string(name: str): - """Query recipe from a given name string""" + """Return a recipe object from a recipe name string.""" match name: case "DelayedScaling": return te_recipe.DelayedScaling() - case "MXFP8BlockScaling": - return te_recipe.MXFP8BlockScaling() case "Float8CurrentScaling": return te_recipe.Float8CurrentScaling() + case "MXFP8BlockScaling": + return te_recipe.MXFP8BlockScaling() case "NVFP4BlockScaling": return te_recipe.NVFP4BlockScaling() case _: diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index b6a5422470..deac8b9c40 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -19,17 +19,15 @@ _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, get_quantization_recipe_from_name_string, - get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.dense import dense from transformer_engine.jax.quantize import ( autocast, - is_scaling_mode_supported, + is_quantize_recipe_supported, QuantizerFactory, noop_quantizer_set, ) @@ -213,9 +211,7 @@ def test_te_bf16_reduce_scatter(self): def test_te_delayed_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -225,9 +221,7 @@ def test_te_delayed_scaling_fp8_all_gather(self): def test_te_delayed_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -237,9 +231,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self): def test_te_current_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -249,9 +241,7 @@ def test_te_current_scaling_fp8_all_gather(self): def test_te_current_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -261,9 +251,7 @@ def test_te_current_scaling_fp8_reduce_scatter(self): def test_te_mxfp8_all_gather(self): """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) self.args.collective_type = "all_gather" @@ -272,9 +260,7 @@ def test_te_mxfp8_all_gather(self): def test_te_mxfp8_reduce_scatter(self): """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) self.args.collective_type = "reduce_scatter" @@ -283,7 +269,7 @@ def test_te_mxfp8_reduce_scatter(self): # def test_te_nvfp4_all_gather(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) # if not is_supported: # self.skipTest(reason) # self.args.collective_type = "all_gather" @@ -292,7 +278,7 @@ def test_te_mxfp8_reduce_scatter(self): # def test_te_nvfp4_reduce_scatter(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) # if not is_supported: # self.skipTest(reason) # self.args.collective_type = "reduce_scatter" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 8f0e9a44cf..3dc3154a44 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -28,16 +28,14 @@ _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, get_quantization_recipe_from_name_string, - get_scaling_mode_from_recipe_name, ) import transformer_engine.jax.cpp_extensions as tex from transformer_engine.jax.quantize import ( autocast, - is_scaling_mode_supported, + is_quantize_recipe_supported, QuantizerFactory, noop_quantizer_set, ) @@ -213,9 +211,7 @@ def test_te_bf16_reduce_scatter_with_dp(self): def test_te_delayed_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -225,9 +221,7 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self): def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -237,9 +231,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): def test_te_current_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -249,9 +241,7 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self): def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -261,9 +251,7 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): def test_te_mxfp8_all_gather_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -273,9 +261,7 @@ def test_te_mxfp8_all_gather_with_dp(self): def test_te_mxfp8_reduce_scatter_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -285,7 +271,7 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # def test_te_nvfp4_all_gather_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) # if not is_supported: # self.skipTest(reason) # self.args.collective_type = "all_gather" @@ -294,7 +280,7 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # def test_te_nvfp4_reduce_scatter_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) # if not is_supported: # self.skipTest(reason) # self.args.collective_type = "reduce_scatter" diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index f242840ba0..788c4efa37 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -19,17 +19,15 @@ _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, get_quantization_recipe_from_name_string, - get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.quantize import ( autocast, - is_scaling_mode_supported, + is_quantize_recipe_supported, QuantizerFactory, noop_quantizer_set, ) @@ -269,9 +267,7 @@ def test_te_bf16_layernorm_mlp_grad(self): def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -280,9 +276,7 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): def test_te_current_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) @@ -291,9 +285,7 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): def test_te_mxfp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported( - get_scaling_mode_from_recipe_name(self.args.quantize_recipe) - ) + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) if not is_supported: self.skipTest(reason) run_layernorm_mlp_grad_tests(self.args, self.mesh) @@ -301,7 +293,7 @@ def test_te_mxfp8_layernorm_mlp_grad(self): # def test_te_nvfp4_layernorm_mlp_grad(self): # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" # self.args.quantize_recipe = "NVFP4BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) # if not is_supported: # self.skipTest(reason) # run_layernorm_mlp_grad_tests(self.args, self.mesh) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index f922671def..fa4d1841c0 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -51,6 +51,7 @@ "fp8_autocast", "is_fp8_available", "is_scaling_mode_supported", + "is_quantize_recipe_supported", "get_supported_scaling_modes", "get_supported_quantization_recipes", "update_collections", @@ -162,6 +163,32 @@ def is_scaling_mode_supported( return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] +_RECIPE_NAME_TO_SCALING_MODE = { + "DelayedScaling": ScalingMode.DELAYED_TENSOR_SCALING, + "Float8CurrentScaling": ScalingMode.CURRENT_TENSOR_SCALING, + "MXFP8BlockScaling": ScalingMode.MXFP8_1D_SCALING, + "NVFP4BlockScaling": ScalingMode.NVFP4_1D_SCALING, +} + + +def is_quantize_recipe_supported(recipe_name: str, gpu_id=None) -> Tuple[bool, str]: + """Check if the given quantization recipe (by name) is supported on the current GPU. + + Args: + recipe_name: Name of the recipe, e.g. "DelayedScaling", "Float8CurrentScaling", + "MXFP8BlockScaling", "NVFP4BlockScaling". + gpu_id: Optional GPU ID to check a specific device (default: all local devices). + + Returns: + A tuple of (supported: bool, reason: str). + """ + scaling_mode = _RECIPE_NAME_TO_SCALING_MODE.get(recipe_name) + if scaling_mode is None: + valid = list(_RECIPE_NAME_TO_SCALING_MODE) + return False, f"Unknown quantization recipe '{recipe_name}'. Valid options: {valid}" + return is_scaling_mode_supported(scaling_mode, gpu_id) + + def is_fp8_available( scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, From 43f3d31c5dfa2c745e14c36c948c20d50cbf7d09 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 15:33:43 -0700 Subject: [PATCH 14/22] cleanup Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 16 ---------- .../jax/collective_gemm/test_dense_grad.py | 7 ++--- examples/jax/collective_gemm/test_gemm.py | 21 ++------------ .../test_layernorm_mlp_grad.py | 11 ++++--- transformer_engine/jax/quantize/helper.py | 29 +++++++++++++++++++ 5 files changed, 39 insertions(+), 45 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 3028f96672..6815932395 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -10,7 +10,6 @@ import numpy as np from jax.experimental import mesh_utils -from transformer_engine.common import recipe as te_recipe from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap @@ -162,21 +161,6 @@ def _create_mesh(args): return mesh -def get_quantization_recipe_from_name_string(name: str): - """Return a recipe object from a recipe name string.""" - match name: - case "DelayedScaling": - return te_recipe.DelayedScaling() - case "Float8CurrentScaling": - return te_recipe.Float8CurrentScaling() - case "MXFP8BlockScaling": - return te_recipe.MXFP8BlockScaling() - case "NVFP4BlockScaling": - return te_recipe.NVFP4BlockScaling() - case _: - raise ValueError(f"Invalid quantization_recipe, got {name}") - - def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): """Create common argument parser for all collective GEMM tests.""" parser = argparse.ArgumentParser(description=description) diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index deac8b9c40..a36efb2ae9 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" -import argparse import unittest import os @@ -20,7 +19,6 @@ DP_AXIS, TPSP_AXIS, cgemm_parser, - get_quantization_recipe_from_name_string, ) from transformer_engine.jax.dense import dense @@ -28,6 +26,7 @@ from transformer_engine.jax.quantize import ( autocast, is_quantize_recipe_supported, + get_quantization_recipe, QuantizerFactory, noop_quantizer_set, ) @@ -111,7 +110,7 @@ def run_dense_grad_tests(args, mesh=None): use_quantization = args.quantize_recipe is not None recipe = ( - get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + get_quantization_recipe(args.quantize_recipe) if use_quantization else None ) with mesh, autocast( enabled=use_quantization, @@ -306,6 +305,6 @@ def test_te_mxfp8_reduce_scatter(self): args = cgemm_parser( "Collective Dense Gradient test on multi-GPU with tensor parallelism" - ).parse_args([]) + ).parse_args() _initialize_distributed(args) run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 3dc3154a44..77b73c2243 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -29,13 +29,13 @@ DP_AXIS, TPSP_AXIS, cgemm_parser, - get_quantization_recipe_from_name_string, ) import transformer_engine.jax.cpp_extensions as tex from transformer_engine.jax.quantize import ( autocast, is_quantize_recipe_supported, + get_quantization_recipe, QuantizerFactory, noop_quantizer_set, ) @@ -60,23 +60,6 @@ def _get_operand_sharding(mesh, collective_op, is_with_dp): return x_sharding, weight_sharding, bias_sharding, output_sharding -def _get_dp_and_tp_sizes(args): - num_gpu = args.num_processes * args.num_devices_per_process - if args.tensor_parallel_size is None: - num_gpu_dp = 2 if args.enable_data_parallel else 1 - assert ( - num_gpu > 1 and num_gpu % num_gpu_dp == 0 - ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" - num_gpu_tp = num_gpu // num_gpu_dp - else: - num_gpu_tp = args.tensor_parallel_size - assert ( - num_gpu > 1 and num_gpu % num_gpu_tp == 0 - ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" - num_gpu_dp = num_gpu // num_gpu_tp - return num_gpu_dp, num_gpu_tp - - @partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) def _jitted_cgemm(x, weight, bias, quantizer_set, contracting_dims, collective_op, output_sharding): output = tex.gemm( @@ -116,7 +99,7 @@ def run_gemm_tests(args, mesh=None): use_quantization = args.quantize_recipe is not None recipe = ( - get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + get_quantization_recipe(args.quantize_recipe) if use_quantization else None ) # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 788c4efa37..320449d9ca 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" -import argparse import unittest import os @@ -20,7 +19,6 @@ DP_AXIS, TPSP_AXIS, cgemm_parser, - get_quantization_recipe_from_name_string, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -28,6 +26,7 @@ from transformer_engine.jax.quantize import ( autocast, is_quantize_recipe_supported, + get_quantization_recipe, QuantizerFactory, noop_quantizer_set, ) @@ -127,7 +126,7 @@ def _value_and_grad_layernorm_mlp( def run_layernorm_mlp_grad_tests(args, mesh=None): - """Execute Dense Gradient tests.""" + """Execute LayerNorm MLP Gradient tests.""" print(args) # Initialize distributed with provided arguments @@ -161,7 +160,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): use_quantization = args.quantize_recipe is not None recipe = ( - get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + get_quantization_recipe(args.quantize_recipe) if use_quantization else None ) with mesh, autocast( enabled=use_quantization, @@ -237,7 +236,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): class TestCollectiveLayerNormMLPGradient(unittest.TestCase): - """Collective Dense Gradient unittests""" + """Collective LayerNorm MLP Gradient unittests""" def setUp(self): self.args = cgemm_parser( @@ -320,6 +319,6 @@ def test_te_mxfp8_layernorm_mlp_grad(self): args = cgemm_parser( "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" - ).parse_args([]) + ).parse_args() _initialize_distributed(args) run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index fa4d1841c0..3fde48a67a 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -52,6 +52,7 @@ "is_fp8_available", "is_scaling_mode_supported", "is_quantize_recipe_supported", + "get_quantization_recipe", "get_supported_scaling_modes", "get_supported_quantization_recipes", "update_collections", @@ -189,6 +190,34 @@ def is_quantize_recipe_supported(recipe_name: str, gpu_id=None) -> Tuple[bool, s return is_scaling_mode_supported(scaling_mode, gpu_id) +_RECIPE_NAME_TO_RECIPE = { + "DelayedScaling": DelayedScaling, + "Float8CurrentScaling": Float8CurrentScaling, + "MXFP8BlockScaling": MXFP8BlockScaling, + "NVFP4BlockScaling": NVFP4BlockScaling, +} + + +def get_quantization_recipe(name: str) -> Recipe: + """Return a recipe object from a recipe name string. + + Args: + name: Recipe name. One of "DelayedScaling", "Float8CurrentScaling", + "MXFP8BlockScaling", or "NVFP4BlockScaling". + + Returns: + A new instance of the corresponding recipe class. + + Raises: + ValueError: If ``name`` does not match any known recipe. + """ + recipe_cls = _RECIPE_NAME_TO_RECIPE.get(name) + if recipe_cls is None: + valid = list(_RECIPE_NAME_TO_RECIPE) + raise ValueError(f"Invalid quantization recipe '{name}'. Valid options: {valid}") + return recipe_cls() + + def is_fp8_available( scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, From 4a3057c288aa545b940e0331196c3b06bead68a7 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:31:54 -0700 Subject: [PATCH 15/22] address comments Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 42 ++++++++--------- .../test_layernorm_mlp_grad.py | 7 ++- transformer_engine/jax/cpp_extensions/gemm.py | 12 ++++- transformer_engine/jax/quantize/helper.py | 46 ++++++++----------- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index a098515af9..647fc24c43 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,28 +28,28 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" -# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" -# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" -# -# # test_dense_grad.py cases -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# # +# # # test_dense_grad.py cases +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" # # # test_layernorm_mlp_grad.py cases -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" @@ -100,7 +100,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + XLA_FLAGS="--xla_gpu_graph_min_graph_size=1 $XLA_FLAGS" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ @@ -109,7 +109,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do PIDS+=($PID) else # For other processes: redirect to log files only - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + XLA_FLAGS="--xla_gpu_graph_min_graph_size=1 $XLA_FLAGS" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & @@ -136,7 +136,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Remove the log files after processing them wait - rm ${TEST_NAME}_gpu_*.log + rm ${TEST_NAME}_gpu_*.log done wait diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 320449d9ca..dfa92511b5 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -169,8 +169,8 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): ): # Build quantizer_set inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). - quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set - quantizer_sets = (quantizer_set, quantizer_set) + quantizer_set = QuantizerFactory.create_set(n_quantizer_sets) if use_quantization else (noop_quantizer_set, noop_quantizer_set) + # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -200,6 +200,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): noop_collective_op_sets, quantizer_sets, ) + jax.profiler.start_trace(f"traces/cgemm_trace_{args.quantize_recipe}") output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, weight_1_sharded, @@ -214,6 +215,8 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets, quantizer_sets, ) + jax.block_until_ready(output) + jax.profiler.stop_trace() jax.block_until_ready(ref_output) jax.block_until_ready(output) gathered_grads = [] diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a8d2f2aa69..d54bd3c525 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -688,10 +688,18 @@ def impl( lhs.shape[lhs_flatten_axis:] ) assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( - "MXFP8 + Collective AG requires LHS dimensions before and after the flatten" + "MXFP8 + Collective AG/RS requires LHS dimensions before and after the flatten" f" axis to be multiples of 128. Got lhs.shape={lhs.shape}," f" lhs_flatten_axis={lhs_flatten_axis}" ) + rhs_first, rhs_last = math.prod(rhs.shape[:rhs_flatten_axis]), math.prod( + rhs.shape[rhs_flatten_axis:] + ) + assert rhs_first % 128 == 0 and rhs_last % 128 == 0, ( + "MXFP8 + Collective AG/RS requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got rhs.shape={rhs.shape}," + f" rhs_flatten_axis={rhs_flatten_axis}" + ) # The scale needs to be in good shape for reordering assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( "MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be" @@ -1295,7 +1303,7 @@ def _te_gemm( if not collective_op.is_none: assert not scaling_mode.is_nvfp4_scaling, ( f"Collective GEMM is not yet supported with {scaling_mode} quantization. " - "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." + "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported." ) out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3fde48a67a..3747ab8245 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -164,32 +164,6 @@ def is_scaling_mode_supported( return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] -_RECIPE_NAME_TO_SCALING_MODE = { - "DelayedScaling": ScalingMode.DELAYED_TENSOR_SCALING, - "Float8CurrentScaling": ScalingMode.CURRENT_TENSOR_SCALING, - "MXFP8BlockScaling": ScalingMode.MXFP8_1D_SCALING, - "NVFP4BlockScaling": ScalingMode.NVFP4_1D_SCALING, -} - - -def is_quantize_recipe_supported(recipe_name: str, gpu_id=None) -> Tuple[bool, str]: - """Check if the given quantization recipe (by name) is supported on the current GPU. - - Args: - recipe_name: Name of the recipe, e.g. "DelayedScaling", "Float8CurrentScaling", - "MXFP8BlockScaling", "NVFP4BlockScaling". - gpu_id: Optional GPU ID to check a specific device (default: all local devices). - - Returns: - A tuple of (supported: bool, reason: str). - """ - scaling_mode = _RECIPE_NAME_TO_SCALING_MODE.get(recipe_name) - if scaling_mode is None: - valid = list(_RECIPE_NAME_TO_SCALING_MODE) - return False, f"Unknown quantization recipe '{recipe_name}'. Valid options: {valid}" - return is_scaling_mode_supported(scaling_mode, gpu_id) - - _RECIPE_NAME_TO_RECIPE = { "DelayedScaling": DelayedScaling, "Float8CurrentScaling": Float8CurrentScaling, @@ -218,6 +192,26 @@ def get_quantization_recipe(name: str) -> Recipe: return recipe_cls() +def is_quantize_recipe_supported(recipe_name: str) -> Tuple[bool, str]: + """Check if the given quantization recipe (by name) is supported on the current GPU. + + Args: + recipe_name: Name of the recipe, e.g. "DelayedScaling", "Float8CurrentScaling", + "MXFP8BlockScaling", "NVFP4BlockScaling". + + Returns: + A tuple of (supported: bool, reason: str). + """ + recipe = get_quantization_recipe(recipe_name) + config = get_quantize_config_with_recipe(recipe) + for tensor_source in TensorSource: + scaling_mode = config.get_scaling_mode(tensor_source) + is_supported, reason = is_scaling_mode_supported(scaling_mode) + if not is_supported: + return is_supported, reason + return True, None + + def is_fp8_available( scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, From 10ae5b8f86eccbab93125eed74a44cb9e15b8ea6 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:34:11 -0700 Subject: [PATCH 16/22] typo Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/test_layernorm_mlp_grad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index dfa92511b5..bdd4f371b3 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -167,9 +167,9 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): - # Build quantizer_set inside autocast so create_set() reads the global recipe + # Build quantizer_sets inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). - quantizer_set = QuantizerFactory.create_set(n_quantizer_sets) if use_quantization else (noop_quantizer_set, noop_quantizer_set) + quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) if use_quantization else (noop_quantizer_set, noop_quantizer_set) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() @@ -232,7 +232,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - tol_dtype = get_tolerance_dtype(quantizer_set) + tol_dtype = get_tolerance_dtype(quantizer_sets[0]) assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) From edb405c6ccba207e3688945ef1dad6f776430e4c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:35:43 -0700 Subject: [PATCH 17/22] include all tests Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 647fc24c43..019e5ebbc3 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,28 +28,28 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" -# # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" -# # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" -# # -# # # test_dense_grad.py cases -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# +# # test_dense_grad.py cases +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" -# -# # test_layernorm_mlp_grad.py cases -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" + +# test_layernorm_mlp_grad.py cases +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" From 976c80da46e4b6116c56f582ca9aafec723b6311 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 21:38:52 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/collective_gemm/test_dense_grad.py | 4 +--- examples/jax/collective_gemm/test_gemm.py | 4 +--- .../jax/collective_gemm/test_layernorm_mlp_grad.py | 10 ++++++---- transformer_engine/jax/cpp_extensions/gemm.py | 4 ++-- transformer_engine/jax/quantize/helper.py | 8 ++++---- 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index a36efb2ae9..1d300f8e90 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -109,9 +109,7 @@ def run_dense_grad_tests(args, mesh=None): collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) use_quantization = args.quantize_recipe is not None - recipe = ( - get_quantization_recipe(args.quantize_recipe) if use_quantization else None - ) + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None with mesh, autocast( enabled=use_quantization, recipe=recipe, diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 77b73c2243..c2db8fc44a 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -98,9 +98,7 @@ def run_gemm_tests(args, mesh=None): ) use_quantization = args.quantize_recipe is not None - recipe = ( - get_quantization_recipe(args.quantize_recipe) if use_quantization else None - ) + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource # (via global_shard_guard) required for collective GEMM sharding axis resolution. diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index bdd4f371b3..e4ad1fc23c 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -159,9 +159,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) use_quantization = args.quantize_recipe is not None - recipe = ( - get_quantization_recipe(args.quantize_recipe) if use_quantization else None - ) + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None with mesh, autocast( enabled=use_quantization, recipe=recipe, @@ -169,7 +167,11 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): ): # Build quantizer_sets inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). - quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) if use_quantization else (noop_quantizer_set, noop_quantizer_set) + quantizer_sets = ( + QuantizerFactory.create_set(n_quantizer_sets=2) + if use_quantization + else (noop_quantizer_set, noop_quantizer_set) + ) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d54bd3c525..3d4a38cafb 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1302,8 +1302,8 @@ def _te_gemm( if not collective_op.is_none: assert not scaling_mode.is_nvfp4_scaling, ( - f"Collective GEMM is not yet supported with {scaling_mode} quantization. " - "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported." + f"Collective GEMM is not yet supported with {scaling_mode} quantization. Only" + " DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported." ) out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3747ab8245..3a93af4a68 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -205,10 +205,10 @@ def is_quantize_recipe_supported(recipe_name: str) -> Tuple[bool, str]: recipe = get_quantization_recipe(recipe_name) config = get_quantize_config_with_recipe(recipe) for tensor_source in TensorSource: - scaling_mode = config.get_scaling_mode(tensor_source) - is_supported, reason = is_scaling_mode_supported(scaling_mode) - if not is_supported: - return is_supported, reason + scaling_mode = config.get_scaling_mode(tensor_source) + is_supported, reason = is_scaling_mode_supported(scaling_mode) + if not is_supported: + return is_supported, reason return True, None From 83f6f12d57fd904e9da40f6af55bbc67e10a755f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:40:47 -0700 Subject: [PATCH 19/22] cleanup tracing Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/test_layernorm_mlp_grad.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index e4ad1fc23c..be94c68d37 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -202,7 +202,6 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): noop_collective_op_sets, quantizer_sets, ) - jax.profiler.start_trace(f"traces/cgemm_trace_{args.quantize_recipe}") output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, weight_1_sharded, @@ -217,8 +216,6 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets, quantizer_sets, ) - jax.block_until_ready(output) - jax.profiler.stop_trace() jax.block_until_ready(ref_output) jax.block_until_ready(output) gathered_grads = [] From 74cae04230a1cee23cf2b27159308631e9c6a85f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:42:05 -0700 Subject: [PATCH 20/22] cleanup Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/run_test_cgemm.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 019e5ebbc3..8340d2010f 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -100,7 +100,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" - XLA_FLAGS="--xla_gpu_graph_min_graph_size=1 $XLA_FLAGS" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ @@ -109,7 +109,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do PIDS+=($PID) else # For other processes: redirect to log files only - XLA_FLAGS="--xla_gpu_graph_min_graph_size=1 $XLA_FLAGS" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & From e6468c680f139334d246d539a1aa6459feff0758 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:46:11 -0700 Subject: [PATCH 21/22] add comment Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3d4a38cafb..bc5f916248 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -994,6 +994,8 @@ def _parse_operand_output_specs( lhs_scale_specs = rhs_scale_specs = (None,) if scaling_mode.is_1d_block_scaling(): rhs_scale_specs = rhs_specs + # Set the seq spec to None to trigger AG the scales as TE/Common CGEMM does not handle + # scale collecting yet if collective_op.is_all_gather: lhs_scale_specs = tuple( None if i == sequence_dim else s for i, s in enumerate(lhs_specs) From ea929cf8cd2f5fb26c7f19f4d603374547e9c927 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Mar 2026 14:49:52 -0700 Subject: [PATCH 22/22] typo Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index bc5f916248..515f02af6e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -702,7 +702,7 @@ def impl( ) # The scale needs to be in good shape for reordering assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( - "MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be" + "MXFP8 + Collective AG/RS requires RHS scale inv sequence dimension to be" f" multiples of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape}," f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" )