3 changes: 2 additions & 1 deletion qa/L1_pytorch_onnx_unittest/test.sh
@@ -6,4 +6,5 @@
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
22 changes: 17 additions & 5 deletions tests/pytorch/test_onnx_export.py
@@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
_test_export_layernorm_mlp(activation=activation)


# Quantization recipes with fp8_dpa=True for attention emulation export test
dpa_quantization_recipes = [None] # None = no quantization
if fp8_available:
dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))

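Not part of the diff: a minimal sketch of exercising one of these recipes outside the test harness. It assumes an FP8-capable GPU and either a native FP8 attention backend or the emulated unfused path enabled via NVTE_UnfusedDPA_Emulate_FP8=1; tensor names and dimensions are illustrative, copied from the test below.

```python
import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Allow the unfused backend to emulate FP8 attention (same switch as in qa/L1_pytorch_onnx_unittest/test.sh).
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

fp8_recipe = recipe.DelayedScaling(fp8_dpa=True)
model = te.DotProductAttention(num_attention_heads=1, kv_channels=64).cuda()
q, k, v = (torch.randn(64, 4, 1, 64, device="cuda", dtype=torch.float16) for _ in range(3))

with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(q, k, v)  # attention runs with FP8 (emulated on the unfused backend)
```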

@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
@@ -730,6 +738,7 @@ def test_export_core_attention(
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
fp8_recipe: recipe.Recipe,
):
# Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
@@ -749,22 +758,25 @@

mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"

is_fp8 = fp8_recipe is not None

model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16,):
return
atol = 5e-1 if is_fp8 else 1e-2
validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
)


@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
# pylint: disable=missing-function-docstring
if is_in_onnx_export_mode():
return FP8EmulationFunc.onnx_forward(
tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
)

if quantizer_name == "QKV_quantizer":
query_layer, key_layer, value_layer = [
x.contiguous() for x in [tensor1, tensor2, tensor3]
@@ -202,6 +207,47 @@ def backward(ctx, grad1, grad2, grad3):
tensors = grad1, grad2, grad3
return tensors[0], tensors[1], tensors[2], None, None, None

@staticmethod
def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
"""
ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
"""
# pylint: disable=unused-argument
is_qkv_quantizer = quantizer_name == "QKV_quantizer"
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
), "ONNX FP8 emulation path supports only Float8 quantizers."
Comment on lines +217 to +219
Contributor
Assert statement will cause ONNX export to fail if non-Float8 quantizers are used. Consider replacing it with a runtime check that raises a more descriptive error or logs a warning.


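Not part of the diff: a minimal sketch of the runtime check the reviewer suggests as a replacement for the assert (exception type and message are illustrative).

```python
# Hypothetical replacement for the assert above, per the review comment:
if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
    raise NotImplementedError(
        "ONNX FP8 emulation path supports only Float8 quantizers, "
        f"got {type(quantizer).__name__}."
    )
```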
if is_qkv_quantizer:
# Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
orig_dtype = tensor1.dtype
shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]

# Flatten and concatenate
combined = torch.cat(
[tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
)
Comment on lines +227 to +230
Collaborator
This is fine for FP8 attention, although we'll need to revisit whenever we support MXFP8 or NVFP4. Why can't we concatenate the 2D tensors?

Collaborator Author

I'm not sure this will work for all layouts and for different max_q_length and max_kv_length. I added an assertion that the quantizer is not MXFP8 because I want to merge this quickly; I will rethink it when adding MXFP8 support.

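Not part of the diff: a toy illustration of the trade-off discussed in this thread (shapes are hypothetical). Flattening before concatenation only requires that element counts be tracked per tensor, while concatenating the unflattened tensors along a dimension requires every other dimension to agree, which need not hold for all qkv_layouts or when max_q_length differs from max_kv_length.

```python
import torch

q = torch.randn(64, 4, 1, 64)   # [max_q_length, batch, heads, head_dim]
kv = torch.randn(32, 4, 1, 64)  # [max_kv_length, batch, heads, head_dim]

flat = torch.cat([q.reshape(-1), kv.reshape(-1)])  # always valid: 1-D tensors
# torch.cat([q, kv], dim=-1)  # error: non-concat dims must match (64 vs 32 at dim 0)
stacked = torch.cat([q, kv], dim=0)  # valid here only because the trailing dims happen to agree
```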

# Quantize + dequantize combined tensor using quantizer's ONNX methods
combined_fp8 = quantizer.onnx_quantize(combined)
out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)

# Split back
out1 = out[: numels[0]].reshape(shapes[0])
out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
out3 = out[numels[0] + numels[1] :].reshape(shapes[2])

return out1, out2, out3
if quantizer_name in ["S_quantizer", "O_quantizer"]:
# Emulate FP8 on single tensor using quantizer's ONNX methods
orig_dtype = tensor1.dtype
t_fp8 = quantizer.onnx_quantize(tensor1)
out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
return out, tensor2, tensor3
# Pass-through
return tensor1, tensor2, tensor3


class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
@@ -1552,7 +1552,9 @@ def forward(
)

if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
allow_emulation = (
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
@@ -479,7 +479,9 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
allow_emulation = (
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
)
if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
34 changes: 26 additions & 8 deletions transformer_engine/pytorch/jit.py
@@ -46,17 +46,35 @@ def wrapper(*args, **kwargs):

# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo

if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: (
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def no_torch_dynamo(recursive=True):
"""Decorator to disable Torch Dynamo, except during ONNX export."""

def decorator(f):
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
disabled_f = (
torch._dynamo.disable(f, recursive=recursive)
if torch.__version__ >= "2.1"
else torch._dynamo.disable(f)
)

@wraps(f)
def wrapper(*args, **kwargs):
if is_in_onnx_export_mode():
return f(*args, **kwargs)
return disabled_f(*args, **kwargs)

return wrapper

return decorator

else:
# Fallback for PyTorch < 2.0: no-op decorator
def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument
"""No-op decorator for PyTorch < 2.0."""
return lambda func: func

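Not part of the diff: a minimal usage sketch of the decorator defined above (the function name is hypothetical). Outside ONNX export the Dynamo-disabled wrapper runs; inside the export context the original function is called so tracing sees plain PyTorch ops.

```python
import torch

@no_torch_dynamo()
def add_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Excluded from torch.compile graphs, except while exporting to ONNX.
    return torch.nn.functional.gelu(x + bias)
```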

def set_jit_fusion_options() -> None: