Fix for async dcp checkpointing with Float8Tensors #2721

pstjohn wants to merge 5 commits into NVIDIA:main
Conversation
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed 1fba8c7 to 641898c
Greptile Summary

This PR fixes a silent data-corruption bug in PyTorch's async DCP save path when checkpointing Float8Tensors. The fix is logically sound and covers the three necessary layers.

Several edge-case issues raised in prior review threads remain open in the current code.

Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant DCP as DCP Async Save
    participant QT as QuantizedTensor.__torch_dispatch__
    participant F8Q as Float8Quantizer.make_empty()
    participant F8T as Float8Tensor (CPU staged)
    participant RE as __reduce_ex__
    participant F8F as _FromFloat8Func.forward (CPU path)
    DCP->>QT: aten.new_empty.default(fp8_tensor, size, device=cpu)
    QT->>F8Q: make_empty(shape, dtype, device=cpu, pin_memory=True)
    F8Q-->>QT: Float8Tensor(_data=uint8 on CPU, _scale_inv on CPU)
    QT-->>DCP: Float8Tensor (CPU staged, correct subclass)
    DCP->>F8T: copy_(original_fp8_tensor)
    Note over F8T: _data, _scale_inv, _transpose copied to CPU buffer
    DCP->>RE: pickle(Float8Tensor on CPU)
    RE->>RE: self.is_cpu is True
    RE->>F8F: dequantize(dtype=self.dtype)
    F8F->>F8F: _data.view(fp8_torch_dtype).float() * _scale_inv, cast to dtype
    F8F-->>RE: plain float tensor
    RE-->>DCP: pickle plain float tensor (weights_only=True compatible)
    DCP->>DCP: write checkpoint to storage
```
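The dequantization step in the diagram (`_data.view(fp8_torch_dtype).float() * _scale_inv`) boils down to multiplying stored low-precision codes by the inverse scale. A plain-Python stand-in (illustration only, not TE's implementation; the scale value is arbitrary):

```python
# Toy sketch of scale-inverse dequantization: store a value as a scaled
# integer code, recover it by multiplying with scale_inv (as _FromFloat8Func
# does with _data * _scale_inv). Not TE's code; scale chosen for illustration.
scale = 16.0
scale_inv = 1.0 / scale
original = 3.14159
stored = round(original * scale)      # low-precision code, analogous to _data
dequantized = stored * scale_inv      # recover an approximation of the input
# Rounding error is bounded by half a quantization step.
assert abs(dequantized - original) <= 0.5 * scale_inv
```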
Last reviewed commit: bbaef50

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter St. John <peterc.stjohn@gmail.com>
```python
out = tensor._quantizer.make_empty(
    shape=torch.Size(size),
    dtype=dtype,
    device=device,
    requires_grad=tensor.requires_grad,
    pin_memory=pin_memory,
)
```
AttributeError when _quantizer is None
tensor._quantizer can be None for Float8Tensor objects deserialized via the GPU path (_make_in_reduce_ex), which does not pass a quantizer argument. If a second async DCP save is attempted after a load/save round-trip, new_empty will be dispatched on the deserialized tensor, causing AttributeError: 'NoneType' object has no attribute 'make_empty'.
A guard is needed before calling make_empty:
```python
if func == torch.ops.aten.new_empty.default:
    tensor = args[0]
    size = args[1]
    dtype = kwargs.get("dtype", tensor.dtype)
    device = kwargs.get("device", tensor.device)
    pin_memory = kwargs.get("pin_memory", False)
    if tensor._quantizer is None:
        raise RuntimeError(
            f"{type(tensor).__name__} does not have a quantizer; "
            "cannot create new_empty QuantizedTensor"
        )
    out = tensor._quantizer.make_empty(
        shape=torch.Size(size),
        dtype=dtype,
        device=device,
        requires_grad=tensor.requires_grad,
        pin_memory=pin_memory,
    )
    return out
```

```python
if self.is_cpu:
    return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
```
is_cpu raises RuntimeError when both _data and _transpose are None
The is_cpu property raises RuntimeError("Both data and transpose are None") when both _data and _transpose are None, e.g. after prepare_for_saving() or clear() has set them to None. The pre-existing GPU path in __reduce_ex__ handled this gracefully by passing None to _make_in_reduce_ex. The new check unconditionally calls self.is_cpu, introducing a regression for any Float8Tensor pickled in a cleared state.
A safe guard would be:
```python
def __reduce_ex__(self, protocol: int) -> tuple:
    if self._data is not None and self._data.is_cpu:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )
```

This covers only the case where _data is a CPU tensor (the staged DCP scenario) and falls through to the original path for all other cases, including _data is None.
wait, didn't i just change this from the other comment starting with "CPU-only transpose path not handled in reduce_ex"?
You're right; I suggested changing to if self.is_cpu: to handle the CPU transpose path, but that creates a new issue when both _data and _transpose are None.
The is_cpu property raises RuntimeError("Both data and transpose are None") in that case, which the original GPU path handled gracefully by passing None to _make_in_reduce_ex.
The fix needs to handle all three cases:

- _data on CPU → dequantize
- _transpose on CPU (no _data) → dequantize
- Both None → fall through to the original path
Here's a clean solution using try-except:
```python
def __reduce_ex__(self, protocol: int) -> tuple:
    """Custom pickling to remove references to FP8 metadata objects

    CPU Float8Tensors are serialized as dequantized plain tensors
    for compatibility with torch.load(weights_only=True), which is
    used by DCP async save staging.
    """
    try:
        is_cpu_tensor = self.is_cpu
    except RuntimeError:
        # Both _data and _transpose are None
        is_cpu_tensor = False
    if is_cpu_tensor:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )
```
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py (outdated review thread, resolved)
| "quantizer": self._quantizer, | ||
| "device": self.device, | ||
| "fake_dtype": self._dtype, |
get_metadata() raises when tensor is in cleared state
Adding "device": self.device is correct for the normal lifecycle, but Float8TensorStorage.device raises RuntimeError("Float8TensorStorage has no data!") when both _data and _transpose are None — exactly the state left by prepare_for_saving() or clear().
Before this PR, get_metadata() returned None for data and data_transpose without raising. Now any call to get_metadata() (e.g., via make_like()) on a cleared tensor would raise instead of propagating gracefully.
A safe guard:

```python
"device": self._data.device if self._data is not None
else (self._transpose.device if self._transpose is not None else None),
```

…e.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
for more information, see https://pre-commit.ci
(includes changes from #2698)
dcp.async_save fails silently with QuantizedTensor (Float8Tensor) — staged tensors contain uninitialized (NaN) data instead of actual FP8 values.
PyTorch's async save stages tensors to CPU by copying raw storage via new_empty() + deep_copy. Float8Tensor is a wrapper subclass with data_ptr()==0 (empty storage), so:
Changes
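The staging failure described above, new_empty() plus a raw-storage copy copying nothing meaningful because the wrapper's own storage is empty, can be modeled with a toy wrapper class (plain Python, not PyTorch; all names hypothetical):

```python
class WrapperTensor:
    """Toy model of a tensor subclass whose real payload lives in _data,
    while the wrapper's own storage is empty (like data_ptr() == 0)."""

    def __init__(self, payload):
        self._data = payload   # real storage lives here
        self.raw = []          # outer storage is empty

def naive_stage(t):
    # Like staging via plain-tensor new_empty + raw-buffer copy:
    # only the (empty) outer buffer is copied, the payload is lost.
    staged = WrapperTensor(None)
    staged.raw = list(t.raw)
    return staged

def subclass_aware_stage(t):
    # Like the PR's fix: dispatch new_empty to the subclass so the
    # inner payload buffers are allocated and copied as well.
    staged = WrapperTensor(None)
    staged._data = list(t._data)
    return staged

src = WrapperTensor(payload=[1.0, 2.0])
assert naive_stage(src)._data is None              # payload silently lost
assert subclass_aware_stage(src)._data == [1.0, 2.0]
```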