diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..89dd0eb77c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) +import transformer_engine.pytorch.attention.dot_product_attention.backends as dpa_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -1390,6 +1391,661 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out, max_logit, (None, None, None, d_softmax_offset) +def _score_mod_causal(score_mod_graph, score_tensor, tensors): + """cuDNN frontend score_mod implementing top-left causal masking.""" + cudnn = dpa_backends._import_cudnn_frontend() + + row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=score_tensor, + input1=tensors["neg_inf"], + mask=keep, + ) + + +def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): + """cuDNN frontend score_mod_bprop implementing top-left causal masking.""" + cudnn = dpa_backends._import_cudnn_frontend() + + row_index = score_mod_graph.gen_index(input=dP_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=dP_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=dP_tensor, + input1=tensors["zero"], + mask=keep, + ) + + +def _score_mod_post_scale_bias(score_mod_graph, score_tensor, _tensors): + """cuDNN frontend score_mod adding post-scale bias.""" + cudnn = dpa_backends._import_cudnn_frontend() + + row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + post_scale_bias = score_mod_graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + post_scale_bias.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.add( + a=score_tensor, + b=post_scale_bias, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors): + """cuDNN frontend score_mod_bprop for score_mods with unit score derivative.""" + return dP_tensor + + +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def score_mod_graph_cache_key(self): + """Graph topology key for softcap score_mod.""" + return ("softcap",) + + def forward(self, score_mod_graph, score_tensor, tensors): + """Apply softcap * tanh(score / softcap).""" + cudnn = dpa_backends._import_cudnn_frontend() + + self.before_tanh_activation = score_mod_graph.div( + a=score_tensor, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + 
self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
+        tanh_out = score_mod_graph.tanh(input=self.before_tanh_activation)
+        tanh_out.set_data_type(cudnn.data_type.FLOAT)
+        return score_mod_graph.mul(
+            a=tanh_out,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+
+    def backward(self, score_mod_graph, dP_tensor, tensors):
+        """Apply the softcap derivative to dP."""
+        cudnn = dpa_backends._import_cudnn_frontend()
+
+        d_tanh_out = score_mod_graph.mul(
+            a=dP_tensor,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+        d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
+        d_before_tanh_activation = score_mod_graph.tanh_backward(
+            loss=d_tanh_out,
+            input=self.before_tanh_activation,
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+        d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
+        return score_mod_graph.div(
+            a=d_before_tanh_activation,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+
+
+def _score_mod_cache_cpu_inputs():
+    """Small CPU tensors for score_mod cache-key tests."""
+    q = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    k = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    v = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    o = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    stats = torch.empty((2, 3, 4, 1), dtype=torch.float32)
+    return q, k, v, o, stats
+
+
+def test_score_mod_cache_bound_method_requires_explicit_key():
+    """Unkeyed bound methods should be uncached instead of keyed by object id."""
+
+    class UnkeyedScoreMod:
+        def forward(self, _score_mod_graph, score_tensor, _tensors):
+            return score_tensor
+
+    key = dpa_backends._score_mod_callback_cache_key(UnkeyedScoreMod().forward)
+
+    assert key is dpa_backends._SCORE_MOD_UNCACHEABLE
+
+
+def test_score_mod_cache_bound_method_explicit_key_stable():
+    """Bound method keys should be stable when a structural graph key is provided."""
+    softcap = _ScoreModSoftcap()
+    key_0 = dpa_backends._score_mod_callback_cache_key(softcap.forward)
+    key_1 = dpa_backends._score_mod_callback_cache_key(softcap.forward)
+    other_key = dpa_backends._score_mod_callback_cache_key(_ScoreModSoftcap().forward)
+
+    assert key_0 == key_1
+    assert key_0 == other_key
+
+
+def test_score_mod_cache_explicit_key_distinguishes_topology():
+    """Stateful score_mods can opt into caching with topology-specific keys."""
+
+    class LayeredScoreMod:
+        def __init__(self, num_layers):
+            self.num_layers = num_layers
+
+        def score_mod_graph_cache_key(self):
+            return {"num_layers": self.num_layers}
+
+        def forward(self, _score_mod_graph, score_tensor, _tensors):
+            return score_tensor
+
+    key_0 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(1).forward)
+    key_1 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(1).forward)
+    key_2 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(2).forward)
+
+    assert key_0 == key_1
+    assert key_0 != key_2
+
+
+def test_score_mod_cache_module_lambda_keys_do_not_collide():
+    """Module-level lambdas should not reuse graphs only because qualnames match."""
+    score_mod_0 = lambda _graph, score_tensor, _tensors: score_tensor  # noqa: E731
+    score_mod_1 = lambda _graph, score_tensor, _tensors: score_tensor  # noqa: E731
+    score_mod_0.__module__ = __name__
+    score_mod_1.__module__ = __name__
+    score_mod_0.__qualname__ = "<lambda>"
+    score_mod_1.__qualname__ = "<lambda>"
+
+    key_0 = dpa_backends._score_mod_callback_cache_key(score_mod_0)
+    key_1 = dpa_backends._score_mod_callback_cache_key(score_mod_1)
+
+    assert key_0 is not
dpa_backends._SCORE_MOD_UNCACHEABLE + assert key_1 is not dpa_backends._SCORE_MOD_UNCACHEABLE + assert key_0 != key_1 + + +def test_score_mod_cache_key_ignores_pass_by_value_values(): + """Scalar CPU tensor values are runtime inputs, not execution-plan metadata.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + key_0 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + ) + key_1 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + ) + key_2 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], dtype=torch.float32)}, + ) + + assert key_0 == key_1 + assert key_0 != key_2 + + +def test_score_mod_cache_fwd_reuses_graph_for_pass_by_value_changes(monkeypatch): + """Fprop graph cache should reuse entries when only scalar CPU tensor values change.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + cache = dpa_backends._cudnn_score_mod_graph_cache + saved_cache = dict(cache) + build_entries = [] + + def fake_build( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ): + del ( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ) + entry = object() + build_entries.append(entry) + return entry + + monkeypatch.setattr(dpa_backends, "_build_cudnn_score_mod_fwd_graph", fake_build) + try: + cache.clear() + entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + o, + stats, + ) + entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + o, + stats, + ) + entry_2 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], dtype=torch.float32)}, + o, + stats, + ) + finally: + cache.clear() + cache.update(saved_cache) + + assert entry_0 is entry_1 + assert entry_2 is not entry_0 + assert len(build_entries) == 2 + + +def test_score_mod_cache_fwd_skips_cache_for_unkeyed_bound_method(monkeypatch): + """Unkeyed bound methods should build fresh graphs instead of using an id-based key.""" + + class UnkeyedScoreMod: + def forward(self, _score_mod_graph, score_tensor, _tensors): + return score_tensor + + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + score_mod = UnkeyedScoreMod() + cache = dpa_backends._cudnn_score_mod_graph_cache + saved_cache = dict(cache) + build_entries = [] + + def fake_build( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ): + del ( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ) + entry = object() + build_entries.append(entry) + return entry + + monkeypatch.setattr(dpa_backends, "_build_cudnn_score_mod_fwd_graph", fake_build) + try: 
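+        # With no explicit graph key, both calls below must rebuild the graph and
+        # leave the shared cache empty rather than keying entries by object id.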
+ cache.clear() + entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + score_mod.forward, + None, + o, + stats, + ) + entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + score_mod.forward, + None, + o, + stats, + ) + assert len(cache) == 0 + finally: + cache.clear() + cache.update(saved_cache) + + assert entry_0 is not entry_1 + assert len(build_entries) == 2 + + +def test_score_mod_tensors_are_version_checked_for_backward(monkeypatch): + """In-place score_mod tensor updates before backward should be rejected.""" + + class FakeEntry: + graph = object() + q = object() + k = object() + v = object() + output = object() + stats = object() + score_mod_graph_tensors = {"softcap": object()} + workspace_size = 1 + + def fake_execute(graph, variant_pack, workspace_size, device): + del graph, variant_pack, workspace_size, device + + q, k, v, _, _ = _score_mod_cache_cpu_inputs() + q = q.requires_grad_() + k = k.requires_grad_() + v = v.requires_grad_() + softcap = torch.tensor(0.8, dtype=torch.float32) + + monkeypatch.setattr(dpa_backends, "_get_cudnn_score_mod_fwd_graph", lambda *args: FakeEntry()) + monkeypatch.setattr(dpa_backends, "_execute_cudnn_graph", fake_execute) + + out = dpa_backends.FusedAttentionWithScoreModFunc.apply( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + None, + {"softcap": softcap}, + None, + False, + ) + softcap.add_(1.0) + + with pytest.raises(RuntimeError, match="modified by an inplace operation"): + out.sum().backward() + + +def _post_scale_bias(config, dtype): + """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" + q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) + kv_idx = torch.arange(config.max_seqlen_kv, dtype=torch.float32, device="cuda").view( + 1, 1, 1, -1 + ) + return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous() + + +def _to_bhsd(tensor, qkv_format): + """Convert SBHD/BSHD test tensors to logical BHSD.""" + if qkv_format == "sbhd": + return tensor.permute(1, 2, 0, 3) + return tensor.permute(0, 2, 1, 3) + + +def _from_bhsd(tensor, qkv_format): + """Convert logical BHSD test tensors to SBHD/BSHD.""" + if qkv_format == "sbhd": + return tensor.permute(2, 0, 1, 3).contiguous() + return tensor.permute(0, 2, 1, 3).contiguous() + + +def _pytorch_softcap_attention(q, k, v, qkv_format, softmax_scale, softcap): + """PyTorch reference for softcapped scaled dot-product attention.""" + q_bhsd = _to_bhsd(q, qkv_format).float() + k_bhsd = _to_bhsd(k, qkv_format).float() + v_bhsd = _to_bhsd(v, qkv_format).float() + scores = torch.matmul(q_bhsd, k_bhsd.transpose(-2, -1)) * softmax_scale + scores = softcap * torch.tanh(scores / softcap) + probs = torch.softmax(scores, dim=-1) + out = _from_bhsd(torch.matmul(probs, v_bhsd), qkv_format).to(v.dtype) + return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize( + "score_mod_case, scalar_loss", + [ + ("causal", False), + ("causal", True), + ("softcap", False), + ("post_scale_bias", False), + ], +) +def test_dot_product_attention_score_mod(dtype, qkv_format, score_mod_case, scalar_loss): + """Compare 
score_mod attention against equivalent reference implementations.""" + try: + dpa_backends._import_cudnn_frontend() + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + + config = ModelConfig( + 2, + 64 if score_mod_case == "causal" else 16, + 4, + 64, + attn_mask_type="no_mask", + ) + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + score_mod=True, + score_mod_bprop=True, + ) + if not available_backends[1] or not fused_attn_backends: + pytest.skip("FusedAttention is not available for this score_mod configuration.") + + if score_mod_case == "post_scale_bias": + bias_config = ModelConfig( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + attn_mask_type="no_mask", + attn_bias_type="post_scale_bias", + bias_shape="1hss", + ) + bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends( + bias_config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not bias_available_backends[1] or not bias_fused_attn_backends: + pytest.skip("FusedAttention is not available for post_scale_bias reference.") + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + if score_mod_case == "softcap": + q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_() + k = torch.randn(kv_shape, dtype=dtype, device="cuda").requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + else: + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + if score_mod_case == "causal": + score_mod_kwargs = { + "score_mod": _score_mod_causal, + "score_mod_bprop": _score_mod_causal_bprop, + "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, + "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)}, + } + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="causal", + layer_number=1, + ).to(dtype=dtype, device="cuda") + out_ref = ref_attn(q_ref, k_ref, v_ref, qkv_format=qkv_format, attn_mask_type="causal") + tols = dict(atol=5e-2, rtol=5e-2) + elif score_mod_case == "softcap": + softcap = 0.8 + softcap_tensor = torch.full((1, 1, 1, 1), softcap) + softcap_score_mod = _ScoreModSoftcap() + score_mod_kwargs = { + "score_mod": softcap_score_mod.forward, + "score_mod_bprop": softcap_score_mod.backward, + "score_mod_tensors": {"softcap": softcap_tensor}, + "score_mod_bprop_tensors": {"softcap": softcap_tensor}, + } + out_ref = _pytorch_softcap_attention( + q_ref, + k_ref, + v_ref, + 
qkv_format, + 1.0 / config.head_dim_qk**0.5, + softcap, + ) + tols = dict(atol=7e-2, rtol=7e-2) + else: + assert score_mod_case == "post_scale_bias" + score_mod_kwargs = { + "score_mod": _score_mod_post_scale_bias, + "score_mod_bprop": _score_mod_identity_bprop, + } + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="no_mask", + core_attention_bias_type="post_scale_bias", + core_attention_bias=_post_scale_bias(config, dtype), + ) + tols = dict(atol=5e-2, rtol=5e-2) + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + **score_mod_kwargs, + ) + + if scalar_loss: + out.sum().backward() + out_ref.sum().backward() + else: + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + model_configs_te_layer = { # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 32e44be2af..16f6c08bcf 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -329,6 +329,8 @@ def get_available_attention_backends( fp8_meta: Optional[Dict[str, Any]] = None, is_training: bool = True, inference_params: Optional[InferenceParams] = None, + score_mod: bool = False, + score_mod_bprop: bool = False, ) -> Tuple[List, List]: """Check for all available attention backends that support a model configuration""" @@ -390,6 +392,8 @@ def test(): inference_params=inference_params, softmax_type=config.softmax_type, return_max_logit=config.return_max_logit, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, # allow all backends to pass so they can be used for testing; # check for FA3 availability later num_splits=1, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 79ebbd4afa..70ba826138 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -4,9 +4,14 @@ """Attention Backends.""" from contextlib import nullcontext +from dataclasses import dataclass +import importlib +import inspect from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError import os +from pathlib import Path +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import logging @@ -89,6 +94,25 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None +_cudnn_score_mod_handles: Dict[torch.device, Any] = {} +_cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {} +_SCORE_MOD_UNCACHEABLE = object() +_CUDNN_FRONTEND_PYTHON_PATH = ( + Path(__file__).resolve().parents[4] / "3rdparty" / "cudnn-frontend" / "python" +) + + +def _import_cudnn_frontend(): + """Import the vendored cuDNN frontend if built, otherwise use the installed package.""" + cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) + cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" + if ( + 
any(cudnn_frontend_package.glob("_compiled_module*")) + and cudnn_frontend_path not in sys.path + ): + sys.path.insert(0, cudnn_frontend_path) + return importlib.import_module("cudnn") + # Try to import Flash Attention v2 try: @@ -1244,6 +1268,787 @@ def convert_to_torch_float8(tensor, dtype): return output.contiguous() +def _bhsd_dim_stride( + tensor: torch.Tensor, tensor_format: str +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """Describe an SBHD/BSHD tensor as cuDNN frontend's logical BHSD format.""" + if tensor_format == "sbhd": + return ( + (tensor.shape[1], tensor.shape[2], tensor.shape[0], tensor.shape[3]), + (tensor.stride(1), tensor.stride(2), tensor.stride(0), tensor.stride(3)), + ) + if tensor_format == "bshd": + return ( + (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), + (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)), + ) + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + + +def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): + """Create a cuDNN graph tensor with BHSD dims and TE-layout strides.""" + dim, stride = _bhsd_dim_stride(tensor, tensor_format) + return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) + + +# score_mod graph cache helpers. +def _freeze_score_mod_cache_key(value: Any) -> Any: + """Convert a user-provided score_mod graph key into a hashable structure.""" + if isinstance(value, torch.Tensor): + raise TypeError( + "score_mod_graph_cache_key() must not include tensors. Pass runtime tensors " + "through score_mod_tensors or score_mod_bprop_tensors instead." + ) + if isinstance(value, dict): + items = ( + ( + _freeze_score_mod_cache_key(key), + _freeze_score_mod_cache_key(val), + ) + for key, val in value.items() + ) + return tuple(sorted(items, key=repr)) + if isinstance(value, (list, tuple)): + return tuple(_freeze_score_mod_cache_key(item) for item in value) + if isinstance(value, (set, frozenset)): + items = (_freeze_score_mod_cache_key(item) for item in value) + return tuple(sorted(items, key=repr)) + try: + hash(value) + except TypeError as exc: + raise TypeError( + "score_mod_graph_cache_key() must return a hashable value or a nested " + "combination of dict/list/tuple/set values." + ) from exc + return value + + +def _score_mod_explicit_cache_key(callback_owner: Any) -> Optional[Any]: + """Return a user-provided structural graph key for a score_mod callback.""" + explicit_key = getattr(callback_owner, "score_mod_graph_cache_key", None) + if explicit_key is None: + return None + explicit_key = explicit_key() if callable(explicit_key) else explicit_key + return _freeze_score_mod_cache_key(explicit_key) + + +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any: + """Create a stable graph cache key for a score_mod callable. + + Module-level named functions are assumed to have stable topology. Anonymous functions + are keyed by code object because lambdas in the same module can share the same + qualname. Stateful bound methods and callable instances need an explicit + score_mod_graph_cache_key(); otherwise their graphs are left uncached to avoid reusing + stale graphs after Python object address reuse. 
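+
+    For example, a module-level named function keys as ("function", module, qualname),
+    a lambda keys as ("function", module, code_object), and a bound method with an
+    explicit key keys as ("bound_method", owner_type, module, qualname, explicit_key).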
+    """
+    if callback is None:
+        return None
+    self_obj = getattr(callback, "__self__", None)
+    func_obj = getattr(callback, "__func__", None)
+    if self_obj is not None and func_obj is not None:
+        explicit_key = _score_mod_explicit_cache_key(self_obj)
+        if explicit_key is None:
+            return _SCORE_MOD_UNCACHEABLE
+        return (
+            "bound_method",
+            type(self_obj),
+            func_obj.__module__,
+            func_obj.__qualname__,
+            explicit_key,
+        )
+
+    explicit_key = _score_mod_explicit_cache_key(callback)
+    if explicit_key is not None:
+        return (
+            "callable",
+            type(callback),
+            getattr(callback, "__module__", None),
+            getattr(callback, "__qualname__", None),
+            explicit_key,
+        )
+
+    if (
+        inspect.isfunction(callback)
+        and callback.__closure__ is None
+        and "<locals>" not in callback.__qualname__
+    ):
+        if callback.__name__ == "<lambda>" or not callback.__qualname__:
+            return ("function", callback.__module__, callback.__code__)
+        return ("function", callback.__module__, callback.__qualname__)
+
+    return _SCORE_MOD_UNCACHEABLE
+
+
+def _score_mod_device_key(device: torch.device) -> Tuple[Any, ...]:
+    """Normalize a tensor device for graph cache keys."""
+    if device.type == "cuda":
+        index = device.index
+        if index is None:
+            index = torch.cuda.current_device()
+        return (device.type, index)
+    return (device.type, device.index)
+
+
+def _score_mod_tensor_metadata(tensor: torch.Tensor) -> Tuple[Any, ...]:
+    """Describe tensor metadata that can affect cuDNN graph construction."""
+    return (
+        tuple(tensor.size()),
+        tuple(tensor.stride()),
+        tensor.dtype,
+        _score_mod_device_key(tensor.device),
+    )
+
+
+def _score_mod_tensor_dict_metadata(
+    tensors: Optional[Dict[str, torch.Tensor]],
+) -> Tuple[Tuple[str, Tuple[Any, ...]], ...]:
+    """Describe score_mod tensor parameters without including their values."""
+    if tensors is None:
+        return ()
+    return tuple((name, _score_mod_tensor_metadata(tensor)) for name, tensor in tensors.items())
+
+
+def _score_mod_bhsd_tensor_metadata(tensor: torch.Tensor, tensor_format: str) -> Tuple[Any, ...]:
+    """Describe an SBHD/BSHD runtime tensor as a cuDNN BHSD graph tensor."""
+    dim, stride = _bhsd_dim_stride(tensor, tensor_format)
+    return (dim, stride, tensor.dtype, _score_mod_device_key(tensor.device))
+
+
+def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]):
+    """Create cuDNN graph tensors matching runtime tensors."""
+    if tensors is None:
+        return {}
+    return {name: graph.tensor_like(tensor) for name, tensor in tensors.items()}
+
+
+# score_mod cuDNN frontend graph helpers.
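+# TE score_mod callbacks take (graph, score, tensors), while the cuDNN frontend's
+# sdpa()/sdpa_backward() expect two-argument callbacks; the wrappers below close over
+# the per-graph tensor dict. A minimal sketch (hypothetical user callback; the
+# graph-tensor names are illustrative):
+#
+#     def my_score_mod(graph, score, tensors):
+#         return graph.add(a=score, b=tensors["bias"])
+#
+#     wrapped = _wrap_score_mod(my_score_mod, {"bias": bias_graph_tensor})
+#     out = wrapped(sdpa_graph, score_tensor)  # matches cuDNN's two-argument signature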
+def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + """Adapt TE's score_mod signature to cuDNN frontend's two-argument callback.""" + if score_mod is None: + return None + + def _wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return _wrapped_score_mod + + +def _get_cudnn_current_stream_handle(cudnn, device: torch.device): + """Return a cuDNN handle for device, bound to PyTorch's current stream.""" + if device.type != "cuda": + raise ValueError(f"score_mod only supports CUDA tensors, got device {device}.") + if device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + + handle = _cudnn_score_mod_handles.get(device) + with torch.cuda.device(device): + if handle is None: + handle = cudnn.create_handle() + _cudnn_score_mod_handles[device] = handle + + stream = torch.cuda.current_stream(device).cuda_stream + cudnn.set_stream(handle=handle, stream=stream) + return handle + + +def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): + """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" + cudnn = _import_cudnn_frontend() + + if dtype == torch.float16: + io_data_type = cudnn.data_type.HALF + elif dtype == torch.bfloat16: + io_data_type = cudnn.data_type.BFLOAT16 + else: + raise ValueError(f"score_mod only supports FP16/BF16 tensors, got {dtype}.") + + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_current_stream_handle(cudnn, device), + ) + return cudnn, graph + + +@dataclass +class _CudnnScoreModFwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod fprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + stats: Optional[Any] + score_mod_graph_tensors: Dict[str, Any] + workspace_size: int + + +@dataclass +class _CudnnScoreModBwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod bprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + d_output: Any + stats: Any + dq: Any + dk: Any + dv: Any + score_mod_graph_tensors: Dict[str, Any] + score_mod_bprop_graph_tensors: Dict[str, Any] + workspace_size: int + + +def _finalize_cudnn_graph(graph) -> int: + """Build a cuDNN frontend Python graph and return its workspace size.""" + cudnn = _import_cudnn_frontend() + + graph.validate() + graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + return max(graph.get_workspace_size(), 1) + + +def _execute_cudnn_graph( + graph, + variant_pack: Dict[Any, torch.Tensor], + workspace_size: int, + device: torch.device, +): + """Execute a built cuDNN frontend Python graph.""" + cudnn = _import_cudnn_frontend() + + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + workspace = torch.empty( + workspace_size, + device=device, + dtype=torch.uint8, + ) + graph.execute( + variant_pack, + workspace, + handle=_get_cudnn_current_stream_handle(cudnn, device), + ) + + +def _cudnn_score_mod_fwd_cache_key( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + stats: 
Optional[torch.Tensor], + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], +) -> Optional[Tuple[Any, ...]]: + """Cache key for score_mod fprop execution plans.""" + score_mod_key = _score_mod_callback_cache_key(score_mod) + if score_mod_key is _SCORE_MOD_UNCACHEABLE: + return None + return ( + "fwd", + is_training, + q_format, + kv_format, + attn_scale, + score_mod_key, + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_tensor_metadata(stats) if stats is not None else None, + _score_mod_tensor_dict_metadata(score_mod_tensors), + ) + + +def _cudnn_score_mod_bwd_cache_key( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> Optional[Tuple[Any, ...]]: + """Cache key for score_mod bprop execution plans.""" + score_mod_key = _score_mod_callback_cache_key(score_mod) + score_mod_bprop_key = _score_mod_callback_cache_key(score_mod_bprop) + if score_mod_key is _SCORE_MOD_UNCACHEABLE or score_mod_bprop_key is _SCORE_MOD_UNCACHEABLE: + return None + return ( + "bwd", + q_format, + kv_format, + attn_scale, + deterministic, + score_mod_key, + score_mod_bprop_key, + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_bhsd_tensor_metadata(d_out, q_format), + _score_mod_tensor_metadata(stats), + _score_mod_tensor_dict_metadata(score_mod_tensors), + _score_mod_tensor_dict_metadata(score_mod_bprop_tensors), + ) + + +def _build_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod fprop.""" + cudnn = _import_cudnn_frontend() + + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) + output, stats_tensor = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + ) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) + + if is_training: + assert stats is not None + stats_tensor.set_output(True).set_dim(stats.size()).set_stride( + stats.stride() + 
).set_data_type(cudnn.data_type.FLOAT) + else: + stats_tensor = None + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModFwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + stats=stats_tensor, + score_mod_graph_tensors=score_mod_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod fprop.""" + key = _cudnn_score_mod_fwd_cache_key( + is_training, + query_layer, + key_layer, + value_layer, + output_layer, + stats, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + ) + if key is None: + return _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +def _build_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod bprop.""" + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + output = _bhsd_graph_tensor(graph, output_layer, q_format) + d_output = _bhsd_graph_tensor(graph, d_out, q_format) + stats_tensor = graph.tensor_like(stats) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, score_mod_bprop_tensors) + if score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod(score_mod_bprop, score_mod_bprop_graph_tensors) + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, kv_format) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats_tensor, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=deterministic, + ) + 
dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModBwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + d_output=d_output, + stats=stats_tensor, + dq=dq, + dk=dk, + dv=dv, + score_mod_graph_tensors=score_mod_graph_tensors, + score_mod_bprop_graph_tensors=score_mod_bprop_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod bprop.""" + key = _cudnn_score_mod_bwd_cache_key( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + if key is None: + return _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +class FusedAttentionWithScoreModFunc(torch.autograd.Function): + """cuDNN frontend Python SDPA path with score_mod callback support.""" + + @staticmethod + def forward( + ctx, + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) + score_mod_tensors = dict(score_mod_tensors or {}) + score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) + output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) + if is_training: + stats = torch.empty( + (*q_bhsd_dim[:-1], 1), + device=query_layer.device, + dtype=torch.float32, + ) + else: + stats = None + + entry = _get_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + } + if is_training: + variant_pack[entry.stats] = stats + for name, graph_tensor in 
entry.score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) + + ctx.is_training = is_training + ctx.q_format = q_format + ctx.kv_format = kv_format + ctx.attn_scale = attn_scale + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensor_names = tuple(score_mod_tensors.keys()) + ctx.score_mod_bprop_tensor_names = tuple(score_mod_bprop_tensors.keys()) + ctx.deterministic = deterministic + if is_training: + # save_for_backward records version counters without copying tensor data. + # This catches in-place score_mod tensor updates before backward. + ctx.save_for_backward( + query_layer, + key_layer, + value_layer, + output_layer, + stats, + *score_mod_tensors.values(), + *score_mod_bprop_tensors.values(), + ) + else: + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) + + return output_layer + + @staticmethod + def backward(ctx, d_out: torch.Tensor): + # pylint: disable=missing-function-docstring + if not ctx.is_training: + raise RuntimeError( + "score_mod backward requires DotProductAttention to be in training mode." + ) + + saved_tensors = ctx.saved_tensors + query_layer, key_layer, value_layer, output_layer, stats = saved_tensors[:5] + score_mod_tensors_end = 5 + len(ctx.score_mod_tensor_names) + score_mod_tensors = dict( + zip(ctx.score_mod_tensor_names, saved_tensors[5:score_mod_tensors_end]) + ) + score_mod_bprop_tensors = dict( + zip(ctx.score_mod_bprop_tensor_names, saved_tensors[score_mod_tensors_end:]) + ) + d_out = d_out.contiguous() + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + entry = _get_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats, + ctx.q_format, + ctx.kv_format, + ctx.attn_scale, + ctx.score_mod, + ctx.score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + ctx.deterministic, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + entry.d_output: d_out, + entry.stats: stats, + entry.dq: dq_layer, + entry.dk: dk_layer, + entry.dv: dv_layer, + } + for name, graph_tensor in entry.score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_bprop_tensors[name] + + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) + + return ( + None, + dq_layer, + dk_layer, + dv_layer, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + class FusedAttnFunc(torch.autograd.Function): """FusedAttention forward and backward implementation""" @@ -1945,6 +2750,10 @@ def forward( inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, fp8_output: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -2068,6 +2877,16 @@ def forward( ) if context_parallel: + assert score_mod is None, "score_mod is not supported with context parallelism!" 
+ assert ( + score_mod_bprop is None + ), "score_mod_bprop is not supported with context parallelism!" + assert ( + score_mod_tensors is None + ), "score_mod_tensors is not supported with context parallelism!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors is not supported with context parallelism!" assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen @@ -2113,6 +2932,39 @@ def forward( layer_number=self.layer_number, return_max_logit=self.return_max_logit, ) + elif score_mod is not None: + assert not fp8, "score_mod is not supported with FP8 FusedAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" + assert ( + attn_mask_type == "no_mask" + and core_attention_bias_type == "no_bias" + and core_attention_bias is None + and self.softmax_type == "vanilla" + and self.attention_dropout == 0.0 + ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" + output = FusedAttentionWithScoreModFunc.apply( + self.training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + self.softmax_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + self.deterministic, + ) else: with self.attention_dropout_ctx(): output = FusedAttnFunc.apply( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..a9d1a48f20 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -892,6 +892,10 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: r""" Dot Product Attention Layer. @@ -1080,6 +1084,18 @@ def forward( Optional split control for FlashAttention-3 only. When set, this value is forwarded to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled. + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback. This is a cuDNN-only path + and is mutually exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, + context parallelism, THD format, KV caching, and return_max_logit. The callback + signature is ``score_mod(graph, score, tensors) -> score``. + score_mod_bprop: Optional[Callable], default = None + Optional cuDNN frontend callback for the backward pass of score_mod. The callback + signature is ``score_mod_bprop(graph, dP, tensors) -> dP``. 
+ score_mod_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod as cuDNN graph tensors. + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod_bprop as cuDNN graph tensors. """ with self.prepare_forward_ctx( @@ -1088,6 +1104,13 @@ def forward( allow_non_contiguous=True, allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: + user_supplied_seqlens = ( + cu_seqlens_q is not None + or cu_seqlens_kv is not None + or cu_seqlens_q_padded is not None + or cu_seqlens_kv_padded is not None + ) + # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -1226,6 +1249,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + if score_mod is not None: + assert inference_params is None, "score_mod is not supported with KV caching!" + # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -1406,6 +1432,84 @@ def forward( else: pad_between_seqs = False + if score_mod is None: + assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" + assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors requires score_mod!" + else: + assert callable(score_mod), "score_mod must be callable!" + assert score_mod_bprop is None or callable( + score_mod_bprop + ), "score_mod_bprop must be callable when provided!" + assert query_layer.dtype in [ + torch.float16, + torch.bfloat16, + ], "score_mod only supports FP16 and BF16 tensors!" + assert ( + key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype + ), "score_mod requires Q, K and V tensors to have the same dtype!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" + assert ( + not user_supplied_seqlens + ), "score_mod is mutually exclusive with explicit sequence length metadata!" + assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" + assert ( + attention_mask is None + ), "score_mod is mutually exclusive with attention_mask!" + assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" + assert window_size is None or window_size == ( + -1, + -1, + ), "score_mod is mutually exclusive with sliding window attention!" + assert ( + core_attention_bias_type == "no_bias" and core_attention_bias is None + ), "score_mod is mutually exclusive with attention bias!" + assert alibi_slopes is None, "score_mod is mutually exclusive with ALiBi!" + assert ( + self.softmax_type == "vanilla" + ), "score_mod is mutually exclusive with sink attention!" + assert ( + self.attention_dropout == 0.0 + ), "score_mod is not supported with attention dropout!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" 
+ assert ( + not checkpoint_core_attention + ), "score_mod is not supported with checkpoint_core_attention!" + assert ( + not is_graph_capturing() + ), "score_mod is not supported with CUDA graph capture!" + assert num_splits == 1, "score_mod is not supported with num_splits != 1!" + assert q_format in ["sbhd", "bshd"] and kv_format in [ + "sbhd", + "bshd", + ], "score_mod only supports SBHD/BSHD QKV formats!" + if score_mod_tensors is not None: + assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_tensors.items() + ), "score_mod_tensors must map string names to torch.Tensor instances!" + if score_mod_bprop_tensors is not None: + assert isinstance( + score_mod_bprop_tensors, dict + ), "score_mod_bprop_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_bprop_tensors.items() + ), "score_mod_bprop_tensors must map string names to torch.Tensor instances!" + # gather attention params for get_attention_backend attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), @@ -1436,14 +1540,21 @@ def forward( is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, + fp8_output=fp8_output, inference_params=inference_params, softmax_type=self.softmax_type, return_max_logit=self.return_max_logit, + checkpoint_core_attention=checkpoint_core_attention, cuda_graph=is_graph_capturing(), num_splits=num_splits, + has_attention_mask=attention_mask is not None, + has_core_attention_bias=core_attention_bias is not None, + user_supplied_seqlens=user_supplied_seqlens, + score_mod=score_mod is not None, + score_mod_bprop=score_mod_bprop is not None, ) global _attention_backends - if is_in_onnx_export_mode(): + if is_in_onnx_export_mode() and score_mod is None: # We do not want to call get_attention_backend() in ONNX mode # and we want to avoid using any global variables like _attention_backends. use_flash_attention = False @@ -1619,6 +1730,10 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) if use_unfused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7df5daabe5..0b6a2b85d5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -246,16 +246,30 @@ class AttentionParams: Whether `DotProductAttention` is in an `autocast` region. fp8_meta : Optional[Dict[str Any]], default = None The FP8 metadata tensor of `DotProductAttention`. + fp8_output : bool, default = False + Whether output is requested in FP8. inference_params : Optional[InferenceParams], default = None Inference-related parameters. See InferenceParams for details. softmax_type : str, default = "vanilla" The type of softmax operation. See DotProductAttention for details. return_max_logit : bool, default = False Whether to output max_logit. + checkpoint_core_attention : bool, default = False + Whether core attention is recomputed during backward. cuda_graph : bool, default = `False` Whether support for cuda graph capture is needed or not. num_splits : int, default = 1 The number of kernels to split attention to. 
+ has_attention_mask : bool, default = False + Whether an explicit attention mask tensor was provided. + has_core_attention_bias : bool, default = False + Whether an explicit core attention bias tensor was provided. + user_supplied_seqlens : bool, default = False + Whether explicit cu_seqlens metadata was provided. + score_mod : bool, default = False + Whether a score_mod callback was provided. + score_mod_bprop : bool, default = False + Whether a score_mod bprop callback was provided. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -284,11 +298,18 @@ class AttentionParams: is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None + fp8_output: bool = False inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" return_max_logit: bool = False + checkpoint_core_attention: bool = False cuda_graph: bool = False num_splits: int = 1 + has_attention_mask: bool = False + has_core_attention_bias: bool = False + user_supplied_seqlens: bool = False + score_mod: bool = False + score_mod_bprop: bool = False def __eq__(self, other): """ @@ -362,11 +383,18 @@ def get_attention_backend( is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta + fp8_output = attention_params.fp8_output inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type return_max_logit = attention_params.return_max_logit + checkpoint_core_attention = attention_params.checkpoint_core_attention cuda_graph = attention_params.cuda_graph num_splits = attention_params.num_splits + has_attention_mask = attention_params.has_attention_mask + has_core_attention_bias = attention_params.has_core_attention_bias + user_supplied_seqlens = attention_params.user_supplied_seqlens + score_mod = attention_params.score_mod + score_mod_bprop = attention_params.score_mod_bprop # Run config logger = logging.getLogger("DotProductAttention") @@ -432,7 +460,7 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. 
- qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) @@ -647,6 +675,90 @@ def get_attention_backend( use_unfused_attention = False logger.debug("Disabling all backends for max_logit with FP8 attention") + # Filter: score_mod + if score_mod_bprop and not score_mod: + logger.debug("Disabling all backends because score_mod_bprop requires score_mod") + use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_3 = False + use_flash_attention_4 = False + use_fused_attention = False + use_unfused_attention = False + if score_mod: + if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4: + logger.debug("Disabling FlashAttention for score_mod") + use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_3 = False + use_flash_attention_4 = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for score_mod") + use_unfused_attention = False + + score_mod_unsupported_reasons = [] + if qkv_dtype not in [torch.float16, torch.bfloat16]: + score_mod_unsupported_reasons.append( + f"unsupported qkv_dtype = {qkv_dtype}; supported: torch.float16, torch.bfloat16" + ) + if qkv_type is not torch.Tensor: + score_mod_unsupported_reasons.append( + f"unsupported qkv_type = {qkv_type}; supported: torch.Tensor" + ) + if fp8: + score_mod_unsupported_reasons.append("FP8 DotProductAttention is enabled") + if fp8_output: + score_mod_unsupported_reasons.append("fp8_output is enabled") + if inference_params is not None: + score_mod_unsupported_reasons.append("KV caching is enabled") + if context_parallel: + score_mod_unsupported_reasons.append("context parallelism is enabled") + if ( + qkv_format == "thd" + or q_format not in ["sbhd", "bshd"] + or kv_format + not in [ + "sbhd", + "bshd", + ] + ): + score_mod_unsupported_reasons.append( + f"unsupported QKV format: q_format = {q_format}, kv_format = {kv_format}" + ) + if user_supplied_seqlens: + score_mod_unsupported_reasons.append("explicit sequence length metadata was provided") + if pad_between_seqs: + score_mod_unsupported_reasons.append("pad_between_seqs is enabled") + if has_attention_mask: + score_mod_unsupported_reasons.append("attention_mask was provided") + if attn_mask_type != "no_mask": + score_mod_unsupported_reasons.append(f"attn_mask_type = {attn_mask_type}") + if window_size is not None and window_size != (-1, -1): + score_mod_unsupported_reasons.append(f"window_size = {window_size}") + if core_attention_bias_type != "no_bias" or has_core_attention_bias: + score_mod_unsupported_reasons.append( + f"core_attention_bias_type = {core_attention_bias_type}" + ) + if alibi_slopes_shape is not None: + score_mod_unsupported_reasons.append("ALiBi slopes were provided") + if softmax_type != "vanilla": + score_mod_unsupported_reasons.append(f"softmax_type = {softmax_type}") + if attention_dropout != 0.0: + score_mod_unsupported_reasons.append(f"attention_dropout = {attention_dropout}") + if return_max_logit: + score_mod_unsupported_reasons.append("return_max_logit is enabled") + if checkpoint_core_attention: + score_mod_unsupported_reasons.append("checkpoint_core_attention is enabled") + if cuda_graph: + score_mod_unsupported_reasons.append("CUDA graph capture is enabled") + if num_splits != 1: + score_mod_unsupported_reasons.append(f"num_splits = {num_splits}") + if 
score_mod_unsupported_reasons and use_fused_attention: + logger.debug( + "Disabling FusedAttention for score_mod because %s", + "; ".join(score_mod_unsupported_reasons), + ) + use_fused_attention = False + # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- @@ -1250,6 +1362,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None + elif score_mod and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]: + logger.debug( + "Disabling FusedAttention for score_mod because sub-backend %s is not " + "F16/BF16 arbitrary-seqlen", + int(fused_attention_backend), + ) + use_fused_attention = False + fused_attention_backend = None # Filter: Determinism # backend | deterministic # ---------------------------------------------