From c91cd3512b44fe1706020ff03542e38fa79c5566 Mon Sep 17 00:00:00 2001 From: Dongmin Ra Date: Fri, 27 Feb 2026 16:27:23 +0900 Subject: [PATCH 1/2] fix: scope get_full_cu_seqlens cache key by device and inference mode Signed-off-by: Dongmin Ra --- .../attention/test_cu_seqlens_cache.py | 97 +++++++++++++++++++ .../attention/dot_product_attention/utils.py | 9 +- 2 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 tests/pytorch/attention/test_cu_seqlens_cache.py diff --git a/tests/pytorch/attention/test_cu_seqlens_cache.py b/tests/pytorch/attention/test_cu_seqlens_cache.py new file mode 100644 index 0000000000..be4895199a --- /dev/null +++ b/tests/pytorch/attention/test_cu_seqlens_cache.py @@ -0,0 +1,97 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +from transformer_engine.pytorch import DotProductAttention +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils +from transformer_engine.pytorch.utils import get_cudnn_version + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") + + +@pytest.fixture(autouse=True) +def clear_cu_seqlens_cache(): + dpa_utils._cu_seqlens_cache.clear() + yield + dpa_utils._cu_seqlens_cache.clear() + + +def _make_dpa(device: torch.device) -> DotProductAttention: + return DotProductAttention( + num_attention_heads=2, + kv_channels=16, + attention_dropout=0.0, + qkv_format="bshd", + attn_mask_type="no_mask", + attention_type="self", + ).to(device=device, dtype=torch.float16) + + +def _make_qkv(device: torch.device, requires_grad: bool = False): + shape = (2, 8, 2, 16) + q = torch.randn(*shape, device=device, dtype=torch.float16, requires_grad=requires_grad) + k = torch.randn(*shape, device=device, dtype=torch.float16, requires_grad=requires_grad) + v = torch.randn(*shape, device=device, dtype=torch.float16, requires_grad=requires_grad) + return q, k, v + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +def test_cu_seqlens_cache_isolated_across_devices_for_forward(): + if torch.cuda.device_count() < 2: + pytest.skip("Requires at least 2 CUDA devices.") + + dev0 = torch.device("cuda:0") + dev1 = torch.device("cuda:1") + + dpa0 = _make_dpa(dev0).eval() + dpa1 = _make_dpa(dev1).eval() + + with torch.no_grad(): + q0, k0, v0 = _make_qkv(dev0) + out0 = dpa0(q0, k0, v0, attn_mask_type="no_mask") + + q1, k1, v1 = _make_qkv(dev1) + out1 = dpa1(q1, k1, v1, attn_mask_type="no_mask") + + assert out0.device == dev0 + assert out1.device == dev1 + + expected_key_0 = (2, 8, dev0, False) + expected_key_1 = (2, 8, dev1, False) + assert expected_key_0 in dpa_utils._cu_seqlens_cache + assert expected_key_1 in dpa_utils._cu_seqlens_cache + + assert dpa_utils._cu_seqlens_cache[expected_key_0].device == dev0 + assert dpa_utils._cu_seqlens_cache[expected_key_1].device == dev1 + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +def test_cu_seqlens_cache_isolated_between_inference_and_train_forward(): + dev = torch.device("cuda:0") + dpa = _make_dpa(dev) + + dpa.eval() + with torch.inference_mode(): + q_inf, k_inf, v_inf = _make_qkv(dev) + out_inf = dpa(q_inf, k_inf, v_inf, attn_mask_type="no_mask") + + inf_key = (2, 8, dev, True) + assert inf_key in dpa_utils._cu_seqlens_cache + assert dpa_utils._cu_seqlens_cache[inf_key].device == dev + + dpa.train() + q_tr, k_tr, v_tr = _make_qkv(dev, requires_grad=True) + out_tr = dpa(q_tr, k_tr, v_tr, attn_mask_type="no_mask") + out_tr.sum().backward() + + train_key = (2, 8, dev, False) + assert train_key in dpa_utils._cu_seqlens_cache + assert dpa_utils._cu_seqlens_cache[train_key].device == dev + + assert out_inf.device == dev + assert out_tr.device == dev + assert dpa_utils._cu_seqlens_cache[inf_key] is not dpa_utils._cu_seqlens_cache[train_key] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..e5e9642c8b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1583,11 +1583,14 @@ def _get_cu_seqlens(batch_size, max_seqlen, device): if is_in_onnx_export_mode(): return _get_cu_seqlens(batch_size, max_seqlen, device) - if (batch_size, max_seqlen) not in _cu_seqlens_cache: - _cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens( + + is_inference = torch.is_inference_mode_enabled() + cu_seqlens_cache_key = (batch_size, max_seqlen, device, is_inference) + if cu_seqlens_cache_key not in _cu_seqlens_cache: + _cu_seqlens_cache[cu_seqlens_cache_key] = _get_cu_seqlens( batch_size, max_seqlen, device ) - return _cu_seqlens_cache[(batch_size, max_seqlen)] + return _cu_seqlens_cache[cu_seqlens_cache_key] @jit_fuser From 02fbe605e2d1bc348f3964a09b96735f578de019 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 09:10:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e5e9642c8b..9acfe0e89c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1587,9 +1587,7 @@ def _get_cu_seqlens(batch_size, max_seqlen, device): is_inference = torch.is_inference_mode_enabled() cu_seqlens_cache_key = (batch_size, max_seqlen, device, is_inference) if cu_seqlens_cache_key not in _cu_seqlens_cache: - _cu_seqlens_cache[cu_seqlens_cache_key] = _get_cu_seqlens( - batch_size, max_seqlen, device - ) + _cu_seqlens_cache[cu_seqlens_cache_key] = _get_cu_seqlens(batch_size, max_seqlen, device) return _cu_seqlens_cache[cu_seqlens_cache_key]