-
Notifications
You must be signed in to change notification settings - Fork 25
Add fsdp2 fp8 unit tests TE 2.10 #492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
e8e63b1
c3e33e3
db36143
13b4007
d91241f
2b8818d
8964d56
54938d9
f771955
c1949d3
dade028
40ba9eb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,55 +11,93 @@ | |||||||||||||||||||||||||||
| from transformer_engine.pytorch.quantization import FP8GlobalStateManager | ||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||
| from run_fsdp2_fp8_model import SimpleNet | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from transformer_engine.pytorch.quantization import quantized_model_init | ||||||||||||||||||||||||||||
| from transformer_engine.common.recipe import Float8CurrentScaling, DelayedScaling, MXFP8BlockScaling | ||||||||||||||||||||||||||||
| fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() | ||||||||||||||||||||||||||||
| mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| NUM_PROCS: int = torch.cuda.device_count() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def assertEqual( | ||||||||||||||||||||||||||||
| l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool: | ||||||||||||||||||||||||||||
| """Ensures two lists are exactly equal.""" | ||||||||||||||||||||||||||||
| def assert_allclose( | ||||||||||||||||||||||||||||
| l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None | ||||||||||||||||||||||||||||
| ) -> bool: | ||||||||||||||||||||||||||||
| """Ensures two lists are equal.""" | ||||||||||||||||||||||||||||
| assert len(l1) == len(l2), "Unequal number of outputs." | ||||||||||||||||||||||||||||
| tols = dict(atol=atol) | ||||||||||||||||||||||||||||
| tols["rtol"] = rtol if rtol is not None else 0 | ||||||||||||||||||||||||||||
| tol = tols["atol"] + (tols["rtol"] * torch.abs(l2)) | ||||||||||||||||||||||||||||
| for i, (t1, t2) in enumerate(zip(l1, l2)): | ||||||||||||||||||||||||||||
| result = torch.allclose(t1, t2, atol=0, rtol=0) | ||||||||||||||||||||||||||||
| result = torch.allclose(t1, t2, **tols) | ||||||||||||||||||||||||||||
| if not result: | ||||||||||||||||||||||||||||
| diff = torch.abs(t1 - t2) | ||||||||||||||||||||||||||||
| exceed_mask = diff > 0 | ||||||||||||||||||||||||||||
| if exceed_mask.any(): | ||||||||||||||||||||||||||||
| indices = torch.nonzero(exceed_mask, as_tuple=True) | ||||||||||||||||||||||||||||
| max_diff = diff[exceed_mask].max() | ||||||||||||||||||||||||||||
| max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] | ||||||||||||||||||||||||||||
| max_location = [idx[max_idx].item() for idx in indices] | ||||||||||||||||||||||||||||
| if diff.dim() == 0: | ||||||||||||||||||||||||||||
| max_diff = diff | ||||||||||||||||||||||||||||
| max_location = [] | ||||||||||||||||||||||||||||
| msg = ( | ||||||||||||||||||||||||||||
| f"Outputs not close enough in tensor at idx={i}. " | ||||||||||||||||||||||||||||
| f"Maximum difference at location {max_location} " | ||||||||||||||||||||||||||||
| f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " | ||||||||||||||||||||||||||||
| f"(diff {max_diff.item()})." | ||||||||||||||||||||||||||||
| f"Outputs not close enough in scalar tensor at idx={i}. " | ||||||||||||||||||||||||||||
| f"Difference: {max_diff.item()}." | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| exceed_mask = diff > tol | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if exceed_mask.any(): | ||||||||||||||||||||||||||||
| indices = torch.nonzero(exceed_mask, as_tuple=True) | ||||||||||||||||||||||||||||
| max_diff = diff[exceed_mask].max() | ||||||||||||||||||||||||||||
| max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] | ||||||||||||||||||||||||||||
| max_location = [idx[max_idx].item() for idx in indices] | ||||||||||||||||||||||||||||
| msg = ( | ||||||||||||||||||||||||||||
| f"Outputs not close enough in tensor at idx={i}. " | ||||||||||||||||||||||||||||
| f"Maximum difference at location {max_location} " | ||||||||||||||||||||||||||||
| f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " | ||||||||||||||||||||||||||||
| f"(diff {max_diff.item()})." | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| raise AssertionError(msg) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _run_test(fp_init, recipe): | ||||||||||||||||||||||||||||
| def _run_test(quantized_init, autocast, recipe): | ||||||||||||||||||||||||||||
| test_dir = Path(__file__).parent.resolve() | ||||||||||||||||||||||||||||
| fsdp_script = test_dir / "run_fsdp2_fp8_model.py" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if fp_init: | ||||||||||||||||||||||||||||
| test_cmd += ["--fp8-init"] | ||||||||||||||||||||||||||||
| test_cmd += ["--recipe", recipe] | ||||||||||||||||||||||||||||
| if quantized_init: | ||||||||||||||||||||||||||||
| test_cmd += ["--quantized-init"] | ||||||||||||||||||||||||||||
| if autocast: | ||||||||||||||||||||||||||||
| test_cmd += ["--autocast"] | ||||||||||||||||||||||||||||
| if autocast or quantized_init: | ||||||||||||||||||||||||||||
| test_cmd += ["--recipe", recipe] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True) | ||||||||||||||||||||||||||||
| subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Load outputs | ||||||||||||||||||||||||||||
| output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu") | ||||||||||||||||||||||||||||
| output_dp = torch.load("all_iters_dp.pt", map_location="cpu") | ||||||||||||||||||||||||||||
| atol = 0 | ||||||||||||||||||||||||||||
| rtol = 0 | ||||||||||||||||||||||||||||
| # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical: | ||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||
| # - quantized_init=True: With FSDP2, the FP8 Adam optimizer | ||||||||||||||||||||||||||||
| # re-quantizes FP32 master weights back to FP8 using a scale derived from | ||||||||||||||||||||||||||||
| # the previous iteration's shard-local max_abs (each rank only sees its | ||||||||||||||||||||||||||||
| # weight shard). In DDP, the same scale is derived from the full tensor's | ||||||||||||||||||||||||||||
| # max_abs. This scale difference causes different FP8 rounding, which | ||||||||||||||||||||||||||||
| # compounds over iterations. Hence we use a relaxed tolerance. | ||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||
| # - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs | ||||||||||||||||||||||||||||
| # (all-reduce vs reduce-scatter), so float non-associativity produces last-bit | ||||||||||||||||||||||||||||
| # differences in the reduced gradients and updated weights. Hence we use a relaxed tolerance. | ||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||
| # When autocast=True and quantized_init=False, FP8 quantization happens after the | ||||||||||||||||||||||||||||
| # FSDP2 AllGather reconstructs the full weight, so both paths compute identical | ||||||||||||||||||||||||||||
| # scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used. | ||||||||||||||||||||||||||||
| if quantized_init or (not quantized_init and not autocast): | ||||||||||||||||||||||||||||
| atol = 1e-6 | ||||||||||||||||||||||||||||
| rtol = 5e-5 | ||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If our reference is ddp with the same fp8 primary weight, then the same cast from fp32 master weight to fp8 happens in both target and reference flow. Then we will have exact match?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When quantized_init=True, weights are stored as FP8 parameters. With FSDP2, each rank only holds a shard of the weight, so the optimizer computes amax = max_abs(master_weight_shard) —> a shard-local value. With DDP, each rank holds the full weight, so it computes amax = max_abs(master_weight_full). Although the shard-local amaxes are all-reduced before computing the next iteration's scale, the scale used for the current re-quantization was already derived from the previous iteration's shard-local amax. Since max_abs(shard) might not be equal to max_abs(full_tensor), the FSDP2 and DDP scales can differ, leading to different FP8 rounding. This compounds over iterations. With autocast-only (no quantized_init), this doesn't happen because weights remain in BF16/FP32 — the optimizer uses regular Adam with no FP8 re-quantization.
TransformerEngine/transformer_engine/common/multi_tensor/adam.cu Lines 170 to 181 in f6efbbf
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)): | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...") | ||||||||||||||||||||||||||||
| assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match | ||||||||||||||||||||||||||||
| assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol) | ||||||||||||||||||||||||||||
| print(f"Tensor at index {idx} passed comparison.") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -70,13 +108,24 @@ def cleanup_artifacts(): | |||||||||||||||||||||||||||
| if os.path.exists(fname): | ||||||||||||||||||||||||||||
| os.remove(fname) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Define test cases explicitly | ||||||||||||||||||||||||||||
| test_cases = [] | ||||||||||||||||||||||||||||
| # All FP8 enabled cases (all recipes) | ||||||||||||||||||||||||||||
| for quantized_init in [True, False]: | ||||||||||||||||||||||||||||
| for autocast in [True, False]: | ||||||||||||||||||||||||||||
| if quantized_init or autocast: | ||||||||||||||||||||||||||||
| for recipe in ["delayed", "current", "mxfp8"]: | ||||||||||||||||||||||||||||
| test_cases.append((quantized_init, autocast, recipe)) | ||||||||||||||||||||||||||||
| # FP8 disabled case (only once) | ||||||||||||||||||||||||||||
| test_cases.append((False, False, "delayed")) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") | ||||||||||||||||||||||||||||
| @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") | ||||||||||||||||||||||||||||
| @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") | ||||||||||||||||||||||||||||
| @pytest.mark.parametrize("fp8_init", ([False])) | ||||||||||||||||||||||||||||
| @pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"])) | ||||||||||||||||||||||||||||
| @pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases) | ||||||||||||||||||||||||||||
| @pytest.mark.usefixtures("cleanup_artifacts") | ||||||||||||||||||||||||||||
| def test_distributed(fp8_init, recipe): | ||||||||||||||||||||||||||||
| def test_distributed(quantized_init, autocast, recipe): | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| batch_size = 2048 | ||||||||||||||||||||||||||||
| input_size = 2048 | ||||||||||||||||||||||||||||
|
|
@@ -90,18 +139,23 @@ def test_distributed(fp8_init, recipe): | |||||||||||||||||||||||||||
| torch.save(input_data.cpu(), input_path) | ||||||||||||||||||||||||||||
| print("Generated and saved shared input tensor.") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| model = SimpleNet(input_size, 2048, 2048) | ||||||||||||||||||||||||||||
| if quantized_init: | ||||||||||||||||||||||||||||
| fp8_recipe = {"delayed": DelayedScaling(), "current": Float8CurrentScaling(), "mxfp8": MXFP8BlockScaling()}[recipe] | ||||||||||||||||||||||||||||
| with quantized_model_init(enabled=True, recipe=fp8_recipe): | ||||||||||||||||||||||||||||
| model = SimpleNet(input_size, 2048, 2048) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| model = SimpleNet(input_size, 2048, 2048) | ||||||||||||||||||||||||||||
| torch.save(model.state_dict(), 'fsdp_model.pth') | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if torch.cuda.device_count() < 4: | ||||||||||||||||||||||||||||
| pytest.skip("FSDP2 test requires at least 4 GPUs") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if fp8_init and not fp8_available: | ||||||||||||||||||||||||||||
| if quantized_init and not fp8_available: | ||||||||||||||||||||||||||||
| pytest.skip(reason_for_no_fp8) | ||||||||||||||||||||||||||||
| if recipe == "mxfp8" and not mxfp8_available: | ||||||||||||||||||||||||||||
| pytest.skip(reason_for_no_mxfp8) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| _run_test(fp8_init, recipe) | ||||||||||||||||||||||||||||
| _run_test(quantized_init, autocast, recipe) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_dummy() -> None: | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does with te.fp8_autocast(enabled=args.fp8_autocast,.. ) do the same?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does do the same but since with TEv2.10, te.fp8_autocast is replaced with te.autocast, I've made the change to be consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So will 'with te.autocast(enabled=args.fp8_autocast, recipe=...)' do the same as if/else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it should. I'll make the changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.