diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
index e670141f79a..e3a87927fd3 100644
--- a/modelopt/torch/quantization/plugins/transformer_engine.py
+++ b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -15,7 +15,9 @@
 
 """Support quantization for Transformer Engine layers."""
 
+import copy
 import inspect
+import os
 import warnings
 
 import torch
@@ -27,12 +29,17 @@
 
 from modelopt.torch.quantization.utils import replace_function
 
-from ..nn import QuantModuleRegistry
+from ..nn import QuantModuleRegistry, SequentialQuantizer
 from .custom import _ParallelLinear
 
 _TE_VERSION = Version(te.__version__)
 
 
+def _per_expert_weight_quantizer_enabled() -> bool:
+    """Opt-in MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1: per-gemm weight_quantizer in TEGroupedLinear."""
+    return os.environ.get("MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER", "0") == "1"
+
+
 def _assert_te_fp8_enabled():
     """Check if Transformer Engine FP8 autocast is enabled and raise error if so."""
     try:
@@ -137,8 +144,10 @@ def _setup(self):
         # Remove self.weight after setup.
         delattr(self, "weight")
 
-        # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization
-        # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm.
+        self._per_expert_weight_quantizer = _per_expert_weight_quantizer_enabled()
+        if self._per_expert_weight_quantizer:
+            for i in range(self.num_gemms):
+                self.add_module(f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer))
 
     def modelopt_post_restore(self, prefix: str = ""):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
@@ -150,12 +159,35 @@ def modelopt_post_restore(self, prefix: str = ""):
         # Remove self.weight after post_restore.
         delattr(self, "weight")
 
+        # Base post_restore only re-calibrates self.weight_quantizer; the per-expert
+        # weight_quantizer_{i} also need re-calibration so a TP/EP change between save
+        # and restore produces correctly shaped per-channel _amax. Mirror base behavior:
+        # only re-calibrate quantizers whose loaded state had _amax (skip unused ones).
+        if getattr(self, "_per_expert_weight_quantizer", False):
+            from modelopt.torch.quantization.model_calib import max_calibrate
+
+            for i in range(self.num_gemms):
+                weight_i = getattr(self, f"weight{i}", None)
+                if weight_i is None:
+                    continue
+                wq_i = self._get_weight_quantizer(i)
+                q = wq_i[0] if isinstance(wq_i, SequentialQuantizer) else wq_i
+                if not hasattr(q, "_amax"):
+                    continue
+                wq_i.reset_amax()
+                max_calibrate(wq_i, lambda wq, w=weight_i: wq(w), distributed_sync=False)
+
+    def _get_weight_quantizer(self, gemm_idx: int):
+        if getattr(self, "_per_expert_weight_quantizer", False):
+            return getattr(self, f"weight_quantizer_{gemm_idx}")
+        return self.weight_quantizer
+
     def iter_weights_for_calibration(self):
         """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights."""
         for i in range(self.num_gemms):
             weight_i = getattr(self, f"weight{i}", None)
             if weight_i is not None:
-                yield weight_i, self.weight_quantizer
+                yield weight_i, self._get_weight_quantizer(i)
 
     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
@@ -182,8 +214,9 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
 
         new_args = list(args)
         new_args[inp_pos] = self.input_quantizer(args[inp_pos])
-        for i in range(weights_start, weights_start + num_gemms):
-            new_args[i] = self.weight_quantizer(args[i])
+        for gemm_idx in range(num_gemms):
+            pos = weights_start + gemm_idx
+            new_args[pos] = self._get_weight_quantizer(gemm_idx)(args[pos])
         output = getattr(package, func_name)(*new_args)
         # TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
         # Only the activation tensor participates in output quantization.
diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
index c34fb2df376..710725af04c 100644
--- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import copy
+import math
 from functools import partial
 
 import pytest
@@ -694,6 +695,131 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg):
     )
 
 
