From a03db8edb7f96a415797b95c42c1987b889486c1 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 28 May 2026 06:49:36 +0000 Subject: [PATCH] fix(quant): sync NVFP4StaticQuantizer global_amax across TP and EP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `promote_nvfp4_static_quantizers` registers `_global_amax` from `reduce_amax(local _amax)` before the distributed-sync block in `max_calibrate`. With TP (different weight shards per rank) or MoE EP (different experts per rank), each rank's local `_amax` covers a different subset of the global weight, so the resulting `_global_amax` diverges across ranks. The existing TP/EP sync only touches `_amax`, leaving `_global_amax` permanently rank-local — making the upper level of the two-level NVFP4 scale TP/EP-layout-dependent. Add `NVFP4StaticQuantizer.sync_global_amax_across_distributed_group` (mirroring the existing `sync_amax_across_distributed_group` on `TensorQuantizer`) and call it from both `sync_quantizer_amax_across_dp_ep` (EP group) and `sync_quantizer_amax_across_tp` (TP group) alongside the existing `_amax` sync. DP doesn't need a separate call because weights are replicated across DP, so `_global_amax` is naturally identical. Co-authored-by: Claude Opus 4.7 (1M context) Signed-off-by: Chenjie Luo --- modelopt/torch/quantization/model_calib.py | 16 ++++++++++++++++ .../quantization/nn/modules/tensor_quantizer.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 78b237847b1..dd8af9e5da1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -291,6 +291,13 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + # NVFP4StaticQuantizer._global_amax is computed locally during promotion, so for + # MoE EP (different experts per rank) it diverges across ranks. Reconcile it + # alongside the per-block _amax sync. + if isinstance(quantizer, NVFP4StaticQuantizer): + quantizer.sync_global_amax_across_distributed_group( + parallel_state.expert_model_parallel_group + ) # TODO: create sync_bias_across_distributed_group # Step 2:Sync amax across data parallelism @@ -341,6 +348,15 @@ def sync_quantizer_amax_across_tp( if quantizer.axis in axes_for_sync and quantizer.amax is not None: quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group) + # Reconcile NVFP4StaticQuantizer._global_amax across TP. The buffer is set + # locally during promotion from this rank's shard of _amax, so each TP rank + # otherwise carries a different value and the two-level NVFP4 scale becomes + # TP-layout-dependent. + if isinstance(quantizer, NVFP4StaticQuantizer): + quantizer.sync_global_amax_across_distributed_group( + parallel_state.tensor_parallel_group + ) + # Step 2: Sync amax across relevant parallelism (such as TP / EP) for name, module in model.named_modules(): if getattr(module, "_parallel_state", None) is None: diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 5e3cea44c2a..4a8efa3e8dc 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1353,6 +1353,23 @@ def global_amax(self, value): else: self._global_amax.data.copy_(value.clone().detach().to(self._global_amax.device)) + def sync_global_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup): + """All-reduce ``_global_amax`` (MAX) across the given group. + + ``_global_amax`` is computed locally from each rank's shard of ``_amax`` at + promotion time, so for TP/EP it must be reconciled across ranks before it is + used as the upper level of the two-level NVFP4 scale. + """ + if parallel_group.is_initialized() and getattr(self, "_global_amax", None) is not None: + try: + dist.all_reduce(self._global_amax, op=dist.ReduceOp.MAX, group=parallel_group.group) + except RuntimeError as e: + warnings.warn( + f"Failed to synchronize _global_amax: {e}, probably because the tensor " + "is on a device which is not supported by the current distributed backend. " + "This warning can be ignored if happening during modelopt restore." + ) + def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: