Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions tests/pytorch/distributed/run_fsdp2_fp8_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
"--quantized-init", action="store_true", default=False, help="Initialize primary weights in FP8 via quantized_model_init."
)
parser.add_argument(
"--autocast", action="store_true", default=False, help="Enable te.autocast for FP8 compute."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
Expand Down Expand Up @@ -169,15 +172,34 @@ def _train(args):

if args.memory_profile:
torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')
if args.fp8_init:
# Build the model with the specified context
with quantized_model_init(enabled=True):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
else:

prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()

# Build the model with the specified context
with quantized_model_init(enabled=args.quantized_init, recipe=fp8_recipe):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
# Move the model to the correct device
if not args.memory_profile:
model.load_state_dict(torch.load('fsdp_model.pth'))
if not args.memory_profile and not args.profile:
# weights_only = False when we have fp8 param in state dict
model.load_state_dict(torch.load('fsdp_model.pth', weights_only=not args.quantized_init))
model.to(device)

# Creating a DeviceMesh for fully_shard
Expand Down Expand Up @@ -215,7 +237,10 @@ def _train(args):
else:
model = DDP(model, device_ids=[LOCAL_RANK])

optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)
if args.quantized_init:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True)
else:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)

input_path = Path("shared_input.pt")
if input_path.exists():
Expand All @@ -226,25 +251,6 @@ def _train(args):
print("Generated and saved shared input tensor.")

out_tensors = []
prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
for iteration in range(args.iter):
if LOCAL_RANK == 0:
print(f"Starting iteration...{iteration}")
Expand All @@ -253,7 +259,7 @@ def _train(args):

# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does with te.fp8_autocast(enabled=args.fp8_autocast,.. ) do the same?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does do the same but since with TEv2.10, te.fp8_autocast is replaced with te.autocast, I've made the change to be consistent.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So will 'with te.autocast(enabled=args.fp8_autocast, recipe=...)' do the same as if/else?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should. I'll make the changes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

with te.autocast(enabled=args.autocast, recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
Expand Down Expand Up @@ -286,6 +292,9 @@ def _train(args):
torch.save(out_tensors, args.gradients_save_file)

if args.memory_profile:
with open('memory_summary.txt', 'w') as f:
f.write(torch.cuda.memory_summary(device=None, abbreviated=False))

snapshot = torch.cuda.memory._snapshot()
import pickle
with open('memory_snapshot.pickle', 'wb') as f:
Expand Down
106 changes: 80 additions & 26 deletions tests/pytorch/distributed/test_torch_fsdp2_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,93 @@
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import torch
from run_fsdp2_fp8_model import SimpleNet

from transformer_engine.pytorch.quantization import quantized_model_init
from transformer_engine.common.recipe import Float8CurrentScaling, DelayedScaling, MXFP8BlockScaling
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

NUM_PROCS: int = torch.cuda.device_count()

def assertEqual(
l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
"""Ensures two lists are exactly equal."""
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
tols = dict(atol=atol)
tols["rtol"] = rtol if rtol is not None else 0
tol = tols["atol"] + (tols["rtol"] * torch.abs(l2))
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=0, rtol=0)
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
exceed_mask = diff > 0
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
if diff.dim() == 0:
max_diff = diff
max_location = []
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
f"Outputs not close enough in scalar tensor at idx={i}. "
f"Difference: {max_diff.item()}."
)
else:
exceed_mask = diff > tol

if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)

def _run_test(fp_init, recipe):
def _run_test(quantized_init, autocast, recipe):
test_dir = Path(__file__).parent.resolve()
fsdp_script = test_dir / "run_fsdp2_fp8_model.py"

test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)]

if fp_init:
test_cmd += ["--fp8-init"]
test_cmd += ["--recipe", recipe]
if quantized_init:
test_cmd += ["--quantized-init"]
if autocast:
test_cmd += ["--autocast"]
if autocast or quantized_init:
test_cmd += ["--recipe", recipe]

subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)

# Load outputs
output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
output_dp = torch.load("all_iters_dp.pt", map_location="cpu")
atol = 0
rtol = 0
# Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical:
#
# - quantized_init=True: With FSDP2, the FP8 Adam optimizer
# re-quantizes FP32 master weights back to FP8 using a scale derived from
# the previous iteration's shard-local max_abs (each rank only sees its
# weight shard). In DDP, the same scale is derived from the full tensor's
# max_abs. This scale difference causes different FP8 rounding, which
# compounds over iterations. Hence we use a relaxed tolerance.
#
# - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs
# (all-reduce vs reduce-scatter), so float non-associativity produces last-bit
# differences in the reduced gradients and updated weights. Hence we use a relaxed tolerance.
#
# When autocast=True and quantized_init=False, FP8 quantization happens after the
# FSDP2 AllGather reconstructs the full weight, so both paths compute identical
# scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used.
if quantized_init or (not quantized_init and not autocast):
atol = 1e-6
rtol = 5e-5
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If our reference is ddp with the same fp8 primary weight, then the same cast from fp32 master weight to fp8 happens in both target and reference flow. Then we will have exact match?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When quantized_init=True, weights are stored as FP8 parameters.
After each optimizer step, the FP8-aware Adam kernel (AdamFunctorMaster<fp8_e4m3> in adam.cu) re-quantizes the updated FP32 master weights back to FP8. This re-quantization requires a scale, which is derived from the amax written during the previous optimizer step for delayed scaling.

