From 839e0061499fc5518d6ddeeb3219e0f961a03ed8 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 12 May 2026 20:28:12 -0700 Subject: [PATCH 1/4] te_qad_debug: per-expert TEGrouped quantizer + logits debug print + infra fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit modelopt/torch/quantization/plugins/transformer_engine.py: MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 opts into per-gemm weight_quantizer_0..N-1 inside _QuantTEGroupedLinear (deepcopied from the shared weight_quantizer). Lets TEGroupedMLP recover per-expert amax granularity, matching SequentialMLP's default behavior. modelopt/torch/distill/plugins/megatron.py: LogitsKLLoss.forward prints student/teacher logit stats (mean/std/ min/max/shape) on rank 0 each call. Diagnostic for the QAD loss-spike investigation — confirms which spec produces which logits without changing the KL math. tests/gpu_megatron/torch/quantization/plugins/test_megatron.py: New test_te_grouped_vs_sequential_default_amax + ..._default_loss cover the structural amax asymmetry between TEGroupedMLP and SequentialMLP (TEGrouped per-linear amax = max-over-Sequential-experts amax) and a finiteness sanity check on the resulting quant error. tools/launcher/common/service_utils.sh: - Fall back to SLURM_PROCID / SLURM_LOCALID when PMIX_*/OMPI_* are unset, so `[[ "$mpi_local_rank" -eq 0 ]]` doesn't silently pass on every rank under plain srun. - util_install_extra_dep: per-node marker so concurrent ranks wait for rank 0 to finish installing (concurrent pip on a shared FS leaves a broken state); also installs nvidia-resiliency-ext. Signed-off-by: Jennifer Chen --- modelopt/torch/distill/plugins/megatron.py | 16 +++ .../plugins/transformer_engine.py | 27 +++- .../quantization/plugins/test_megatron.py | 126 ++++++++++++++++++ tools/launcher/common/service_utils.sh | 2 + 4 files changed, 166 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 9a98eee9c77..ac9c9b01b3c 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -318,6 +318,22 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """ predictions, targets = self.pre_forward(predictions, targets) + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + if not hasattr(self, "_dbg_call"): + self._dbg_call = 0 + with torch.no_grad(): + s = predictions.float() + t = targets.float() + print( + f"[LogitsKLLoss call={self._dbg_call}] " + f"student: mean={s.mean().item():.5f} std={s.std().item():.5f} " + f"min={s.min().item():.3f} max={s.max().item():.3f} shape={tuple(predictions.shape)} | " + f"teacher: mean={t.mean().item():.5f} std={t.std().item():.5f} " + f"min={t.min().item():.3f} max={t.max().item():.3f}", + flush=True, + ) + self._dbg_call += 1 + # Division by temp should happen prior to finding max for both student and teacher. output_teacher = targets.float() / self._temperature output_student = predictions.float() / self._temperature diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index e670141f79a..7e675197e4c 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -15,7 +15,11 @@ """Support quantization for Transformer Engine layers.""" +import copy +import os import inspect +import copy +import os import warnings import torch @@ -33,6 +37,11 @@ _TE_VERSION = Version(te.__version__) +def _per_expert_weight_quantizer_enabled() -> bool: + """Opt-in MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1: per-gemm weight_quantizer in TEGroupedLinear.""" + return os.environ.get("MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER", "0") == "1" + + def _assert_te_fp8_enabled(): """Check if Transformer Engine FP8 autocast is enabled and raise error if so.""" try: @@ -137,8 +146,10 @@ def _setup(self): # Remove self.weight after setup. delattr(self, "weight") - # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization - # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm. + self._per_expert_weight_quantizer = _per_expert_weight_quantizer_enabled() + if self._per_expert_weight_quantizer: + for i in range(self.num_gemms): + self.add_module(f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer)) def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to @@ -150,12 +161,17 @@ def modelopt_post_restore(self, prefix: str = ""): # Remove self.weight after post_restore. delattr(self, "weight") + def _get_weight_quantizer(self, gemm_idx: int): + if getattr(self, "_per_expert_weight_quantizer", False): + return getattr(self, f"weight_quantizer_{gemm_idx}") + return self.weight_quantizer + def iter_weights_for_calibration(self): """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.""" for i in range(self.num_gemms): weight_i = getattr(self, f"weight{i}", None) if weight_i is not None: - yield weight_i, self.weight_quantizer + yield weight_i, self._get_weight_quantizer(i) @staticmethod def te_grouped_quantized_linear_fn(package, func_name, self, *args): @@ -182,8 +198,9 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args): new_args = list(args) new_args[inp_pos] = self.input_quantizer(args[inp_pos]) - for i in range(weights_start, weights_start + num_gemms): - new_args[i] = self.weight_quantizer(args[i]) + for gemm_idx in range(num_gemms): + pos = weights_start + gemm_idx + new_args[pos] = self._get_weight_quantizer(gemm_idx)(args[pos]) output = getattr(package, func_name)(*new_args) # TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`. # Only the activation tensor participates in output quantization. diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index c34fb2df376..710725af04c 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import math from functools import partial import pytest @@ -694,6 +695,131 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg): ) +def _test_te_grouped_vs_sequential_default_amax_helper(tp_size, ep_size, quant_cfg, rank, size): + """TEGrouped per-linear amax should equal max-over-Sequential-experts under default sync=False.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + seed=SEED, + ) + + te_grouped = _gpt_model_provider( + tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True, + transformer_impl="transformer_engine", num_moe_experts=4, + ) + forward = get_forward(te_grouped, batch_size=8) + + sequential = _gpt_model_provider( + tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False, + num_moe_experts=4, transformer_impl="modelopt", + ) + copy_weights_from_grouped_to_non_grouped(te_grouped, sequential) + + for module in te_grouped.modules(): + if isinstance(module, TopKRouter): + module.topk = module.num_experts + for module in sequential.modules(): + if isinstance(module, TopKRouter): + module.topk = module.num_experts + + mtq.quantize(te_grouped, quant_cfg, forward) + mtq.quantize(sequential, quant_cfg, forward) + + te_modules = [m for m in te_grouped.modules() if isinstance(m, TEGroupedMLP)] + seq_modules = [m for m in sequential.modules() if isinstance(m, SequentialMLP)] + assert len(te_modules) == len(seq_modules) + + saw_per_expert_divergence = False + for te_mlp, seq_mlp in zip(te_modules, seq_modules): + for linear_name in ("linear_fc1", "linear_fc2"): + te_amax = getattr(te_mlp, linear_name).weight_quantizer.amax + assert te_amax is not None and te_amax.numel() == 1 + + seq_amaxes = torch.stack([ + getattr(expert, linear_name).weight_quantizer.amax.view(()) + for expert in seq_mlp.local_experts + ]) + seq_max = seq_amaxes.max() + + assert torch.allclose(te_amax.view(()), seq_max, atol=1e-5, rtol=1e-5), ( + f"TEGrouped per-linear amax ({te_amax.item()}) != " + f"max-over-Sequential-experts ({seq_max.item()}) for {linear_name}" + ) + + if (seq_amaxes.max() - seq_amaxes.min()).item() > 1e-5: + saw_per_expert_divergence = True + + assert saw_per_expert_divergence, ( + "Expected per-expert weight amax to diverge across SequentialMLP experts." + ) + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]) +def test_te_grouped_vs_sequential_default_amax(dist_workers_size_4, quant_cfg): + dist_workers_size_4.run( + partial(_test_te_grouped_vs_sequential_default_amax_helper, 1, 2, quant_cfg) + ) + + +def _test_te_grouped_vs_sequential_default_loss_helper(tp_size, ep_size, quant_cfg, rank, size): + """TEGrouped quantized output should diverge from BF16 reference more than SequentialMLP under default sync=False.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + seed=SEED, + ) + + te_grouped = _gpt_model_provider( + tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True, + transformer_impl="transformer_engine", num_moe_experts=4, + ) + forward = get_forward(te_grouped, batch_size=8) + + sequential = _gpt_model_provider( + tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False, + num_moe_experts=4, transformer_impl="modelopt", + ) + copy_weights_from_grouped_to_non_grouped(te_grouped, sequential) + + for module in te_grouped.modules(): + if isinstance(module, TopKRouter): + module.topk = module.num_experts + for module in sequential.modules(): + if isinstance(module, TopKRouter): + module.topk = module.num_experts + + ref_te = forward(te_grouped) + ref_seq = forward(sequential) + + mtq.quantize(te_grouped, quant_cfg, forward) + mtq.quantize(sequential, quant_cfg, forward) + + out_te = forward(te_grouped) + out_seq = forward(sequential) + + err_te = (out_te - ref_te).abs().mean().item() + err_seq = (out_seq - ref_seq).abs().mean().item() + + if rank == 0: + print( + f"\n[default-amax] TEGrouped quant-err={err_te:.6f}, " + f"Sequential quant-err={err_seq:.6f}, ratio TE/Seq={err_te / max(err_seq, 1e-12):.3f}" + ) + + # At toy scale (4 small experts) the per-tensor amax difference is dominated + # by other numerical noise (~few %); the effect amplifies at production scale + # (e.g. 128 experts in Nemotron Nano). Just sanity-check both errors are finite. + assert err_te > 0 and err_seq > 0 + assert math.isfinite(err_te) and math.isfinite(err_seq) + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]) +def test_te_grouped_vs_sequential_default_loss(dist_workers_size_4, quant_cfg): + dist_workers_size_4.run( + partial(_test_te_grouped_vs_sequential_default_loss_helper, 1, 2, quant_cfg) + ) + + @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("sync_weight_amax", [True, False]) def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax): diff --git a/tools/launcher/common/service_utils.sh b/tools/launcher/common/service_utils.sh index b3e9e3f725d..4f6e920b701 100755 --- a/tools/launcher/common/service_utils.sh +++ b/tools/launcher/common/service_utils.sh @@ -20,6 +20,8 @@ native_mpi_local_rank=$OMPI_COMM_WORLD_LOCAL_RANK # Works with Slurm launching with `--mpi=pmix` mpi_rank=${PMIX_RANK:-${native_mpi_rank:-${SLURM_PROCID:-0}}} mpi_local_rank=${PMIX_LOCAL_RANK:-${native_mpi_local_rank:-${SLURM_LOCALID:-0}}} +mpi_rank=${PMIX_RANK:-${native_mpi_rank:-${SLURM_PROCID:-0}}} +mpi_local_rank=${PMIX_LOCAL_RANK:-${native_mpi_local_rank:-${SLURM_LOCALID:-0}}} FAIL=0 FAIL_EXIT=0 From cc7953ee2221166ff876ed657f580109bf6002ed Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 27 May 2026 15:47:21 -0700 Subject: [PATCH 2/4] revert logging Signed-off-by: Jennifer Chen --- modelopt/torch/distill/plugins/megatron.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index ac9c9b01b3c..9a98eee9c77 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -318,22 +318,6 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """ predictions, targets = self.pre_forward(predictions, targets) - if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: - if not hasattr(self, "_dbg_call"): - self._dbg_call = 0 - with torch.no_grad(): - s = predictions.float() - t = targets.float() - print( - f"[LogitsKLLoss call={self._dbg_call}] " - f"student: mean={s.mean().item():.5f} std={s.std().item():.5f} " - f"min={s.min().item():.3f} max={s.max().item():.3f} shape={tuple(predictions.shape)} | " - f"teacher: mean={t.mean().item():.5f} std={t.std().item():.5f} " - f"min={t.min().item():.3f} max={t.max().item():.3f}", - flush=True, - ) - self._dbg_call += 1 - # Division by temp should happen prior to finding max for both student and teacher. output_teacher = targets.float() / self._temperature output_student = predictions.float() / self._temperature From 4a58f0289422f63711750eda7c398f645630cc9f Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 27 May 2026 15:57:42 -0700 Subject: [PATCH 3/4] te_per_expert: dedup imports + add TP-change post-restore caveat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - transformer_engine.py: dedup `import copy`/`import os` left over from the rebase, sort the four imports alphabetically. - transformer_engine.py: comment near the per-expert weight_quantizer setup explaining that base modelopt_post_restore won't re-calibrate the weight_quantizer_{i} modules, so save/restore is only safe when TP/EP is unchanged. Per-channel _amax shape depends on the TP-sliced output dim. - service_utils.sh: drop the duplicated mpi_rank / mpi_local_rank re-assignments — main already carries the SLURM fallback, the extra two lines were leftover rebase noise. Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/plugins/transformer_engine.py | 8 ++++++-- tools/launcher/common/service_utils.sh | 2 -- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index 7e675197e4c..3e06b08f6ec 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -16,9 +16,7 @@ """Support quantization for Transformer Engine layers.""" import copy -import os import inspect -import copy import os import warnings @@ -150,6 +148,12 @@ def _setup(self): if self._per_expert_weight_quantizer: for i in range(self.num_gemms): self.add_module(f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer)) + # NOTE: base modelopt_post_restore only re-calibrates self.weight_quantizer. + # The per-expert weight_quantizer_{i} _amax buffers ride through state_dict + # unchanged, which works when TP/EP is identical between save and restore. + # If parallelism changes (per-channel _amax shape depends on the TP-sliced + # output dim), the loaded _amax won't match the new weight shape — needs a + # custom modelopt_post_restore that re-runs max_calibrate per expert. def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to diff --git a/tools/launcher/common/service_utils.sh b/tools/launcher/common/service_utils.sh index 4f6e920b701..b3e9e3f725d 100755 --- a/tools/launcher/common/service_utils.sh +++ b/tools/launcher/common/service_utils.sh @@ -20,8 +20,6 @@ native_mpi_local_rank=$OMPI_COMM_WORLD_LOCAL_RANK # Works with Slurm launching with `--mpi=pmix` mpi_rank=${PMIX_RANK:-${native_mpi_rank:-${SLURM_PROCID:-0}}} mpi_local_rank=${PMIX_LOCAL_RANK:-${native_mpi_local_rank:-${SLURM_LOCALID:-0}}} -mpi_rank=${PMIX_RANK:-${native_mpi_rank:-${SLURM_PROCID:-0}}} -mpi_local_rank=${PMIX_LOCAL_RANK:-${native_mpi_local_rank:-${SLURM_LOCALID:-0}}} FAIL=0 FAIL_EXIT=0 From b1e32d9c0a6b5f889c045169fed8dec3fc584d79 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 27 May 2026 16:07:33 -0700 Subject: [PATCH 4/4] sync weight_quantizer_{i} during post restore Signed-off-by: Jennifer Chen --- .../plugins/transformer_engine.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index 3e06b08f6ec..e3a87927fd3 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -29,7 +29,7 @@ from modelopt.torch.quantization.utils import replace_function -from ..nn import QuantModuleRegistry +from ..nn import QuantModuleRegistry, SequentialQuantizer from .custom import _ParallelLinear _TE_VERSION = Version(te.__version__) @@ -148,12 +148,6 @@ def _setup(self): if self._per_expert_weight_quantizer: for i in range(self.num_gemms): self.add_module(f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer)) - # NOTE: base modelopt_post_restore only re-calibrates self.weight_quantizer. - # The per-expert weight_quantizer_{i} _amax buffers ride through state_dict - # unchanged, which works when TP/EP is identical between save and restore. - # If parallelism changes (per-channel _amax shape depends on the TP-sliced - # output dim), the loaded _amax won't match the new weight shape — needs a - # custom modelopt_post_restore that re-runs max_calibrate per expert. def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to @@ -165,6 +159,24 @@ def modelopt_post_restore(self, prefix: str = ""): # Remove self.weight after post_restore. delattr(self, "weight") + # Base post_restore only re-calibrates self.weight_quantizer; the per-expert + # weight_quantizer_{i} also need re-calibration so a TP/EP change between save + # and restore produces correctly shaped per-channel _amax. Mirror base behavior: + # only re-calibrate quantizers whose loaded state had _amax (skip unused ones). + if getattr(self, "_per_expert_weight_quantizer", False): + from modelopt.torch.quantization.model_calib import max_calibrate + + for i in range(self.num_gemms): + weight_i = getattr(self, f"weight{i}", None) + if weight_i is None: + continue + wq_i = self._get_weight_quantizer(i) + q = wq_i[0] if isinstance(wq_i, SequentialQuantizer) else wq_i + if not hasattr(q, "_amax"): + continue + wq_i.reset_amax() + max_calibrate(wq_i, lambda wq, w=weight_i: wq(w), distributed_sync=False) + def _get_weight_quantizer(self, gemm_idx: int): if getattr(self, "_per_expert_weight_quantizer", False): return getattr(self, f"weight_quantizer_{gemm_idx}")