Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
380 changes: 380 additions & 0 deletions benchmarks/linear/benchmark_graph_safe_grouped_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,380 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Benchmark MXFP8 graph-safe grouped MLP.

This mirrors ``benchmark_grouped_linear.py`` but targets the graph-safe TE ops
path used by grouped MLP:

GroupedLinear -> ScaledSwiGLU -> GroupedLinear

The benchmark intentionally uses CUDA-device ``m_splits`` and MXFP8 only.

Example:

python benchmarks/linear/benchmark_graph_safe_grouped_linear.py

Forward-only:

python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only

Nsight Systems:

(optionally: unset DEBUGINFOD_URLS)

nsys profile \
--output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --profile
"""

# Match the Qwen MXFP8 SFT launch toggles before importing TE.
import os

# setdefault keeps any values the launcher already exported; these are set
# before the transformer_engine imports below, since (per the comment above)
# they must be in place before TE is loaded.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
os.environ.setdefault("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")
os.environ.setdefault("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1")
os.environ.setdefault("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1")

import argparse
from contextlib import nullcontext

import pandas as pd
import torch
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import MXFP8BlockScaling
from transformer_engine.pytorch.quantization import FP8GlobalStateManager


# Capability probe cached at import time; main() uses the reason string to
# fail fast on unsupported hardware.
MXFP8_AVAILABLE, REASON_FOR_NO_MXFP8 = FP8GlobalStateManager.is_mxfp8_available()


def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string of integers.

    Empty fields (e.g. from a trailing comma or an empty string) are skipped.
    """
    parsed: list[int] = []
    for token in value.split(","):
        if token:
            parsed.append(int(token))
    return parsed


def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]:
    """Return ``num_groups`` equal split sizes that sum to ``total_tokens``.

    Raises:
        ValueError: if ``total_tokens`` is not evenly divisible by ``num_groups``.
    """
    per_group, remainder = divmod(total_tokens, num_groups)
    if remainder:
        raise ValueError(
            "Uniform split requires total_tokens divisible by num_groups, "
            f"got total_tokens={total_tokens}, num_groups={num_groups}"
        )
    return [per_group] * num_groups


def build_grouped_mlp(
    *,
    num_groups: int,
    hidden_dim: int,
    ffn_hidden_dim: int,
    dtype: torch.dtype,
    single_grouped_weight: bool,
    accumulate_into_main_grad: bool,
    glu_interleave_size: int,
) -> te_ops.Sequential:
    """Build the graph-safe GroupedLinear -> ScaledSwiGLU -> GroupedLinear stack.

    fc1 projects to ``2 * ffn_hidden_dim`` — presumably the gate/value halves
    consumed by SwiGLU (confirm against ScaledSwiGLU's contract).
    """
    # Both grouped linears share everything except their in/out feature sizes.
    shared_kwargs = dict(
        bias=False,
        device="cuda",
        dtype=dtype,
        single_grouped_weight=single_grouped_weight,
        accumulate_into_main_grad=accumulate_into_main_grad,
    )
    # Initialize the linear weights under quantized (MXFP8) model init.
    with te.quantized_model_init(enabled=True, recipe=MXFP8BlockScaling()):
        fc1 = te_ops.GroupedLinear(num_groups, hidden_dim, 2 * ffn_hidden_dim, **shared_kwargs)
        fc2 = te_ops.GroupedLinear(num_groups, ffn_hidden_dim, hidden_dim, **shared_kwargs)
    activation = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
    return te_ops.Sequential(fc1, activation, fc2)


def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None:
    """Create and fill Megatron-style FP32 ``main_grad`` buffers on each parameter.

    With ``accumulate_into_main_grad=True`` the grouped linears accumulate
    weight gradients into ``param.main_grad`` rather than ``param.grad``, so
    the buffer must exist (and be reset) before each benchmark run.

    Args:
        module: Module whose parameters receive/refresh ``main_grad`` buffers.
        value: Fill value for the buffers.
    """
    with torch.no_grad():
        for param in module.parameters():
            if getattr(param, "main_grad", None) is None:
                # Allocate and fill in one call instead of empty() + fill_().
                param.main_grad = torch.full(
                    param.size(), value, device=param.device, dtype=torch.float32
                )
            else:
                # Reuse the existing buffer so repeated calls do not change
                # the allocation pattern between benchmark iterations.
                param.main_grad.fill_(value)


def zero_grads(module: torch.nn.Module, x: torch.Tensor, scales: torch.Tensor) -> None:
    """Drop accumulated gradients on the module and both benchmark inputs.

    ``main_grad`` buffers (if present) are deliberately left alone so their
    allocations persist across iterations.
    """
    for leaf in (x, scales):
        leaf.grad = None
    module.zero_grad(set_to_none=True)


