From fbb0702939387d7766dd0b7359511a38eef18d89 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 5 Jan 2026 11:10:11 -0800
Subject: [PATCH 01/11] Update THD sink attention logic for newer cuDNN versions

THD sink attention is supported in cuDNN 9.18.0 and later.

Signed-off-by: Chen Cui
---
 .../pytorch/attention/dot_product_attention/utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index bf19388d7e..3749d40e37 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -716,10 +716,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_unfused_attention = False
     if qkv_format == "thd":
-        logger.debug(
-            "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-        )
-        use_fused_attention = False
+        if cudnn_version < (9, 18, 0):
+            logger.debug(
+                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
+            )
+            use_fused_attention = False
         logger.debug(
             "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
             softmax_type,

From 01848c0687b89ec87d586d5e1070772d7cf68ea8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 5 Jan 2026 19:11:33 +0000
Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/attention/dot_product_attention/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 3749d40e37..fce04bfa2d 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -718,7 +718,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
     if qkv_format == "thd":
         if cudnn_version < (9, 18, 0):
             logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
+                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
+                softmax_type,
             )
             use_fused_attention = False
         logger.debug(
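
Note on the gate introduced in patches 01/02: cudnn_version is compared as an ordered
(major, minor, patch) tuple, so "< (9, 18, 0)" is a lexicographic comparison. A minimal
standalone sketch of the same pattern; parse_cudnn_version is a made-up helper used only
for illustration, not the Transformer Engine API:

    def parse_cudnn_version(version_str):
        # "9.17.1" -> (9, 17, 1); tuple comparison is lexicographic, which matches
        # the usual major/minor/patch ordering.
        return tuple(int(part) for part in version_str.split(".")[:3])

    assert parse_cudnn_version("9.17.1") < (9, 18, 0)      # older cuDNN: fused THD path disabled
    assert not parse_cudnn_version("9.18.0") < (9, 18, 0)  # 9.18.0: fused THD path allowed
    assert not parse_cudnn_version("10.0.0") < (9, 18, 0)  # newer major version also passes
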
From ab13ba0aff7ab94669cd8939c35411a094f02e4c Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 6 Jan 2026 15:33:24 -0800
Subject: [PATCH 03/11] Update THD sink attention logic for CP > 1

Signed-off-by: Chen Cui
---
 .../dot_product_attention/context_parallel.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 75b360e485..a12ac9ae1a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4026,28 +4026,29 @@ def attn_forward_func_with_cp(
     assert not sliding_window_attn or cp_comm_type in [
         "a2a",
         "all_gather",
-    ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
+    ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
 
     enable_mla = k.shape[-1] != v.shape[-1]
     assert not enable_mla or cp_comm_type in [
         "p2p",
         "a2a+p2p",
-    ], "Context parallelism does not support MLA with {cp_comm_type=}!"
+    ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
 
     if fp8 and fp8_meta is not None:
         if fp8_meta["recipe"].fp8_dpa:
             assert (
                 softmax_type == "vanilla"
-            ), "Context parallelism does not support {softmax_type=} with FP8 attention!"
+            ), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
     assert (
         softmax_type == "vanilla" or use_fused_attention
-    ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
+    ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
     assert (
         softmax_type == "vanilla" or cp_comm_type == "a2a"
-    ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
-    assert (
-        softmax_type == "vanilla" or qkv_format != "thd"
-    ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
+    ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
+    if get_cudnn_version() < (9, 18, 0):
+        assert (
+            softmax_type == "vanilla" or qkv_format != "thd"
+        ), f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
 
     args = [
         is_training,

From c2c4341ef27e4d1ee3c75fa008a9ee071fa6ac9a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 6 Jan 2026 23:55:03 +0000
Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../attention/dot_product_attention/context_parallel.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index a12ac9ae1a..a5931188dc 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4046,9 +4046,10 @@ def attn_forward_func_with_cp(
         softmax_type == "vanilla" or cp_comm_type == "a2a"
     ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
     if get_cudnn_version() < (9, 18, 0):
-        assert (
-            softmax_type == "vanilla" or qkv_format != "thd"
-        ), f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
+        assert softmax_type == "vanilla" or qkv_format != "thd", (
+            f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
+            " qkv_format = 'thd'!"
+        )
 
     args = [
         is_training,
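
Taken together, the assertions in attn_forward_func_with_cp now encode a small support
matrix for non-vanilla (sink) softmax under context parallelism. The sketch below is only
an illustrative restatement of that logic; the helper name and signature are invented and
are not part of Transformer Engine:

    def cp_supports_sink_softmax(
        softmax_type, use_fused_attention, cp_comm_type, qkv_format, cudnn_version, fp8_dpa=False
    ):
        """Return True if context parallelism accepts this softmax/layout combination."""
        if softmax_type == "vanilla":
            return True  # vanilla softmax is unrestricted by these checks
        if fp8_dpa:
            return False  # no sink softmax together with FP8 attention under CP
        if not use_fused_attention or cp_comm_type != "a2a":
            return False  # sink softmax under CP needs FusedAttention and cp_comm_type = "a2a"
        if qkv_format == "thd" and cudnn_version < (9, 18, 0):
            return False  # thd layout with sink softmax needs cuDNN 9.18.0+
        return True

    assert cp_supports_sink_softmax("learnable", True, "a2a", "thd", (9, 18, 0))
    assert not cp_supports_sink_softmax("learnable", True, "a2a", "thd", (9, 17, 1))
    assert not cp_supports_sink_softmax("learnable", True, "p2p", "bshd", (9, 18, 0))
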
From 392a0334b84679dcb53ac0966460eef9e3f75392 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Wed, 7 Jan 2026 14:54:44 -0800
Subject: [PATCH 05/11] Add unit test for THD + sink attention

Signed-off-by: Chen Cui
---
 tests/pytorch/attention/test_attention.py            | 11 +++++++++++
 .../pytorch/attention/dot_product_attention/utils.py |  5 -----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index eb7905bcd5..a81ee34ab4 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -419,6 +419,17 @@ def test_dpa_softmax(dtype, model_configs, model):
     )
 
 
+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types"""
+    test_dot_product_attention(
+        dtype, model_configs, model, True, True, "thd_thd_thd", False, False
+    )
+
+
 model_configs_mla = {
     # test: ModelConfig(b, sq, hq, dqk)
     "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index fce04bfa2d..8f8f4f4621 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -722,11 +722,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 softmax_type,
             )
             use_fused_attention = False
-        logger.debug(
-            "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
-            softmax_type,
-        )
-        use_unfused_attention = False
     if context_parallel:
         logger.debug(
             "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"

From a75af446b7f2e59be2cbd0d4a712ef84e5b6d067 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 7 Jan 2026 22:55:28 +0000
Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/attention/test_attention.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index a81ee34ab4..de6f983e6f 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -425,9 +425,7 @@ def test_dpa_softmax_thd(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_softmax.keys())
 def test_dpa_softmax_thd(dtype, model_configs, model):
     """Test DotProductAttention module with different softmax types"""
-    test_dot_product_attention(
-        dtype, model_configs, model, True, True, "thd_thd_thd", False, False
-    )
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
 
 
 model_configs_mla = {
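
For context on why qkv_format = "thd" needs dedicated backend support: it packs
variable-length sequences along a single token dimension and describes the batch with
cumulative sequence lengths instead of padding. A rough sketch of what such inputs look
like; the shapes and values are illustrative, not the exact tensors built by
test_dot_product_attention:

    import torch

    # Three sequences of different lengths packed along one token dimension.
    seq_lens = [5, 3, 8]
    total_tokens = sum(seq_lens)          # the "t" in thd
    num_heads, head_dim = 16, 64

    # q/k/v are [t, h, d] instead of the padded [b, s, h, d] of bshd.
    q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    # Cumulative sequence lengths mark the boundaries between packed sequences:
    # tokens cu_seqlens[i]:cu_seqlens[i+1] belong to sequence i.
    cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32)
    assert cu_seqlens[-1].item() == total_tokens
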
From 200632f6ab17b987832c93d9fb61f2a72f7f3366 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Thu, 8 Jan 2026 15:42:15 -0800
Subject: [PATCH 07/11] Address review comments

Signed-off-by: Chen Cui
---
 .../pytorch/attention/dot_product_attention/utils.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 8f8f4f4621..ac0d2bb400 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -718,17 +718,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
     if qkv_format == "thd":
         if cudnn_version < (9, 18, 0):
             logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
+                "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN version < 9.18",
                 softmax_type,
             )
             use_fused_attention = False
     if context_parallel:
-        logger.debug(
-            "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
-            " = %s",
-            softmax_type,
-        )
-        use_unfused_attention = False
         if cp_comm_type != "a2a":
             logger.debug(
                 "Disabling FusedAttention for context parallelism with softmax_type = %s and"

From 7f4333abc3ee23e330222703cbc0d127f8afdaf2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 8 Jan 2026 23:43:10 +0000
Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/attention/dot_product_attention/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index ac0d2bb400..097a3b60e5 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -718,7 +718,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
     if qkv_format == "thd":
         if cudnn_version < (9, 18, 0):
             logger.debug(
-                "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN version < 9.18",
+                "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
+                " version < 9.18",
                 softmax_type,
             )
             use_fused_attention = False

From 549888e74318668640fc78099a7c7bf9b5d54663 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 20 Jan 2026 14:04:01 -0800
Subject: [PATCH 09/11] Do not skip THD CP sink attention test

Signed-off-by: Chen Cui
---
 tests/pytorch/attention/test_attention_with_cp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index 9480b8de70..59cc5f6547 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -283,9 +283,9 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if config.softmax_type != "vanilla" and qkv_format == "thd":
+    if get_cudnn_version() < (9, 18, 0) and config.softmax_type != "vanilla" and qkv_format == "thd":
         pytest.skip(
-            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
         )
 
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
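
Patch 09 narrows the skip from "always skip thd with sink softmax" to "skip only on cuDNN
older than 9.18.0". Calling pytest.skip() inside the test body, rather than using a
module-level skipif decorator, is what lets the check combine the runtime cuDNN version
with per-case parameters such as qkv_format. A generic sketch of that pattern, with
made-up names standing in for the real probes:

    import pytest

    def get_backend_version():
        # Stand-in for a runtime capability probe such as get_cudnn_version().
        return (9, 18, 0)

    @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
    def test_feature(qkv_format):
        # Evaluated per parametrized case, so only the unsupported combination is skipped.
        if get_backend_version() < (9, 18, 0) and qkv_format == "thd":
            pytest.skip("thd layout requires backend >= 9.18.0")
        assert qkv_format in ("bshd", "sbhd", "thd")
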
From f318e8b7e490c2eaae5346eac18084081480c43c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 20 Jan 2026 22:04:53 +0000
Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/attention/test_attention_with_cp.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index 59cc5f6547..06ed6e5723 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -283,9 +283,14 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if get_cudnn_version() < (9, 18, 0) and config.softmax_type != "vanilla" and qkv_format == "thd":
+    if (
+        get_cudnn_version() < (9, 18, 0)
+        and config.softmax_type != "vanilla"
+        and qkv_format == "thd"
+    ):
         pytest.skip(
-            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
+            " non-vanilla softmax types!"
         )
 
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}

From 96e5c6d9a851722234ee8e7338438e00f8661a18 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Wed, 21 Jan 2026 19:10:34 -0800
Subject: [PATCH 11/11] Disable deterministic mode for sink attention

Signed-off-by: Chen Cui
---
 .../pytorch/attention/dot_product_attention/utils.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 8362e33846..fcac740cc3 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -1041,6 +1041,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_flash_attention_2 = False
     if use_fused_attention and deterministic:
+        if softmax_type != "vanilla":
+            logger.debug(
+                "Disabling FusedAttention for determinism reasons with softmax_type = %s. "
+                "Sink attention (off-by-one and learnable softmax) requires "
+                "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
+                softmax_type,
+            )
+            use_fused_attention = False
+            fused_attention_backend = None
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
             logger.debug("Disabling FusedAttention for determinism reasons with FP8")
            use_fused_attention = False
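
As the new debug message in patch 11 states, the fused sink-attention kernels are not
deterministic, so using a non-vanilla softmax_type on the FusedAttention path requires
opting out of deterministic mode via NVTE_ALLOW_NONDETERMINISTIC_ALGO=1. A hedged usage
sketch: apart from that environment variable (named by the patch itself), the constructor
arguments shown are assumptions and should be checked against the installed
Transformer Engine version:

    import os

    # Must be set before Transformer Engine selects an attention backend; exporting it
    # in the launch environment is the most reliable place.
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

    import torch
    import transformer_engine.pytorch as te

    # softmax_type / qkv_format values follow the sink-attention naming used in this
    # series; consult the DotProductAttention docstring for the exact supported options.
    attn = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
        softmax_type="learnable",
        qkv_format="thd",
    )
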