Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ae9b8fb
initial prototype
vthumbe1503 May 22, 2026
ad44101
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
beb020d
address review comment
vthumbe1503 May 31, 2026
1be6e66
cleanup
vthumbe1503 May 31, 2026
b10446b
some more
vthumbe1503 May 31, 2026
834241e
done
vthumbe1503 May 31, 2026
d27c24c
fix merge conflicts
vthumbe1503 May 31, 2026
0a3855d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2026
44c5b6e
clean
vthumbe1503 May 31, 2026
2c1aac0
Merge branch 'te_dtype' of github.com:vthumbe1503/TransformerEngine i…
vthumbe1503 May 31, 2026
57ba8ec
cache python_to_cpp and cpp_to_python casts for dtype
vthumbe1503 Jun 1, 2026
f83db87
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2026
11016c9
add the missing conversion file
vthumbe1503 Jun 1, 2026
2225f7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2026
f20c7f6
cleanup comments
vthumbe1503 Jun 1, 2026
d8f94d1
cleanup
vthumbe1503 Jun 1, 2026
497427f
lint
vthumbe1503 Jun 1, 2026
7a635d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2026
de110f9
address review comment
vthumbe1503 Jun 2, 2026
69f6edc
resolve merge conflicts
vthumbe1503 Jun 2, 2026
86acd9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
418f848
address review comments
vthumbe1503 Jun 2, 2026
0a5bd29
Merge branch 'te_dtype' of github.com:vthumbe1503/TransformerEngine i…
vthumbe1503 Jun 2, 2026
252745b
fix build docs
vthumbe1503 Jun 2, 2026
5a3163b
fix review comment, lint
vthumbe1503 Jun 2, 2026
7c02a40
address review comments
vthumbe1503 Jun 2, 2026
e0c2e10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions benchmarks/benchmark_rht_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
Expand All @@ -17,7 +16,7 @@
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
torch.bfloat16: te.DType.kBFloat16,
}


Expand All @@ -31,7 +30,7 @@ def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
fp4_dtype=te.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/quickstart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,8 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):

def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine_torch as tex

fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
fp8_type = te.DType.kFloat8E4M3 if fp8_format == "e4m3" else te.DType.kFloat8E5M2
scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type)
Expand Down
13 changes: 7 additions & 6 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
import transformer_engine_torch as tex
from transformer_engine.pytorch import DType
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch import (
autocast,
Expand Down Expand Up @@ -328,34 +329,34 @@ def run_dpa_with_cp(
).cuda()
if scaling_mode == "delayed":
qkv_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
if scaling_mode == "current":
qkv_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
device="cuda",
)
dout_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=DType.kFloat8E5M2,
device="cuda",
)
if scaling_mode == "mxfp8":
qkv_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
)
qkv_quantizer.optimize_for_gemm = True
qkv_quantizer.internal = False
dout_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/debug/run_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import torch.distributed as dist
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.pytorch import DType
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch import is_fp8_available
Expand Down Expand Up @@ -683,7 +683,7 @@ def _run_test_with_combinations(
)

# test_fake_quant_fp8
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
dtype_options = [DType.kFloat8E4M3, DType.kFloat8E5M2, None]
_run_test_with_combinations(
test_fake_quant_fp8,
dtype_options,
Expand Down
14 changes: 7 additions & 7 deletions tests/pytorch/debug/test_api_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import torch
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch import DType

import nvdlfw_inspect.api as debug_api

try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
Expand Down Expand Up @@ -128,12 +128,12 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
default_quantizer1 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
)
default_quantizer2 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=DType.kFloat8E5M2,
)

output1 = debug_api.transformer_engine.modify_tensor(
Expand All @@ -145,7 +145,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
tensor=tensor,
)
assert type(output1) == Float8Tensor
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
assert output1._fp8_dtype == DType.kFloat8E4M3

output2 = debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
Expand All @@ -156,7 +156,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
iteration=0,
)
assert type(output2) == Float8Tensor
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
assert output2._fp8_dtype == DType.kFloat8E5M2

assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1",
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
)
tensor_fp8 = quantizer(tensor)

Expand Down Expand Up @@ -372,7 +372,7 @@ def log_stats():
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=DType.kFloat8E4M3,
)

def fp8_tensor(t):
Expand Down
43 changes: 21 additions & 22 deletions tests/pytorch/debug/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.quantization import _default_sf_compute
from transformer_engine.pytorch import (
Expand Down Expand Up @@ -57,7 +56,7 @@ def _cast_to_fp8(tensor, scale, dtype):


def _get_current_scale(tensor, fp8_dtype):
if fp8_dtype == tex.DType.kFloat8E4M3:
if fp8_dtype == te.DType.kFloat8E4M3:
fp8_max = Format.E4M3.value.max_fwd
else:
fp8_max = Format.E5M2.value.max_fwd
Expand Down Expand Up @@ -93,19 +92,19 @@ def _emulate_linear(
input: torch.Tensor,
weight: torch.Tensor,
fprop_fp8: bool = False,
fprop_input_fake_quant: tex.DType = None,
fprop_input_fake_quant: te.DType = None,
fprop_input_scale: torch.Tensor = None,
fprop_weight_fake_quant: tex.DType = None,
fprop_weight_fake_quant: te.DType = None,
fprop_weight_scale: torch.Tensor = None,
dgrad_fp8: bool = False,
dgrad_gradient_fake_quant: tex.DType = None,
dgrad_gradient_fake_quant: te.DType = None,
dgrad_gradient_scale: torch.Tensor = None,
dgrad_weight_fake_quant: tex.DType = None,
dgrad_weight_fake_quant: te.DType = None,
dgrad_weight_scale: torch.Tensor = None,
wgrad_fp8: bool = False,
wgrad_gradient_fake_quant: tex.DType = None,
wgrad_gradient_fake_quant: te.DType = None,
wgrad_gradient_scale: torch.Tensor = None,
wgrad_input_fake_quant: tex.DType = None,
wgrad_input_fake_quant: te.DType = None,
wgrad_input_scale: torch.Tensor = None,
loss_multiplier: float = 1.0,
activation_sync=None,
Expand All @@ -116,10 +115,10 @@ def _emulate_linear(
activation = _fp8_gemm_kernel(
input,
_scalar(fprop_input_scale or 1.0),
tex.DType.kFloat8E4M3,
te.DType.kFloat8E4M3,
weight,
_scalar(fprop_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
te.DType.kFloat8E4M3,
_2X_ACC_FPROP,
)
activation = activation.clone().detach().contiguous().requires_grad_(True)
Expand Down Expand Up @@ -152,10 +151,10 @@ def _emulate_linear(
dgrad = _fp8_gemm_kernel(
weight.T,
_scalar(dgrad_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
te.DType.kFloat8E4M3,
gradient,
_scalar(dgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
te.DType.kFloat8E5M2,
_2X_ACC_DGRAD,
).T
else:
Expand All @@ -176,10 +175,10 @@ def _emulate_linear(
wgrad = _fp8_gemm_kernel(
input.T,
_scalar(wgrad_input_scale or 1.0),
tex.DType.kFloat8E4M3,
te.DType.kFloat8E4M3,
gradient.T,
_scalar(wgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
te.DType.kFloat8E5M2,
_2X_ACC_WGRAD,
).T
else:
Expand Down Expand Up @@ -470,17 +469,17 @@ def set_scaling_factors(model, input_kwargs, fp8_kwargs):
def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs):
# Compute per tensor scaling factor if respective flag in input_kwargs is set.
if input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["fprop_input_scale"] = te.DType.kFloat8E4M3
if input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["fprop_weight_scale"] = te.DType.kFloat8E4M3
if input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2
fp8_kwargs["dgrad_gradient_scale"] = te.DType.kFloat8E5M2
if input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["dgrad_weight_scale"] = te.DType.kFloat8E4M3
if input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2
fp8_kwargs["wgrad_gradient_scale"] = te.DType.kFloat8E5M2
if input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["wgrad_input_scale"] = te.DType.kFloat8E4M3


@create_config_file
Expand Down Expand Up @@ -651,7 +650,7 @@ def init_and_warmup():


all_combinations = list(
itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6)
itertools.product([te.DType.kFloat8E4M3, te.DType.kFloat8E5M2, None], repeat=6)
)
subset_combinations = random.sample(all_combinations, 10)

Expand Down Expand Up @@ -687,7 +686,7 @@ def test_fake_quant_fp8(
def fake_quant_fp8_create_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"}
format_to_str = {te.DType.kFloat8E4M3: "FP8E4M3", te.DType.kFloat8E5M2: "FP8E5M2"}
gemms = ""

def _add_tensor(quant_format, tensor):
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
fp8_dtype = te.DType.kFloat8E4M3
fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda")
fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda")

Expand Down Expand Up @@ -516,7 +516,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
fp8_dtype = te.DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
Expand Down
11 changes: 5 additions & 6 deletions tests/pytorch/distributed/run_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Expand Down Expand Up @@ -399,7 +398,7 @@ def _test_quantizer(input_dtype, fp8_dtype):

Args:
input_dtype (torch.dtype): The data type of the input.
fp8_dtype (tex.DType): The data type of the fp8.
fp8_dtype (te.DType): The data type of the fp8.
"""

M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
Expand Down Expand Up @@ -443,7 +442,7 @@ def test_quantizer():
return

input_dtypes = [torch.float32, torch.bfloat16]
fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
fp8_dtypes = [te.DType.kFloat8E4M3, te.DType.kFloat8E5M2]

for input_dtype in input_dtypes:
for fp8_dtype in fp8_dtypes:
Expand Down Expand Up @@ -514,7 +513,7 @@ def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):

Args:
input_dtype (torch.dtype): The data type of the input.
low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
low_precision_dtype (te.DType): The data type of the low precision, can be fp4 or fp8.
"""

M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
Expand Down Expand Up @@ -623,8 +622,8 @@ def test_quantized_all_gather():
return

input_dtypes = [torch.bfloat16]
fp4_dtype = [tex.DType.kFloat4E2M1]
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
fp4_dtype = [te.DType.kFloat4E2M1]
fp8_dtype = [te.DType.kFloat8E4M3, te.DType.kFloat8E5M2]
quantizer_cls_nvfp4 = [NVFP4Quantizer]
# add FP8 quantizers if needed
quantizer_cls_fp8 = []
Expand Down
7 changes: 3 additions & 4 deletions tests/pytorch/distributed/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
is_bf16_available,
)
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
Expand Down Expand Up @@ -107,17 +106,17 @@ def make_reference_and_test_tensors(
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=te.DType.kFloat8E4M3,
)
test = quantizer(test)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=te.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
test = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3)(test)
elif quantization == "nvfp4":
test = NVFP4Quantizer(
with_rht=False,
Expand Down
Loading
Loading