From 641898cdc93679b90609f7738712e4fa9afd4760 Mon Sep 17 00:00:00 2001
From: "Peter St. John"
Date: Mon, 2 Mar 2026 08:39:17 -0800
Subject: [PATCH 1/4] fix for async dcp checkpointing

Signed-off-by: Peter St. John
---
 .../distributed/run_fsdp2_fused_adam.py       | 21 +++++++------------
 tests/pytorch/distributed/test_torch_fsdp2.py | 10 ---------
 .../pytorch/quantized_tensor.py               | 15 +++++++++++++
 .../pytorch/tensor/float8_tensor.py           | 20 +++++++++++++++++-
 .../tensor/storage/float8_tensor_storage.py   |  7 ++++++-
 5 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py
index 0439bf1b5a..7302755e7f 100644
--- a/tests/pytorch/distributed/run_fsdp2_fused_adam.py
+++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py
@@ -506,17 +506,13 @@ def test_dcp_output_parity(recipe=None, async_save=False):
     else:
         model_state = model.state_dict()
 
+    save_state = {"model": model_state, "optimizer": optimizer.state_dict()}
+
     if not async_save:
-        dcp.save(
-            {"model": model_state, "optimizer": optimizer.state_dict()},
-            checkpoint_id=checkpoint_dir,
-        )
-        future = None
+        dcp.save(save_state, checkpoint_id=checkpoint_dir)
     else:
-        future = dcp.async_save(
-            {"model": model_state, "optimizer": optimizer.state_dict()},
-            checkpoint_id=checkpoint_dir,
-        )
+        future = dcp.async_save(save_state, checkpoint_id=checkpoint_dir)
+        future.result()  # Block on async save completion
 
     # ── Build a fresh model and load the checkpoint ──────────────────
     model2 = _build_model(fp8_init=True, recipe=recipe)
@@ -545,9 +541,6 @@ def test_dcp_output_parity(recipe=None, async_save=False):
 
     state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()}
 
-    if async_save:
-        future.result()  # Block on async save completion
-
     dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
     model2.load_state_dict(
         state_to_load["model"],
@@ -572,7 +565,7 @@ def test_dcp_output_parity(recipe=None, async_save=False):
             ref_output,
             rtol=0.05,
             atol=0.1,
-            msg="Fresh model loaded from DCP checkpoint produces different output",
+            msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
         )
     else:
         torch.testing.assert_close(
@@ -580,7 +573,7 @@ def test_dcp_output_parity(recipe=None, async_save=False):
             ref_output,
             rtol=0,
             atol=0,
-            msg="Fresh model loaded from DCP checkpoint produces different output",
+            msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
         )
 
     # ── Verify one more training step produces identical results ─────
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index b10f31ea07..f51887f799 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -171,16 +171,6 @@ def test_fsdp2_dcp_output_parity(fp_recipe):
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
 def test_fsdp2_dcp_output_parity_async(fp_recipe):
     """DCP save/load round-trip into a fresh model produces identical outputs."""
-    if fp_recipe in ("DelayedScaling", "Float8CurrentScaling"):
-        pytest.xfail(
-            f"async DCP save/load with {fp_recipe} uses StateDictStager._offload_tensor() which "
-            "tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass"
-            "(_make_wrapper_subclass) with data_ptr() == 0 (empty storage). The staging code at "
-            "line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with "
-            "uninitialized garbage data. The actual FP8 data (in _data, _scale_inv attributes) is "
-            "deep-copied but ignored by DCP when writing."
-        )
-
     if fp_recipe == "MXFP8BlockScaling":
         pytest.xfail(
             "MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index cb697bc197..bbcbd51425 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -556,6 +556,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.view.default:
             raise NotImplementedError("{cls.__name__} class does not support tensor views")
 
+        # New empty op (used by DCP async staging to create CPU copies)
+        if func == torch.ops.aten.new_empty.default:
+            tensor = args[0]
+            size = args[1]
+            device = kwargs.get("device", tensor.device)
+            pin_memory = kwargs.get("pin_memory", False)
+            out = tensor._quantizer.make_empty(
+                shape=torch.Size(size),
+                dtype=tensor.dtype,
+                device=device,
+                requires_grad=tensor.requires_grad,
+                pin_memory=pin_memory,
+            )
+            return out
+
         # Empty like op
         if func == torch.ops.aten.empty_like.default:
             tensor = args[0]
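
The new aten.new_empty.default branch above exists because DCP's async-save
staging allocates host-side copies through Tensor.new_empty, per the comment
in the hunk. A rough sketch of that calling pattern, with a plain tensor
standing in for the quantized source (constructing a real Float8Tensor needs
a recipe-specific quantizer and is elided here):

    import torch

    # Stand-in for the tensor being staged; on a QuantizedTensor subclass the
    # same call now routes through the aten.new_empty.default handler instead
    # of falling through __torch_dispatch__.
    src = torch.randn(4, 8)

    # Allocate an uninitialized host buffer (pin_memory=True would be typical
    # on a CUDA system), then copy the source into it.
    staged = src.new_empty(src.shape, device="cpu", pin_memory=False)
    staged.copy_(src)
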
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index c60bb2308d..6120fdaca6 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -152,6 +152,7 @@ def make_empty(
             requires_grad=requires_grad,
             data_transpose=data_transpose,
             quantizer=self,
+            device=device,
         )
 
     def calibrate(self, tensor: torch.Tensor) -> None:
@@ -378,6 +379,7 @@ def make_empty(
             requires_grad=requires_grad,
             data_transpose=data_transpose,
             quantizer=self,
+            device=device,
         )
 
     def calibrate(self, tensor: torch.Tensor) -> None:
@@ -951,6 +953,15 @@ def is_cuda(self):
             return self._transpose.is_cuda
         raise RuntimeError("Both data and transpose are None")
 
+    @property
+    def is_cpu(self):
+        """Return whether the tensor is on CPU."""
+        if self._data is not None:
+            return self._data.is_cpu
+        if self._transpose is not None:
+            return self._transpose.is_cpu
+        raise RuntimeError("Both data and transpose are None")
+
     @classmethod
     def _make_in_reduce_ex(
         cls,
@@ -975,7 +986,14 @@ def _make_in_reduce_ex(
         )
 
     def __reduce_ex__(self, protocol: int) -> tuple:
-        """Custom pickling to remove references to FP8 metadata objects"""
+        """Custom pickling to remove references to FP8 metadata objects
+
+        CPU Float8Tensors are serialized as dequantized plain tensors
+        for compatibility with torch.load(weights_only=True), which is
+        used by DCP async save staging.
+        """
+        if self._data is not None and self._data.is_cpu:
+            return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
         return (
             Float8Tensor._make_in_reduce_ex,
             (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
index a815b366b2..df164fe338 100644
--- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
@@ -14,7 +14,7 @@
 
 from ...quantized_tensor import QuantizedTensorStorage, Quantizer
 
-from ...constants import TE_DType as torch_to_transformer_engine_dtype
+from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch
 from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
 
@@ -35,6 +35,10 @@ def forward(
         if tensor._data is not None:
             if tensor._data.numel() == 0:
                 return torch.empty_like(tensor._data, dtype=dtype)
+            if tensor._data.is_cpu:
+                # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale
+                fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype]
+                return (tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv).to(dtype)
 
         # Cast from FP8
         return tex.dequantize(tensor, te_dtype)
@@ -130,6 +134,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "data_transpose": self._transpose,
             "quantizer": self._quantizer,
+            "device": self.device,
         }
 
     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
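
The CPU fallback added to _FromFloat8Func.forward can be exercised in
isolation. A self-contained sketch with synthetic values; the real code pulls
the same three pieces from _data, _fp8_dtype, and _scale_inv, and
TE_DType_To_Torch is assumed to map the TE enum to a torch dtype such as
torch.float8_e4m3fn:

    import torch

    data = torch.randint(0, 255, (4, 4), dtype=torch.uint8)  # raw FP8 payload
    scale_inv = torch.tensor([2.0])  # dequantization scale

    # Reinterpret the bytes as FP8, upcast to float32, apply the inverse
    # scale, then cast to the requested dtype -- the same steps as the
    # fallback above.
    dequant = (data.view(torch.float8_e4m3fn).float() * scale_inv).to(torch.bfloat16)
    assert dequant.shape == (4, 4) and dequant.dtype == torch.bfloat16
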
+ """ + if self._data is not None and self._data.is_cpu: + return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a815b366b2..df164fe338 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -14,7 +14,7 @@ from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor @@ -35,6 +35,10 @@ def forward( if tensor._data is not None: if tensor._data.numel() == 0: return torch.empty_like(tensor._data, dtype=dtype) + if tensor._data.is_cpu: + # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale + fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + return (tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv).to(dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype) @@ -130,6 +134,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, + "device": self.device, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: From 5c34ffcdf6e9f3f6060d0f8437f6761517d953d9 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Mar 2026 11:48:06 -0600 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Peter St. John --- transformer_engine/pytorch/quantized_tensor.py | 3 ++- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index bbcbd51425..b5ad9bd701 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -560,11 +560,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.new_empty.default: tensor = args[0] size = args[1] + dtype = kwargs.get("dtype", tensor.dtype) device = kwargs.get("device", tensor.device) pin_memory = kwargs.get("pin_memory", False) out = tensor._quantizer.make_empty( shape=torch.Size(size), - dtype=tensor.dtype, + dtype=dtype, device=device, requires_grad=tensor.requires_grad, pin_memory=pin_memory, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 6120fdaca6..69a2892d51 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -992,7 +992,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: for compatibility with torch.load(weights_only=True), which is used by DCP async save staging. 
""" - if self._data is not None and self._data.is_cpu: + if self.is_cpu: return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, From 5df7fb4e18414230490dacdce5584476c4c66183 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 12 Mar 2026 00:54:14 +0530 Subject: [PATCH 3/4] Update transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/tensor/storage/float8_tensor_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a75fd95dc0..a176db0f18 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -38,7 +38,7 @@ def forward( if tensor._data.is_cpu: # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype] - return (tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv).to(dtype) + return (tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv.to(tensor._data.device)).to(dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype) From bbaef502aa1cc547a7eb828a47c10d0c087e44eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 19:25:00 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/tensor/storage/float8_tensor_storage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a176db0f18..0d9afd56d6 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -38,7 +38,10 @@ def forward( if tensor._data.is_cpu: # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype] - return (tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv.to(tensor._data.device)).to(dtype) + return ( + tensor._data.view(fp8_torch_dtype).float() + * tensor._scale_inv.to(tensor._data.device) + ).to(dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype)