Skip to content

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974

Open
vthumbe1503 wants to merge 9 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix
Open

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
vthumbe1503 wants to merge 9 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP Sync and Async checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the amax reduction group as the fsdp group in the quantizer

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • DCP Sync Checkpoint loading

    • untyped_storage is now defined for the base QuantizedTensor to return empty storage. Untyped_storage refers to the backing storage that we use to create all the internal tensors. Since we use make_wrapper_subclass to create TE QuantizedTensors, we use dont have any backing storage associated with the tensor. data_ptr on our Custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains sharded param tensor for checkpointing. It does so by calling view(-1) on our Quantized sharded model parameters. We return back a dequantized 1D tensor in TE. So, the sharded tensor that FSDP2 maintains for checkpointing is BF16 and Quantized sharded param is our custom FP8 tensor. It evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(Quantized sharded parameter) to see if the same_tensor. With us returning empty storage now, this would never be equal to sharded tensor's untyped storage.
  • DCP Async Checkpointing

    • to_new_empty function with device="cpu" is being used in Async Checkpointing. This function returned Quantizer.make_empty without setting the device. For device = "cpu" we now dequantize. So that the Async checkpointing directly saves the bf16 data on disk and reload works fine.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes DCP sync and async checkpoint loading for MXFP8/NVFP4 and async loading for all quantization recipes when using FSDP2 with QuantizedModelInit. The core changes introduce a _to_copy dispatch handler that moves all internal buffers when .cpu() / .to(device=...) is called, promote all _make_*_in_reduce_ex reconstructors to module-level functions (avoiding getattr in the pickle stream), serialize pybind11 DType enums as plain int, register TE types with torch.serialization.add_safe_globals, and add a CUDA-bounce path so MXFP8/NVFP4 tensors can dequantize from CPU.

  • Sync checkpoint fix: QuantizedTensor.untyped_storage() now returns a zero-byte storage so FSDP2's same-tensor identity check never falsely matches the BF16 sharded param against the QuantizedTensor.
  • Async checkpoint fix: Float8Tensor.__reduce_ex__ drops the CPU-dequantize fallback and always round-trips FP8 buffers; MXFP8/NVFP4 dequantize by temporarily moving to CUDA.
  • Pickle safety: all four tensor types now register module-level reconstructors and serialize TE_DType as int to keep weights_only=True loads free of getattr gadgets.

Confidence Score: 4/5

Safe to merge once the None-data guard is added to _make_float8_tensor_in_reduce_ex.

In _make_float8_tensor_in_reduce_ex, device=data.device is evaluated unconditionally, but Float8Tensor._data can be None when the tensor only has a transpose buffer populated. Pickling and restoring such a tensor will raise AttributeError: 'NoneType' object has no attribute 'device'. The parallel helper functions for MXFP8, NVFP4, and Float8Blockwise all guard against this with if data is not None checks — Float8Tensor is the only one that missed the pattern. The rest of the changes look correct.

