From 2c4913358f70258d5ea1da9c68a88d40746b6730 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 19:46:41 +0100 Subject: [PATCH 01/10] jjit bug fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/jit.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 5884188b7e..c19b24944f 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -51,9 +51,24 @@ def wrapper(*args, **kwargs): import torch._dynamo if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: ( - f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive) - ) + + def no_torch_dynamo(recursive=True): + """Decorator to disable Torch Dynamo, except during ONNX export.""" + + def decorator(f): + disabled_f = torch._dynamo.disable(f, recursive=recursive) + + @wraps(f) + def wrapper(*args, **kwargs): + # Check dynamically at call time, not at decoration time + if is_in_onnx_export_mode(): + return f(*args, **kwargs) + return disabled_f(*args, **kwargs) + + return wrapper + + return decorator + else: # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True no_torch_dynamo = lambda recursive=True: torch._dynamo.disable From bec2c3ced042b184390ab9aaffc4328c6adffabb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 20:11:46 +0100 Subject: [PATCH 02/10] fix' Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_onnx_export.py | 23 +++++-- .../dot_product_attention/backends.py | 64 +++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 50cd150c4e..c55b0cc193 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): _test_export_layernorm_mlp(activation=activation) +# FP8 recipes with fp8_dpa=True for attention FP8 emulation export test +fp8_dpa_recipes = [None] # None = no FP8 +if fp8_available: + fp8_dpa_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) + fp8_dpa_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) + + +@pytest.mark.parametrize("fp8_recipe", fp8_dpa_recipes) @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -730,6 +738,7 @@ def test_export_core_attention( precision: torch.dtype, use_mask: bool, attn_mask_type: str, + fp8_recipe: recipe.Recipe, ): # Set dimensions (these are arbitrary). 
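[Editor's note] The behavior change in PATCH 01 boils down to moving the ONNX-export check from decoration time into the returned wrapper: the replaced one-liner evaluated is_in_onnx_export_mode() once, when the decorator ran at import time, so export mode enabled later still hit the Dynamo-disabled function. A condensed, standalone sketch of the call-time pattern follows; the module-level flag is a stand-in for TE's is_in_onnx_export_mode(), and it assumes torch >= 2.1 (PATCH 06 later folds in the 2.0 fallback).

    from functools import wraps

    import torch
    import torch._dynamo

    _IN_ONNX_EXPORT = False  # stand-in for transformer_engine's ONNX-export state


    def no_torch_dynamo(recursive=True):
        """Disable TorchDynamo for the wrapped function, except during ONNX export."""

        def decorator(f):
            # Build the Dynamo-disabled variant once, at decoration time.
            disabled_f = torch._dynamo.disable(f, recursive=recursive)

            @wraps(f)
            def wrapper(*args, **kwargs):
                # Checked on every call, so export mode enabled after import is honored.
                if _IN_ONNX_EXPORT:
                    return f(*args, **kwargs)
                return disabled_f(*args, **kwargs)

            return wrapper

        return decorator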
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) @@ -749,22 +758,26 @@ def test_export_core_attention( mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" + fp8_str = "_fp8_dpa" if fp8_recipe is not None else "" + fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx" + + is_fp8 = fp8_recipe is not None model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, - attention_dropout=0.5, + attention_dropout=0.0 if is_fp8 else 0.5, # Disable dropout for FP8 deterministic results qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) - te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) + do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe) + te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return + atol = 5e-2 if is_fp8 else 1e-2 # Higher tolerance for FP8 due to quantization effects validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c726ed8849..0c93c37b43 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -164,6 +164,9 @@ class FP8EmulationFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): # pylint: disable=missing-function-docstring + if is_in_onnx_export_mode(): + return FP8EmulationFunc.onnx_forward(tensor1, tensor2, tensor3, quantizer_name, qkv_layout) + if quantizer_name == "QKV_quantizer": query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] @@ -202,6 +205,67 @@ def backward(ctx, grad1, grad2, grad3): tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None + @staticmethod + def onnx_forward(tensor1, tensor2, tensor3, quantizer_name, qkv_layout=None): + """ + ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. + + This method performs quantize + dequantize to emulate FP8 effects using ONNX-compatible ops. + For ONNX export, we use current scaling (dynamic) quantization. + + Parameters + ---------- + tensor1, tensor2, tensor3 : torch.Tensor + Input tensors (e.g., Q, K, V for QKV_quantizer, or single tensor for S/O quantizers) + quantizer_name : str + Name of quantizer: "QKV_quantizer", "S_quantizer", "O_quantizer", etc. 
+ qkv_layout : str, optional + QKV layout string (not used in ONNX path) + + Returns + ------- + Tuple of emulated tensors + """ + # pylint: disable=unused-argument + + def _fp8_emulate(tensor): + """Quantize + dequantize using existing ONNX-compatible ops.""" + orig_dtype = tensor.dtype + tensor_fp32 = tensor.to(torch.float32) + data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor_fp32) + out = torch.ops.tex.fp8_dequantize(data, scale_inv) + return out.to(orig_dtype) + + if quantizer_name == "QKV_quantizer": + # Combine Q, K, V -> quantize together -> split back + orig_dtype = tensor1.dtype + shapes = [tensor1.shape, tensor2.shape, tensor3.shape] + numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] + + # Flatten and concatenate + combined = torch.cat([ + tensor1.reshape(-1), + tensor2.reshape(-1), + tensor3.reshape(-1) + ], dim=0).to(torch.float32) + + # Quantize + dequantize combined tensor (shared scale) + data, scale_inv = torch.ops.tex.fp8_cs_quantize(combined) + out = torch.ops.tex.fp8_dequantize(data, scale_inv) + + # Split back + out1 = out[:numels[0]].reshape(shapes[0]).to(orig_dtype) + out2 = out[numels[0]:numels[0] + numels[1]].reshape(shapes[1]).to(orig_dtype) + out3 = out[numels[0] + numels[1]:].reshape(shapes[2]).to(orig_dtype) + + return out1, out2, out3 + elif quantizer_name in ["S_quantizer", "O_quantizer"]: + # Emulate FP8 on single tensor + return _fp8_emulate(tensor1), tensor2, tensor3 + else: + # Pass-through + return tensor1, tensor2, tensor3 + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms From 0e054ad27658165c873a523a80d2b68a894fa2bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 19:29:35 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/backends.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0c93c37b43..5e5ef000e5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -165,7 +165,9 @@ class FP8EmulationFunc(torch.autograd.Function): def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): # pylint: disable=missing-function-docstring if is_in_onnx_export_mode(): - return FP8EmulationFunc.onnx_forward(tensor1, tensor2, tensor3, quantizer_name, qkv_layout) + return FP8EmulationFunc.onnx_forward( + tensor1, tensor2, tensor3, quantizer_name, qkv_layout + ) if quantizer_name == "QKV_quantizer": query_layer, key_layer, value_layer = [ @@ -243,20 +245,18 @@ def _fp8_emulate(tensor): numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] # Flatten and concatenate - combined = torch.cat([ - tensor1.reshape(-1), - tensor2.reshape(-1), - tensor3.reshape(-1) - ], dim=0).to(torch.float32) + combined = torch.cat( + [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0 + ).to(torch.float32) # Quantize + dequantize combined tensor (shared scale) data, scale_inv = torch.ops.tex.fp8_cs_quantize(combined) out = torch.ops.tex.fp8_dequantize(data, scale_inv) # Split back - out1 = out[:numels[0]].reshape(shapes[0]).to(orig_dtype) - out2 = out[numels[0]:numels[0] + 
numels[1]].reshape(shapes[1]).to(orig_dtype) - out3 = out[numels[0] + numels[1]:].reshape(shapes[2]).to(orig_dtype) + out1 = out[: numels[0]].reshape(shapes[0]).to(orig_dtype) + out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]).to(orig_dtype) + out3 = out[numels[0] + numels[1] :].reshape(shapes[2]).to(orig_dtype) return out1, out2, out3 elif quantizer_name in ["S_quantizer", "O_quantizer"]: From 4bc878c54630899e9c92c8019c9318698a8ee8b6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 20:35:14 +0100 Subject: [PATCH 04/10] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/attention/dot_product_attention/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7e..f80d715a13 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -474,7 +474,9 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False From 75ca1745627a6e62e71a6ddeacef7f52b7503fec Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 20:42:00 +0100 Subject: [PATCH 05/10] fix Signed-off-by: Pawel Gadzinski --- .../dot_product_attention/backends.py | 37 +++++++++---------- .../attention/dot_product_attention/utils.py | 4 +- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5e5ef000e5..96791f72df 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -166,7 +166,7 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou # pylint: disable=missing-function-docstring if is_in_onnx_export_mode(): return FP8EmulationFunc.onnx_forward( - tensor1, tensor2, tensor3, quantizer_name, qkv_layout + tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout ) if quantizer_name == "QKV_quantizer": @@ -208,17 +208,19 @@ def backward(ctx, grad1, grad2, grad3): return tensors[0], tensors[1], tensors[2], None, None, None @staticmethod - def onnx_forward(tensor1, tensor2, tensor3, quantizer_name, qkv_layout=None): + def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None): """ ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. This method performs quantize + dequantize to emulate FP8 effects using ONNX-compatible ops. - For ONNX export, we use current scaling (dynamic) quantization. + Uses the quantizer's onnx_quantize/onnx_dequantize methods for proper scaling behavior. 
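[Editor's note] The quantize-plus-dequantize round trip this path performs can be pictured with plain PyTorch casts standing in for the quantizer's onnx_quantize/onnx_dequantize methods. The amax-based "current" scale and the E4M3 maximum below are illustrative assumptions, not TE's actual quantizer internals, and torch.float8_e4m3fn requires PyTorch >= 2.1.

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


    def fake_fp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
        """Return x after an FP8 quantize/dequantize round trip (same shape and dtype)."""
        orig_dtype = x.dtype
        x32 = x.float()
        amax = x32.abs().amax().clamp(min=1e-12)
        scale = FP8_E4M3_MAX / amax                    # per-tensor dynamic scale
        x_fp8 = (x32 * scale).to(torch.float8_e4m3fn)  # quantize
        return (x_fp8.float() / scale).to(orig_dtype)  # dequantize

    s = torch.randn(4, 64, 64)
    print((s - fake_fp8_roundtrip(s)).abs().max())  # the quantization error the test's atol must absorb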
Parameters ---------- tensor1, tensor2, tensor3 : torch.Tensor Input tensors (e.g., Q, K, V for QKV_quantizer, or single tensor for S/O quantizers) + quantizer : Quantizer + The quantizer object with onnx_quantize/onnx_dequantize methods quantizer_name : str Name of quantizer: "QKV_quantizer", "S_quantizer", "O_quantizer", etc. qkv_layout : str, optional @@ -230,14 +232,6 @@ def onnx_forward(tensor1, tensor2, tensor3, quantizer_name, qkv_layout=None): """ # pylint: disable=unused-argument - def _fp8_emulate(tensor): - """Quantize + dequantize using existing ONNX-compatible ops.""" - orig_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor_fp32) - out = torch.ops.tex.fp8_dequantize(data, scale_inv) - return out.to(orig_dtype) - if quantizer_name == "QKV_quantizer": # Combine Q, K, V -> quantize together -> split back orig_dtype = tensor1.dtype @@ -247,21 +241,24 @@ def _fp8_emulate(tensor): # Flatten and concatenate combined = torch.cat( [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0 - ).to(torch.float32) + ) - # Quantize + dequantize combined tensor (shared scale) - data, scale_inv = torch.ops.tex.fp8_cs_quantize(combined) - out = torch.ops.tex.fp8_dequantize(data, scale_inv) + # Quantize + dequantize combined tensor using quantizer's ONNX methods + combined_fp8 = quantizer.onnx_quantize(combined) + out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype) # Split back - out1 = out[: numels[0]].reshape(shapes[0]).to(orig_dtype) - out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]).to(orig_dtype) - out3 = out[numels[0] + numels[1] :].reshape(shapes[2]).to(orig_dtype) + out1 = out[: numels[0]].reshape(shapes[0]) + out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]) + out3 = out[numels[0] + numels[1] :].reshape(shapes[2]) return out1, out2, out3 elif quantizer_name in ["S_quantizer", "O_quantizer"]: - # Emulate FP8 on single tensor - return _fp8_emulate(tensor1), tensor2, tensor3 + # Emulate FP8 on single tensor using quantizer's ONNX methods + orig_dtype = tensor1.dtype + t_fp8 = quantizer.onnx_quantize(tensor1) + out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype) + return out, tensor2, tensor3 else: # Pass-through return tensor1, tensor2, tensor3 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f80d715a13..bf19388d7e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -474,9 +474,7 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = ( - os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() - ) + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False From 1f0111f4ad7b88ca99cd7db73f75e46eae1d52a7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 27 Jan 2026 16:54:57 +0000 Subject: [PATCH 06/10] fixes Signed-off-by: Pawel Gadzinski --- qa/L1_pytorch_onnx_unittest/test.sh | 3 +- tests/pytorch/test_onnx_export.py | 3 +- .../dot_product_attention/backends.py | 20 +-------- .../dot_product_attention.py | 5 ++- .../attention/dot_product_attention/utils.py | 5 ++- 
transformer_engine/pytorch/jit.py | 42 +++++++++---------- 6 files changed, 32 insertions(+), 46 deletions(-) diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e129..6f9ff54e48 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,5 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available +NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index c55b0cc193..a1c54b22fc 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -766,7 +766,6 @@ def test_export_core_attention( model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, - attention_dropout=0.0 if is_fp8 else 0.5, # Disable dropout for FP8 deterministic results qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device="cuda") @@ -775,7 +774,7 @@ def test_export_core_attention( serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return - atol = 5e-2 if is_fp8 else 1e-2 # Higher tolerance for FP8 due to quantization effects + atol = 1.5e-1 if is_fp8 else 1e-2 validate_result( fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 96791f72df..c6820bd87a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -211,29 +211,11 @@ def backward(ctx, grad1, grad2, grad3): def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None): """ ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. - - This method performs quantize + dequantize to emulate FP8 effects using ONNX-compatible ops. - Uses the quantizer's onnx_quantize/onnx_dequantize methods for proper scaling behavior. - - Parameters - ---------- - tensor1, tensor2, tensor3 : torch.Tensor - Input tensors (e.g., Q, K, V for QKV_quantizer, or single tensor for S/O quantizers) - quantizer : Quantizer - The quantizer object with onnx_quantize/onnx_dequantize methods - quantizer_name : str - Name of quantizer: "QKV_quantizer", "S_quantizer", "O_quantizer", etc. - qkv_layout : str, optional - QKV layout string (not used in ONNX path) - - Returns - ------- - Tuple of emulated tensors """ # pylint: disable=unused-argument if quantizer_name == "QKV_quantizer": - # Combine Q, K, V -> quantize together -> split back + # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3. 
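[Editor's note] As a plain-Python illustration of the "flatten + concatenate + quantize + split" scheme named in the comment above (stand-in casts only, not TE code): quantizing the concatenated buffer means the tensor with the largest magnitude dictates one scale shared by Q, K and V, matching the "shared scale" comment from PATCH 02 at the cost of extra quantization error on the smaller-magnitude tensors.

    import torch  # requires PyTorch >= 2.1 for torch.float8_e4m3fn


    def shared_scale_roundtrip(q, k, v, fp8_max=448.0):
        """FP8-emulate Q, K, V with a single shared per-tensor scale."""
        shapes = [q.shape, k.shape, v.shape]
        numels = [q.numel(), k.numel(), v.numel()]
        combined = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)]).float()
        scale = fp8_max / combined.abs().amax().clamp(min=1e-12)  # one scale for all three
        roundtrip = (combined * scale).to(torch.float8_e4m3fn).float() / scale
        chunks = roundtrip.split(numels)                          # undo the concatenation
        return [c.reshape(s).to(q.dtype) for c, s in zip(chunks, shapes)]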
orig_dtype = tensor1.dtype shapes = [tensor1.shape, tensor2.shape, tensor3.shape] numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6e5a12a103..9c1db63341 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1522,7 +1522,10 @@ def forward( ) if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + or is_in_onnx_export_mode() + ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7e..6a8e7b755b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -474,7 +474,10 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + or is_in_onnx_export_mode() + ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index c19b24944f..8e377b0203 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -50,28 +50,26 @@ def wrapper(*args, **kwargs): if torch.__version__ >= "2": import torch._dynamo - if torch.__version__ >= "2.1": - - def no_torch_dynamo(recursive=True): - """Decorator to disable Torch Dynamo, except during ONNX export.""" - - def decorator(f): - disabled_f = torch._dynamo.disable(f, recursive=recursive) - - @wraps(f) - def wrapper(*args, **kwargs): - # Check dynamically at call time, not at decoration time - if is_in_onnx_export_mode(): - return f(*args, **kwargs) - return disabled_f(*args, **kwargs) - - return wrapper - - return decorator - - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable + def no_torch_dynamo(recursive=True): + """Decorator to disable Torch Dynamo, except during ONNX export.""" + + def decorator(f): + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + disabled_f = ( + torch._dynamo.disable(f, recursive=recursive) + if torch.__version__ >= "2.1" + else torch._dynamo.disable(f) + ) + + @wraps(f) + def wrapper(*args, **kwargs): + if is_in_onnx_export_mode(): + return f(*args, **kwargs) + return disabled_f(*args, **kwargs) + + return wrapper + + return decorator def set_jit_fusion_options() -> None: From 768796bde28541478f7adb5eeafd0f914d76bd61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:56:31 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/dot_product_attention/dot_product_attention.py 
| 3 +-- .../pytorch/attention/dot_product_attention/utils.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 9c1db63341..67581c9edf 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1523,8 +1523,7 @@ def forward( if use_unfused_attention: allow_emulation = ( - os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" - or is_in_onnx_export_mode() + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() ) if checkpoint_core_attention: return self._checkpointed_attention_forward( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 6a8e7b755b..f80d715a13 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -475,8 +475,7 @@ def get_attention_backend( use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( - os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" - or is_in_onnx_export_mode() + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") From 6f04da2cc1f9c8d8a876e5d48a7ce14eaf206384 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 27 Jan 2026 17:38:40 +0000 Subject: [PATCH 08/10] lint fixes Signed-off-by: Pawel Gadzinski --- .../pytorch/attention/dot_product_attention/backends.py | 7 +++---- transformer_engine/pytorch/jit.py | 6 +++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 34d2dd4e8b..a194ca7aae 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -235,15 +235,14 @@ def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou out3 = out[numels[0] + numels[1] :].reshape(shapes[2]) return out1, out2, out3 - elif quantizer_name in ["S_quantizer", "O_quantizer"]: + if quantizer_name in ["S_quantizer", "O_quantizer"]: # Emulate FP8 on single tensor using quantizer's ONNX methods orig_dtype = tensor1.dtype t_fp8 = quantizer.onnx_quantize(tensor1) out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype) return out, tensor2, tensor3 - else: - # Pass-through - return tensor1, tensor2, tensor3 + # Pass-through + return tensor1, tensor2, tensor3 class UnfusedDotProductAttention(torch.nn.Module): diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 8e377b0203..502b3a5823 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -46,7 +46,6 @@ def wrapper(*args, **kwargs): # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo @@ -70,6 +69,11 @@ def wrapper(*args, **kwargs): return wrapper return decorator +else: + # Fallback for PyTorch < 2.0: no-op decorator + def 
no_torch_dynamo(recursive=True): # pylint: disable=unused-argument + """No-op decorator for PyTorch < 2.0.""" + return lambda func: func def set_jit_fusion_options() -> None: From c3a1acfab0a7a1ecd892f18cec5a22b22055b48d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 17:40:28 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/jit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 502b3a5823..1b93b8254c 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -69,6 +69,7 @@ def wrapper(*args, **kwargs): return wrapper return decorator + else: # Fallback for PyTorch < 2.0: no-op decorator def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument From c07ffb7953aed7ce30053dbea09c4e5aef580834 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 27 Jan 2026 22:31:49 +0000 Subject: [PATCH 10/10] fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_onnx_export.py | 12 ++++++------ .../attention/dot_product_attention/backends.py | 6 +++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index a1c54b22fc..9aea3bc274 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -713,14 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): _test_export_layernorm_mlp(activation=activation) -# FP8 recipes with fp8_dpa=True for attention FP8 emulation export test -fp8_dpa_recipes = [None] # None = no FP8 +# Quantization recipes with fp8_dpa=True for attention emulation export test +dpa_quantization_recipes = [None] # None = no quantization if fp8_available: - fp8_dpa_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) - fp8_dpa_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) + dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) + dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) -@pytest.mark.parametrize("fp8_recipe", fp8_dpa_recipes) +@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes) @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -774,7 +774,7 @@ def test_export_core_attention( serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return - atol = 1.5e-1 if is_fp8 else 1e-2 + atol = 5e-1 if is_fp8 else 1e-2 validate_result( fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a194ca7aae..aa6c063951 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -213,8 +213,12 @@ def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. """ # pylint: disable=unused-argument + is_qkv_quantizer = quantizer_name == "QKV_quantizer" + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ), "ONNX FP8 emulation path supports only Float8 quantizers." 
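[Editor's note] A hedged outline of the export flow the parametrized test drives through its do_export() helper. The helper's exact arguments (opset, exporter settings, validation against the serialized outputs) are not reproduced here, and the shapes are illustrative, so treat this as a sketch rather than the test's actual call.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    model = (
        te.attention.DotProductAttention(
            num_attention_heads=1, kv_channels=64, attention_dropout=0.0
        )
        .to("cuda")
        .eval()
    )
    q = torch.randn(64, 4, 1, 64, device="cuda", dtype=torch.float16)
    k, v = q.clone(), q.clone()

    fp8_recipe = recipe.Float8CurrentScaling(fp8_dpa=True)
    # te.onnx_export(True) flips is_in_onnx_export_mode(), which routes backend
    # selection and FP8EmulationFunc onto their ONNX-compatible branches.
    with te.onnx_export(True), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        torch.onnx.export(
            model, (q, k, v), "te_core_attention_fp8_dpa.onnx", input_names=["q", "k", "v"]
        )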
- if quantizer_name == "QKV_quantizer": + if is_qkv_quantizer: # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3. orig_dtype = tensor1.dtype shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
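[Editor's note] Taken together, the series makes the unfused backend's FP8 emulation reachable both through the NVTE_UnfusedDPA_Emulate_FP8=1 switch (as in the QA script above) and automatically during ONNX export. A hedged eager-mode usage sketch follows; whether the unfused backend is actually selected also depends on which fused/flash backends are available on the machine, and the shapes are illustrative rather than the test's fixtures.

    import os

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Opt in to the emulated-FP8 unfused backend; the patched get_attention_backend()
    # reads this flag (or ONNX-export mode) when deciding whether to allow it.
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    seq_len, batch, heads, kv_channels = 64, 4, 1, 64
    q, k, v = (
        torch.randn(seq_len, batch, heads, kv_channels, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )

    model = te.attention.DotProductAttention(
        num_attention_heads=heads, kv_channels=kv_channels, attn_mask_type="causal"
    ).to("cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling(fp8_dpa=True)):
        out = model(q, k, v)  # attention tensors pass through the quantize/dequantize emulation
    print(out.shape)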