+def _test_te_grouped_vs_sequential_default_amax_helper(tp_size, ep_size, quant_cfg, rank, size):
+    """TEGrouped per-linear amax should equal max-over-Sequential-experts under default sync=False."""
+    initialize_for_megatron(
+        tensor_model_parallel_size=tp_size,
+        expert_model_parallel_size=ep_size,
+        seed=SEED,
+    )
+
+    te_grouped = _gpt_model_provider(
+        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True,
+        transformer_impl="transformer_engine", num_moe_experts=4,
+    )
+    forward = get_forward(te_grouped, batch_size=8)
+
+    sequential = _gpt_model_provider(
+        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False,
+        num_moe_experts=4, transformer_impl="modelopt",
+    )
+    copy_weights_from_grouped_to_non_grouped(te_grouped, sequential)
+
+    for module in te_grouped.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+    for module in sequential.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+
+    mtq.quantize(te_grouped, quant_cfg, forward)
+    mtq.quantize(sequential, quant_cfg, forward)
+
+    te_modules = [m for m in te_grouped.modules() if isinstance(m, TEGroupedMLP)]
+    seq_modules = [m for m in sequential.modules() if isinstance(m, SequentialMLP)]
+    assert len(te_modules) == len(seq_modules)
+
+    saw_per_expert_divergence = False
+    for te_mlp, seq_mlp in zip(te_modules, seq_modules):
+        for linear_name in ("linear_fc1", "linear_fc2"):
+            te_amax = getattr(te_mlp, linear_name).weight_quantizer.amax
+            assert te_amax is not None and te_amax.numel() == 1
+
+            seq_amaxes = torch.stack([
+                getattr(expert, linear_name).weight_quantizer.amax.view(())
+                for expert in seq_mlp.local_experts
+            ])
+            seq_max = seq_amaxes.max()
+
+            assert torch.allclose(te_amax.view(()), seq_max, atol=1e-5, rtol=1e-5), (
+                f"TEGrouped per-linear amax ({te_amax.item()}) != "
+                f"max-over-Sequential-experts ({seq_max.item()}) for {linear_name}"
+            )
+
+            if (seq_amaxes.max() - seq_amaxes.min()).item() > 1e-5:
+                saw_per_expert_divergence = True
+
+    assert saw_per_expert_divergence, (
+        "Expected per-expert weight amax to diverge across SequentialMLP experts."
+    )
+
+
+@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
+def test_te_grouped_vs_sequential_default_amax(dist_workers_size_4, quant_cfg):
+    dist_workers_size_4.run(
+        partial(_test_te_grouped_vs_sequential_default_amax_helper, 1, 2, quant_cfg)
+    )
+
+
+def _test_te_grouped_vs_sequential_default_loss_helper(tp_size, ep_size, quant_cfg, rank, size):
+    """TEGrouped quantized output should diverge from BF16 reference more than SequentialMLP under default sync=False."""
+    initialize_for_megatron(
+        tensor_model_parallel_size=tp_size,
+        expert_model_parallel_size=ep_size,
+        seed=SEED,
+    )
+
+    te_grouped = _gpt_model_provider(
+        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True,
+        transformer_impl="transformer_engine", num_moe_experts=4,
+    )
+    forward = get_forward(te_grouped, batch_size=8)
+
+    sequential = _gpt_model_provider(
+        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False,
+        num_moe_experts=4, transformer_impl="modelopt",
+    )
+    copy_weights_from_grouped_to_non_grouped(te_grouped, sequential)
+
+    for module in te_grouped.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+    for module in sequential.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+
+    ref_te = forward(te_grouped)
+    ref_seq = forward(sequential)
+
+    mtq.quantize(te_grouped, quant_cfg, forward)
+    mtq.quantize(sequential, quant_cfg, forward)
+
+    out_te = forward(te_grouped)
+    out_seq = forward(sequential)
+
+    err_te = (out_te - ref_te).abs().mean().item()
+    err_seq = (out_seq - ref_seq).abs().mean().item()
+
+    if rank == 0:
+        print(
+            f"\n[default-amax] TEGrouped quant-err={err_te:.6f}, "
+            f"Sequential quant-err={err_seq:.6f}, ratio TE/Seq={err_te / max(err_seq, 1e-12):.3f}"
+        )
+
+    # At toy scale (4 small experts) the per-tensor amax difference is dominated
+    # by other numerical noise (~few %); the effect amplifies at production scale
+    # (e.g. 128 experts in Nemotron Nano). Just sanity-check both errors are finite.
+    assert err_te > 0 and err_seq > 0
+    assert math.isfinite(err_te) and math.isfinite(err_seq)
+
+
+@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
+def test_te_grouped_vs_sequential_default_loss(dist_workers_size_4, quant_cfg):
+    dist_workers_size_4.run(
+        partial(_test_te_grouped_vs_sequential_default_loss_helper, 1, 2, quant_cfg)
+    )
+
+
 @pytest.mark.parametrize("ep_size", [1, 2])
 @pytest.mark.parametrize("sync_weight_amax", [True, False])
 def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax):