
Fix for async dcp checkpointing with Float8Tensors #2721

Open
pstjohn wants to merge 5 commits into NVIDIA:main from pstjohn:pstjohn/fix-async-dcp

Conversation

@pstjohn
Contributor

pstjohn commented Mar 2, 2026

(includes changes from #2698)

dcp.async_save fails silently with QuantizedTensor (Float8Tensor) — staged tensors contain uninitialized (NaN) data instead of actual FP8 values.

PyTorch's async save stages tensors to CPU by copying raw storage via new_empty() + deep_copy. Float8Tensor is a wrapper subclass with data_ptr()==0 (empty storage), so:

  1. new_empty() falls through to default dispatch, returning a plain tensor instead of a Float8Tensor
  2. The deep-copied _data/_scale_inv attributes land on the plain tensor but are ignored by DCP's write path

Changes

  • quantized_tensor.py: Handle aten.new_empty.default in torch_dispatch so staging preserves the Float8Tensor subclass type
  • float8_tensor_storage.py: Add a CPU fallback in dequantize() using PyTorch native FP8 dtypes, since tex.dequantize is CUDA-only and the staged tensor lives on CPU
  • run_fsdp2_fused_adam.py: Remove the _dequantize_state_dict workaround — dcp.async_save now works transparently
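For context, the CPU dequantize fallback amounts to reinterpreting the raw uint8 bytes as FP8 values and multiplying by the scale inverse. Below is a pure-Python sketch of that arithmetic for the float8_e4m3fn encoding (illustration only; the real code views the storage with torch's native FP8 dtype rather than decoding bits by hand, and fp8_e4m3_to_float / dequantize_cpu are hypothetical names):

```python
def fp8_e4m3_to_float(byte: int) -> float:
    """Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits
    (bias 7), 3 mantissa bits; 0x7F/0xFF encode NaN, there is no inf."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def dequantize_cpu(data_bytes, scale_inv):
    """Mirror of the CPU path: uint8 -> FP8 value -> float -> scale."""
    return [fp8_e4m3_to_float(b) * scale_inv for b in data_bytes]


# 0x38 encodes 1.0, 0x40 encodes 2.0, 0x3C encodes 1.5
print(dequantize_cpu([0x38, 0x40, 0x3C], 0.5))  # [0.5, 1.0, 0.75]
```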

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
pstjohn force-pushed the pstjohn/fix-async-dcp branch from 1fba8c7 to 641898c (March 11, 2026 17:05)
pstjohn marked this pull request as ready for review (March 11, 2026 17:06)
@greptile-apps
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR fixes a silent data corruption bug in PyTorch's dcp.async_save when checkpointing models that contain Float8Tensor weights. The root cause was that DCP's staging path calls aten.new_empty on the tensor followed by a deep copy — but because Float8Tensor is a wrapper subclass with data_ptr() == 0, the default dispatch returned a plain (uninitialized) tensor instead of a Float8Tensor, so the actual FP8 values were never staged.

The fix is logically sound and covers the three necessary layers:

  • quantized_tensor.py — intercepts aten.new_empty.default in __torch_dispatch__ and delegates to _quantizer.make_empty(), preserving the Float8Tensor subclass type and placing the staging buffer on the requested device.
  • float8_tensor.py — adds the is_cpu property and an override of __reduce_ex__ that dequantizes CPU-staged Float8Tensors to plain floats before pickling, enabling weights_only=True compatibility required by DCP's write path.
  • float8_tensor_storage.py — adds a pure-PyTorch CPU dequantize path (uint8 → FP8 view → float32 → scale → target dtype) since tex.dequantize is CUDA-only, and propagates device through get_metadata().

Several edge-case issues raised in prior review threads remain open in the current code:

  • _quantizer can be None for deserialized tensors, causing an AttributeError on a second async save round-trip.
  • __reduce_ex__ calls self.is_cpu unconditionally, which raises RuntimeError when both _data and _transpose are None (e.g. after prepare_for_saving() / clear()).
  • get_metadata() now raises in the same cleared state.

Addressing these before merge would make the fix production-hardened.

Confidence Score: 3/5

  • The happy-path DCP async staging fix is correct, but several edge cases in error handling remain unresolved from prior review threads.
  • The core three-layer fix (new_empty dispatch, CPU reduce_ex, CPU dequantize fallback) is logically correct and the xfail removals in the tests confirm expected coverage. However, multiple known-open issues from the review thread — notably the _quantizer is None AttributeError on a second save round-trip, and the __reduce_ex__ / is_cpu crash on cleared tensors — are still present in the HEAD code, which limits confidence in production correctness for all lifecycle states of a Float8Tensor.
  • transformer_engine/pytorch/tensor/float8_tensor.py (__reduce_ex__ line 997 — unconditional is_cpu call) and transformer_engine/pytorch/quantized_tensor.py (line 570 — unguarded _quantizer.make_empty call) need attention before merging.

Important Files Changed

transformer_engine/pytorch/quantized_tensor.py: Adds aten.new_empty.default dispatch to return a proper Float8Tensor for DCP async staging; _quantizer null-check and dtype semantics issues remain open from prior review threads.
transformer_engine/pytorch/tensor/float8_tensor.py: Adds is_cpu property and CPU-aware __reduce_ex__ that dequantizes staged tensors to plain floats; is_cpu will raise RuntimeError when both _data and _transpose are None (edge case from prior threads still unresolved).
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py: Adds CPU dequantize fallback via uint8 → fp8 view → float → scale_inv and adds a device field to get_metadata(); the device property raises when the tensor is cleared, which pre-existing call sites could hit.
tests/pytorch/distributed/run_fsdp2_fused_adam.py: Refactors the test to share a save_state dict, moves future.result() immediately after async_save, and improves assert messages; straightforward and correct.
tests/pytorch/distributed/test_torch_fsdp2.py: Removes the xfail markers for DelayedScaling and Float8CurrentScaling async DCP tests, now that the underlying bug is fixed.

Sequence Diagram

sequenceDiagram
    participant DCP as DCP Async Save
    participant QT as QuantizedTensor.__torch_dispatch__
    participant F8Q as Float8Quantizer.make_empty()
    participant F8T as Float8Tensor (CPU staged)
    participant RE as __reduce_ex__
    participant F8F as _FromFloat8Func.forward (CPU path)

    DCP->>QT: aten.new_empty.default(fp8_tensor, size, device=cpu)
    QT->>F8Q: make_empty(shape, dtype, device=cpu, pin_memory=True)
    F8Q-->>QT: Float8Tensor(_data=uint8 on CPU, _scale_inv on CPU)
    QT-->>DCP: Float8Tensor (CPU staged, correct subclass)

    DCP->>F8T: copy_(original_fp8_tensor)
    Note over F8T: _data, _scale_inv, _transpose copied to CPU buffer

    DCP->>RE: pickle(Float8Tensor on CPU)
    RE->>RE: self.is_cpu → True
    RE->>F8F: dequantize(dtype=self.dtype)
    F8F->>F8F: _data.view(fp8_torch_dtype).float() * _scale_inv → cast to dtype
    F8F-->>RE: plain float tensor
    RE-->>DCP: pickle plain float tensor (weights_only=True compatible)

    DCP->>DCP: write checkpoint to storage

Last reviewed commit: bbaef50

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter St. John <peterc.stjohn@gmail.com>
Comment on lines +566 to +572
out = tensor._quantizer.make_empty(
    shape=torch.Size(size),
    dtype=dtype,
    device=device,
    requires_grad=tensor.requires_grad,
    pin_memory=pin_memory,
)
Contributor

AttributeError when _quantizer is None

tensor._quantizer can be None for Float8Tensor objects deserialized via the GPU path (_make_in_reduce_ex), which does not pass a quantizer argument. If a second async DCP save is attempted after a load/save round-trip, new_empty will be dispatched on the deserialized tensor, causing AttributeError: 'NoneType' object has no attribute 'make_empty'.

A guard is needed before calling make_empty:

if func == torch.ops.aten.new_empty.default:
    tensor = args[0]
    size = args[1]
    dtype = kwargs.get("dtype", tensor.dtype)
    device = kwargs.get("device", tensor.device)
    pin_memory = kwargs.get("pin_memory", False)
    if tensor._quantizer is None:
        raise RuntimeError(
            f"{type(tensor).__name__} does not have a quantizer; "
            "cannot create new_empty QuantizedTensor"
        )
    out = tensor._quantizer.make_empty(
        shape=torch.Size(size),
        dtype=dtype,
        device=device,
        requires_grad=tensor.requires_grad,
        pin_memory=pin_memory,
    )
    return out

Comment on lines +995 to +996
if self.is_cpu:
    return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
Contributor

is_cpu raises RuntimeError when both _data and _transpose are None

The is_cpu property raises RuntimeError("Both data and transpose are None") when both _data and _transpose are None — for example, after prepare_for_saving() or clear() sets them to None. The pre-existing GPU path in __reduce_ex__ handled this gracefully by passing None to _make_in_reduce_ex. The new check unconditionally calls self.is_cpu, introducing a regression for any Float8Tensor pickled in a cleared state.

A safe guard would be:

def __reduce_ex__(self, protocol: int) -> tuple:
    if self._data is not None and self._data.is_cpu:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )

This covers only the case where _data is a CPU tensor (the staged DCP scenario) and falls through to the original path for all other cases, including _data is None.

Contributor Author

Wait, didn't I just change this from the other comment starting with "CPU-only transpose path not handled in reduce_ex"?

Contributor

You're right — I suggested changing to if self.is_cpu: to handle the CPU transpose path, but that creates a new issue when both _data and _transpose are None.

The is_cpu property raises RuntimeError("Both data and transpose are None") in that case, which the original GPU path handled gracefully by passing None to _make_in_reduce_ex.

The fix needs to handle all three cases:

  1. _data on CPU → dequantize
  2. _transpose on CPU (no _data) → dequantize
  3. Both None → fall through to original path

Here's a clean solution using try-except:

Suggested change

-if self.is_cpu:
-    return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
+def __reduce_ex__(self, protocol: int) -> tuple:
+    """Custom pickling to remove references to FP8 metadata objects.
+
+    CPU Float8Tensors are serialized as dequantized plain tensors
+    for compatibility with torch.load(weights_only=True), which is
+    used by DCP async save staging.
+    """
+    try:
+        is_cpu_tensor = self.is_cpu
+    except RuntimeError:
+        # Both _data and _transpose are None
+        is_cpu_tensor = False
+    if is_cpu_tensor:
+        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
+    return (
+        Float8Tensor._make_in_reduce_ex,
+        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
+    )

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines 138 to 140
    "quantizer": self._quantizer,
    "device": self.device,
    "fake_dtype": self._dtype,
Contributor

get_metadata() raises when tensor is in cleared state

Adding "device": self.device is correct for the normal lifecycle, but Float8TensorStorage.device raises RuntimeError("Float8TensorStorage has no data!") when both _data and _transpose are None — exactly the state left by prepare_for_saving() or clear().

Before this PR, get_metadata() returned None for data and data_transpose without raising. Now any call to get_metadata() (e.g., via make_like()) on a cleared tensor would raise instead of propagating gracefully.

A safe guard:

"device": self._data.device if self._data is not None
          else (self._transpose.device if self._transpose is not None else None),

…e.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member

/te-ci pytorch
