41 changes: 1 addition & 40 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py
@@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"):
pytest.xfail(
f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor "
"subclasses fail when parameters are initialized on CUDA. "
"Use device='meta' + reset_parameters() after sharding."
)

world_size, device = _get_dist_info()

model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False)
@@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name):
- Saved tensor shapes match expected (unsharded) shapes
"""
recipe = get_recipe_from_string(recipe_name)
if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.state_dict import (
@@ -692,40 +679,14 @@ def test_dcp_output_parity(recipe_name, async_save):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access: "
"/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function "
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

if recipe_name == "NVFP4BlockScaling":
pytest.xfail(
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
)

if (
recipe_name == "Float8BlockScaling"
and not async_save
and torch.cuda.get_device_capability()[0] == 12
):
if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail(
"Float8BlockScaling is failing on SM120 with RuntimeError: "
"transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 "
"in function quantize_transpose_vector_blockwise: Assertion failed: pow2_scale. On "
"Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which "
"requires using power of two scaling factors."
)
if recipe_name == "Float8BlockScaling" and async_save:
pytest.xfail(
"Float8BlockScaling: async DCP save/load round-trip produces different model "
"outputs — quantization metadata (scales) is not correctly persisted through "
"async distributed checkpointing. On SM120, additionally fails with pow2_scale "
"assertion in quantize_transpose_vector_blockwise."
)

import torch.distributed.checkpoint as dcp

13 changes: 0 additions & 13 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@@ -380,19 +380,6 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
"data and scale_inv into a single buffer in pre_all_gather, split in post."
)

if recipe_name == "Float8BlockScaling" and fp8_init:
pytest.xfail(
"Float8BlockScaling + fp8_init: scale inverse padding is not handled "
"correctly during FSDP2 all-gather slice ops."
)
if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer":
pytest.xfail(
"NVFP4BlockScaling + fp8_init + TransformerLayer: "
"_check_fp8_fsdp2_allgather numerical error compounds across multiple "
"linear layers in the transformer block (up to ~1e-2 max abs diff). "
"LayerNormLinear passes with relaxed tolerances. "
"NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py."
)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

50 changes: 50 additions & 0 deletions tests/pytorch/test_quantized_tensor.py
@@ -616,6 +616,56 @@ def test_identity_op(
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)

@pytest.mark.parametrize("quantization", _quantization_list)
def test_cpu_dequantize(
self,
*,
quantization: str,
shape: Iterable[int] = (128, 128),
dtype: torch.dtype = torch.bfloat16,
) -> None:
"""Dequantize on a CPU-resident QuantizedTensor."""

# Construct a quantized tensor on CUDA.
_, x_cuda = make_reference_and_test_tensors(
shape=shape,
quantization=quantization,
test_dtype=dtype,
requires_grad=False,
)
assert isinstance(x_cuda, QuantizedTensor)
assert x_cuda.device.type == "cuda"

# Reference: dequantize on CUDA, then move the dense result to CPU.
ref_cpu = x_cuda.dequantize().to(device="cpu")

# Move the QuantizedTensor itself to CPU and dequantize there.
# ``.cpu()`` routes through ``aten._to_copy.default`` so all inner
# buffers (data, scales, amax) are moved to CPU.
x_cpu = x_cuda.cpu()
assert isinstance(x_cpu, QuantizedTensor)
assert x_cpu.device.type == "cpu"
for attr in (
"_data",
"_rowwise_data",
"_columnwise_data",
"_rowwise_scale_inv",
"_columnwise_scale_inv",
"_amax_rowwise",
"_amax_columnwise",
):
buf = getattr(x_cpu, attr, None)
if buf is not None:
assert buf.device.type == "cpu", f"{attr} did not move to CPU"

# Dequantize the CPU tensor. Implementation may bounce through CUDA
# internally, but must return a CPU tensor.
y_cpu = x_cpu.dequantize()
assert y_cpu.device.type == "cpu"
assert y_cpu.dtype == ref_cpu.dtype
assert y_cpu.shape == ref_cpu.shape
torch.testing.assert_close(y_cpu, ref_cpu, rtol=0, atol=0)

@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("dim", [0, 1])
def test_chunk(
11 changes: 10 additions & 1 deletion transformer_engine/common/util/pybind_helper.h
@@ -23,7 +23,16 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
.def("__reduce_ex__", \
[](transformer_engine::DType self, pybind11::object /*protocol*/) { \
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
pybind11::make_tuple(static_cast<int>(self))); \
}) \
.def("__reduce__", [](transformer_engine::DType self) { \
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
pybind11::make_tuple(static_cast<int>(self))); \
}); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
59 changes: 59 additions & 0 deletions transformer_engine/pytorch/__init__.py
@@ -89,8 +89,67 @@
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
_make_float8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import (
_make_mxfp8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
_make_nvfp4_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
_make_float8_blockwise_tensor_in_reduce_ex,
)

try:
torch._dynamo.config.error_on_nested_jit_trace = False
except AttributeError:
pass # error_on_nested_jit_trace was added in PyTorch 2.2.0

# To allow for safe unpickling of QuantizedTensors when using DCP
# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its
# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see
# ``transformer_engine/common/util/pybind_helper.h``) so its pickle
# stream encodes as ``(tex.DType, (int,))`` and only the class itself
# needs to be allow-listed below.
import transformer_engine_torch as tex

try:
from torch.serialization import add_safe_globals

add_safe_globals(
[
# Storage mixins (used during pickling of internal-only tensors)
QuantizedTensorStorage,
Float8TensorStorage,
MXFP8TensorStorage,
NVFP4TensorStorage,
Float8BlockwiseQTensorStorage,
# Quantizer types embedded in metadata
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
Float8BlockQuantizer,
# pybind11 enum used as Quantizer.dtype
tex.DType,
# __reduce_ex__ reconstructors (module-level functions).
_make_float8_tensor_in_reduce_ex,
_make_mxfp8_tensor_in_reduce_ex,
_make_nvfp4_tensor_in_reduce_ex,
_make_float8_blockwise_tensor_in_reduce_ex,
]
)
except (ImportError, AttributeError):
import warnings as _warnings

_warnings.warn(
"transformer_engine: torch.serialization.add_safe_globals is "
"unavailable on this PyTorch version (added in 2.4). DCP "
"checkpointing of QuantizedTensor weights with FSDP2 will not "
"work; upgrade to PyTorch >= 2.4 to enable it.",
RuntimeWarning,
stacklevel=2,
)
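
As a rough illustration of what this allow-listing enables (a sketch, not part of the diff; it assumes a build where ``tex.DType`` carries the ``__reduce__`` override from ``pybind_helper.h`` above):

import io
import pickle

import torch
import transformer_engine_torch as tex

# Plain pickle round-trip: __reduce_ex__ encodes the enum as
# (tex.DType, (value,)), so unpickling simply calls tex.DType(value).
restored = pickle.loads(pickle.dumps(tex.DType.kFloat8E4M3))
assert restored == tex.DType.kFloat8E4M3

# weights_only loading: allow-listing the enum class is enough, because the
# pickle stream references nothing but the class itself and an int.
torch.serialization.add_safe_globals([tex.DType])
buffer = io.BytesIO()
torch.save({"dtype": tex.DType.kFloat8E4M3}, buffer)
buffer.seek(0)
assert torch.load(buffer, weights_only=True)["dtype"] == tex.DType.kFloat8E4M3
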
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -44,6 +44,7 @@
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
@@ -1641,7 +1642,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
if is_dtensor and isinstance(
quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer)
):
device_mesh = dtensor_param.device_mesh
amax_reduction_group = (
device_mesh.get_group(mesh_dim="shard")
62 changes: 57 additions & 5 deletions transformer_engine/pytorch/quantized_tensor.py
@@ -552,9 +552,26 @@ def half(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return self.dequantize(dtype=torch.float16)

def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor:
"""Move tensor to CPU while preserving the QuantizedTensor type.

Routes through ``aten._to_copy.default`` so the subclass-preserving
handler in ``__torch_dispatch__`` runs (rather than dequantizing).

"""
# pylint: disable=missing-function-docstring
return self.dequantize().cpu(memory_format=memory_format)
return self.to(device=torch.device("cpu"), memory_format=memory_format)

def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is created via ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``), which are
an implementation detail of the quantization scheme. This method needs to
be defined to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Comment on lines +565 to +574
Contributor

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Collaborator

The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.

Member

Yeah, while I don't think we ever use QuantizedTensors in SplitAlongDim, the concat path does sound plausible to hit.

Collaborator Author

Need to resolve this comment after going thoroughly over the noop_cat consequences for QuantizedTensors.

Collaborator Author

The behavior is unchanged with this change, and I would argue the implementation is now more correct. Before this change, the default untyped_storage() implementation that QuantizedTensor(torch.Tensor) inherits gave a storage with two properties:

  1. storage.nbytes() returns a byte count based on the fake_dtype that we use to register QuantizedTensor as a torch.Tensor via torch's make_wrapper_subclass.

  2. storage.data_ptr() raises an error saying the storage is invalid and has no data_ptr().

Neither is ideal. The first is grossly incorrect for two reasons: we manage the backing storage for the inner tensors of QuantizedTensor ourselves, so torch has no idea about it, and nbytes based on fake_dtype is misleading since it may not match the number of bytes we actually allocate. The second is what now causes problems with FSDP2, which expects some storage for its identity check.

For QuantizedTensor, noop_cat today always returns an actual torch.cat (which goes through dequantization), luckily because this condition is true. The condition remains true with this change as well, since nbytes() now returns 0.

QuantizedTensor.data_ptr() today gives you 0, while QuantizedTensor.untyped_storage().data_ptr() raises an invalid-storage error, which is inconsistent. Returning an empty storage fixes that inconsistency.

As far as identity checking goes, FSDP2 only runs its comparison logic when data_ptr() is not 0, and it doesn't really make sense to compare two empty storages anyway.
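
For context on the concat concern, a rough sketch of the kind of check being discussed; the real helper lives in module/_common.py, and the name and signature below are illustrative only, not the actual TE code.

def _can_noop_cat(chunks, full_nbytes: int) -> bool:
    # Illustrative sketch: the in-place/no-op concat path is only taken when
    # the first chunk's storage already spans the full concatenated buffer.
    # A QuantizedTensor's untyped_storage() now reports nbytes() == 0, so this
    # returns False and callers fall back to a plain out-of-place torch.cat,
    # which dequantizes QuantizedTensor inputs.
    return chunks[0].untyped_storage().nbytes() == full_nbytes
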


def expand_as(self, other: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
@@ -608,6 +625,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
dst.copy_(src)
return None

# _to_copy op (used by .to(device=...), .cpu(), DCP staging).
# Preserve the QuantizedTensor subclass and move all internal
# buffers (data, scales, etc.) to the requested device.
if func == torch.ops.aten._to_copy.default:
tensor = args[0]
kw = dict(kwargs) if kwargs else {}
dtype = kw.get("dtype", None)
if dtype is None or dtype == tensor.dtype:
target_device = kw.get("device", tensor.device) or tensor.device
target_device = torch.device(target_device)
pin_memory = bool(kw.get("pin_memory", False))
non_blocking = bool(kw.get("non_blocking", False))
new_metadata = {"device": target_device}
# Update tensor storage metadata
for key, value in tensor.get_metadata().items():
if isinstance(value, torch.Tensor):
value = value.to(device=target_device, non_blocking=non_blocking)
if pin_memory and target_device.type == "cpu":
value = value.pin_memory()
new_metadata[key] = value
# Update torch Tensor metadata
new_metadata.update(
{
"dtype": tensor.dtype,
"shape": tensor.shape,
"requires_grad": tensor.requires_grad,
}
)
return type(tensor)(**new_metadata)

# View op
if func == torch.ops.aten.view.default:
raise NotImplementedError(f"{cls.__name__} class does not support tensor views")
@@ -748,14 +795,19 @@ def make_like(
"""Create new quantized tensor

By default, new tensor has the same attributes and underlying
data. This function is intended to create view of tensors.

data. This function is intended to create a view of ``tensor``.
"""
shape = shape if shape is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
kwargs["fake_dtype"] = dtype
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
return cls(
shape=shape,
dtype=dtype,
requires_grad=requires_grad,
device=tensor.device,
**kwargs,
)

def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
"""Create `QuantizedTensor` with given nominal dtype
@@ -61,6 +61,7 @@ def forward(
kwargs = tensor.get_metadata()
for key, val in init_kwargs.items():
kwargs[key] = val
kwargs["device"] = tensor.device
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)

@staticmethod