diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index ecda481ed9..1abb49e98c 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"): - pytest.xfail( - f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor " - "subclasses fail when parameters are initialized on CUDA. " - "Use device='meta' + reset_parameters() after sharding." - ) - world_size, device = _get_dist_info() model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False) @@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name): - Saved tensor shapes match expected (unsharded) shapes """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) from safetensors.torch import load_file, save_file from torch.distributed.checkpoint.state_dict import ( @@ -692,26 +679,7 @@ def test_dcp_output_parity(recipe_name, async_save): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access: " - "/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function " - "multi_tensor_apply: CUDA Error: an illegal memory access was encountered. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) - - if recipe_name == "NVFP4BlockScaling": - pytest.xfail( - "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() " - "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage" - ) - - if ( - recipe_name == "Float8BlockScaling" - and not async_save - and torch.cuda.get_device_capability()[0] == 12 - ): + if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12: pytest.xfail( "Float8BlockScaling is failing on SM120 with RuntimeError: " "transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 " @@ -719,13 +687,6 @@ def test_dcp_output_parity(recipe_name, async_save): "Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which " "requires using power of two scaling factors." ) - if recipe_name == "Float8BlockScaling" and async_save: - pytest.xfail( - "Float8BlockScaling: async DCP save/load round-trip produces different model " - "outputs — quantization metadata (scales) is not correctly persisted through " - "async distributed checkpointing. On SM120, additionally fails with pow2_scale " - "assertion in quantize_transpose_vector_blockwise." 
- ) import torch.distributed.checkpoint as dcp diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py index 6342e63e75..9383355fcc 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py @@ -380,19 +380,6 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type): "data and scale_inv into a single buffer in pre_all_gather, split in post." ) - if recipe_name == "Float8BlockScaling" and fp8_init: - pytest.xfail( - "Float8BlockScaling + fp8_init: scale inverse padding is not handled " - "correctly during FSDP2 all-gather slice ops." - ) - if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer": - pytest.xfail( - "NVFP4BlockScaling + fp8_init + TransformerLayer: " - "_check_fp8_fsdp2_allgather numerical error compounds across multiple " - "linear layers in the transformer block (up to ~1e-2 max abs diff). " - "LayerNormLinear passes with relaxed tolerances. " - "NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py." - ) torch.manual_seed(42) torch.cuda.manual_seed(42) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 526045e43e..119914fbc3 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -616,6 +616,56 @@ def test_identity_op( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_cpu_dequantize( + self, + *, + quantization: str, + shape: Iterable[int] = (128, 128), + dtype: torch.dtype = torch.bfloat16, + ) -> None: + """Dequantize on a CPU-resident QuantizedTensor.""" + + # Construct a quantized tensor on CUDA. + _, x_cuda = make_reference_and_test_tensors( + shape=shape, + quantization=quantization, + test_dtype=dtype, + requires_grad=False, + ) + assert isinstance(x_cuda, QuantizedTensor) + assert x_cuda.device.type == "cuda" + + # Reference: dequantize on CUDA, then move the dense result to CPU. + ref_cpu = x_cuda.dequantize().to(device="cpu") + + # Move the QuantizedTensor itself to CPU and dequantize there. + # ``.cpu()`` routes through ``aten._to_copy.default`` so all inner + # buffers (data, scales, amax) are moved to CPU. + x_cpu = x_cuda.cpu() + assert isinstance(x_cpu, QuantizedTensor) + assert x_cpu.device.type == "cpu" + for attr in ( + "_data", + "_rowwise_data", + "_columnwise_data", + "_rowwise_scale_inv", + "_columnwise_scale_inv", + "_amax_rowwise", + "_amax_columnwise", + ): + buf = getattr(x_cpu, attr, None) + if buf is not None: + assert buf.device.type == "cpu", f"{attr} did not move to CPU" + + # Dequantize the CPU tensor. Implementation may bounce through CUDA + # internally, but must return a CPU tensor. 
+        y_cpu = x_cpu.dequantize()
+        assert y_cpu.device.type == "cpu"
+        assert y_cpu.dtype == ref_cpu.dtype
+        assert y_cpu.shape == ref_cpu.shape
+        torch.testing.assert_close(y_cpu, ref_cpu, rtol=0, atol=0)
+
     @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("dim", [0, 1])
     def test_chunk(
diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h
index ef7687e3e9..ed48fe4d61 100644
--- a/transformer_engine/common/util/pybind_helper.h
+++ b/transformer_engine/common/util/pybind_helper.h
@@ -23,7 +23,16 @@
       .value("kBFloat16", transformer_engine::DType::kBFloat16)                          \
       .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)                      \
       .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2)                      \
-      .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1);                     \
+      .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1)                      \
+      .def("__reduce_ex__",                                                              \
+           [](transformer_engine::DType self, pybind11::object /*protocol*/) {           \
+             return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)),       \
+                                         pybind11::make_tuple(static_cast<int>(self)));  \
+           })                                                                            \
+      .def("__reduce__", [](transformer_engine::DType self) {                            \
+        return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)),            \
+                                    pybind11::make_tuple(static_cast<int>(self)));       \
+      });                                                                                \
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())         \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)                               \
       .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)                 \
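With this override in place, pickling a `tex.DType` member encodes a direct call to the enum class itself (e.g. `tex.DType(2)`), so allow-listing the class alone is enough for `weights_only` loading. A minimal sketch of the resulting round-trip (illustrative only, not part of the patch; assumes a built `transformer_engine_torch` extension):

```python
import pickle

import transformer_engine_torch as tex

# The override reduces the pickle payload to (tex.DType, (int_value,)).
payload = pickle.dumps(tex.DType.kFloat8E4M3)
restored = pickle.loads(payload)  # unpickling calls tex.DType(int_value)
assert restored == tex.DType.kFloat8E4M3
```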
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 3ff0d75ee4..aebdbd01a3 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -89,8 +89,67 @@
 from transformer_engine.pytorch.tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor import NVFP4Tensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    _make_float8_tensor_in_reduce_ex,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import (
+    _make_mxfp8_tensor_in_reduce_ex,
+)
+from transformer_engine.pytorch.tensor.nvfp4_tensor import (
+    _make_nvfp4_tensor_in_reduce_ex,
+)
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+    _make_float8_blockwise_tensor_in_reduce_ex,
+)
 
 try:
     torch._dynamo.config.error_on_nested_jit_trace = False
 except AttributeError:
     pass  # error_on_nested_jit_trace was added in PyTorch 2.2.0
+
+# Allow safe unpickling of QuantizedTensors when using DCP
+# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its
+# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see
+# ``transformer_engine/common/util/pybind_helper.h``) so its pickle
+# stream encodes as ``(tex.DType, (int,))`` and only the class itself
+# needs to be allow-listed below.
+import transformer_engine_torch as tex
+
+try:
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals(
+        [
+            # Storage mixins (used during pickling of internal-only tensors)
+            QuantizedTensorStorage,
+            Float8TensorStorage,
+            MXFP8TensorStorage,
+            NVFP4TensorStorage,
+            Float8BlockwiseQTensorStorage,
+            # Quantizer types embedded in metadata
+            Quantizer,
+            Float8Quantizer,
+            Float8CurrentScalingQuantizer,
+            MXFP8Quantizer,
+            NVFP4Quantizer,
+            Float8BlockQuantizer,
+            # pybind11 enum used as Quantizer.dtype
+            tex.DType,
+            # __reduce_ex__ reconstructors (module-level functions).
+            _make_float8_tensor_in_reduce_ex,
+            _make_mxfp8_tensor_in_reduce_ex,
+            _make_nvfp4_tensor_in_reduce_ex,
+            _make_float8_blockwise_tensor_in_reduce_ex,
+        ]
+    )
+except (ImportError, AttributeError):
+    import warnings as _warnings
+
+    _warnings.warn(
+        "transformer_engine: torch.serialization.add_safe_globals is "
+        "unavailable on this PyTorch version (added in 2.4). DCP "
+        "checkpointing of QuantizedTensor weights with FSDP2 will not "
+        "work; upgrade to PyTorch >= 2.4 to enable it.",
+        RuntimeWarning,
+        stacklevel=2,
+    )
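A hedged sketch of the round-trip this allow-list enables (illustrative only, not part of the patch; assumes a CUDA device, PyTorch >= 2.4, and the `Float8Quantizer(scale, amax, fp8_dtype)` constructor used elsewhere in TE's tests):

```python
import io

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Quantize a dense tensor into a Float8Tensor wrapper subclass.
quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = quantizer(torch.randn(128, 128, device="cuda"))

# Safe (weights_only) deserialization works because the module-level
# reconstructor and tex.DType are allow-listed at import time.
buf = io.BytesIO()
torch.save(x, buf)
buf.seek(0)
y = torch.load(buf, weights_only=True)
torch.testing.assert_close(y.dequantize(), x.dequantize())
```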
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 746177ec78..a1213fe493 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -44,6 +44,7 @@
 from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
 from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
@@ -1641,7 +1642,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 raise RuntimeError("Weight quantizer has not been initialized")
             quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
             quantizer.internal = False
-            if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
+            if is_dtensor and isinstance(
+                quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer)
+            ):
                 device_mesh = dtensor_param.device_mesh
                 amax_reduction_group = (
                     device_mesh.get_group(mesh_dim="shard")
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index 7163e2b172..404796fd63 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -552,9 +552,26 @@ def half(self) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
         return self.dequantize(dtype=torch.float16)
 
-    def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
+    def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor:
+        """Move tensor to CPU while preserving the QuantizedTensor type.
+
+        Routes through ``aten._to_copy.default`` so the subclass-preserving
+        handler in ``__torch_dispatch__`` runs (rather than dequantizing).
+
+        """
         # pylint: disable=missing-function-docstring
-        return self.dequantize().cpu(memory_format=memory_format)
+        return self.to(device=torch.device("cpu"), memory_format=memory_format)
+
+    def untyped_storage(self) -> torch.UntypedStorage:
+        """Return an empty UntypedStorage on the tensor's device.
+
+        ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
+        backing storage of its own; the actual bytes live in the inner
+        buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
+        an implementation detail of the quantization scheme. Defining this
+        method avoids DCP staging errors with FSDP2.
+        """
+        return torch.UntypedStorage(0, device=self.device)
 
     def expand_as(self, other: torch.Tensor) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -608,6 +625,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 dst.copy_(src)
                 return None
 
+        # _to_copy op (used by .to(device=...), .cpu(), DCP staging).
+        # Preserve the QuantizedTensor subclass and move all internal
+        # buffers (data, scales, etc.) to the requested device.
+        if func == torch.ops.aten._to_copy.default:
+            tensor = args[0]
+            kw = dict(kwargs) if kwargs else {}
+            dtype = kw.get("dtype", None)
+            if dtype is None or dtype == tensor.dtype:
+                target_device = kw.get("device", tensor.device) or tensor.device
+                target_device = torch.device(target_device)
+                pin_memory = bool(kw.get("pin_memory", False))
+                non_blocking = bool(kw.get("non_blocking", False))
+                new_metadata = {"device": target_device}
+                # Update tensor storage metadata
+                for key, value in tensor.get_metadata().items():
+                    if isinstance(value, torch.Tensor):
+                        value = value.to(device=target_device, non_blocking=non_blocking)
+                        if pin_memory and target_device.type == "cpu":
+                            value = value.pin_memory()
+                    new_metadata[key] = value
+                # Update torch Tensor metadata
+                new_metadata.update(
+                    {
+                        "dtype": tensor.dtype,
+                        "shape": tensor.shape,
+                        "requires_grad": tensor.requires_grad,
+                    }
+                )
+                return type(tensor)(**new_metadata)
+
         # View op
         if func == torch.ops.aten.view.default:
             raise NotImplementedError("{cls.__name__} class does not support tensor views")
@@ -748,14 +795,19 @@ def make_like(
         """Create new quantized tensor
 
         By default, new tensor has the same attributes and underlying
-        data. This function is intended to create view of tensors.
-
+        data. This function is intended to create a view of ``tensor``.
         """
         shape = shape if shape is not None else tensor.shape
         dtype = dtype if dtype is not None else tensor.dtype
         kwargs = tensor.get_metadata()
         kwargs["fake_dtype"] = dtype
-        return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
+        return cls(
+            shape=shape,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            device=tensor.device,
+            **kwargs,
+        )
 
     def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
         """Create `QuantizedTensor` with given nominal dtype
diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py
index ba3407e13b..56cf503630 100644
--- a/transformer_engine/pytorch/tensor/_quantization_helpers.py
+++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py
@@ -61,6 +61,7 @@ def forward(
         kwargs = tensor.get_metadata()
         for key, val in init_kwargs.items():
             kwargs[key] = val
+        kwargs["device"] = tensor.device
         return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
 
     @staticmethod
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index d0296902a9..d9c9510d17 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -333,21 +333,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
         return _ReshapeFunc.apply(self, shape)
 
-    def untyped_storage(self) -> torch.UntypedStorage:
-        """Return the underlying UntypedStorage of the FP8 data.
-
-        Note that FP8 block-scaled tensor may involve multiple
-        buffers: row-wise FP8 data, row-wise scales, column-wise FP8
-        data, column-wise scales. The UntypedStorage of the row-wise
-        FP8 data is returned if it exists, and otherwise the
-        UntypedStorage of the column-wise FP8 data.
- - """ - data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data - if data is not None: - return data.untyped_storage() - return torch.UntypedStorage(0, device=self.device) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -432,6 +417,24 @@ def contiguous( return self raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + _make_float8_blockwise_tensor_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + self._quantizer, + self._is_2D_scaled, + None, # data_format + ), + ) + @classmethod def _make_in_reduce_ex( cls, @@ -444,7 +447,7 @@ def _make_in_reduce_ex( dtype: torch.dtype, quantizer: Quantizer, is_2D_scaled: bool, - data_format: Any = None, # pylint: disable=unused-argument + data_format: Any = None, ) -> Float8BlockwiseQTensor: """Build Float8BlockwiseQTensor, for use in __reduce__ @@ -452,34 +455,17 @@ def _make_in_reduce_ex( arguments. """ - return Float8BlockwiseQTensor( - shape=shape, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - quantizer=quantizer, - is_2D_scaled=is_2D_scaled, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" - return ( - Float8BlockwiseQTensor._make_in_reduce_ex, - ( - self.shape, - self._rowwise_data, - self._rowwise_scale_inv, - self._columnwise_data, - self._columnwise_scale_inv, - self._fp8_dtype, - self.dtype, - self._quantizer, - self._is_2D_scaled, - None, # data_format - ), + return _make_float8_blockwise_tensor_in_reduce_ex( + shape, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + dtype, + quantizer, + is_2D_scaled, + data_format, ) def _get_data(self) -> Float8BlockwiseQTensor: @@ -653,6 +639,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=None, quantizer=self._quantizer, is_2D_scaled=is_2D_scaled, + device=rowwise_data.device, ) # For 2D block scaling, derive columnwise data and scales from rowwise @@ -668,6 +655,46 @@ def fsdp_post_all_gather( return out, all_gather_outputs +def _make_float8_blockwise_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: Any = None, # pylint: disable=unused-argument +) -> Float8BlockwiseQTensor: + """Reconstruct a ``Float8BlockwiseQTensor`` from ``__reduce_ex__``. + + Defined at module level (not as a Float8BlockwiseQTensor classmethod) + so the pickle stream references it via a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. + """ + # Infer device from inner buffers so the wrapper subclass stays + # consistent with its data (e.g. CPU after DCP staging deserialize). 
+ device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device + return Float8BlockwiseQTensor( + shape=shape, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + quantizer=quantizer, + is_2D_scaled=is_2D_scaled, + device=device, + ) + + class _ViewFunc(torch.autograd.Function): """View function @@ -749,6 +776,7 @@ def forward( quantizer=tensor._quantizer, is_2D_scaled=tensor._is_2D_scaled, requires_grad=tensor.requires_grad, + device=tensor.device, ) @staticmethod @@ -778,6 +806,7 @@ def backward( quantizer=grad._quantizer, is_2D_scaled=grad._is_2D_scaled, requires_grad=grad.requires_grad, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -863,6 +892,7 @@ def forward( quantizer=tensor._quantizer, is_2D_scaled=tensor._is_2D_scaled, requires_grad=tensor.requires_grad, + device=tensor.device, ) @staticmethod @@ -891,6 +921,7 @@ def backward( quantizer=grad._quantizer, is_2D_scaled=grad._is_2D_scaled, requires_grad=grad.requires_grad, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c4c5934f97..66567c1eb3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -154,6 +154,7 @@ def create_tensor_from_data( requires_grad=requires_grad, data_transpose=None, quantizer=self, + device=data.device, ) def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: @@ -335,6 +336,7 @@ def create_tensor_from_data( requires_grad=requires_grad, data_transpose=None, quantizer=self, + device=data.device, ) def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: @@ -355,6 +357,7 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: requires_grad=False, data_transpose=None, quantizer=self, + device=data.device, ) def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: @@ -587,6 +590,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) if func in (aten.slice.Tensor, aten.select.int): @@ -687,6 +691,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_scale_inv=scale_inv, data_transpose=func_transposed_out, quantizer=quantizer, + device=tensor.device, ) return out_tensor @@ -860,6 +865,7 @@ def fsdp_post_all_gather( "quantizer": self._quantizer, "requires_grad": False, "data": data, + "device": data.device, } out = Float8Tensor(**fp8_args) @@ -898,6 +904,20 @@ def is_cpu(self): return self._transpose.is_cpu raise RuntimeError("Both data and transpose are None") + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects. + + Always serializes the underlying FP8 buffers (no dequantization + fallback for CPU tensors) so that DCP async-staging round-trips + preserve bitwise-identical data. ``Float8Tensor`` is registered + with ``torch.serialization.add_safe_globals`` to keep + ``torch.load(weights_only=True)`` compatibility. 
+ """ + return ( + _make_float8_tensor_in_reduce_ex, + (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), + ) + @classmethod def _make_in_reduce_ex( cls, @@ -905,37 +925,14 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, fp8_scale_inv: torch.Tensor, dtype: torch.dtype, - shape: torch.shape, + shape: torch.Size, ) -> Float8Tensor: """Build Float8Tensor, for use in __reduce__ - __reduce_ex__ assumes object constructor has positional arguments. """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - shape=shape, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects - - CPU Float8Tensors are serialized as dequantized plain tensors - for compatibility with torch.load(weights_only=True), which is - used by DCP async save staging. - """ - data_is_cpu = self._data is not None and self._data.is_cpu - transpose_is_cpu = self._transpose is not None and self._transpose.is_cpu - if data_is_cpu or transpose_is_cpu: - return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) - return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), - ) + return _make_float8_tensor_in_reduce_ex(data, fp8_dtype, fp8_scale_inv, dtype, shape) def _get_data(self) -> Float8Tensor: """Get tensor data property""" @@ -1000,6 +997,29 @@ def _set_data(self, tensor: torch.Tensor) -> None: data = property(_get_data, _set_data) +def _make_float8_tensor_in_reduce_ex( + data: torch.Tensor, + fp8_dtype: TE_DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + shape: torch.Size, +) -> Float8Tensor: + """Reconstruct a ``Float8Tensor`` from its ``__reduce_ex__`` payload. + Defined at module level (not as a Float8Tensor classmethod) so the pickle stream + references it via a single ``GLOBAL`` opcode rather than the + ``(getattr, (cls, name))`` reduction that bound classmethods/static + methods produce. 
+ """ + return Float8Tensor( + data=data, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + shape=shape, + device=data.device if data is not None else None, + ) + + class _ViewFunc(torch.autograd.Function): """View function @@ -1036,6 +1056,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) @staticmethod @@ -1083,6 +1104,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) @staticmethod diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 134f8b5a61..6c455b2690 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -161,6 +161,7 @@ def create_tensor_from_data( fp8_dtype=fp8_dtype, quantizer=self, with_gemm_swizzled_scales=False, + device=data.device, ) def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: @@ -346,6 +347,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) if func == torch.ops.aten.copy_.default: @@ -452,6 +454,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=False, + device=tensor.device, ) for splitted_tensor_data in zip(*out_data) ] @@ -541,6 +544,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) # Default case @@ -692,10 +696,30 @@ def fsdp_post_all_gather( shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, + device=( + rowwise_data.device if rowwise_data is not None else columnwise_data.device + ), ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + _make_mxfp8_tensor_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + self.shape, + self._quantizer, + self._with_gemm_swizzled_scales, + ), + ) + @classmethod def _make_in_reduce_ex( cls, @@ -705,7 +729,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, - shape: torch.shape, + shape: torch.Size, quantizer: Optional[Quantizer] = None, with_gemm_swizzled_scales: bool = False, ) -> MXFP8Tensor: @@ -715,33 +739,16 @@ def _make_in_reduce_ex( arguments. 
""" - return MXFP8Tensor( - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - shape=shape, - quantizer=quantizer, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling""" - return ( - MXFP8Tensor._make_in_reduce_ex, - ( - self._rowwise_data, - self._rowwise_scale_inv, - self._columnwise_data, - self._columnwise_scale_inv, - self._fp8_dtype, - self.dtype, - self.shape, - self._quantizer, - self._with_gemm_swizzled_scales, - ), + return _make_mxfp8_tensor_in_reduce_ex( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + dtype, + shape, + quantizer, + with_gemm_swizzled_scales, ) def _get_data(self) -> MXFP8Tensor: @@ -832,6 +839,46 @@ def is_cuda(self): raise RuntimeError("MXFP8Tensor has no data!") +def _make_mxfp8_tensor_in_reduce_ex( + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + shape: torch.Size, + quantizer: Optional[Quantizer] = None, + with_gemm_swizzled_scales: bool = False, +) -> MXFP8Tensor: + """Reconstruct an ``MXFP8Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level (not as a MXFP8Tensor classmethod) so the pickle stream + references it via a single ``GLOBAL`` opcode rather than the + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. + """ + # Infer device from inner buffers so the wrapper subclass stays + # consistent with its data (CPU after DCP staging deserialize, + # CUDA after the usual quantize path). 
+ device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + shape=shape, + quantizer=quantizer, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + device=device, + ) + + class _ViewFunc(torch.autograd.Function): """View function @@ -891,6 +938,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -918,6 +966,7 @@ def backward( fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -979,6 +1028,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -1004,6 +1054,7 @@ def backward( columnwise_scale_inv=grad._columnwise_scale_inv, fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, + device=grad.device, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index df7a2b4bd3..3e84075707 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -584,6 +584,7 @@ def fsdp_post_all_gather( quantizer=self._quantizer, requires_grad=False, with_gemm_swizzled_scales=False, + device=rowwise_data.device, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -722,11 +723,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): quantizer=tensor._quantizer, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + _make_nvfp4_tensor_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._fp4_dtype, + self.dtype, + self._quantizer, + self._with_gemm_swizzled_scales, + ), + ) + @classmethod def _make_in_reduce_ex( cls, @@ -748,38 +769,18 @@ def _make_in_reduce_ex( arguments. 
""" - return NVFP4Tensor( - shape=shape, - dtype=dtype, - fp4_dtype=fp4_dtype, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - amax_rowwise=amax_rowwise, - amax_columnwise=amax_columnwise, - quantizer=quantizer, - requires_grad=False, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling""" - return ( - NVFP4Tensor._make_in_reduce_ex, - ( - self.shape, - self._rowwise_data, - self._rowwise_scale_inv, - self._columnwise_data, - self._columnwise_scale_inv, - self._amax_rowwise, - self._amax_columnwise, - self._fp4_dtype, - self.dtype, - self._quantizer, - self._with_gemm_swizzled_scales, - ), + return _make_nvfp4_tensor_in_reduce_ex( + shape, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + dtype, + quantizer, + with_gemm_swizzled_scales, ) def _get_data(self) -> NVFP4Tensor: @@ -872,6 +873,50 @@ def is_cuda(self): raise RuntimeError("NVFP4Tensor has no data!") +def _make_nvfp4_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + with_gemm_swizzled_scales: bool = False, +) -> NVFP4Tensor: + """Reconstruct an ``NVFP4Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. + """ + # Infer device from whichever inner buffer is populated so the wrapper + # subclass stays consistent with its data buffers (e.g. CPU after DCP + # async-staging deserialize, CUDA after the usual quantize path). 
+ device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + device=device, + ) + + class _ViewFunc(torch.autograd.Function): """View function @@ -951,6 +996,7 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -993,6 +1039,7 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -1077,6 +1124,7 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -1119,6 +1167,7 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index de7f8f58e2..3a72ec5d1a 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -139,11 +139,6 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, - "device": ( - self._data.device - if self._data is not None - else (self._transpose.device if self._transpose is not None else None) - ), "fake_dtype": self._dtype, } diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 842f42838b..874555f465 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -35,12 +35,16 @@ def forward( if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) - dtype = torch_to_transformer_engine_dtype[dtype] - - # Make sure FP8 data is in expected format - if tensor._rowwise_data is not None or tensor._columnwise_data is not None: - return tex.dequantize(tensor, dtype) - raise ValueError("Cannot dequantize MXFP8 tensor with no data") + if tensor._rowwise_data is None and tensor._columnwise_data is None: + raise ValueError("Cannot dequantize MXFP8 tensor with no data") + te_dtype = torch_to_transformer_engine_dtype[dtype] + # ``tex.dequantize`` requires CUDA-resident buffers. 
+        src_device = tensor.device
+        if src_device.type != "cuda":
+            cuda_tensor = tensor.to(device=torch.device("cuda"))
+            result = tex.dequantize(cuda_tensor, te_dtype)
+            return result.to(device=src_device)
+        return tex.dequantize(tensor, te_dtype)
 
     @staticmethod
     def backward(
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index e51acb71e5..490184e5f8 100644
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -47,13 +47,18 @@ def forward(
         if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0:
             return torch.empty(tensor.size(), dtype=dtype, device=tensor.device)
 
-        # Dequantize row-wise data
-        if tensor._rowwise_data is not None:
-            return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
-
-        if tensor._columnwise_data is not None:
+        if tensor._rowwise_data is None and tensor._columnwise_data is None:
+            raise ValueError("Attempted to dequantize NVFP4 tensor with no data")
+        if tensor._rowwise_data is None and tensor._columnwise_data is not None:
             raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
-        raise ValueError("Attempted to dequantize NVFP4 tensor with no data")
+
+        # ``tex.dequantize`` requires CUDA-resident buffers; bounce via CUDA if needed.
+        src_device = tensor.device
+        if src_device.type != "cuda":
+            cuda_tensor = tensor.to(device=torch.device("cuda"))
+            result = tex.dequantize(cuda_tensor, torch_to_transformer_engine_dtype[dtype])
+            return result.to(device=src_device)
+        return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
 
     @staticmethod
     def backward(