Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 34 additions & 47 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,52 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
"""Shared functions for the collective GEMM tests"""

import argparse

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils

from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap


# Add this after your existing imports
def dtype_tols(dtype, rtol=None, atol=None):
    """Expected numerical tolerances for a data type.

    Args:
        dtype: jax.numpy dtype or its string name (e.g. ``jnp.bfloat16``
            or ``"bfloat16"``).
        rtol: Relative tolerance override. If None, the per-dtype default
            is used.
        atol: Absolute tolerance override. If None, the per-dtype default
            is used.

    Returns:
        Dict with ``"rtol"`` and ``"atol"`` keys, suitable for unpacking
        into ``np.testing.assert_allclose``.
    """
    # Per-dtype defaults. FP8 quantization introduces ~1% error; the fp8
    # values match the C++ getTolerances for fp8 types.
    if dtype in (jnp.float16, "float16"):
        defaults = {"rtol": 1e-3, "atol": 1e-6}
    elif dtype in (jnp.bfloat16, "bfloat16"):
        defaults = {"rtol": 1e-2, "atol": 1e-5}
    elif dtype in (jnp.float8_e4m3fn, "float8_e4m3fn", jnp.float8_e5m2, "float8_e5m2"):
        defaults = {"rtol": 1e-2, "atol": 1e-2}
    else:
        # float32 and any unrecognized dtype fall back to tight fp32 tolerances.
        defaults = {"rtol": 1e-5, "atol": 1e-8}

    # Apply caller overrides individually, so a partially specified tolerance
    # (only rtol, or only atol) is honored instead of being discarded.
    return {
        "rtol": defaults["rtol"] if rtol is None else rtol,
        "atol": defaults["atol"] if atol is None else atol,
    }


def assert_allclose(
actual,
desired,
rtol=None,
atol=None,
dtype=None,
**kwargs,
):
def get_tolerance_dtype(quantizer_set):
    """Pick the dtype that drives numerical-tolerance selection.

    Uses ``quantizer_set.x.q_dtype`` when an input quantizer is present.
    When no quantizer is active (NO_SCALING / noop path, where
    ``quantizer_set.x`` is None), falls back to bfloat16.
    """
    x_quantizer = quantizer_set.x
    if x_quantizer is None:
        return jnp.bfloat16
    return x_quantizer.q_dtype


def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs):
    """Check if two tensors are close.

    Args:
        actual: Computed value (array or Python float).
        desired: Reference value (array or Python float).
        rtol: Relative tolerance; defaults per-dtype via ``dtype_tols``.
        atol: Absolute tolerance; defaults per-dtype via ``dtype_tols``.
        dtype: Dtype used to select default tolerances; inferred from
            ``actual`` when None.
        **kwargs: Forwarded to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: If the tensors differ beyond the tolerances.
    """
    # Infer data type if needed; plain Python floats are treated as fp32.
    if dtype is None:
        dtype = "float32" if isinstance(actual, float) else actual.dtype

    # Determine tolerances: per-dtype defaults, with explicit rtol/atol
    # arguments taking precedence.
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32 so low-precision inputs compare in a common type.
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)


def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
    """Print element-wise mismatch diagnostics when two outputs diverge.

    When the tensors are within tolerance, prints nothing. On divergence,
    prints a 0/1 mask of the offending positions and the absolute
    differences at those positions. NOTE: despite the name, this helper
    only prints — it does not raise.
    """
    close = jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol)
    if close:
        return
    abs_diff = jnp.abs(ref_output - gathered_output)
    exceeds = abs_diff > (atol + rtol * jnp.abs(gathered_output))
    print(exceeds.astype(int))
    print(jnp.where(exceeds, abs_diff, 0))


# Shared constants for all tests
# Shared constants
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"

# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

# Global flag to track if distributed has been initialized
_distributed_initialized = False


def _is_distributed_initialized():
    """Check if JAX distributed has been initialized."""
    # Reflects the module-level `_distributed_initialized` flag maintained by
    # `_initialize_distributed` (which declares it `global` before setting it);
    # this does not query jax.distributed state directly.
    return _distributed_initialized


def _initialize_distributed(args):
"""Initialize JAX distributed with custom arguments."""
global _distributed_initialized

# Check if already initialized
if _distributed_initialized:
return

Expand All @@ -105,14 +87,10 @@ def _initialize_distributed(args):
assert (
args.num_devices_per_process is not None
), "Either local_device_ids or num_devices_per_process must be provided"
# Calculate device range for this process
# Single process single device: each process gets one unique device
# Single process multiple devices: each process gets a unique range of devices
start_device = args.process_id * args.num_devices_per_process
device_range = range(start_device, start_device + args.num_devices_per_process)
global_device_ids_for_this_process = ",".join(map(str, device_range))
else:
# Use explicitly provided global device IDs
global_device_ids_for_this_process = args.local_device_ids
args.num_devices_per_process = len(args.local_device_ids.split(","))

Expand Down Expand Up @@ -229,7 +207,16 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
help="Type of collective operation",
)
parser.add_argument(
"--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
"--quantize-recipe",
type=str,
default=None,
choices=[
"DelayedScaling",
"Float8CurrentScaling",
"MXFP8BlockScaling",
"NVFP4BlockScaling",
],
help="Quantization recipe to use. Omit for BF16 (no quantization).",
)
parser.add_argument(
"--enable-data-parallel", action="store_true", help="Enable data parallelism"
Expand Down
72 changes: 50 additions & 22 deletions examples/jax/collective_gemm/run_test_cgemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,36 @@ else
echo "NVLINK support detected"
fi

# Define the test files to run
TEST_FILES=(
"test_gemm.py"
"test_dense_grad.py"
"test_layernorm_mlp_grad.py"
# Define individual test cases to run (file::class::method)
# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all
# the time.
TEST_CASES=(
# test_gemm.py cases
"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp"
"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp"
"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp"
"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp"
"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp"
"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp"
# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp"
# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp"
#
# # test_dense_grad.py cases
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather"
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter"
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather"
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter"
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather"
"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter"
# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather"
# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter"
#
# # test_layernorm_mlp_grad.py cases
"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad"
"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad"
"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad"
"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad"
# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad"
)

echo
Expand Down Expand Up @@ -57,32 +82,35 @@ cleanup() {
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM

# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
echo "=== Starting test file: $TEST_FILE ..."
echo "=== Starting test: $TEST_CASE ..."

# Extract just the test method name for log/xml file naming
TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}')

# Clear PIDs array for this test file
# Clear PIDs array for this test case
PIDS=()

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"
LOG_FILE="${TEST_NAME}_gpu_${i}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
Expand All @@ -93,22 +121,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
# Wait for all processes to finish
wait

# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE PASSED"
elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "... $TEST_FILE INVALID"
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi

# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
rm ${TEST_NAME}_gpu_*.log
done

wait
Expand Down
Loading
Loading