Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions modelopt/torch/quantization/plugins/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

"""Support quantization for Transformer Engine layers."""

import copy
import inspect
import os
import warnings

import torch
Expand All @@ -27,12 +29,17 @@

from modelopt.torch.quantization.utils import replace_function

from ..nn import QuantModuleRegistry
from ..nn import QuantModuleRegistry, SequentialQuantizer
from .custom import _ParallelLinear

_TE_VERSION = Version(te.__version__)


def _per_expert_weight_quantizer_enabled() -> bool:
"""Opt-in MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1: per-gemm weight_quantizer in TEGroupedLinear."""
return os.environ.get("MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER", "0") == "1"


def _assert_te_fp8_enabled():
"""Check if Transformer Engine FP8 autocast is enabled and raise error if so."""
try:
Expand Down Expand Up @@ -137,8 +144,10 @@ def _setup(self):
# Remove self.weight after setup.
delattr(self, "weight")

# TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization
# with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm.
self._per_expert_weight_quantizer = _per_expert_weight_quantizer_enabled()
if self._per_expert_weight_quantizer:
for i in range(self.num_gemms):
self.add_module(f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer))

def modelopt_post_restore(self, prefix: str = ""):
# GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
Expand All @@ -150,12 +159,35 @@ def modelopt_post_restore(self, prefix: str = ""):
# Remove self.weight after post_restore.
delattr(self, "weight")

# Base post_restore only re-calibrates self.weight_quantizer; the per-expert
# weight_quantizer_{i} also need re-calibration so a TP/EP change between save
# and restore produces correctly shaped per-channel _amax. Mirror base behavior:
# only re-calibrate quantizers whose loaded state had _amax (skip unused ones).
if getattr(self, "_per_expert_weight_quantizer", False):
from modelopt.torch.quantization.model_calib import max_calibrate

for i in range(self.num_gemms):
weight_i = getattr(self, f"weight{i}", None)
if weight_i is None:
continue
wq_i = self._get_weight_quantizer(i)
q = wq_i[0] if isinstance(wq_i, SequentialQuantizer) else wq_i
if not hasattr(q, "_amax"):
continue
wq_i.reset_amax()
max_calibrate(wq_i, lambda wq, w=weight_i: wq(w), distributed_sync=False)

def _get_weight_quantizer(self, gemm_idx: int):
if getattr(self, "_per_expert_weight_quantizer", False):
return getattr(self, f"weight_quantizer_{gemm_idx}")
return self.weight_quantizer

def iter_weights_for_calibration(self):
"""Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights."""
for i in range(self.num_gemms):
weight_i = getattr(self, f"weight{i}", None)
if weight_i is not None:
yield weight_i, self.weight_quantizer
yield weight_i, self._get_weight_quantizer(i)

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
Expand All @@ -182,8 +214,9 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):

new_args = list(args)
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
for i in range(weights_start, weights_start + num_gemms):
new_args[i] = self.weight_quantizer(args[i])
for gemm_idx in range(num_gemms):
pos = weights_start + gemm_idx
new_args[pos] = self._get_weight_quantizer(gemm_idx)(args[pos])
output = getattr(package, func_name)(*new_args)
# TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
# Only the activation tensor participates in output quantization.
Expand Down
126 changes: 126 additions & 0 deletions tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import copy
import math
from functools import partial

import pytest
Expand Down Expand Up @@ -694,6 +695,131 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg):
)


def _test_te_grouped_vs_sequential_default_amax_helper(tp_size, ep_size, quant_cfg, rank, size):
"""TEGrouped per-linear amax should equal max-over-Sequential-experts under default sync=False."""
initialize_for_megatron(
tensor_model_parallel_size=tp_size,
expert_model_parallel_size=ep_size,
seed=SEED,
)

te_grouped = _gpt_model_provider(
tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True,
transformer_impl="transformer_engine", num_moe_experts=4,
)
forward = get_forward(te_grouped, batch_size=8)

sequential = _gpt_model_provider(
tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False,
num_moe_experts=4, transformer_impl="modelopt",
)
copy_weights_from_grouped_to_non_grouped(te_grouped, sequential)

for module in te_grouped.modules():
if isinstance(module, TopKRouter):
module.topk = module.num_experts
for module in sequential.modules():
if isinstance(module, TopKRouter):
module.topk = module.num_experts

mtq.quantize(te_grouped, quant_cfg, forward)
mtq.quantize(sequential, quant_cfg, forward)

