From 4197bee78341999183108c9d09d2cb88346196dc Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 13 May 2026 03:55:02 +0000 Subject: [PATCH 01/10] all changes in Signed-off-by: Varun Thumbe --- .../fsdp2_tests/run_fsdp2_fused_adam.py | 36 ----------- .../fsdp2_tests/run_fsdp2_model.py | 13 ---- transformer_engine/pytorch/__init__.py | 57 +++++++++++++++++ transformer_engine/pytorch/module/base.py | 5 +- .../pytorch/quantized_tensor.py | 63 +++++++++++++++++-- .../pytorch/tensor/_quantization_helpers.py | 1 + .../pytorch/tensor/float8_blockwise_tensor.py | 14 ----- .../pytorch/tensor/float8_tensor.py | 15 +++-- .../tensor/storage/float8_tensor_storage.py | 5 -- 9 files changed, 128 insertions(+), 81 deletions(-) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index ecda481ed9..92cebf2b53 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"): - pytest.xfail( - f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor " - "subclasses fail when parameters are initialized on CUDA. " - "Use device='meta' + reset_parameters() after sharding." - ) - world_size, device = _get_dist_info() model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False) @@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name): - Saved tensor shapes match expected (unsharded) shapes """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) from safetensors.torch import load_file, save_file from torch.distributed.checkpoint.state_dict import ( @@ -692,24 +679,8 @@ def test_dcp_output_parity(recipe_name, async_save): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access: " - "/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function " - "multi_tensor_apply: CUDA Error: an illegal memory access was encountered. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) - - if recipe_name == "NVFP4BlockScaling": - pytest.xfail( - "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() " - "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage" - ) - if ( recipe_name == "Float8BlockScaling" - and not async_save and torch.cuda.get_device_capability()[0] == 12 ): pytest.xfail( @@ -719,13 +690,6 @@ def test_dcp_output_parity(recipe_name, async_save): "Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which " "requires using power of two scaling factors." ) - if recipe_name == "Float8BlockScaling" and async_save: - pytest.xfail( - "Float8BlockScaling: async DCP save/load round-trip produces different model " - "outputs — quantization metadata (scales) is not correctly persisted through " - "async distributed checkpointing. On SM120, additionally fails with pow2_scale " - "assertion in quantize_transpose_vector_blockwise." - ) import torch.distributed.checkpoint as dcp diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py index 6342e63e75..9383355fcc 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py @@ -380,19 +380,6 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type): "data and scale_inv into a single buffer in pre_all_gather, split in post." ) - if recipe_name == "Float8BlockScaling" and fp8_init: - pytest.xfail( - "Float8BlockScaling + fp8_init: scale inverse padding is not handled " - "correctly during FSDP2 all-gather slice ops." - ) - if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer": - pytest.xfail( - "NVFP4BlockScaling + fp8_init + TransformerLayer: " - "_check_fp8_fsdp2_allgather numerical error compounds across multiple " - "linear layers in the transformer block (up to ~1e-2 max abs diff). " - "LayerNormLinear passes with relaxed tolerances. " - "NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py." - ) torch.manual_seed(42) torch.cuda.manual_seed(42) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index d145cf0a21..6284e96cef 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -91,3 +91,60 @@ torch._dynamo.config.error_on_nested_jit_trace = False except AttributeError: pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 + + +# Allow QuantizedTensor subclasses (and the metadata they pickle) to +# round-trip through ``torch.load(weights_only=True)``. DCP async-staging +# writes a torch.save / torch.load step internally, so without this the +# default safe-unpickler rejects our custom classes. +try: + from torch.serialization import add_safe_globals + from transformer_engine_torch import DType as _TE_DType + + add_safe_globals( + [ + # Wrapper subclasses + QuantizedTensor, + Float8Tensor, + MXFP8Tensor, + NVFP4Tensor, + Float8BlockwiseQTensor, + # Storage mixins (used during pickling of internal-only tensors) + QuantizedTensorStorage, + Float8TensorStorage, + MXFP8TensorStorage, + NVFP4TensorStorage, + Float8BlockwiseQTensorStorage, + # Quantizer types embedded in metadata + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, + Float8BlockQuantizer, + # __reduce_ex__ constructors (bound classmethods). + Float8Tensor._make_in_reduce_ex, + MXFP8Tensor._make_in_reduce_ex, + NVFP4Tensor._make_in_reduce_ex, + Float8BlockwiseQTensor._make_in_reduce_ex, + # The pickle stream produced by ``__reduce_ex__`` references + # the pybind11 enum ``transformer_engine_torch.DType`` (e.g. + # the ``fp8_dtype`` argument) and uses ``builtins.getattr`` to + # resolve both the enum members and the bound-classmethod + # ``_make_in_reduce_ex`` callables above. Both must be + # allow-listed for ``torch.load(weights_only=True)`` (used + # internally by DCP async-staging) to accept the stream. + _TE_DType, + getattr, + ] + ) +except (ImportError, AttributeError): + import warnings as _warnings + _warnings.warn( + "transformer_engine: torch.serialization.add_safe_globals is " + "unavailable on this PyTorch version (added in 2.4). DCP " + "checkpointing of QuantizedTensor weights with FSDP2 will not " + "work; upgrade to PyTorch >= 2.4 to enable it.", + RuntimeWarning, + stacklevel=2, + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e6bedee0c0..873c980579 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -41,6 +41,7 @@ from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage @@ -1466,7 +1467,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False - if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + if is_dtensor and isinstance( + quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer) + ): device_mesh = dtensor_param.device_mesh amax_reduction_group = ( device_mesh.get_group(mesh_dim="shard") diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..1470f5eca2 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -529,9 +529,26 @@ def half(self) -> torch.Tensor: # pylint: disable=missing-function-docstring return self.dequantize(dtype=torch.float16) - def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor: + def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor: + """Move tensor to CPU while preserving the QuantizedTensor type. + + Routes through ``aten._to_copy.default`` so the subclass-preserving + handler in ``__torch_dispatch__`` runs (rather than dequantizing). + + """ # pylint: disable=missing-function-docstring - return self.dequantize().cpu(memory_format=memory_format) + return self.to(device=torch.device("cpu"), memory_format=memory_format) + + def untyped_storage(self) -> torch.UntypedStorage: + """Return an empty UntypedStorage on the tensor's device. + + ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real + backing storage of its own; the actual bytes live in the inner + buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are + an implementation detail of the quantization scheme. Need to define + this method to avoid DCP staging errors with FSDP2. + """ + return torch.UntypedStorage(0, device=self.device) def expand_as(self, other: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -585,6 +602,34 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.copy_(src) return None + # _to_copy op (used by .to(device=...), .cpu(), DCP staging). + # Preserve the QuantizedTensor subclass and move all internal + # buffers (data, scales, etc.) to the requested device. + if func == torch.ops.aten._to_copy.default: + tensor = args[0] + kw = dict(kwargs) if kwargs else {} + target_device = kw.get("device", tensor.device) or tensor.device + target_device = torch.device(target_device) + target_dtype = kw.get("dtype", tensor.dtype) or tensor.dtype + pin_memory = bool(kw.get("pin_memory", False)) + non_blocking = bool(kw.get("non_blocking", False)) + + new_metadata = {} + for key, value in tensor.get_metadata().items(): + if isinstance(value, torch.Tensor): + value = value.to(device=target_device, non_blocking=non_blocking) + if pin_memory and target_device.type == "cpu": + value = value.pin_memory() + new_metadata[key] = value + new_metadata["fake_dtype"] = target_dtype + return type(tensor)( + shape=tensor.shape, + dtype=target_dtype, + requires_grad=tensor.requires_grad, + device=target_device, + **new_metadata, + ) + # View op if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") @@ -725,14 +770,24 @@ def make_like( """Create new quantized tensor By default, new tensor has the same attributes and underlying - data. This function is intended to create view of tensors. + data. This function is intended to create a view of ``tensor``, + so the new tensor lives on the same device. To move quantized + data across devices use ``.to(device=...)`` / ``.cpu()`` / + ``aten._to_copy.default`` instead, which actually copies the + inner buffers. """ shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() kwargs["fake_dtype"] = dtype - return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) + return cls( + shape=shape, + dtype=dtype, + requires_grad=requires_grad, + device=tensor.device, + **kwargs, + ) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: """Create `QuantizedTensor` with given nominal dtype diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index ba3407e13b..56cf503630 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -61,6 +61,7 @@ def forward( kwargs = tensor.get_metadata() for key, val in init_kwargs.items(): kwargs[key] = val + kwargs["device"] = tensor.device return type(tensor)(tensor.shape, tensor.dtype, **kwargs) @staticmethod diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..0beca32a18 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -389,20 +389,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) - def untyped_storage(self) -> torch.UntypedStorage: - """Return the underlying UntypedStorage of the FP8 data. - - Note that FP8 block-scaled tensor may involve multiple - buffers: row-wise FP8 data, row-wise scales, column-wise FP8 - data, column-wise scales. The UntypedStorage of the row-wise - FP8 data is returned if it exists, and otherwise the - UntypedStorage of the column-wise FP8 data. - - """ - data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data - if data is not None: - return data.untyped_storage() - return torch.UntypedStorage(0, device=self.device) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed6091c85b..ddee20b32e 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1007,16 +1007,14 @@ def _make_in_reduce_ex( ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects + """Custom pickling to remove references to FP8 metadata objects. - CPU Float8Tensors are serialized as dequantized plain tensors - for compatibility with torch.load(weights_only=True), which is - used by DCP async save staging. + Always serializes the underlying FP8 buffers (no dequantization + fallback for CPU tensors) so that DCP async-staging round-trips + preserve bitwise-identical data. ``Float8Tensor`` is registered + with ``torch.serialization.add_safe_globals`` to keep + ``torch.load(weights_only=True)`` compatibility. """ - data_is_cpu = self._data is not None and self._data.is_cpu - transpose_is_cpu = self._transpose is not None and self._transpose.is_cpu - if data_is_cpu or transpose_is_cpu: - return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), @@ -1177,3 +1175,4 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring return grad.reshape(ctx.shape), None + diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index de7f8f58e2..3a72ec5d1a 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -139,11 +139,6 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, - "device": ( - self._data.device - if self._data is not None - else (self._transpose.device if self._transpose is not None else None) - ), "fake_dtype": self._dtype, } From 8496440690f401d5efadf87bd29c60d349847536 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 04:01:37 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py | 5 +---- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 1 - transformer_engine/pytorch/tensor/float8_tensor.py | 1 - 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index 92cebf2b53..1abb49e98c 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -679,10 +679,7 @@ def test_dcp_output_parity(recipe_name, async_save): """ recipe = get_recipe_from_string(recipe_name) - if ( - recipe_name == "Float8BlockScaling" - and torch.cuda.get_device_capability()[0] == 12 - ): + if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12: pytest.xfail( "Float8BlockScaling is failing on SM120 with RuntimeError: " "transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 " diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 6284e96cef..0c019c209c 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -140,6 +140,7 @@ ) except (ImportError, AttributeError): import warnings as _warnings + _warnings.warn( "transformer_engine: torch.serialization.add_safe_globals is " "unavailable on this PyTorch version (added in 2.4). DCP " diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 0beca32a18..6928543f43 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -389,7 +389,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ddee20b32e..c7b69581a0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1175,4 +1175,3 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring return grad.reshape(ctx.shape), None - From dde83666cade2cb8b5686224d770eb58769352b2 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 13 May 2026 04:02:10 +0000 Subject: [PATCH 03/10] simplify Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/quantized_tensor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 1470f5eca2..71d59c529f 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -771,11 +771,6 @@ def make_like( By default, new tensor has the same attributes and underlying data. This function is intended to create a view of ``tensor``, - so the new tensor lives on the same device. To move quantized - data across devices use ``.to(device=...)`` / ``.cpu()`` / - ``aten._to_copy.default`` instead, which actually copies the - inner buffers. - """ shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype From 9cdfe7a3078881d6c11c096b129a4bc196347f0f Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 13 May 2026 04:19:13 +0000 Subject: [PATCH 04/10] address review comment Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/__init__.py | 42 ++++++---- .../pytorch/quantized_tensor.py | 28 +++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 69 ++++++++-------- .../pytorch/tensor/float8_tensor.py | 54 +++++++------ .../pytorch/tensor/mxfp8_tensor.py | 67 ++++++++-------- .../pytorch/tensor/nvfp4_tensor.py | 79 ++++++++++--------- 6 files changed, 192 insertions(+), 147 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 0c019c209c..3da552833b 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -86,6 +86,18 @@ from transformer_engine.pytorch.tensor import MXFP8Tensor from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor import NVFP4Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + _make_float8_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import ( + _make_mxfp8_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + _make_nvfp4_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + _make_float8_blockwise_tensor_in_reduce_ex, +) try: torch._dynamo.config.error_on_nested_jit_trace = False @@ -97,9 +109,18 @@ # round-trip through ``torch.load(weights_only=True)``. DCP async-staging # writes a torch.save / torch.load step internally, so without this the # default safe-unpickler rejects our custom classes. +# +# The ``_make_*_in_reduce_ex`` reconstructors are defined as module-level +# functions (not classmethods) so they pickle as a single ``GLOBAL`` opcode +# rather than a ``(getattr, (cls, name))`` reduction. Their ``fp8_dtype`` / +# ``fp4_dtype`` arguments are passed as plain ``int`` values (converted back +# to the pybind11 ``transformer_engine_torch.DType`` enum on reconstruction) +# and ``Quantizer.__getstate__`` similarly serializes its embedded ``dtype`` +# as an ``int``. Together these keep the pickle stream free of pybind11-enum +# reductions and bound-classmethod references, so we don't need to allow-list +# ``builtins.getattr`` or the enum type itself for ``weights_only=True``. try: from torch.serialization import add_safe_globals - from transformer_engine_torch import DType as _TE_DType add_safe_globals( [ @@ -122,20 +143,11 @@ MXFP8Quantizer, NVFP4Quantizer, Float8BlockQuantizer, - # __reduce_ex__ constructors (bound classmethods). - Float8Tensor._make_in_reduce_ex, - MXFP8Tensor._make_in_reduce_ex, - NVFP4Tensor._make_in_reduce_ex, - Float8BlockwiseQTensor._make_in_reduce_ex, - # The pickle stream produced by ``__reduce_ex__`` references - # the pybind11 enum ``transformer_engine_torch.DType`` (e.g. - # the ``fp8_dtype`` argument) and uses ``builtins.getattr`` to - # resolve both the enum members and the bound-classmethod - # ``_make_in_reduce_ex`` callables above. Both must be - # allow-listed for ``torch.load(weights_only=True)`` (used - # internally by DCP async-staging) to accept the stream. - _TE_DType, - getattr, + # __reduce_ex__ reconstructors (module-level functions). + _make_float8_tensor_in_reduce_ex, + _make_mxfp8_tensor_in_reduce_ex, + _make_nvfp4_tensor_in_reduce_ex, + _make_float8_blockwise_tensor_in_reduce_ex, ] ) except (ImportError, AttributeError): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 71d59c529f..f261a2ddaa 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -254,6 +254,34 @@ def __init__(self, *, rowwise: bool, columnwise: bool) -> None: self.internal = False self.optimize_for_gemm = False + def __getstate__(self): + """Custom pickling. + + FP8/FP4 quantizer subclasses store ``self.dtype`` as a + ``transformer_engine_torch.DType`` (pybind11 enum). Pybind11 + enums reduce as ``(getattr, (Enum, name))``, which would force + callers using ``torch.load(weights_only=True)`` (e.g. DCP + async-staging) to allow-list ``builtins.getattr``. Serialize + ``dtype`` as an ``int`` here so the pickle stream stays free of + those enum reductions. Subclass overrides should call + ``super().__getstate__()`` rather than ``self.__dict__.copy()`` + to preserve this behavior. + """ + from transformer_engine_torch import DType as _TE_DType + + state = self.__dict__.copy() + if isinstance(state.get("dtype"), _TE_DType): + state["dtype"] = int(state["dtype"]) + return state + + def __setstate__(self, state): + """Reconstruct ``dtype`` from its serialized ``int`` form.""" + from transformer_engine_torch import DType as _TE_DType + + if isinstance(state.get("dtype"), int): + state["dtype"] = _TE_DType(state["dtype"]) + self.__dict__.update(state) + def __repr__(self): return ( f"{self.__class__.__name__}(" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 6928543f43..c70c30c4ae 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -473,49 +473,17 @@ def contiguous( return self raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") - @classmethod - def _make_in_reduce_ex( - cls, - shape: torch.Size, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - fp8_dtype: TE_DType, - dtype: torch.dtype, - quantizer: Quantizer, - is_2D_scaled: bool, - data_format: Any = None, # pylint: disable=unused-argument - ) -> Float8BlockwiseQTensor: - """Build Float8BlockwiseQTensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8BlockwiseQTensor( - shape=shape, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - quantizer=quantizer, - is_2D_scaled=is_2D_scaled, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( - Float8BlockwiseQTensor._make_in_reduce_ex, + _make_float8_blockwise_tensor_in_reduce_ex, ( self.shape, self._rowwise_data, self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - self._fp8_dtype, + int(self._fp8_dtype), self.dtype, self._quantizer, self._is_2D_scaled, @@ -709,6 +677,39 @@ def fsdp_post_all_gather( return out, all_gather_outputs +def _make_float8_blockwise_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: int, + dtype: torch.dtype, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: Any = None, # pylint: disable=unused-argument +) -> Float8BlockwiseQTensor: + """Reconstruct a ``Float8BlockwiseQTensor`` from ``__reduce_ex__``. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp8_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. + """ + return Float8BlockwiseQTensor( + shape=shape, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=TE_DType(fp8_dtype), + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + quantizer=quantizer, + is_2D_scaled=is_2D_scaled, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c7b69581a0..d81458b3c4 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -287,7 +287,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = self.__dict__.copy() + state = super().__getstate__() state["amax_reduction_group"] = None return state @@ -983,29 +983,6 @@ def is_cpu(self): return self._transpose.is_cpu raise RuntimeError("Both data and transpose are None") - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - fp8_dtype: TE_DType, - fp8_scale_inv: torch.Tensor, - dtype: torch.dtype, - shape: torch.shape, - ) -> Float8Tensor: - """Build Float8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - shape=shape, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects. @@ -1016,8 +993,8 @@ def __reduce_ex__(self, protocol: int) -> tuple: ``torch.load(weights_only=True)`` compatibility. """ return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), + _make_float8_tensor_in_reduce_ex, + (self._data, int(self._fp8_dtype), self._scale_inv, self.dtype, self.shape), ) def _get_data(self) -> Float8Tensor: @@ -1083,6 +1060,31 @@ def _set_data(self, tensor: torch.Tensor) -> None: data = property(_get_data, _set_data) +def _make_float8_tensor_in_reduce_ex( + data: torch.Tensor, + fp8_dtype: int, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + shape: torch.Size, +) -> Float8Tensor: + """Reconstruct a ``Float8Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level (not as a classmethod) so the pickle stream + references it via a single ``GLOBAL`` opcode rather than the + ``(getattr, (cls, name))`` reduction that bound classmethods/static + methods produce. ``fp8_dtype`` is passed as an ``int`` and converted + back to the pybind11 ``TE_DType`` enum here so the pickle stream + stays free of enum reductions as well. + """ + return Float8Tensor( + data=data, + fp8_dtype=TE_DType(fp8_dtype), + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + shape=shape, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..96f4b9554a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -760,47 +760,16 @@ def fsdp_post_all_gather( out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs - @classmethod - def _make_in_reduce_ex( - cls, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - fp8_dtype: TE_DType, - dtype: torch.dtype, - shape: torch.shape, - quantizer: Optional[Quantizer] = None, - with_gemm_swizzled_scales: bool = False, - ) -> MXFP8Tensor: - """Build MXFP8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return MXFP8Tensor( - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - shape=shape, - quantizer=quantizer, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling""" return ( - MXFP8Tensor._make_in_reduce_ex, + _make_mxfp8_tensor_in_reduce_ex, ( self._rowwise_data, self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - self._fp8_dtype, + int(self._fp8_dtype), self.dtype, self.shape, self._quantizer, @@ -896,6 +865,38 @@ def is_cuda(self): raise RuntimeError("MXFP8Tensor has no data!") +def _make_mxfp8_tensor_in_reduce_ex( + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: int, + dtype: torch.dtype, + shape: torch.Size, + quantizer: Optional[Quantizer] = None, + with_gemm_swizzled_scales: bool = False, +) -> MXFP8Tensor: + """Reconstruct an ``MXFP8Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp8_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. + """ + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=TE_DType(fp8_dtype), + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + shape=shape, + quantizer=quantizer, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 285a7f030a..0ec472c592 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -165,7 +165,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = self.__dict__.copy() + state = super().__getstate__() state["amax_reduction_group"] = None return state @@ -820,46 +820,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Default case return super().__torch_dispatch__(func, types, args, kwargs) - @classmethod - def _make_in_reduce_ex( - cls, - shape: torch.Size, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - amax_rowwise: torch.Tensor, - amax_columnwise: torch.Tensor, - fp4_dtype: TE_DType, - dtype: torch.dtype, - quantizer: Quantizer, - with_gemm_swizzled_scales: bool = False, - ) -> NVFP4Tensor: - """Build NVFP4Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return NVFP4Tensor( - shape=shape, - dtype=dtype, - fp4_dtype=fp4_dtype, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - amax_rowwise=amax_rowwise, - amax_columnwise=amax_columnwise, - quantizer=quantizer, - requires_grad=False, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling""" return ( - NVFP4Tensor._make_in_reduce_ex, + _make_nvfp4_tensor_in_reduce_ex, ( self.shape, self._rowwise_data, @@ -868,7 +832,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._amax_rowwise, self._amax_columnwise, - self._fp4_dtype, + int(self._fp4_dtype), self.dtype, self._quantizer, self._with_gemm_swizzled_scales, @@ -965,6 +929,43 @@ def is_cuda(self): raise RuntimeError("NVFP4Tensor has no data!") +def _make_nvfp4_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: int, + dtype: torch.dtype, + quantizer: Quantizer, + with_gemm_swizzled_scales: bool = False, +) -> NVFP4Tensor: + """Reconstruct an ``NVFP4Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp4_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. + """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=TE_DType(fp4_dtype), + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + class _ViewFunc(torch.autograd.Function): """View function From ec01bb6eb9d1c81ada01dbd8ec12f0dab881a2ff Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 14 May 2026 03:21:11 +0000 Subject: [PATCH 05/10] fix CI, Test for CPU quantized tensor Signed-off-by: Varun Thumbe --- tests/pytorch/test_quantized_tensor.py | 51 +++++++++++++++++++ .../pytorch/quantized_tensor.py | 42 +++++++-------- .../pytorch/tensor/float8_blockwise_tensor.py | 14 +++++ .../pytorch/tensor/float8_tensor.py | 9 ++++ .../pytorch/tensor/mxfp8_tensor.py | 21 ++++++++ .../pytorch/tensor/nvfp4_tensor.py | 16 ++++++ .../tensor/storage/mxfp8_tensor_storage.py | 16 +++--- .../tensor/storage/nvfp4_tensor_storage.py | 17 ++++--- 8 files changed, 153 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 526045e43e..d019929639 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -616,6 +616,57 @@ def test_identity_op( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_cpu_dequantize( + self, + *, + quantization: str, + shape: Iterable[int] = (128, 128), + dtype: torch.dtype = torch.bfloat16, + ) -> None: + """Dequantize on a CPU-resident QuantizedTensor. + """ + + # Construct a quantized tensor on CUDA. + _, x_cuda = make_reference_and_test_tensors( + shape=shape, + quantization=quantization, + test_dtype=dtype, + requires_grad=False, + ) + assert isinstance(x_cuda, QuantizedTensor) + assert x_cuda.device.type == "cuda" + + # Reference: dequantize on CUDA, then move the dense result to CPU. + ref_cpu = x_cuda.dequantize().to(device="cpu") + + # Move the QuantizedTensor itself to CPU and dequantize there. + # ``.cpu()`` routes through ``aten._to_copy.default`` so all inner + # buffers (data, scales, amax) are moved to CPU. + x_cpu = x_cuda.cpu() + assert isinstance(x_cpu, QuantizedTensor) + assert x_cpu.device.type == "cpu" + for attr in ( + "_data", + "_rowwise_data", + "_columnwise_data", + "_rowwise_scale_inv", + "_columnwise_scale_inv", + "_amax_rowwise", + "_amax_columnwise", + ): + buf = getattr(x_cpu, attr, None) + if buf is not None: + assert buf.device.type == "cpu", f"{attr} did not move to CPU" + + # Dequantize the CPU tensor. Implementation may bounce through CUDA + # internally, but must return a CPU tensor. + y_cpu = x_cpu.dequantize() + assert y_cpu.device.type == "cpu" + assert y_cpu.dtype == ref_cpu.dtype + assert y_cpu.shape == ref_cpu.shape + torch.testing.assert_close(y_cpu, ref_cpu, rtol=0, atol=0) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("dim", [0, 1]) def test_chunk( diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index f261a2ddaa..9b99a1978a 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -636,27 +636,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten._to_copy.default: tensor = args[0] kw = dict(kwargs) if kwargs else {} - target_device = kw.get("device", tensor.device) or tensor.device - target_device = torch.device(target_device) - target_dtype = kw.get("dtype", tensor.dtype) or tensor.dtype - pin_memory = bool(kw.get("pin_memory", False)) - non_blocking = bool(kw.get("non_blocking", False)) - - new_metadata = {} - for key, value in tensor.get_metadata().items(): - if isinstance(value, torch.Tensor): - value = value.to(device=target_device, non_blocking=non_blocking) - if pin_memory and target_device.type == "cpu": - value = value.pin_memory() - new_metadata[key] = value - new_metadata["fake_dtype"] = target_dtype - return type(tensor)( - shape=tensor.shape, - dtype=target_dtype, - requires_grad=tensor.requires_grad, - device=target_device, - **new_metadata, - ) + dtype = kw.get("dtype", None) + if dtype is None or dtype == tensor.dtype: + target_device = kw.get("device", tensor.device) or tensor.device + target_device = torch.device(target_device) + pin_memory = bool(kw.get("pin_memory", False)) + non_blocking = bool(kw.get("non_blocking", False)) + new_metadata = {"device": target_device} + # Update tensor storage metadata + for key, value in tensor.get_metadata().items(): + if isinstance(value, torch.Tensor): + value = value.to(device=target_device, non_blocking=non_blocking) + if pin_memory and target_device.type == "cpu": + value = value.pin_memory() + new_metadata[key] = value + # Update torch Tensor metadata + new_metadata.update({ + "dtype": tensor.dtype, + "shape": tensor.shape, + "requires_grad": tensor.requires_grad, + }) + return type(tensor)(**new_metadata) # View op if func == torch.ops.aten.view.default: diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index c70c30c4ae..64c4882344 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -256,6 +256,7 @@ def make_empty( quantizer=self, is_2D_scaled=self.block_scaling_dim == 2, requires_grad=requires_grad, + device=tensor_kwargs["device"], ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -662,6 +663,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=None, quantizer=self._quantizer, is_2D_scaled=is_2D_scaled, + device=rowwise_data.device, ) # For 2D block scaling, derive columnwise data and scales from rowwise @@ -697,6 +699,13 @@ def _make_float8_blockwise_tensor_in_reduce_ex( as an ``int`` and converted back to the pybind11 ``TE_DType`` enum here. """ + # Infer device from inner buffers so the wrapper subclass stays + # consistent with its data (e.g. CPU after DCP staging deserialize). + device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device return Float8BlockwiseQTensor( shape=shape, rowwise_data=rowwise_data, @@ -707,6 +716,7 @@ def _make_float8_blockwise_tensor_in_reduce_ex( dtype=dtype, quantizer=quantizer, is_2D_scaled=is_2D_scaled, + device=device, ) @@ -791,6 +801,7 @@ def forward( quantizer=tensor._quantizer, is_2D_scaled=tensor._is_2D_scaled, requires_grad=tensor.requires_grad, + device=tensor.device, ) @staticmethod @@ -820,6 +831,7 @@ def backward( quantizer=grad._quantizer, is_2D_scaled=grad._is_2D_scaled, requires_grad=grad.requires_grad, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -905,6 +917,7 @@ def forward( quantizer=tensor._quantizer, is_2D_scaled=tensor._is_2D_scaled, requires_grad=tensor.requires_grad, + device=tensor.device, ) @staticmethod @@ -933,6 +946,7 @@ def backward( quantizer=grad._quantizer, is_2D_scaled=grad._is_2D_scaled, requires_grad=grad.requires_grad, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index d81458b3c4..5b93a74ee3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -197,6 +197,7 @@ def create_tensor_from_data( requires_grad=requires_grad, data_transpose=None, quantizer=self, + device=data.device, ) def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: @@ -420,6 +421,7 @@ def create_tensor_from_data( requires_grad=requires_grad, data_transpose=None, quantizer=self, + device=data.device, ) def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: @@ -440,6 +442,7 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: requires_grad=False, data_transpose=None, quantizer=self, + device=data.device, ) def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: @@ -672,6 +675,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) if func in (aten.slice.Tensor, aten.select.int): @@ -772,6 +776,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_scale_inv=scale_inv, data_transpose=func_transposed_out, quantizer=quantizer, + device=tensor.device, ) return out_tensor @@ -945,6 +950,7 @@ def fsdp_post_all_gather( "quantizer": self._quantizer, "requires_grad": False, "data": data, + "device": data.device, } out = Float8Tensor(**fp8_args) @@ -1082,6 +1088,7 @@ def _make_float8_tensor_in_reduce_ex( fp8_scale_inv=fp8_scale_inv, dtype=dtype, shape=shape, + device=data.device ) @@ -1121,6 +1128,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) @staticmethod @@ -1168,6 +1176,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, data_transpose=out_transpose, quantizer=tensor._quantizer, + device=tensor.device, ) @staticmethod diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 96f4b9554a..186ee689cc 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -158,6 +158,7 @@ def make_empty( quantizer=self, requires_grad=requires_grad, with_gemm_swizzled_scales=self.optimize_for_gemm, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -225,6 +226,7 @@ def create_tensor_from_data( fp8_dtype=fp8_dtype, quantizer=self, with_gemm_swizzled_scales=False, + device=data.device, ) def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: @@ -410,6 +412,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) if func == torch.ops.aten.copy_.default: @@ -516,6 +519,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=False, + device=tensor.device, ) for splitted_tensor_data in zip(*out_data) ] @@ -605,6 +609,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=False, fp8_dtype=tensor._fp8_dtype, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) # Default case @@ -756,6 +761,9 @@ def fsdp_post_all_gather( shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, + device=( + rowwise_data.device if rowwise_data is not None else columnwise_data.device + ), ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs @@ -884,6 +892,14 @@ def _make_mxfp8_tensor_in_reduce_ex( as an ``int`` and converted back to the pybind11 ``TE_DType`` enum here. """ + # Infer device from inner buffers so the wrapper subclass stays + # consistent with its data (CPU after DCP staging deserialize, + # CUDA after the usual quantize path). + device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device return MXFP8Tensor( rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, @@ -894,6 +910,7 @@ def _make_mxfp8_tensor_in_reduce_ex( shape=shape, quantizer=quantizer, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + device=device, ) @@ -956,6 +973,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -983,6 +1001,7 @@ def backward( fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -1044,6 +1063,7 @@ def forward( fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -1069,6 +1089,7 @@ def backward( columnwise_scale_inv=grad._columnwise_scale_inv, fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, + device=grad.device, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 0ec472c592..26f3799da5 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -388,6 +388,7 @@ def make_empty( requires_grad=requires_grad, with_gemm_swizzled_scales=False, row_scaled_nvfp4=self.row_scaled_nvfp4, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -677,6 +678,7 @@ def fsdp_post_all_gather( quantizer=self._quantizer, requires_grad=False, with_gemm_swizzled_scales=False, + device=rowwise_data.device, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -815,6 +817,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): quantizer=tensor._quantizer, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) # Default case @@ -950,6 +953,14 @@ def _make_nvfp4_tensor_in_reduce_ex( as an ``int`` and converted back to the pybind11 ``TE_DType`` enum here. """ + # Infer device from whichever inner buffer is populated so the wrapper + # subclass stays consistent with its data buffers (e.g. CPU after DCP + # async-staging deserialize, CUDA after the usual quantize path). + device = None + if rowwise_data is not None: + device = rowwise_data.device + elif columnwise_data is not None: + device = columnwise_data.device return NVFP4Tensor( shape=shape, dtype=dtype, @@ -963,6 +974,7 @@ def _make_nvfp4_tensor_in_reduce_ex( quantizer=quantizer, requires_grad=False, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + device=device, ) @@ -1045,6 +1057,7 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -1087,6 +1100,7 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None @@ -1171,6 +1185,7 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + device=tensor.device, ) @staticmethod @@ -1213,6 +1228,7 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + device=grad.device, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 842f42838b..874555f465 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -35,12 +35,16 @@ def forward( if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) - dtype = torch_to_transformer_engine_dtype[dtype] - - # Make sure FP8 data is in expected format - if tensor._rowwise_data is not None or tensor._columnwise_data is not None: - return tex.dequantize(tensor, dtype) - raise ValueError("Cannot dequantize MXFP8 tensor with no data") + if tensor._rowwise_data is None and tensor._columnwise_data is None: + raise ValueError("Cannot dequantize MXFP8 tensor with no data") + te_dtype = torch_to_transformer_engine_dtype[dtype] + # ``tex.dequantize`` requires CUDA-resident buffers. + src_device = tensor.device + if src_device.type != "cuda": + cuda_tensor = tensor.to(device=torch.device("cuda")) + result = tex.dequantize(cuda_tensor, te_dtype) + return result.to(device=src_device) + return tex.dequantize(tensor, te_dtype) @staticmethod def backward( diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e51acb71e5..490184e5f8 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -47,13 +47,18 @@ def forward( if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) - # Dequantize row-wise data - if tensor._rowwise_data is not None: - return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) - - if tensor._columnwise_data is not None: + if tensor._rowwise_data is None and tensor._columnwise_data is None: + raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + if tensor._rowwise_data is None and tensor._columnwise_data is not None: raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") - raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + + # ``tex.dequantize`` requires CUDA-resident buffers. If the tensor has + src_device = tensor.device + if src_device.type != "cuda": + cuda_tensor = tensor.to(device=torch.device("cuda")) + result = tex.dequantize(cuda_tensor, torch_to_transformer_engine_dtype[dtype]) + return result.to(device=src_device) + return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) @staticmethod def backward( From 7bf7f2e5aba727d77bf890029161da9def1b0603 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 03:25:08 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_quantized_tensor.py | 3 +-- transformer_engine/pytorch/quantized_tensor.py | 12 +++++++----- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index d019929639..119914fbc3 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -624,8 +624,7 @@ def test_cpu_dequantize( shape: Iterable[int] = (128, 128), dtype: torch.dtype = torch.bfloat16, ) -> None: - """Dequantize on a CPU-resident QuantizedTensor. - """ + """Dequantize on a CPU-resident QuantizedTensor.""" # Construct a quantized tensor on CUDA. _, x_cuda = make_reference_and_test_tensors( diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 24fa097e22..074ab9f3fb 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -674,11 +674,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): value = value.pin_memory() new_metadata[key] = value # Update torch Tensor metadata - new_metadata.update({ - "dtype": tensor.dtype, - "shape": tensor.shape, - "requires_grad": tensor.requires_grad, - }) + new_metadata.update( + { + "dtype": tensor.dtype, + "shape": tensor.shape, + "requires_grad": tensor.requires_grad, + } + ) return type(tensor)(**new_metadata) # View op diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c3569c7172..58753c8d73 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1003,7 +1003,7 @@ def _make_float8_tensor_in_reduce_ex( fp8_scale_inv=fp8_scale_inv, dtype=dtype, shape=shape, - device=data.device + device=data.device, ) From e7f3441c09a1e2df6672d8bc38b4a6581245483d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 14 May 2026 04:55:43 +0000 Subject: [PATCH 07/10] add things thats just necessary Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/__init__.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7951052c37..7cc7058696 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -107,32 +107,12 @@ except AttributeError: pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 - -# Allow QuantizedTensor subclasses (and the metadata they pickle) to -# round-trip through ``torch.load(weights_only=True)``. DCP async-staging -# writes a torch.save / torch.load step internally, so without this the -# default safe-unpickler rejects our custom classes. -# -# The ``_make_*_in_reduce_ex`` reconstructors are defined as module-level -# functions (not classmethods) so they pickle as a single ``GLOBAL`` opcode -# rather than a ``(getattr, (cls, name))`` reduction. Their ``fp8_dtype`` / -# ``fp4_dtype`` arguments are passed as plain ``int`` values (converted back -# to the pybind11 ``transformer_engine_torch.DType`` enum on reconstruction) -# and ``Quantizer.__getstate__`` similarly serializes its embedded ``dtype`` -# as an ``int``. Together these keep the pickle stream free of pybind11-enum -# reductions and bound-classmethod references, so we don't need to allow-list -# ``builtins.getattr`` or the enum type itself for ``weights_only=True``. +# To allow for safe unpickling of QuantizedTensors when +# using DCP checkpointing with FSDP2. try: from torch.serialization import add_safe_globals - add_safe_globals( [ - # Wrapper subclasses - QuantizedTensor, - Float8Tensor, - MXFP8Tensor, - NVFP4Tensor, - Float8BlockwiseQTensor, # Storage mixins (used during pickling of internal-only tensors) QuantizedTensorStorage, Float8TensorStorage, From cdb23b6198b8c899d0c02c28ce763fd7bfcbc5d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 04:57:11 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7cc7058696..7e83b23658 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -111,6 +111,7 @@ # using DCP checkpointing with FSDP2. try: from torch.serialization import add_safe_globals + add_safe_globals( [ # Storage mixins (used during pickling of internal-only tensors) From cb449813aad5ad88da24cd5b80270a144438fac7 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 14 May 2026 23:27:06 +0000 Subject: [PATCH 09/10] fix test Signed-off-by: Varun Thumbe --- .../common/util/pybind_helper.h | 13 ++++- transformer_engine/pytorch/__init__.py | 12 ++++- .../pytorch/quantized_tensor.py | 28 ----------- .../pytorch/tensor/float8_blockwise_tensor.py | 46 +++++++++++++++--- .../pytorch/tensor/float8_tensor.py | 35 ++++++++++---- .../pytorch/tensor/mxfp8_tensor.py | 44 ++++++++++++++--- .../pytorch/tensor/nvfp4_tensor.py | 47 ++++++++++++++++--- 7 files changed, 163 insertions(+), 62 deletions(-) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index ef7687e3e9..df33497cf1 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -23,7 +23,18 @@ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ - .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \ + .def( \ + "__reduce_ex__", \ + [](transformer_engine::DType self, pybind11::object /*protocol*/) { \ + return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ + pybind11::make_tuple(static_cast(self))); \ + }) \ + .def("__reduce__", \ + [](transformer_engine::DType self) { \ + return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ + pybind11::make_tuple(static_cast(self))); \ + }); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7e83b23658..aebdbd01a3 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -107,8 +107,14 @@ except AttributeError: pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 -# To allow for safe unpickling of QuantizedTensors when -# using DCP checkpointing with FSDP2. +# To allow for safe unpickling of QuantizedTensors when using DCP +# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its +# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see +# ``transformer_engine/common/util/pybind_helper.h``) so its pickle +# stream encodes as ``(tex.DType, (int,))`` and only the class itself +# needs to be allow-listed below. +import transformer_engine_torch as tex + try: from torch.serialization import add_safe_globals @@ -127,6 +133,8 @@ MXFP8Quantizer, NVFP4Quantizer, Float8BlockQuantizer, + # pybind11 enum used as Quantizer.dtype + tex.DType, # __reduce_ex__ reconstructors (module-level functions). _make_float8_tensor_in_reduce_ex, _make_mxfp8_tensor_in_reduce_ex, diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 074ab9f3fb..404796fd63 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -256,34 +256,6 @@ def __init__(self, *, rowwise: bool, columnwise: bool) -> None: self.internal = False self.optimize_for_gemm = False - def __getstate__(self): - """Custom pickling. - - FP8/FP4 quantizer subclasses store ``self.dtype`` as a - ``transformer_engine_torch.DType`` (pybind11 enum). Pybind11 - enums reduce as ``(getattr, (Enum, name))``, which would force - callers using ``torch.load(weights_only=True)`` (e.g. DCP - async-staging) to allow-list ``builtins.getattr``. Serialize - ``dtype`` as an ``int`` here so the pickle stream stays free of - those enum reductions. Subclass overrides should call - ``super().__getstate__()`` rather than ``self.__dict__.copy()`` - to preserve this behavior. - """ - from transformer_engine_torch import DType as _TE_DType - - state = self.__dict__.copy() - if isinstance(state.get("dtype"), _TE_DType): - state["dtype"] = int(state["dtype"]) - return state - - def __setstate__(self, state): - """Reconstruct ``dtype`` from its serialized ``int`` form.""" - from transformer_engine_torch import DType as _TE_DType - - if isinstance(state.get("dtype"), int): - state["dtype"] = _TE_DType(state["dtype"]) - self.__dict__.update(state) - def __repr__(self): return ( f"{self.__class__.__name__}(" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 8b64ed5b57..d9c9510d17 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -427,7 +427,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - int(self._fp8_dtype), + self._fp8_dtype, self.dtype, self._quantizer, self._is_2D_scaled, @@ -435,6 +435,39 @@ def __reduce_ex__(self, protocol: int) -> tuple: ), ) + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: Any = None, + ) -> Float8BlockwiseQTensor: + """Build Float8BlockwiseQTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return _make_float8_blockwise_tensor_in_reduce_ex( + shape, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + dtype, + quantizer, + is_2D_scaled, + data_format, + ) + def _get_data(self) -> Float8BlockwiseQTensor: """Get tensor data property""" return self @@ -628,7 +661,7 @@ def _make_float8_blockwise_tensor_in_reduce_ex( rowwise_scale_inv: torch.Tensor, columnwise_data: torch.Tensor, columnwise_scale_inv: torch.Tensor, - fp8_dtype: int, + fp8_dtype: TE_DType, dtype: torch.dtype, quantizer: Quantizer, is_2D_scaled: bool, @@ -636,11 +669,10 @@ def _make_float8_blockwise_tensor_in_reduce_ex( ) -> Float8BlockwiseQTensor: """Reconstruct a ``Float8BlockwiseQTensor`` from ``__reduce_ex__``. - Defined at module level so the pickle stream uses a single + Defined at module level (not as a Float8BlockwiseQTensor classmethod) + so the pickle stream references it via a single ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` - reduction that bound classmethods produce. ``fp8_dtype`` is passed - as an ``int`` and converted back to the pybind11 ``TE_DType`` enum - here. + reduction that bound classmethods produce. """ # Infer device from inner buffers so the wrapper subclass stays # consistent with its data (e.g. CPU after DCP staging deserialize). @@ -653,7 +685,7 @@ def _make_float8_blockwise_tensor_in_reduce_ex( shape=shape, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=TE_DType(fp8_dtype), + fp8_dtype=fp8_dtype, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 58753c8d73..0a3fca2ae9 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -245,7 +245,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = super().__getstate__() + state = self.__dict__.copy() state["amax_reduction_group"] = None return state @@ -915,7 +915,25 @@ def __reduce_ex__(self, protocol: int) -> tuple: """ return ( _make_float8_tensor_in_reduce_ex, - (self._data, int(self._fp8_dtype), self._scale_inv, self.dtype, self.shape), + (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), + ) + + @classmethod + def _make_in_reduce_ex( + cls, + data: torch.Tensor, + fp8_dtype: TE_DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + shape: torch.Size, + ) -> Float8Tensor: + """Build Float8Tensor, for use in __reduce__ + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return _make_float8_tensor_in_reduce_ex( + data, fp8_dtype, fp8_scale_inv, dtype, shape ) def _get_data(self) -> Float8Tensor: @@ -983,27 +1001,24 @@ def _set_data(self, tensor: torch.Tensor) -> None: def _make_float8_tensor_in_reduce_ex( data: torch.Tensor, - fp8_dtype: int, + fp8_dtype: TE_DType, fp8_scale_inv: torch.Tensor, dtype: torch.dtype, shape: torch.Size, ) -> Float8Tensor: """Reconstruct a ``Float8Tensor`` from its ``__reduce_ex__`` payload. - - Defined at module level (not as a classmethod) so the pickle stream + Defined at module level (not as a Float8Tensor classmethod) so the pickle stream references it via a single ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` reduction that bound classmethods/static - methods produce. ``fp8_dtype`` is passed as an ``int`` and converted - back to the pybind11 ``TE_DType`` enum here so the pickle stream - stays free of enum reductions as well. + methods produce. """ return Float8Tensor( data=data, - fp8_dtype=TE_DType(fp8_dtype), + fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, dtype=dtype, shape=shape, - device=data.device, + device=data.device if data is not None else None, ) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index f8b4c658ad..6c455b2690 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -712,7 +712,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - int(self._fp8_dtype), + self._fp8_dtype, self.dtype, self.shape, self._quantizer, @@ -720,6 +720,37 @@ def __reduce_ex__(self, protocol: int) -> tuple: ), ) + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + shape: torch.Size, + quantizer: Optional[Quantizer] = None, + with_gemm_swizzled_scales: bool = False, + ) -> MXFP8Tensor: + """Build MXFP8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return _make_mxfp8_tensor_in_reduce_ex( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + dtype, + shape, + quantizer, + with_gemm_swizzled_scales, + ) + def _get_data(self) -> MXFP8Tensor: """Get tensor data property""" return super().data @@ -813,7 +844,7 @@ def _make_mxfp8_tensor_in_reduce_ex( rowwise_scale_inv: torch.Tensor, columnwise_data: torch.Tensor, columnwise_scale_inv: torch.Tensor, - fp8_dtype: int, + fp8_dtype: TE_DType, dtype: torch.dtype, shape: torch.Size, quantizer: Optional[Quantizer] = None, @@ -821,11 +852,10 @@ def _make_mxfp8_tensor_in_reduce_ex( ) -> MXFP8Tensor: """Reconstruct an ``MXFP8Tensor`` from its ``__reduce_ex__`` payload. - Defined at module level so the pickle stream uses a single + Defined at module level (not as a MXFP8Tensor classmethod) so the pickle stream + references it via a single ``GLOBAL`` opcode rather than the ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` - reduction that bound classmethods produce. ``fp8_dtype`` is passed - as an ``int`` and converted back to the pybind11 ``TE_DType`` enum - here. + reduction that bound classmethods produce. """ # Infer device from inner buffers so the wrapper subclass stays # consistent with its data (CPU after DCP staging deserialize, @@ -838,7 +868,7 @@ def _make_mxfp8_tensor_in_reduce_ex( return MXFP8Tensor( rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=TE_DType(fp8_dtype), + fp8_dtype=fp8_dtype, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index c0d50035d4..3e84075707 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -165,7 +165,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = super().__getstate__() + state = self.__dict__.copy() state["amax_reduction_group"] = None return state @@ -741,13 +741,48 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._amax_rowwise, self._amax_columnwise, - int(self._fp4_dtype), + self._fp4_dtype, self.dtype, self._quantizer, self._with_gemm_swizzled_scales, ), ) + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + with_gemm_swizzled_scales: bool = False, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return _make_nvfp4_tensor_in_reduce_ex( + shape, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + dtype, + quantizer, + with_gemm_swizzled_scales, + ) + def _get_data(self) -> NVFP4Tensor: """Get tensor data property""" return super().data @@ -846,7 +881,7 @@ def _make_nvfp4_tensor_in_reduce_ex( columnwise_scale_inv: torch.Tensor, amax_rowwise: torch.Tensor, amax_columnwise: torch.Tensor, - fp4_dtype: int, + fp4_dtype: TE_DType, dtype: torch.dtype, quantizer: Quantizer, with_gemm_swizzled_scales: bool = False, @@ -855,9 +890,7 @@ def _make_nvfp4_tensor_in_reduce_ex( Defined at module level so the pickle stream uses a single ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` - reduction that bound classmethods produce. ``fp4_dtype`` is passed - as an ``int`` and converted back to the pybind11 ``TE_DType`` enum - here. + reduction that bound classmethods produce. """ # Infer device from whichever inner buffer is populated so the wrapper # subclass stays consistent with its data buffers (e.g. CPU after DCP @@ -870,7 +903,7 @@ def _make_nvfp4_tensor_in_reduce_ex( return NVFP4Tensor( shape=shape, dtype=dtype, - fp4_dtype=TE_DType(fp4_dtype), + fp4_dtype=fp4_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, From ddc1ece68119eeac7c8cacdd9ef8ba9fc56d8879 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 23:28:52 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/pybind_helper.h | 16 +++++++--------- .../pytorch/tensor/float8_tensor.py | 4 +--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index df33497cf1..ed48fe4d61 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -24,17 +24,15 @@ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \ - .def( \ - "__reduce_ex__", \ - [](transformer_engine::DType self, pybind11::object /*protocol*/) { \ - return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ - pybind11::make_tuple(static_cast(self))); \ - }) \ - .def("__reduce__", \ - [](transformer_engine::DType self) { \ + .def("__reduce_ex__", \ + [](transformer_engine::DType self, pybind11::object /*protocol*/) { \ return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ pybind11::make_tuple(static_cast(self))); \ - }); \ + }) \ + .def("__reduce__", [](transformer_engine::DType self) { \ + return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ + pybind11::make_tuple(static_cast(self))); \ + }); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 0a3fca2ae9..66567c1eb3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -932,9 +932,7 @@ def _make_in_reduce_ex( arguments. """ - return _make_float8_tensor_in_reduce_ex( - data, fp8_dtype, fp8_scale_inv, dtype, shape - ) + return _make_float8_tensor_in_reduce_ex(data, fp8_dtype, fp8_scale_inv, dtype, shape) def _get_data(self) -> Float8Tensor: """Get tensor data property"""