[PyTorch] Python DType enum#3039

Open
vthumbe1503 wants to merge 27 commits into
NVIDIA:mainfrom
vthumbe1503:te_dtype

Conversation

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented May 22, 2026

Description

Replace the pybind tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. For backward compatibility, C++ -> Python and Python -> C++ DType object conversions are cached at the pybind boundaries to reduce CPU overhead.

Motivation

  • CPU overheads: tex.DType is a pybind enum, so every access/compare/convert in Python crosses into C-extension code.
  • torch.compile: tex.DType won't work with torch.compile — TorchDynamo doesn't understand pybind enums, so it graph-breaks (or fails to trace) when one flows through a compiled region.
  • Checkpointing: tex.DType lives in tensor/quantizer state and lands in checkpoints; pickling a pybind enum is fragile and awkward to allow-list vs. a stdlib python enum.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Canonical DType: Add a pure-Python DType(IntEnum) in constants.py as the single source of truth. Its members are defined from the C++ enum values, and an import-time assert verifies that it stays in sync with the pybind enum tex.DType.
  • Migration: Repoint TE_DType maps and move pytorch modules, examples, benchmarks, and tests off raw tex.DType onto constants.DType.
  • Backward compatibility: Add DTypeSupported = Union[DType, tex.DType]; tex.DType is still accepted at constructor boundaries and stays allow-listed for loading old checkpoints.
  • Python → C++: Register a cached pybind implicit conversion (dtype_pybind_conversion.h) so a constants.DType is auto-accepted wherever a C++ tex.DType is expected.
  • C++ → Python: Add cached MakePythonDType (csrc/common.*) and use it at quantizer/quantizedtensor construction.
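The backward-compatibility bullet above can be sketched as follows. This is a minimal illustration, not the PR's code: `PybindDType` and `ToyQuantizer` are hypothetical stand-ins for `tex.DType` and a real quantizer, and the numeric values are illustrative only. It shows a constructor accepting either type and storing only the canonical `DType`.

```python
import enum
from typing import Union


class PybindDType:
    """Hypothetical stand-in for the pybind enum ``tex.DType`` (not the real binding)."""

    def __init__(self, value: int) -> None:
        self._value = value

    def __int__(self) -> int:
        return self._value


class DType(enum.IntEnum):
    # Illustrative values; the real members take their values from the C++ enum.
    kFloat32 = 4
    kFloat8E4M3 = 7

    @classmethod
    def cast(cls, dtype: "DTypeSupported") -> "DType":
        """Return ``dtype`` unchanged if already canonical, else map it by integer value."""
        if isinstance(dtype, cls):
            return dtype
        return cls(int(dtype))


# One alias reused by every constructor that accepts either enum.
DTypeSupported = Union[DType, PybindDType]


class ToyQuantizer:
    """Hypothetical quantizer that normalizes its dtype once, in the constructor."""

    def __init__(self, dtype: DTypeSupported) -> None:
        self.dtype = DType.cast(dtype)


new_style = ToyQuantizer(DType.kFloat8E4M3)
old_style = ToyQuantizer(PybindDType(7))
```

Normalizing at the boundary means code after the constructor only ever sees `DType` members, so it needs no per-call type checks.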

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits May 22, 2026 21:50
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title from "initial prototype" to "TE_DType in python" May 22, 2026
// when this runs, so the GIL is held and Python imports are legal.
static pybind11::object te_dtype_cls =
pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType");
return te_dtype_cls(static_cast<int>(dtype));
Collaborator Author

Find a way to bind the C++ and Python DType through the pybind cast mechanism.

Collaborator Author

This is done for Python -> C++.

For C++ -> Python, it can't be avoided.

Comment thread transformer_engine/pytorch/__init__.py Outdated
# pybind11 enum used as Quantizer.dtype
tex.DType,
# Python IntEnum used as Quantizer.dtype
TE_DType,
Collaborator Author

Save/load backward compatibility should be preserved.

Collaborator Author

Done

vthumbe1503 and others added 14 commits May 31, 2026 19:45
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title from "TE_DType in python" to "[PyTorch] Python DType enum" Jun 1, 2026
vthumbe1503 and others added 2 commits June 1, 2026 08:27
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 1, 2026 08:32
@vthumbe1503
Collaborator Author

/te-ci

@greptile-apps
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR replaces the pybind11 tex.DType enum with a canonical Python DType(IntEnum) in transformer_engine/pytorch/constants.py, migrating all internal usages across 65 files. Cross-boundary compatibility is provided by a new type caster (dtype_pybind_conversion.h) that accepts both types on the Python→C++ path, and by MakePythonDType for C++→Python construction at quantizer/tensor callsites.

  • New DType(IntEnum): Defined in constants.py from the C++ enum values; import-time assert verifies sync; __eq__/__ne__/__hash__ overrides keep DType and tex.DType interoperable.
  • C++ type caster (dtype_pybind_conversion.h): Accepts IntEnum, plain int, or tex.DType on the load path; the cast path still returns cached tex.DType objects for the standard pybind11 direction.
  • MakePythonDType (csrc/common.cpp): A cached helper that explicitly returns constants.DType members at quantizer/tensor construction sites in quantizer.cpp and cast.cpp.

Confidence Score: 4/5

Safe to merge with minor defensive hardening recommended in MakePythonDType and the pybind __eq__ lambda.

The migration is thorough across all 65 files, the type caster correctly handles both IntEnum and pybind enums in both conversion directions, and the equality/hash contracts between DType and tex.DType are properly maintained. Two non-blocking quality improvements are noted: a missing null guard in MakePythonDType that would produce a confusing error rather than a silent wrong result, and an unnecessary py::cast(self) per comparison in the __eq__ lambda.

transformer_engine/pytorch/csrc/common.cpp (MakePythonDType null-return path) and transformer_engine/pytorch/csrc/extensions/pybind.cpp (__eq__ lambda)

Important Files Changed

Filename | Overview
transformer_engine/pytorch/constants.py | Adds DType(IntEnum) as the canonical Python dtype; import-time assert syncs it with tex.DType; updates TE_DType/TE_DType_To_Torch maps to use DType.
transformer_engine/common/util/dtype_pybind_conversion.h | New type caster; load path handles int/IntEnum/tex.DType; cast path returns cached tex.DType singletons (by design).
transformer_engine/pytorch/csrc/common.cpp | Adds MakePythonDType with magic-static cache; missing null-guard on cache[idx] return value.
transformer_engine/pytorch/csrc/extensions/pybind.cpp | Patches __eq__/__ne__ on tex.DType; lambda unnecessarily calls py::cast(self) on every invocation.
transformer_engine/pytorch/csrc/quantizer.cpp | All py::cast(this->dtype) calls replaced with MakePythonDType; no missed sites.
transformer_engine/pytorch/__init__.py | DType added to exports and pickle allowlist; tex.DType kept for backward-compatible checkpoint loading.
transformer_engine/pytorch/tensor/float8_tensor.py | Float8Quantizer.dtype changed to DType; DType.cast() applied at constructor boundaries.

Sequence Diagram

sequenceDiagram
    participant PY as Python caller
    participant DType as constants.DType (IntEnum)
    participant Caster as dtype_pybind_conversion.h
    participant CPP as C++ transformer_engine::DType
    participant MakePY as MakePythonDType()

    Note over PY,MakePY: Python to C++ path
    PY->>Caster: pybind load(src)
    alt src is int/IntEnum
        Caster->>CPP: PyLong_AsLong to static_cast DType
    else src is tex.DType
        Caster->>CPP: type_caster_base::load
    end

    Note over PY,MakePY: C++ to Python path
    CPP->>MakePY: MakePythonDType(dtype)
    MakePY-->>DType: import transformer_engine.pytorch.DType cached
    MakePY-->>PY: constants.DType member

    Note over PY,MakePY: Equality
    PY->>DType: "DType.kX == tex.DType.kY"
    DType->>DType: isinstance tex.DType then int comparison
    PY->>CPP: "tex.DType.kX == DType.kY"
    CPP->>CPP: isinstance py int then int comparison

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +44 to +49
# Fail fast at import time if a new enumerator is added
# on the C++ side without being mirrored above.
assert {m.name for m in DType} == set(tex.DType.__members__), (
    "DType is out of sync with transformer_engine_torch.DType; "
    "add the new pybind enumerator to DType in constants.py."
)
Contributor

P2 Import-time sync check can be silently skipped

Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.
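A minimal sketch of the suggested `if ... raise` form, assuming a hypothetical helper `check_in_sync` and a stand-in `BoundDType` in place of the real pybind enum (both names and all values are illustrative). Unlike the assert in the diff, which compares names only, this sketch also compares integer values; that extra check is an assumption, not something the PR does.

```python
import enum


class BoundDType(enum.Enum):
    """Hypothetical stand-in for the pybind enum; only ``__members__`` and ``.value`` are used."""

    kByte = 0
    kInt32 = 2
    kFloat32 = 4


class DType(enum.IntEnum):
    kByte = 0
    kInt32 = 2
    kFloat32 = 4


def check_in_sync(py_enum, bound_enum) -> None:
    """Raise if the two enums disagree on names or values; still runs under ``python -O``."""
    py_members = {m.name: int(m) for m in py_enum}
    bound_members = {name: int(m.value) for name, m in bound_enum.__members__.items()}
    if py_members != bound_members:
        missing = sorted(set(bound_members) - set(py_members))
        extra = sorted(set(py_members) - set(bound_members))
        raise RuntimeError(
            f"DType is out of sync with the bound enum (missing={missing}, extra={extra})"
        )


# Runs at import time, like the original assert, but is not stripped by -O.
check_in_sync(DType, BoundDType)
```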

Comment on lines +81 to +86
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/
Contributor

P2 Duplicate sentence in the docstring — the line "pybind-bound transformer_engine::DType argument is expected." appears twice.

Suggested change
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/

Comment on lines +14 to +41
class DType(enum.IntEnum):
    """Python mirror of ``transformer_engine_torch.DType`` (pybind11 enum).
    Members are constructed manually from the underlying pybind enum so
    that this class is the single source of truth for dtype tags used
    across ``transformer_engine.pytorch``.
    """

    kByte = int(tex.DType.kByte)
    kInt32 = int(tex.DType.kInt32)
    kFloat32 = int(tex.DType.kFloat32)
    kFloat16 = int(tex.DType.kFloat16)
    kBFloat16 = int(tex.DType.kBFloat16)
    kFloat8E4M3 = int(tex.DType.kFloat8E4M3)
    kFloat8E5M2 = int(tex.DType.kFloat8E5M2)
    kFloat4E2M1 = int(tex.DType.kFloat4E2M1)

    @classmethod
    def cast(cls, dtype: "DTypeSupported") -> "DType":
        """Normalize any ``DTypeSupported`` value to the canonical ``DType`` ``IntEnum``.
        ``DType`` is the canonical dtype tag used internally throughout
        ``transformer_engine.pytorch``, and is what this function always outputs.
        The pybind ``transformer_engine_torch.DType`` enum is an additional type
        accepted as input (for backward compatibility), which this function maps
        to the matching ``DType`` member so stored attributes are always ``DType``.
        """
        if isinstance(dtype, cls):
            return dtype
        return cls(int(dtype))
Contributor

P1 Equality comparison between constants.DType and tex.DType silently returns False

tex.DType is a pybind11 enum without .arithmetic(), so its __eq__ only compares with the same C-extension type. constants.DType is an IntEnum (a Python int subclass), so int.__eq__ is used on the left side — but CPython's int.__eq__ returns NotImplemented for non-PyLong objects, and pybind11's __eq__ also returns NotImplemented for a non-tex.DType right-hand side. Python falls back to identity comparison, yielding False. Existing user code like if quantizer.dtype == tex.DType.kFloat8E4M3: now silently evaluates to False even though the types are equivalent. The PR's documented backward-compat guarantee covers constructors and checkpoints but not equality comparisons, leaving this as an undocumented silent break.
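The fallback described above can be reproduced without pybind. In this sketch `ForeignDType` is a hypothetical stand-in that, like the enum described in the comment, only compares equal to its own type; `PlainDType` and `InteropDType` are likewise illustrative names. It shows the silent `False` and one possible fix via `__eq__`/`__ne__`/`__hash__` overrides, which is the approach the Greptile summary attributes to the PR.

```python
import enum


class ForeignDType:
    """Hypothetical stand-in for a pybind enum that only compares with its own type."""

    def __init__(self, value: int) -> None:
        self._value = value

    def __int__(self) -> int:
        return self._value

    def __eq__(self, other):
        if isinstance(other, ForeignDType):
            return self._value == other._value
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self._value)


class PlainDType(enum.IntEnum):
    kFloat8E4M3 = 7  # illustrative value


class InteropDType(enum.IntEnum):
    kFloat8E4M3 = 7  # illustrative value

    def __eq__(self, other):
        if isinstance(other, ForeignDType):
            return int(self) == int(other)
        return int.__eq__(self, other)

    def __ne__(self, other):
        # int defines its own __ne__, so it must be overridden alongside __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self) -> int:
        # Defining __eq__ clears the inherited __hash__; restore the integer hash.
        return int.__hash__(self)


foreign = ForeignDType(7)
plain_result = PlainDType.kFloat8E4M3 == foreign      # both sides return NotImplemented
interop_result = InteropDType.kFloat8E4M3 == foreign  # handled by the override
reflected_result = foreign == InteropDType.kFloat8E4M3
```

The reflected comparison also works because Python tries the right-hand operand's `__eq__` once the left-hand side returns `NotImplemented`.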

@vthumbe1503 vthumbe1503 requested a review from ptrendx June 1, 2026 18:13
@vthumbe1503 vthumbe1503 requested a review from pggPL June 1, 2026 18:13
Comment thread benchmarks/benchmark_rht_cast.py Outdated

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import constants
Member

Considering how frequent it is, shouldn't we just expose the DType in the top module (transformer_engine.pytorch)?

return cached_dtype_object;
}

/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
Member

Not sure if that is accurate to be honest. My understanding of this function is that it converts the Python object of constants.DType to tex.DType, and only then the real conversion to the C++ type happens - this would actually increase the overhead I think. I believe the right approach is to use the custom type_caster from pybind to get the DType from either type.

Comment on lines +1263 to +1264
# ``constants.DType`` is implicitly convertible to ``transformer_engine::DType``
# on the C++ side, so pass it straight to the pybind function.
Member

Do we need those comments? They do not really give that much value.

Comment on lines +578 to +580
// Construct Python tensor. ``MakePythonDType`` returns the cached
// ``constants.DType`` IntEnum member, so ``_fp8_dtype`` ends up as that
// IntEnum without a per-call ``DType.cast`` normalization.
Member

Same, please clean those comments (especially duplicate ones).

return transformer_engine::DType::kFloat8E5M2;
}

pybind11::object MakePythonDType(transformer_engine::DType dtype) {
Member

I am slightly worried about the thread safety of this function when there is no GIL.

* produces a fresh ``tex.DType`` pybind enum object on every call, which the
* Python constructors then have to normalize via ``DType.cast``.
*
* Must be called with the GIL held (always true inside pybind-invoked code).
Member

That is not necessarily true with the GIL-less Python becoming a thing.

Comment thread transformer_engine/pytorch/__init__.py Outdated
Float8BlockQuantizer,
# pybind11 enum used as Quantizer.dtype
# Python IntEnum used as Quantizer.dtype.
constants.DType,
Member

Do we still need to add it here even if it is a regular Python object?

Comment thread transformer_engine/pytorch/constants.py Outdated
kFloat4E2M1 = int(tex.DType.kFloat4E2M1)

@classmethod
def cast(cls, dtype: "DTypeSupported") -> "DType":
Member

DTypeSupported?

Comment thread transformer_engine/pytorch/constants.py Outdated
Comment on lines +52 to +54
# tex.DType is the pybind enum kept for backward compatibility.
# in the constructors for QuantizedTensors and Quantizers.
DTypeSupported = Union[DType, tex.DType]
Member

Hmmm, why can't you just use this union type directly in the cast function declaration rather than having this indirection?

Collaborator Author

This is reused in multiple places including Quantizer and QuantizedTensor constructors. So I defined it once here.

# ``transformer_engine.h``). Use the bracket syntax ``TE_DType[torch_dtype]``
# to resolve a ``torch.dtype`` to its matching ``DType`` member.
# Used for passing dtypes into cuda extension.
TE_DType = {
Member

nit: we sometimes need int64

Comment on lines +75 to +81
DType.kByte: torch.uint8,
DType.kFloat8E4M3: torch.float8_e4m3fn,
DType.kFloat8E5M2: torch.float8_e5m2,
DType.kInt32: torch.int32,
DType.kFloat32: torch.float32,
DType.kFloat16: torch.half,
DType.kBFloat16: torch.bfloat16,
Member

{value: key for key, value in TE_DType.items()}
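A small sketch of that suggestion, with strings standing in for `torch.dtype` objects so it runs without PyTorch (the entries and numeric values are illustrative, not the real tables). The inversion reproduces the hand-written map only if `TE_DType` is one-to-one and holds exactly the same pairs, which the snippet under review does not confirm, so the sketch checks the first condition explicitly.

```python
import enum


class DType(enum.IntEnum):
    # Illustrative values only.
    kByte = 0
    kFloat32 = 4
    kFloat16 = 5


# Strings stand in for ``torch.dtype`` objects so the sketch runs without PyTorch.
TE_DType = {
    "torch.uint8": DType.kByte,
    "torch.float32": DType.kFloat32,
    "torch.float16": DType.kFloat16,
}

# Inversion is only lossless when the forward map is one-to-one.
if len(set(TE_DType.values())) != len(TE_DType):
    raise ValueError("TE_DType is not one-to-one; the inverse map would drop entries")

TE_DType_To_Torch = {value: key for key, value in TE_DType.items()}
```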

vthumbe1503 and others added 9 commits June 2, 2026 06:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>