te_modules = [m for m in te_grouped.modules() if isinstance(m, TEGroupedMLP)]
seq_modules = [m for m in sequential.modules() if isinstance(m, SequentialMLP)]
assert len(te_modules) == len(seq_modules)

saw_per_expert_divergence = False
for te_mlp, seq_mlp in zip(te_modules, seq_modules):
for linear_name in ("linear_fc1", "linear_fc2"):
te_amax = getattr(te_mlp, linear_name).weight_quantizer.amax
assert te_amax is not None and te_amax.numel() == 1

seq_amaxes = torch.stack([
getattr(expert, linear_name).weight_quantizer.amax.view(())
for expert in seq_mlp.local_experts
])
seq_max = seq_amaxes.max()

assert torch.allclose(te_amax.view(()), seq_max, atol=1e-5, rtol=1e-5), (
f"TEGrouped per-linear amax ({te_amax.item()}) != "
f"max-over-Sequential-experts ({seq_max.item()}) for {linear_name}"
)

if (seq_amaxes.max() - seq_amaxes.min()).item() > 1e-5:
saw_per_expert_divergence = True

assert saw_per_expert_divergence, (
"Expected per-expert weight amax to diverge across SequentialMLP experts."
)


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
def test_te_grouped_vs_sequential_default_amax(dist_workers_size_4, quant_cfg):
dist_workers_size_4.run(
partial(_test_te_grouped_vs_sequential_default_amax_helper, 1, 2, quant_cfg)
)
Comment on lines +698 to +761
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make these “default” tests control the feature flag explicitly.

These helpers never clear or set MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER, so a worker process with that flag exported will exercise the new per-expert path instead of the default shared-quantizer path. That also means the env-enabled branch added in this PR still has no dedicated coverage. Please force the flag off in these default-behavior tests and add a separate env-on case for the new path.

Also applies to: 764-820

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py` around lines
698 - 761, The default-path helper
_test_te_grouped_vs_sequential_default_amax_helper must explicitly disable the
feature flag: set os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="0"
(saving the original and restoring it after the test) before calling
initialize_for_megatron so the shared-quantizer behavior is forced; then add a
new test variant that sets
os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="1" and runs the same
checks to cover the env-on path. Apply the same explicit env-off/restore and
corresponding env-on test change to the other similar helper/test pair
exercising TEGrouped vs Sequential (the companion helper referenced in the
review) so both default and env-enabled code paths have dedicated coverage.



def _test_te_grouped_vs_sequential_default_loss_helper(tp_size, ep_size, quant_cfg, rank, size):
"""TEGrouped quantized output should diverge from BF16 reference more than SequentialMLP under default sync=False."""
initialize_for_megatron(
tensor_model_parallel_size=tp_size,
expert_model_parallel_size=ep_size,
seed=SEED,
)

te_grouped = _gpt_model_provider(
tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True,
transformer_impl="transformer_engine", num_moe_experts=4,
)
forward = get_forward(te_grouped, batch_size=8)

sequential = _gpt_model_provider(
tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False,
num_moe_experts=4, transformer_impl="modelopt",
)
copy_weights_from_grouped_to_non_grouped(te_grouped, sequential)

for module in te_grouped.modules():
if isinstance(module, TopKRouter):
module.topk = module.num_experts
for module in sequential.modules():
if isinstance(module, TopKRouter):
module.topk = module.num_experts

ref_te = forward(te_grouped)
ref_seq = forward(sequential)

mtq.quantize(te_grouped, quant_cfg, forward)
mtq.quantize(sequential, quant_cfg, forward)

out_te = forward(te_grouped)
out_seq = forward(sequential)

err_te = (out_te - ref_te).abs().mean().item()
err_seq = (out_seq - ref_seq).abs().mean().item()

if rank == 0:
print(
f"\n[default-amax] TEGrouped quant-err={err_te:.6f}, "
f"Sequential quant-err={err_seq:.6f}, ratio TE/Seq={err_te / max(err_seq, 1e-12):.3f}"
)

# At toy scale (4 small experts) the per-tensor amax difference is dominated
# by other numerical noise (~few %); the effect amplifies at production scale
# (e.g. 128 experts in Nemotron Nano). Just sanity-check both errors are finite.
assert err_te > 0 and err_seq > 0
assert math.isfinite(err_te) and math.isfinite(err_seq)


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
def test_te_grouped_vs_sequential_default_loss(dist_workers_size_4, quant_cfg):
dist_workers_size_4.run(
partial(_test_te_grouped_vs_sequential_default_loss_helper, 1, 2, quant_cfg)
)


@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("sync_weight_amax", [True, False])
def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax):
Expand Down
Loading