With FSDP2, each rank only holds a shard of the weight, so the optimizer computes amax = max_abs(master_weight_shard) —> a shard-local value.

With DDP, each rank holds the full weight, so it computes amax = max_abs(master_weight_full).

Although the shard-local amaxes are all-reduced before computing the next iteration's scale, the scale used for the current re-quantization was already derived from the previous iteration's shard-local amax. Since max_abs(shard) might not be equal to max_abs(full_tensor), the FSDP2 and DDP scales can differ, leading to different FP8 rounding. This compounds over iterations.

With autocast-only (no quantized_init), this doesn't happen because weights remain in BF16/FP32 — the optimizer uses regular Adam with no FP8 re-quantization.

fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max);

if constexpr (is_fp8_type) {
fp8_data.max = transformer_engine::reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(
fp8_data.max, fp8_data.warp_id);
if (threadIdx.x == 0) {
if (fp8_data.amax_ptr != nullptr) {
transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max);
}
if (fp8_data.scale_inv_ptr != nullptr) {
*fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale);
}
}
}


for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):

print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match
assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol)
print(f"Tensor at index {idx} passed comparison.")


Expand All @@ -70,13 +108,24 @@ def cleanup_artifacts():
if os.path.exists(fname):
os.remove(fname)

# Build the explicit parameter matrix for test_distributed.
# For every (quantized_init, autocast) combination in which FP8 is active
# (at least one flag set), exercise all three recipes; the fully-disabled
# FP8 configuration only needs to run once, so append it separately.
test_cases = [
    (quantized_init, autocast, recipe)
    for quantized_init in (True, False)
    for autocast in (True, False)
    if quantized_init or autocast
    for recipe in ("delayed", "current", "mxfp8")
]
# Single FP8-disabled case (recipe value is irrelevant here).
test_cases.append((False, False, "delayed"))


@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("fp8_init", ([False]))
@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"]))
@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases)
@pytest.mark.usefixtures("cleanup_artifacts")
def test_distributed(fp8_init, recipe):
def test_distributed(quantized_init, autocast, recipe):

batch_size = 2048
input_size = 2048
Expand All @@ -90,18 +139,23 @@ def test_distributed(fp8_init, recipe):
torch.save(input_data.cpu(), input_path)
print("Generated and saved shared input tensor.")

model = SimpleNet(input_size, 2048, 2048)
if quantized_init:
fp8_recipe = {"delayed": DelayedScaling(), "current": Float8CurrentScaling(), "mxfp8": MXFP8BlockScaling()}[recipe]
with quantized_model_init(enabled=True, recipe=fp8_recipe):
model = SimpleNet(input_size, 2048, 2048)
else:
model = SimpleNet(input_size, 2048, 2048)
torch.save(model.state_dict(), 'fsdp_model.pth')

if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")

if fp8_init and not fp8_available:
if quantized_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

_run_test(fp8_init, recipe)
_run_test(quantized_init, autocast, recipe)


def test_dummy() -> None:
Expand Down
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from torch.distributed.tensor import DTensor

import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
Expand Down Expand Up @@ -1053,7 +1054,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2:
if IS_HIP_EXTENSION and not self.primary_weights_in_fp8 and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2:
FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True

if self.fp8_parameters or fp8_enabled:
Expand Down Expand Up @@ -1358,9 +1359,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
self.keep_fp8_weight_transpose_cache = False
param = FSDPAGTensor(
param,
module=self,
fp8_meta_index=fp8_meta_index,
keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache
fp8_meta_index=fp8_meta_index,
)

# Redo parameter wrap in case we broke it above
Expand Down
24 changes: 22 additions & 2 deletions transformer_engine/pytorch/optimizers/fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
from torch.distributed._tensor import DTensor
from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
Expand Down Expand Up @@ -377,8 +378,17 @@ def _initialize_state(
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
# (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9)
# Extract local tensor from DTensor (e.g. from FSDP2) to avoid
# QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in
# torch.empty_like, and to ensure optimizer states are plain tensors.
local_param = param._local_tensor if isinstance(param, DTensor) else param
# ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor.
local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
param_for_empty = (
local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param
)
if store_param_remainders:
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
Expand Down Expand Up @@ -423,7 +433,17 @@ def initialize_state(self, param, store_param_remainders):
store_param_remainders=store_param_remainders,
)
if not store_param_remainders:
self.set_scaled_state(param, "master_param", param.clone().detach().float())
# (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9)
#Extract local tensor from DTensor and dequantize QuantizedTensor
# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
# ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor.
local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param
if isinstance(local_param, QuantizedTensor):
master = local_param.dequantize(dtype=torch.float32).clone().detach()
else:
master = local_param.clone().detach().float()
self.set_scaled_state(param, "master_param", master)

def state_dict(self):
"""Override the state_dict() of pytorch. Before returning the state_dict, cast all
Expand Down
Loading
Loading