Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,13 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# NVFP4StaticQuantizer._global_amax is computed locally during promotion, so for
# MoE EP (different experts per rank) it diverges across ranks. Reconcile it
# alongside the per-block _amax sync.
if isinstance(quantizer, NVFP4StaticQuantizer):
quantizer.sync_global_amax_across_distributed_group(
parallel_state.expert_model_parallel_group
)
# TODO: create sync_bias_across_distributed_group

# Step 2: Sync amax across data parallelism
Expand Down Expand Up @@ -341,6 +348,15 @@ def sync_quantizer_amax_across_tp(
if quantizer.axis in axes_for_sync and quantizer.amax is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)

# Reconcile NVFP4StaticQuantizer._global_amax across TP. The buffer is set
# locally during promotion from this rank's shard of _amax, so each TP rank
# otherwise carries a different value and the two-level NVFP4 scale becomes
# TP-layout-dependent.
if isinstance(quantizer, NVFP4StaticQuantizer):
quantizer.sync_global_amax_across_distributed_group(
parallel_state.tensor_parallel_group
)

# Step 2: Sync amax across relevant parallelism (such as TP / EP)
for name, module in model.named_modules():
if getattr(module, "_parallel_state", None) is None:
Expand Down
17 changes: 17 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,23 @@ def global_amax(self, value):
else:
self._global_amax.data.copy_(value.clone().detach().to(self._global_amax.device))

def sync_global_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
    """Reconcile ``_global_amax`` across the ranks of ``parallel_group``.

    Each rank derives ``_global_amax`` from its own shard of ``_amax`` at promotion
    time, so under TP/EP the values differ from rank to rank. An in-place MAX
    all-reduce gives every rank the same value before it serves as the upper level
    of the two-level NVFP4 scale.

    Args:
        parallel_group: Process-group wrapper to reduce over. Nothing happens when
            the group is not initialized or when this quantizer has no
            ``_global_amax`` set.

    A ``RuntimeError`` raised by the all-reduce is not propagated; it is reported
    as a warning instead.
    """
    # Guard clauses: skip when there is no group to talk to or nothing to sync.
    if not parallel_group.is_initialized():
        return

    global_amax = getattr(self, "_global_amax", None)
    if global_amax is None:
        return

    try:
        # In-place reduction: every rank ends up holding the group-wide maximum.
        dist.all_reduce(global_amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
    except RuntimeError as e:
        # Best-effort sync: report the failure instead of aborting the caller.
        warnings.warn(
            f"Failed to synchronize _global_amax: {e}, probably because the tensor "
            "is on a device which is not supported by the current distributed backend. "
            "This warning can be ignored if happening during modelopt restore."
        )

def _fake_quantize(self, inputs):
"""Fake quantization using two-level scaling with _amax and _global_amax."""
if self.amax is not None:
Expand Down
Loading