def run_grouped_mlp_steps(
    module: torch.nn.Module,
    x: torch.Tensor,
    split_sizes: torch.Tensor,
    scales: torch.Tensor,
    grad_output: torch.Tensor,
    *,
    recipe: MXFP8BlockScaling,
    fwd_only: bool,
    num_steps: int,
    accumulate_into_main_grad: bool,
) -> torch.Tensor:
    """Run ``num_steps`` synthetic microbatches through the grouped MLP eagerly.

    Returns the output of the final forward pass. In forward+backward mode,
    gradients are reset once up front and then accumulate across all steps.
    """
    autocast_ctx = te.autocast(enabled=True, recipe=recipe)

    # Forward-only: no grad tracking, no gradient bookkeeping.
    if fwd_only:
        with torch.no_grad(), autocast_ctx:
            for _ in range(num_steps):
                out = module(x, split_sizes, scales, split_sizes)
        return out

    # Forward + backward: start each run from clean gradients.
    zero_grads(module, x, scales)
    if accumulate_into_main_grad:
        init_main_grads(module)

    with autocast_ctx:
        for step in range(num_steps):
            torch.cuda.nvtx.range_push(f"step_{step}")
            out = module(x, split_sizes, scales, split_sizes)
            out.backward(grad_output)
            torch.cuda.nvtx.range_pop()
    return out


def benchmark_case(
    *,
    total_tokens: int,
    hidden_dim: int,
    ffn_hidden_dim: int,
    num_groups: int,
    dtype: torch.dtype,
    fwd_only: bool,
    single_grouped_weight: bool,
    accumulate_into_main_grad: bool,
    glu_interleave_size: int,
    num_microbatches: int,
    min_run_time: float,
    profile: bool,
) -> float:
    """Benchmark one grouped MLP shape.

    Builds the module, warms it up (128 steps, which also triggers op fusion),
    then times ``num_microbatches`` steps per measurement with
    ``torch.utils.benchmark.Timer``.

    Returns:
        Median time per microbatch in milliseconds.
    """
    split_sizes_list = make_uniform_splits(total_tokens, num_groups)
    # m_splits live on the CUDA device, matching the module docstring's
    # intent of exercising the graph-safe path.
    split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda")
    x = torch.randn(
        (total_tokens, hidden_dim),
        dtype=dtype,
        device="cuda",
        requires_grad=not fwd_only,
    )
    scales = torch.ones(
        (total_tokens,),
        dtype=dtype,
        device="cuda",
        requires_grad=not fwd_only,
    )
    grad_output = torch.ones((total_tokens, hidden_dim), dtype=dtype, device="cuda")

    module = build_grouped_mlp(
        num_groups=num_groups,
        hidden_dim=hidden_dim,
        ffn_hidden_dim=ffn_hidden_dim,
        dtype=dtype,
        single_grouped_weight=single_grouped_weight,
        accumulate_into_main_grad=accumulate_into_main_grad,
        glu_interleave_size=glu_interleave_size,
    )
    recipe = MXFP8BlockScaling()

    print(
        "case:",
        f"tokens={total_tokens}",
        f"hidden={hidden_dim}",
        f"ffn_hidden={ffn_hidden_dim}",
        f"num_groups={num_groups}",
        f"fwd_only={fwd_only}",
        f"single_grouped_weight={single_grouped_weight}",
        f"accumulate_into_main_grad={accumulate_into_main_grad}",
        f"glu_interleave_size={glu_interleave_size}",
    )
    print(f"m_splits: {split_sizes_list}")

    # Warmup also forces the op-fuser to materialize the expected fused ops.
    run_grouped_mlp_steps(
        module,
        x,
        split_sizes,
        scales,
        grad_output,
        recipe=recipe,
        fwd_only=fwd_only,
        num_steps=128,
        accumulate_into_main_grad=accumulate_into_main_grad,
    )
    torch.cuda.synchronize()

    # NOTE(review): peeks at private TE fuser state to report which fused op
    # was selected; may break if TE's internal layout changes.
    forward_ops = module._module_groups[0]._forward_ops
    print("forward fused op:", type(forward_ops[0][0]).__name__ if forward_ops else "none")
    if not fwd_only:
        backward_ops = module._module_groups[0]._backward_ops
        print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none")

    label = "graph_safe_grouped_mlp_mxfp8_swiglu"
    # emit_nvtx annotates each op for Nsight when profiling; nullcontext keeps
    # the non-profiling path overhead-free.
    timing_context = (
        torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext()
    )
    with timing_context:
        torch.cuda.nvtx.range_push(label)
        # The stmt string's names must match the keys of ``globals`` below.
        timing = benchmark.Timer(
            stmt=(
                "run_grouped_mlp_steps("
                "module, x, split_sizes, scales, grad_output, "
                "recipe=recipe, fwd_only=fwd_only, num_steps=num_microbatches, "
                "accumulate_into_main_grad=accumulate_into_main_grad)"
            ),
            globals={
                "run_grouped_mlp_steps": run_grouped_mlp_steps,
                "module": module,
                "x": x,
                "split_sizes": split_sizes,
                "scales": scales,
                "grad_output": grad_output,
                "recipe": recipe,
                "fwd_only": fwd_only,
                "num_microbatches": num_microbatches,
                "accumulate_into_main_grad": accumulate_into_main_grad,
            },
            num_threads=1,
        ).blocked_autorange(min_run_time=min_run_time)
        torch.cuda.nvtx.range_pop()

    print(f"mxfp8_swiglu: {timing}\n")
    # Each timed statement covers num_microbatches steps; normalize to ms/step.
    return timing.median * 1000 / num_microbatches


def main() -> None:
    """CLI entry point: sweep the benchmark grid and print a summary table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable NVTX profiling annotations")
    parser.add_argument(
        "--fwd-only",
        action="store_true",
        default=False,
        help="Benchmark forward only. Default benchmarks forward + backward.",
    )
    parser.add_argument(
        "--num-groups",
        type=str,
        default="8",
        help="Comma-separated local grouped GEMM/expert counts.",
    )
    parser.add_argument(
        "--token-dims",
        type=str,
        default="65536",
        help="Comma-separated total token counts to benchmark.",
    )
    parser.add_argument("--hidden-dim", type=int, default=7168)
    parser.add_argument("--ffn-hidden-dim", type=int, default=2048)
    parser.add_argument("--num-microbatches", type=int, default=32)
    parser.add_argument("--min-run-time", type=float, default=10.0)
    parser.add_argument("--glu-interleave-size", type=int, default=32)
    parser.add_argument(
        "--single-grouped-weight",
        action="store_true",
        default=False,
        help="Use one GroupedTensor parameter for each grouped linear.",
    )
    args = parser.parse_args()

    # Fail fast on unsupported environments.
    if not MXFP8_AVAILABLE:
        raise RuntimeError(f"MXFP8 is not available: {REASON_FOR_NO_MXFP8}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    # Fixed benchmark configuration.
    dtype = torch.bfloat16
    accumulate_into_main_grad = True
    token_counts = parse_int_list(args.token_dims)
    group_counts = parse_int_list(args.num_groups)

    print("Environment toggles:")
    toggle_names = (
        "CUDA_DEVICE_MAX_CONNECTIONS",
        "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
        "NVTE_CUTEDSL_FUSED_GROUPED_MLP",
        "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL",
    )
    for toggle in toggle_names:
        print(f"  {toggle}={os.environ.get(toggle)}")
    print("Recipe: MXFP8BlockScaling")
    print("Activation: ScaledSwiGLU")
    print(f"Default GLU interleave size: {args.glu_interleave_size}")
    print()

    rows = []
    for group_count in group_counts:
        for token_count in token_counts:
            per_microbatch_ms = benchmark_case(
                total_tokens=token_count,
                hidden_dim=args.hidden_dim,
                ffn_hidden_dim=args.ffn_hidden_dim,
                num_groups=group_count,
                dtype=dtype,
                fwd_only=args.fwd_only,
                single_grouped_weight=args.single_grouped_weight,
                accumulate_into_main_grad=accumulate_into_main_grad,
                glu_interleave_size=args.glu_interleave_size,
                num_microbatches=args.num_microbatches,
                min_run_time=args.min_run_time,
                profile=args.profile,
            )
            rows.append(
                {
                    "tokens": token_count,
                    "hidden_dim": args.hidden_dim,
                    "ffn_hidden_dim": args.ffn_hidden_dim,
                    "num_groups": group_count,
                    "glu_interleave_size": args.glu_interleave_size,
                    "single_grouped_weight": args.single_grouped_weight,
                    "accumulate_into_main_grad": accumulate_into_main_grad,
                    "mode": "fwd" if args.fwd_only else "fwd_bwd",
                    "time_per_microbatch_ms": per_microbatch_ms,
                }
            )

    # Explicit column list keeps the printed frame identical even when the
    # sweep produced no rows.
    summary = pd.DataFrame(
        rows,
        columns=[
            "tokens",
            "hidden_dim",
            "ffn_hidden_dim",
            "num_groups",
            "glu_interleave_size",
            "single_grouped_weight",
            "accumulate_into_main_grad",
            "mode",
            "time_per_microbatch_ms",
        ],
    )
    print(summary)


# Script entry point: run the full benchmark sweep.
if __name__ == "__main__":
    main()
Loading
Loading