transformer_engine/pytorch/tensor/float8_tensor.py — specifically _make_float8_tensor_in_reduce_ex.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/float8_tensor.py Removes the CPU-dequantize fallback in reduce_ex and introduces _make_float8_tensor_in_reduce_ex at module level; device=data.device will raise AttributeError when _data is None (transpose-only tensors).
transformer_engine/pytorch/quantized_tensor.py Adds untyped_storage() returning empty storage, new _to_copy dispatch handler that moves all internal buffers to the target device while preserving the subclass, and Quantizer getstate/setstate for safe DType pickling.
transformer_engine/pytorch/init.py Registers TE quantized tensor types and module-level reconstructors with torch.serialization.add_safe_globals to allow weights_only=True DCP loads; falls back gracefully on older PyTorch.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Removes per-class untyped_storage() override and promotes _make_float8_blockwise_tensor_in_reduce_ex to module level with int fp8_dtype and device inference from inner buffers.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py _FromMXFP8Func.forward now bounces CPU tensors to CUDA for dequantization and returns the result back to the original device.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py _FromNVFP4Func.forward adds the same CUDA-bounce pattern as MXFP8 for CPU dequantize; has a truncated comment on the bounce block.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Removes the device key from get_metadata() so callers (make_like, _to_copy handler) can pass device explicitly without a duplicate-kwarg collision.
transformer_engine/pytorch/module/base.py Extends the DTensor amax_reduction_group branch to also handle NVFP4Quantizer, matching the existing Float8CurrentScalingQuantizer handling.
tests/pytorch/test_quantized_tensor.py Adds test_cpu_dequantize covering all quantization recipes: moves tensor to CPU, dequantizes, and checks bit-exact match against dequantize-on-CUDA-then-move-to-CPU reference.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Removes xfail markers for MXFP8BlockScaling, NVFP4BlockScaling, and Float8BlockScaling async DCP now that the underlying bugs are fixed.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Removes xfail guards for Float8BlockScaling+fp8_init and NVFP4+fp8_init+TransformerLayer; the Float8BlockScaling SM120 xfail is now applied regardless of async_save.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[DCP Checkpoint Load] --> B{Sync or Async?}
    B -->|Sync| C[FSDP2 compares untyped_storage]
    C --> D["QuantizedTensor.untyped_storage()\nreturns empty 0-byte storage"]
    D --> E[Storage != BF16 sharded param\n→ not same_tensor → copy path taken]
    E --> F[copy_ dispatches to QuantizedTensor\n→ quantize_ or dequantize]
    B -->|Async| G[DCP staging serializes\nQuantizedTensor to CPU]
    G --> H["__reduce_ex__ called\n(module-level _make_*_in_reduce_ex)"]
    H --> I[fp8/fp4_dtype serialized as int\nQuantizer.dtype serialized as int]
    I --> J["torch.load(weights_only=True)\nuses add_safe_globals whitelist"]
    J --> K[Reconstruct QuantizedTensor\nvia _make_*_in_reduce_ex]
    G --> L{Float8Tensor}
    G --> M{MXFP8Tensor / NVFP4Tensor}
    L --> N[Serialize FP8 buffers directly\nno CPU-dequantize fallback]
    M --> O["dequantize() on CPU:\nbounce to CUDA → tex.dequantize\n→ move result back to CPU"]
    K --> P[QuantizedTensor on CUDA\nready for training]
    O --> P
Loading

Reviews (6): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors

Comment on lines +820 to +828
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Typo: "neec" should be "need" — appears in both NVFP4 tolerance blocks.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)

Comment on lines +867 to +875
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Same typo ("neec") in the second NVFP4 tolerance block for the post-training-step check.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need dequant + quant here?

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:

  • It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
  • Would this affect FSDP or CPU offloading?
  • Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?

# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
target_size = size

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

# differences vs the manual dequantize-then-allgather path.
if isinstance(param, NVFP4Tensor):
tols = dict(atol=5e-4, rtol=5e-3)
tols = dict(atol=0.125, rtol=0.25)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the tolerances so much bigger? Is it also due to the dequant+quant path? If so, the comment above is no longer relevant and should be replaced with a better one (but I would still like an explanation why we cannot just load the nvfp4 values from the checkpoint).

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see now why you want to dequantize. I don't think this is needed though - we should be able to create the QuantlizedTensor on the CPU and save it, no? I remember that the CPU offloading of the activations faced similar problem and already had to support some CPU ops on the QuantizedTensor anyway.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp2_dcp_laod_fix branch from 3589ffa to 4197bee Compare May 13, 2026 04:00
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner May 13, 2026 04:00
pre-commit-ci Bot and others added 2 commits May 13, 2026 04:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/__init__.py Outdated
# allow-listed for ``torch.load(weights_only=True)`` (used
# internally by DCP async-staging) to accept the stream.
_TE_DType,
getattr,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 security Session-wide getattr whitelisted for weights_only=True loading

getattr is registered as a safe global at module import time. add_safe_globals is process-wide in PyTorch, so any torch.load(…, weights_only=True) call made anywhere in a session that has imported transformer_engine.pytorch — including checkpoint loads for entirely different models — now has getattr available to the pickle stream. A malicious checkpoint loaded elsewhere could use getattr to access sensitive attributes of any already-constructed object reachable from the whitelisted globals (e.g. getattr(Float8Quantizer_instance, 'amax_reduction_group') to obtain a process group, or to build callable gadget chains). The weights_only=True flag is specifically a defence against untrusted pickle payloads; adding a general-purpose reflective accessor defeats that defence.

A targeted alternative: serialize _fp8_dtype as its integer value (int(self._fp8_dtype)) and reconstruct it in _make_in_reduce_ex via TE_DType(int_value), then add TE_DType to safe globals instead of getattr. This preserves the weights_only invariant without whitelisting a reflective accessor.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

vthumbe1503 and others added 6 commits May 13, 2026 04:19
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants