[PyTorch] Python DType enum#3039
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
| // when this runs, so the GIL is held and Python imports are legal. | ||
| static pybind11::object te_dtype_cls = | ||
| pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType"); | ||
| return te_dtype_cls(static_cast<int>(dtype)); |
There was a problem hiding this comment.
Find a way to bind C++ and python Dtype through pybind cast mechanism
There was a problem hiding this comment.
This is done for Python. -> C++
For C++ to Python. --> Cant avoid this.
| # pybind11 enum used as Quantizer.dtype | ||
| tex.DType, | ||
| # Python IntEnum used as Quantizer.dtype | ||
| TE_DType, |
There was a problem hiding this comment.
save/load backward compatibilty should be there
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci |
Greptile SummaryThis PR replaces the pybind11
Confidence Score: 4/5Safe to merge with minor defensive hardening recommended in MakePythonDType and the pybind eq lambda. The migration is thorough across all 65 files, the type caster correctly handles both IntEnum and pybind enums in both conversion directions, and the equality/hash contracts between DType and tex.DType are properly maintained. Two non-blocking quality improvements are noted: a missing null guard in MakePythonDType that would produce a confusing error rather than a silent wrong result, and an unnecessary py::cast(self) per comparison in the eq lambda. transformer_engine/pytorch/csrc/common.cpp (MakePythonDType null-return path) and transformer_engine/pytorch/csrc/extensions/pybind.cpp (eq lambda) Important Files Changed
Sequence DiagramsequenceDiagram
participant PY as Python caller
participant DType as constants.DType (IntEnum)
participant Caster as dtype_pybind_conversion.h
participant CPP as C++ transformer_engine::DType
participant MakePY as MakePythonDType()
Note over PY,MakePY: Python to C++ path
PY->>Caster: pybind load(src)
alt src is int/IntEnum
Caster->>CPP: PyLong_AsLong to static_cast DType
else src is tex.DType
Caster->>CPP: type_caster_base::load
end
Note over PY,MakePY: C++ to Python path
CPP->>MakePY: MakePythonDType(dtype)
MakePY-->>DType: import transformer_engine.pytorch.DType cached
MakePY-->>PY: constants.DType member
Note over PY,MakePY: Equality
PY->>DType: "DType.kX == tex.DType.kY"
DType->>DType: isinstance tex.DType then int comparison
PY->>CPP: "tex.DType.kX == DType.kY"
CPP->>CPP: isinstance py int then int comparison
Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| # Fail fast at import time if a new enumerator is added | ||
| # on the C++ side without being mirrored above. | ||
| assert {m.name for m in DType} == set(tex.DType.__members__), ( | ||
| "DType is out of sync with transformer_engine_torch.DType; " | ||
| "add the new pybind enumerator to DType in constants.py." | ||
| ) |
There was a problem hiding this comment.
Import-time sync check can be silently skipped
Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. | ||
| * Allows a Python object of type ``transformer_engine.pytorch.constants.DType`` | ||
| * to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. | ||
| * pybind-bound ``transformer_engine::DType`` argument is expected. | ||
| * Must be called after the pybind ``DType`` enum has been registered. | ||
| */ |
There was a problem hiding this comment.
Duplicate sentence in the docstring — the line "pybind-bound
transformer_engine::DType argument is expected." appears twice.
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. | |
| * Allows a Python object of type ``transformer_engine.pytorch.constants.DType`` | |
| * to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. | |
| * pybind-bound ``transformer_engine::DType`` argument is expected. | |
| * Must be called after the pybind ``DType`` enum has been registered. | |
| */ | |
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. | |
| * Allows a Python object of type ``transformer_engine.pytorch.constants.DType`` | |
| * to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. | |
| * Must be called after the pybind ``DType`` enum has been registered. | |
| */ |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| class DType(enum.IntEnum): | ||
| """Python mirror of ``transformer_engine_torch.DType`` (pybind11 enum). | ||
| Members are constructed manually from the underlying pybind enum so | ||
| that this class is the single source of truth for dtype tags used | ||
| across ``transformer_engine.pytorch``. | ||
| """ | ||
|
|
||
| kByte = int(tex.DType.kByte) | ||
| kInt32 = int(tex.DType.kInt32) | ||
| kFloat32 = int(tex.DType.kFloat32) | ||
| kFloat16 = int(tex.DType.kFloat16) | ||
| kBFloat16 = int(tex.DType.kBFloat16) | ||
| kFloat8E4M3 = int(tex.DType.kFloat8E4M3) | ||
| kFloat8E5M2 = int(tex.DType.kFloat8E5M2) | ||
| kFloat4E2M1 = int(tex.DType.kFloat4E2M1) | ||
|
|
||
| @classmethod | ||
| def cast(cls, dtype: "DTypeSupported") -> "DType": | ||
| """Normalize any ``DTypeSupported`` value to the canonical ``DType`` ``IntEnum``. | ||
| ``DType`` is the canonical dtype tag used internally throughout | ||
| ``transformer_engine.pytorch``, and is what this function always outputs. | ||
| The pybind ``transformer_engine_torch.DType`` enum is an additional type | ||
| accepted as input (for backward compatibility), which this function maps | ||
| to the matching ``DType`` member so stored attributes are always ``DType``. | ||
| """ | ||
| if isinstance(dtype, cls): | ||
| return dtype | ||
| return cls(int(dtype)) |
There was a problem hiding this comment.
Equality comparison between
constants.DType and tex.DType silently returns False
tex.DType is a pybind11 enum without .arithmetic(), so its __eq__ only compares with the same C-extension type. constants.DType is an IntEnum (a Python int subclass), so int.__eq__ is used on the left side — but CPython's int.__eq__ returns NotImplemented for non-PyLong objects, and pybind11's __eq__ also returns NotImplemented for a non-tex.DType right-hand side. Python falls back to identity comparison, yielding False. Existing user code like if quantizer.dtype == tex.DType.kFloat8E4M3: now silently evaluates to False even though the types are equivalent. The PR's documented backward-compat guarantee covers constructors and checkpoints but not equality comparisons, leaving this as an undocumented silent break.
|
|
||
| import transformer_engine.pytorch as te | ||
| import transformer_engine_torch as tex | ||
| from transformer_engine.pytorch import constants |
There was a problem hiding this comment.
Considering how frequent it is, shouldn't we just expose the DType in the top module (transformer_engine.pytorch)?
| return cached_dtype_object; | ||
| } | ||
|
|
||
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. |
There was a problem hiding this comment.
Not sure if that is accurate to be honest. My understanding of this function is that it converts the Python object of constants.DType to tex.DType, and only then the real conversion to the C++ type happens - this would actually increase the overhead I think. I believe the right approach is to use the custom type_caster from pybind to get the DType from either type.
| # ``constants.DType`` is implicitly convertible to ``transformer_engine::DType`` | ||
| # on the C++ side, so pass it straight to the pybind function. |
There was a problem hiding this comment.
Do we need those comments? They do not really give that much value.
| // Construct Python tensor. ``MakePythonDType`` returns the cached | ||
| // ``constants.DType`` IntEnum member, so ``_fp8_dtype`` ends up as that | ||
| // IntEnum without a per-call ``DType.cast`` normalization. |
There was a problem hiding this comment.
Same, please clean those comments (especially duplicate ones).
| return transformer_engine::DType::kFloat8E5M2; | ||
| } | ||
|
|
||
| pybind11::object MakePythonDType(transformer_engine::DType dtype) { |
There was a problem hiding this comment.
I am slightly worried about the thread safety of this function when there is no GIL.
| * produces a fresh ``tex.DType`` pybind enum object on every call, which the | ||
| * Python constructors then have to normalize via ``DType.cast``. | ||
| * | ||
| * Must be called with the GIL held (always true inside pybind-invoked code). |
There was a problem hiding this comment.
That is not necessarily true with the GIL-less Python becoming a thing.
| Float8BlockQuantizer, | ||
| # pybind11 enum used as Quantizer.dtype | ||
| # Python IntEnum used as Quantizer.dtype. | ||
| constants.DType, |
There was a problem hiding this comment.
Do we still need to add it here even if it is a regular Python object?
| kFloat4E2M1 = int(tex.DType.kFloat4E2M1) | ||
|
|
||
| @classmethod | ||
| def cast(cls, dtype: "DTypeSupported") -> "DType": |
| # tex.DType is the pybind enum kept for backward compatibility. | ||
| # in the constructors for QuantizedTensors and Quantizers. | ||
| DTypeSupported = Union[DType, tex.DType] |
There was a problem hiding this comment.
Hmmm, why can't you just use this union type directly in the cast function declaration rather than having this indirection?
There was a problem hiding this comment.
This is reused in multiple places including Quantizer and QuantizedTensor constructors. So I defined it once here.
| # ``transformer_engine.h``). Use the bracket syntax ``TE_DType[torch_dtype]`` | ||
| # to resolve a ``torch.dtype`` to its matching ``DType`` member. | ||
| # Used for passing dtypes into cuda extension. | ||
| TE_DType = { |
| DType.kByte: torch.uint8, | ||
| DType.kFloat8E4M3: torch.float8_e4m3fn, | ||
| DType.kFloat8E5M2: torch.float8_e5m2, | ||
| DType.kInt32: torch.int32, | ||
| DType.kFloat32: torch.float32, | ||
| DType.kFloat16: torch.half, | ||
| DType.kBFloat16: torch.bfloat16, |
There was a problem hiding this comment.
{value: key for key, value in TE_DType.items()}Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Description
Replace the pybind tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. For backward compat, cpp->python and python->cpp DType object conversions are cached at the pybind boundaries to reduce CPU overheads.
Motivation
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: