From 11c3ed21fd82b0451481a7181fdcd0d740bcbf77 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 8 May 2026 21:41:26 +0000 Subject: [PATCH 01/12] Add cuDNN score_mod attention path Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 125 ++++++++ .../dot_product_attention/backends.py | 291 +++++++++++++++++- .../dot_product_attention.py | 141 ++++++++- 3 files changed, 555 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..879f48dc0c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1390,6 +1390,131 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out, max_logit, (None, None, None, d_softmax_offset) +def _score_mod_causal(score_mod_graph, score_tensor, tensors): + """cuDNN frontend score_mod implementing top-left causal masking.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=score_tensor, + input1=tensors["neg_inf"], + mask=keep, + ) + + +def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): + """cuDNN frontend score_mod_bprop implementing top-left causal masking.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = score_mod_graph.gen_index(input=dP_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=dP_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + 
input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=dP_tensor, + input1=tensors["zero"], + mask=keep, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod(dtype, qkv_format): + """Compare score_mod causal masking against standard cuDNN causal attention.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + config = ModelConfig(2, 64, 4, 64, attn_mask_type="no_mask") + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not available_backends[1] or not fused_attn_backends: + pytest.skip("FusedAttention is not available for this score_mod configuration.") + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, 
k, v)] + + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="causal", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, + score_mod_bprop_tensors={"zero": torch.full((1, 1, 1, 1), 0.0)}, + ) + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="causal", + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=5e-2, rtol=5e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + model_configs_te_layer = { # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 79ebbd4afa..e7e7a7d0f2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1244,6 +1244,255 @@ def convert_to_torch_float8(tensor, dtype): return output.contiguous() +def _format_to_bhsd(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: + """Convert TE's SBHD/BSHD tensor formats to cuDNN frontend's BHSD format.""" + if tensor_format == "sbhd": + return tensor.permute(1, 2, 0, 3).contiguous() + if tensor_format == "bshd": + return tensor.permute(0, 2, 
1, 3).contiguous() + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + + +def _bhsd_to_format(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: + """Convert cuDNN frontend's BHSD format back to TE's SBHD/BSHD tensor formats.""" + if tensor_format == "sbhd": + return tensor.permute(2, 0, 1, 3).contiguous() + if tensor_format == "bshd": + return tensor.permute(0, 2, 1, 3).contiguous() + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + + +def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): + """Create cuDNN graph tensors matching runtime tensors.""" + if tensors is None: + return {} + return {name: graph.tensor_like(tensor) for name, tensor in tensors.items()} + + +def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + """Adapt TE's score_mod signature to cuDNN frontend's two-argument callback.""" + if score_mod is None: + return None + + def _wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return _wrapped_score_mod + + +def _build_cudnn_pygraph(dtype: torch.dtype): + """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" + import cudnn # pylint: disable=import-outside-toplevel + + if dtype == torch.float16: + io_data_type = cudnn.data_type.HALF + elif dtype == torch.bfloat16: + io_data_type = cudnn.data_type.BFLOAT16 + else: + raise ValueError(f"score_mod only supports FP16/BF16 tensors, got {dtype}.") + + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + return cudnn, graph + + +def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], device: torch.device): + """Build and execute a cuDNN frontend Python graph without caching.""" + import cudnn # pylint: disable=import-outside-toplevel + + graph.validate() + 
graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + + workspace = torch.empty( + max(graph.get_workspace_size(), 1), + device=device, + dtype=torch.uint8, + ) + graph.execute(variant_pack, workspace) + + +class FusedAttentionWithScoreModFunc(torch.autograd.Function): + """cuDNN frontend Python SDPA path with score_mod callback support.""" + + @staticmethod + def forward( + ctx, + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + query_bhsd = _format_to_bhsd(query_layer, q_format) + key_bhsd = _format_to_bhsd(key_layer, kv_format) + value_bhsd = _format_to_bhsd(value_layer, kv_format) + + cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) + q = graph.tensor_like(query_bhsd) + k = graph.tensor_like(key_bhsd) + v = graph.tensor_like(value_bhsd) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_bhsd = torch.empty( + (*query_bhsd.shape[:-1], value_bhsd.shape[-1]), + device=query_bhsd.device, + dtype=query_bhsd.dtype, + ) + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + ) + 
output.set_output(True).set_dim(output_bhsd.size()).set_stride(output_bhsd.stride()) + + variant_pack = { + q: query_bhsd, + k: key_bhsd, + v: value_bhsd, + output: output_bhsd, + } + if is_training: + stats_bhs1 = torch.empty( + (*query_bhsd.shape[:-1], 1), + device=query_bhsd.device, + dtype=torch.float32, + ) + stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( + stats_bhs1.stride() + ).set_data_type(cudnn.data_type.FLOAT) + variant_pack[stats] = stats_bhs1 + else: + stats_bhs1 = None + for name, graph_tensor in score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + + _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + + ctx.is_training = is_training + ctx.q_format = q_format + ctx.kv_format = kv_format + ctx.attn_scale = attn_scale + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensors = dict(score_mod_tensors or {}) + ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + ctx.deterministic = deterministic + if is_training: + ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1) + else: + ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd) + + return _bhsd_to_format(output_bhsd, q_format) + + @staticmethod + def backward(ctx, d_out: torch.Tensor): + # pylint: disable=missing-function-docstring + if not ctx.is_training: + raise RuntimeError( + "score_mod backward requires DotProductAttention to be in training mode." 
+ ) + + query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1 = ctx.saved_tensors + d_out_bhsd = _format_to_bhsd(d_out, ctx.q_format) + + cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) + q = graph.tensor_like(query_bhsd) + k = graph.tensor_like(key_bhsd) + v = graph.tensor_like(value_bhsd) + output = graph.tensor_like(output_bhsd) + d_output = graph.tensor_like(d_out_bhsd) + stats = graph.tensor_like(stats_bhs1) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_bprop_tensors) + if ctx.score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(ctx.score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod( + ctx.score_mod_bprop, score_mod_bprop_graph_tensors + ) + + dq_bhsd = torch.empty_like(query_bhsd) + dk_bhsd = torch.empty_like(key_bhsd) + dv_bhsd = torch.empty_like(value_bhsd) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats, + attn_scale=ctx.attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=ctx.deterministic, + ) + dq.set_output(True).set_dim(dq_bhsd.size()).set_stride(dq_bhsd.stride()) + dk.set_output(True).set_dim(dk_bhsd.size()).set_stride(dk_bhsd.stride()) + dv.set_output(True).set_dim(dv_bhsd.size()).set_stride(dv_bhsd.stride()) + + variant_pack = { + q: query_bhsd, + k: key_bhsd, + v: value_bhsd, + output: output_bhsd, + d_output: d_out_bhsd, + stats: stats_bhs1, + dq: dq_bhsd, + dk: dk_bhsd, + dv: dv_bhsd, + } + for name, graph_tensor in score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = ctx.score_mod_tensors[name] + for name, graph_tensor in score_mod_bprop_graph_tensors.items(): + variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] + + _build_and_run_cudnn_graph(graph, 
variant_pack, query_bhsd.device) + + return ( + None, + _bhsd_to_format(dq_bhsd, ctx.q_format), + _bhsd_to_format(dk_bhsd, ctx.kv_format), + _bhsd_to_format(dv_bhsd, ctx.kv_format), + None, + None, + None, + None, + None, + None, + None, + None, + ) + + class FusedAttnFunc(torch.autograd.Function): """FusedAttention forward and backward implementation""" @@ -1945,6 +2194,10 @@ def forward( inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, fp8_output: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -2067,7 +2320,43 @@ def forward( cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group ) - if context_parallel: + if score_mod is not None: + assert ( + not context_parallel + ), "score_mod is not supported with context parallelism!" + assert ( + not fp8 + ), "score_mod is not supported with FP8 FusedAttention!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" + assert ( + attn_mask_type == "no_mask" + and core_attention_bias_type == "no_bias" + and core_attention_bias is None + and self.softmax_type == "vanilla" + and self.attention_dropout == 0.0 + ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" 
+ output = FusedAttentionWithScoreModFunc.apply( + self.training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + self.softmax_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + self.deterministic, + ) + elif context_parallel: assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..7a7745d7c2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -892,6 +892,10 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: r""" Dot Product Attention Layer. @@ -1080,6 +1084,18 @@ def forward( Optional split control for FlashAttention-3 only. When set, this value is forwarded to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled. + score_mod: Optional[Callable], default = None + cuDNN frontend score modification callback. This is a cuDNN-only path and is mutually + exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, context parallelism, + THD format, KV caching, and return_max_logit. The callback signature is + ``score_mod(graph, score, tensors) -> score``. + score_mod_bprop: Optional[Callable], default = None + Optional cuDNN frontend callback for the backward pass of score_mod. 
The callback + signature is ``score_mod_bprop(graph, dP, tensors) -> dP``. + score_mod_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod as cuDNN graph tensors. + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod_bprop as cuDNN graph tensors. """ with self.prepare_forward_ctx( @@ -1088,6 +1104,13 @@ def forward( allow_non_contiguous=True, allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: + user_supplied_seqlens = ( + cu_seqlens_q is not None + or cu_seqlens_kv is not None + or cu_seqlens_q_padded is not None + or cu_seqlens_kv_padded is not None + ) + # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -1406,6 +1429,86 @@ def forward( else: pad_between_seqs = False + if score_mod is None: + assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" + assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" + assert score_mod_bprop_tensors is None, "score_mod_bprop_tensors requires score_mod!" + else: + assert callable(score_mod), "score_mod must be callable!" + assert ( + score_mod_bprop is None or callable(score_mod_bprop) + ), "score_mod_bprop must be callable when provided!" + assert ( + query_layer.dtype in [torch.float16, torch.bfloat16] + ), "score_mod only supports FP16 and BF16 tensors!" + assert ( + key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype + ), "score_mod requires Q, K and V tensors to have the same dtype!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + not self.fp8 + ), "score_mod is not supported with FP8 DotProductAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" 
+ assert ( + not context_parallel + ), "score_mod is not supported with context parallelism!" + assert inference_params is None, "score_mod is not supported with KV caching!" + assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" + assert ( + not user_supplied_seqlens + ), "score_mod is mutually exclusive with explicit sequence length metadata!" + assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" + assert attention_mask is None, "score_mod is mutually exclusive with attention_mask!" + assert ( + attn_mask_type == "no_mask" + ), "score_mod requires attn_mask_type='no_mask'!" + assert ( + window_size is None or window_size == (-1, -1) + ), "score_mod is mutually exclusive with sliding window attention!" + assert ( + core_attention_bias_type == "no_bias" and core_attention_bias is None + ), "score_mod is mutually exclusive with attention bias!" + assert alibi_slopes is None, "score_mod is mutually exclusive with ALiBi!" + assert ( + self.softmax_type == "vanilla" + ), "score_mod is mutually exclusive with sink attention!" + assert ( + self.attention_dropout == 0.0 + ), "score_mod is not supported with attention dropout!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" + assert ( + not checkpoint_core_attention + ), "score_mod is not supported with checkpoint_core_attention!" + assert ( + not is_graph_capturing() + ), "score_mod is not supported with CUDA graph capture!" + assert num_splits == 1, "score_mod is not supported with num_splits != 1!" + assert ( + q_format in ["sbhd", "bshd"] and kv_format in ["sbhd", "bshd"] + ), "score_mod only supports SBHD/BSHD QKV formats!" + if score_mod_tensors is not None: + assert isinstance( + score_mod_tensors, dict + ), "score_mod_tensors must be a dict!" 
+ assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_tensors.items() + ), "score_mod_tensors must map string names to torch.Tensor instances!" + if score_mod_bprop_tensors is not None: + assert isinstance( + score_mod_bprop_tensors, dict + ), "score_mod_bprop_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_bprop_tensors.items() + ), "score_mod_bprop_tensors must map string names to torch.Tensor instances!" + # gather attention params for get_attention_backend attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), @@ -1443,7 +1546,39 @@ def forward( num_splits=num_splits, ) global _attention_backends - if is_in_onnx_export_mode(): + if score_mod is not None: + use_flash_attention = False + flash_attention_backend = None + use_fused_attention = True + use_unfused_attention = False + q_type = dpa_utils.TE_DType[query_layer.dtype] + fused_attention_backend = tex.get_fused_attn_backend( + self.training, + q_type, + q_type, + dpa_utils.QKVLayout["bshd_bshd_bshd"], + dpa_utils.AttnBiasType["no_bias"], + dpa_utils.AttnMaskType["no_mask"], + dpa_utils.SoftmaxType["vanilla"], + 0.0, + num_attention_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + -1, + -1, + False, + is_graph_capturing(), + self.deterministic, + ) + if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend: + raise ValueError( + "score_mod requires a cuDNN FusedAttention backend, but no fused " + "attention backend supports the provided inputs." + ) + elif is_in_onnx_export_mode(): # We do not want to call get_attention_backend() in ONNX mode # and we want to avoid using any global variables like _attention_backends. 
use_flash_attention = False @@ -1619,6 +1754,10 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) if use_unfused_attention: From eb35191a7071d868ff4163b79666004e79d87b54 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 18:43:04 +0000 Subject: [PATCH 02/12] Avoid BHSD copies in score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 13 +- .../dot_product_attention/backends.py | 124 +++++++++--------- 2 files changed, 73 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 879f48dc0c..195087efb0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1436,7 +1436,8 @@ def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) -def test_dot_product_attention_score_mod(dtype, qkv_format): +@pytest.mark.parametrize("scalar_loss", [False, True]) +def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): """Compare score_mod causal masking against standard cuDNN causal attention.""" try: import cudnn # pylint: disable=unused-import,import-outside-toplevel @@ -1504,9 +1505,13 @@ def test_dot_product_attention_score_mod(dtype, qkv_format): attn_mask_type="causal", ) - d_out = torch.randn_like(out) - out.backward(d_out) - out_ref.backward(d_out) + if scalar_loss: + out.sum().backward() + out_ref.sum().backward() + else: + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) tols = dict(atol=5e-2, rtol=5e-2) 
torch.testing.assert_close(out, out_ref, **tols) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index e7e7a7d0f2..3f765dd634 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1244,22 +1244,27 @@ def convert_to_torch_float8(tensor, dtype): return output.contiguous() -def _format_to_bhsd(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: - """Convert TE's SBHD/BSHD tensor formats to cuDNN frontend's BHSD format.""" +def _bhsd_dim_stride( + tensor: torch.Tensor, tensor_format: str +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """Describe an SBHD/BSHD tensor as cuDNN frontend's logical BHSD format.""" if tensor_format == "sbhd": - return tensor.permute(1, 2, 0, 3).contiguous() + return ( + (tensor.shape[1], tensor.shape[2], tensor.shape[0], tensor.shape[3]), + (tensor.stride(1), tensor.stride(2), tensor.stride(0), tensor.stride(3)), + ) if tensor_format == "bshd": - return tensor.permute(0, 2, 1, 3).contiguous() + return ( + (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), + (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)), + ) raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") -def _bhsd_to_format(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: - """Convert cuDNN frontend's BHSD format back to TE's SBHD/BSHD tensor formats.""" - if tensor_format == "sbhd": - return tensor.permute(2, 0, 1, 3).contiguous() - if tensor_format == "bshd": - return tensor.permute(0, 2, 1, 3).contiguous() - raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") +def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): + """Create a cuDNN graph tensor with BHSD dims and TE-layout strides.""" + dim, stride = 
_bhsd_dim_stride(tensor, tensor_format) + return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): @@ -1340,23 +1345,19 @@ def forward( deterministic: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - query_bhsd = _format_to_bhsd(query_layer, q_format) - key_bhsd = _format_to_bhsd(key_layer, kv_format) - value_bhsd = _format_to_bhsd(value_layer, kv_format) + q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) - q = graph.tensor_like(query_bhsd) - k = graph.tensor_like(key_bhsd) - v = graph.tensor_like(value_bhsd) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) - output_bhsd = torch.empty( - (*query_bhsd.shape[:-1], value_bhsd.shape[-1]), - device=query_bhsd.device, - dtype=query_bhsd.dtype, - ) + output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) + output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) output, stats = graph.sdpa( name="te_score_mod_sdpa", q=q, @@ -1367,18 +1368,18 @@ def forward( use_causal_mask=False, score_mod=wrapped_score_mod, ) - output.set_output(True).set_dim(output_bhsd.size()).set_stride(output_bhsd.stride()) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) variant_pack = { - q: query_bhsd, - k: key_bhsd, - v: value_bhsd, - output: output_bhsd, + q: query_layer, + k: key_layer, + v: value_layer, + output: output_layer, } if is_training: stats_bhs1 = torch.empty( - (*query_bhsd.shape[:-1], 1), - 
device=query_bhsd.device, + (*q_bhsd_dim[:-1], 1), + device=query_layer.device, dtype=torch.float32, ) stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( @@ -1390,7 +1391,7 @@ def forward( for name, graph_tensor in score_mod_graph_tensors.items(): variant_pack[graph_tensor] = score_mod_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) ctx.is_training = is_training ctx.q_format = q_format @@ -1402,11 +1403,11 @@ def forward( ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) ctx.deterministic = deterministic if is_training: - ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1) + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer, stats_bhs1) else: - ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd) + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) - return _bhsd_to_format(output_bhsd, q_format) + return output_layer @staticmethod def backward(ctx, d_out: torch.Tensor): @@ -1416,15 +1417,15 @@ def backward(ctx, d_out: torch.Tensor): "score_mod backward requires DotProductAttention to be in training mode." 
) - query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1 = ctx.saved_tensors - d_out_bhsd = _format_to_bhsd(d_out, ctx.q_format) + query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors + d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) - q = graph.tensor_like(query_bhsd) - k = graph.tensor_like(key_bhsd) - v = graph.tensor_like(value_bhsd) - output = graph.tensor_like(output_bhsd) - d_output = graph.tensor_like(d_out_bhsd) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + q = _bhsd_graph_tensor(graph, query_layer, ctx.q_format) + k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) + v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) + output = _bhsd_graph_tensor(graph, output_layer, ctx.q_format) + d_output = _bhsd_graph_tensor(graph, d_out, ctx.q_format) stats = graph.tensor_like(stats_bhs1) score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) @@ -1438,9 +1439,12 @@ def backward(ctx, d_out: torch.Tensor): ctx.score_mod_bprop, score_mod_bprop_graph_tensors ) - dq_bhsd = torch.empty_like(query_bhsd) - dk_bhsd = torch.empty_like(key_bhsd) - dv_bhsd = torch.empty_like(value_bhsd) + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, ctx.q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, ctx.kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, ctx.kv_format) dq, dk, dv = graph.sdpa_backward( name="te_score_mod_sdpa_backward", q=q, @@ -1455,33 +1459,33 @@ def backward(ctx, d_out: torch.Tensor): score_mod_bprop=wrapped_score_mod_bprop, use_deterministic_algorithm=ctx.deterministic, ) - dq.set_output(True).set_dim(dq_bhsd.size()).set_stride(dq_bhsd.stride()) - dk.set_output(True).set_dim(dk_bhsd.size()).set_stride(dk_bhsd.stride()) - dv.set_output(True).set_dim(dv_bhsd.size()).set_stride(dv_bhsd.stride()) + 
dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) variant_pack = { - q: query_bhsd, - k: key_bhsd, - v: value_bhsd, - output: output_bhsd, - d_output: d_out_bhsd, + q: query_layer, + k: key_layer, + v: value_layer, + output: output_layer, + d_output: d_out, stats: stats_bhs1, - dq: dq_bhsd, - dk: dk_bhsd, - dv: dv_bhsd, + dq: dq_layer, + dk: dk_layer, + dv: dv_layer, } for name, graph_tensor in score_mod_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_tensors[name] for name, graph_tensor in score_mod_bprop_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) return ( None, - _bhsd_to_format(dq_bhsd, ctx.q_format), - _bhsd_to_format(dk_bhsd, ctx.kv_format), - _bhsd_to_format(dv_bhsd, ctx.kv_format), + dq_layer, + dk_layer, + dv_layer, None, None, None, From 57ce106435dba1e95ef8e58d15b64bfdcbfca59c Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 21:21:44 +0000 Subject: [PATCH 03/12] Test relative position score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 138 ++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 195087efb0..f715836a30 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1432,6 +1432,41 @@ def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): ) +def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors): + """cuDNN frontend score_mod adding relative position bias.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = 
score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + relative_position = score_mod_graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + relative_position.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.add( + a=score_tensor, + b=relative_position, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors): + """cuDNN frontend score_mod_bprop for score_mods with unit score derivative.""" + return dP_tensor + + +def _relative_position_bias(config, dtype): + """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" + q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) + kv_idx = torch.arange(config.max_seqlen_kv, dtype=torch.float32, device="cuda").view( + 1, 1, 1, -1 + ) + return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @@ -1520,6 +1555,109 @@ def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): torch.testing.assert_close(v.grad, v_ref.grad, **tols) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod_relative_position(dtype, qkv_format): + """Compare relative-position score_mod against materialized post-scale bias.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + 
except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + + config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") + bias_config = ModelConfig( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + attn_mask_type="no_mask", + attn_bias_type="post_scale_bias", + bias_shape="1hss", + ) + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends( + bias_config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if ( + not available_backends[1] + or not fused_attn_backends + or not bias_available_backends[1] + or not bias_fused_attn_backends + ): + pytest.skip("FusedAttention is not available for this relative-position configuration.") + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + ref_attn = DotProductAttention( + config.num_heads, + 
config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=_score_mod_relative_position, + score_mod_bprop=_score_mod_identity_bprop, + ) + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="no_mask", + core_attention_bias_type="post_scale_bias", + core_attention_bias=_relative_position_bias(config, dtype), + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=5e-2, rtol=5e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + model_configs_te_layer = { # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), From e6ba0ea8907c7bb26a4499bb2b61dc605fe71c8d Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 21:44:30 +0000 Subject: [PATCH 04/12] Test softcap score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 153 ++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f715836a30..c42a42d4b4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1458,6 +1458,53 @@ def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors): return dP_tensor +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def forward(self, score_mod_graph, score_tensor, tensors): + """Apply softcap * tanh(score / softcap).""" + import cudnn # pylint: disable=import-outside-toplevel + + 
self.before_tanh_activation = score_mod_graph.div( + a=score_tensor, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = score_mod_graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, score_mod_graph, dP_tensor, tensors): + """Apply softcap derivative to dP.""" + import cudnn # pylint: disable=import-outside-toplevel + + d_tanh_out = score_mod_graph.mul( + a=dP_tensor, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = score_mod_graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) @@ -1467,6 +1514,32 @@ def _relative_position_bias(config, dtype): return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous() +def _to_bhsd(tensor, qkv_format): + """Convert SBHD/BSHD test tensors to logical BHSD.""" + if qkv_format == "sbhd": + return tensor.permute(1, 2, 0, 3) + return tensor.permute(0, 2, 1, 3) + + +def _from_bhsd(tensor, qkv_format): + """Convert logical BHSD test tensors to SBHD/BSHD.""" + if qkv_format == "sbhd": + return tensor.permute(2, 0, 1, 3).contiguous() + return tensor.permute(0, 2, 1, 3).contiguous() + + +def _pytorch_softcap_attention(q, k, v, qkv_format, softmax_scale, softcap): + 
"""PyTorch reference for softcapped scaled dot-product attention.""" + q_bhsd = _to_bhsd(q, qkv_format).float() + k_bhsd = _to_bhsd(k, qkv_format).float() + v_bhsd = _to_bhsd(v, qkv_format).float() + scores = torch.matmul(q_bhsd, k_bhsd.transpose(-2, -1)) * softmax_scale + scores = softcap * torch.tanh(scores / softcap) + probs = torch.softmax(scores, dim=-1) + out = _from_bhsd(torch.matmul(probs, v_bhsd), qkv_format).to(v.dtype) + return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @@ -1555,6 +1628,86 @@ def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): torch.testing.assert_close(v.grad, v_ref.grad, **tols) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod_softcap(dtype, qkv_format): + """Compare softcap score_mod against PyTorch math attention.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not available_backends[1] or not fused_attn_backends: + pytest.skip("FusedAttention is not 
available for this softcap configuration.") + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_() + k = torch.randn(kv_shape, dtype=dtype, device="cuda").requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + softcap = 0.8 + softcap_tensor = torch.full((1, 1, 1, 1), softcap) + softcap_score_mod = _ScoreModSoftcap() + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap_tensor}, + score_mod_bprop_tensors={"softcap": softcap_tensor}, + ) + out_ref = _pytorch_softcap_attention( + q_ref, + k_ref, + v_ref, + qkv_format, + 1.0 / config.head_dim_qk**0.5, + softcap, + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=7e-2, rtol=7e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) From dcb6b492cc3396a95a57132ba8fcd90711020233 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 
2026 23:17:04 +0000 Subject: [PATCH 05/12] Run score_mod graphs on current CUDA stream Signed-off-by: Vladimir Cherepanov --- .../dot_product_attention/backends.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 3f765dd634..625c030e0c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -89,6 +89,7 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None +_cudnn_score_mod_handles: Dict[torch.device, Any] = {} # Try to import Flash Attention v2 try: @@ -1285,7 +1286,25 @@ def _wrapped_score_mod(sdpa_graph, score_tensor): return _wrapped_score_mod -def _build_cudnn_pygraph(dtype: torch.dtype): +def _get_cudnn_current_stream_handle(cudnn, device: torch.device): + """Return a cuDNN handle for device, bound to PyTorch's current stream.""" + if device.type != "cuda": + raise ValueError(f"score_mod only supports CUDA tensors, got device {device}.") + if device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + + handle = _cudnn_score_mod_handles.get(device) + with torch.cuda.device(device): + if handle is None: + handle = cudnn.create_handle() + _cudnn_score_mod_handles[device] = handle + + stream = torch.cuda.current_stream(device).cuda_stream + cudnn.set_stream(handle=handle, stream=stream) + return handle + + +def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" import cudnn # pylint: disable=import-outside-toplevel @@ -1300,6 +1319,7 @@ def _build_cudnn_pygraph(dtype: torch.dtype): io_data_type=io_data_type, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_current_stream_handle(cudnn, 
device), ) return cudnn, graph @@ -1322,7 +1342,11 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev device=device, dtype=torch.uint8, ) - graph.execute(variant_pack, workspace) + graph.execute( + variant_pack, + workspace, + handle=_get_cudnn_current_stream_handle(cudnn, device), + ) class FusedAttentionWithScoreModFunc(torch.autograd.Function): @@ -1347,7 +1371,7 @@ def forward( # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) q = _bhsd_graph_tensor(graph, query_layer, q_format) k = _bhsd_graph_tensor(graph, key_layer, kv_format) v = _bhsd_graph_tensor(graph, value_layer, kv_format) @@ -1420,7 +1444,7 @@ def backward(ctx, d_out: torch.Tensor): query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) q = _bhsd_graph_tensor(graph, query_layer, ctx.q_format) k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) From fefcbe7f37f995a09ff4df7c5c88c227d6fd0e71 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 21:47:04 +0000 Subject: [PATCH 06/12] Add PyTorch score_mod execution plan cache Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 157 +++++ .../dot_product_attention/backends.py | 547 +++++++++++++++--- 2 files changed, 616 insertions(+), 88 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c42a42d4b4..a8f1811dc0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -26,6 +26,7 @@ from 
transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) +import transformer_engine.pytorch.attention.dot_product_attention.backends as dpa_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -1505,6 +1506,162 @@ def backward(self, score_mod_graph, dP_tensor, tensors): ) +def _score_mod_cache_cpu_inputs(): + """Small CPU tensors for score_mod cache-key tests.""" + q = torch.empty((2, 4, 3, 8), dtype=torch.float16) + k = torch.empty((2, 4, 3, 8), dtype=torch.float16) + v = torch.empty((2, 4, 3, 8), dtype=torch.float16) + o = torch.empty((2, 4, 3, 8), dtype=torch.float16) + stats = torch.empty((2, 3, 4, 1), dtype=torch.float32) + return q, k, v, o, stats + + +def test_score_mod_cache_bound_method_key_stable(): + """Bound method keys should be stable across repeated attribute access.""" + softcap = _ScoreModSoftcap() + key_0 = dpa_backends._score_mod_callback_cache_key(softcap.forward) + key_1 = dpa_backends._score_mod_callback_cache_key(softcap.forward) + other_key = dpa_backends._score_mod_callback_cache_key(_ScoreModSoftcap().forward) + + assert key_0 == key_1 + assert key_0 != other_key + + +def test_score_mod_cache_key_ignores_pass_by_value_values(): + """Scalar CPU tensor values are runtime inputs, not execution-plan metadata.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + key_0 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + ) + key_1 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + ) + key_2 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], 
dtype=torch.float32)}, + ) + + assert key_0 == key_1 + assert key_0 != key_2 + + +def test_score_mod_cache_fwd_reuses_graph_for_pass_by_value_changes(monkeypatch): + """Fprop graph cache should reuse entries when only scalar CPU tensor values change.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + cache = dpa_backends._cudnn_score_mod_graph_cache + saved_cache = dict(cache) + build_entries = [] + + def fake_build( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ): + del ( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + entry = object() + build_entries.append(entry) + return entry + + monkeypatch.setattr(dpa_backends, "_build_cudnn_score_mod_fwd_graph", fake_build) + try: + cache.clear() + entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + o, + stats, + ) + entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + o, + stats, + ) + entry_2 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], dtype=torch.float32)}, + o, + stats, + ) + finally: + cache.clear() + cache.update(saved_cache) + + assert entry_0 is entry_1 + assert entry_2 is not entry_0 + assert len(build_entries) == 2 + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py 
b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 625c030e0c..66a49bbc64 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -4,6 +4,7 @@ """Attention Backends.""" from contextlib import nullcontext +from dataclasses import dataclass from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError import os @@ -90,6 +91,7 @@ _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None _cudnn_score_mod_handles: Dict[torch.device, Any] = {} +_cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {} # Try to import Flash Attention v2 try: @@ -1268,6 +1270,52 @@ def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: + """Create a stable cache key for a score_mod callable.""" + if callback is None: + return None + self_obj = getattr(callback, "__self__", None) + func_obj = getattr(callback, "__func__", None) + if self_obj is not None and func_obj is not None: + return ("bound_method", id(self_obj), id(func_obj)) + return ("callable", id(callback)) + + +def _score_mod_device_key(device: torch.device) -> Tuple[Any, ...]: + """Normalize a tensor device for graph cache keys.""" + if device.type == "cuda": + index = device.index + if index is None: + index = torch.cuda.current_device() + return (device.type, index) + return (device.type, device.index) + + +def _score_mod_tensor_metadata(tensor: torch.Tensor) -> Tuple[Any, ...]: + """Describe tensor metadata that can affect cuDNN graph construction.""" + return ( + tuple(tensor.size()), + tuple(tensor.stride()), + tensor.dtype, + _score_mod_device_key(tensor.device), + ) + + +def _score_mod_tensor_dict_metadata( + tensors: Optional[Dict[str, torch.Tensor]], +) -> 
Tuple[Tuple[str, Tuple[Any, ...]], ...]: + """Describe score_mod tensor parameters without including their values.""" + if tensors is None: + return () + return tuple((name, _score_mod_tensor_metadata(tensor)) for name, tensor in tensors.items()) + + +def _score_mod_bhsd_tensor_metadata(tensor: torch.Tensor, tensor_format: str) -> Tuple[Any, ...]: + """Describe an SBHD/BSHD runtime tensor as a cuDNN BHSD graph tensor.""" + dim, stride = _bhsd_dim_stride(tensor, tensor_format) + return (dim, stride, tensor.dtype, _score_mod_device_key(tensor.device)) + + def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): """Create cuDNN graph tensors matching runtime tensors.""" if tensors is None: @@ -1324,8 +1372,41 @@ def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): return cudnn, graph -def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], device: torch.device): - """Build and execute a cuDNN frontend Python graph without caching.""" +@dataclass +class _CudnnScoreModFwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod fprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + stats: Optional[Any] + score_mod_graph_tensors: Dict[str, Any] + workspace_size: int + + +@dataclass +class _CudnnScoreModBwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod bprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + d_output: Any + stats: Any + dq: Any + dk: Any + dv: Any + score_mod_graph_tensors: Dict[str, Any] + score_mod_bprop_graph_tensors: Dict[str, Any] + workspace_size: int + + +def _finalize_cudnn_graph(graph) -> int: + """Build a cuDNN frontend Python graph and return its workspace size.""" import cudnn # pylint: disable=import-outside-toplevel graph.validate() @@ -1336,9 +1417,22 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev except cudnn.cudnnGraphNotSupportedError as exc: raise 
RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + return max(graph.get_workspace_size(), 1) + + +def _execute_cudnn_graph( + graph, + variant_pack: Dict[Any, torch.Tensor], + workspace_size: int, + device: torch.device, +): + """Execute a built cuDNN frontend Python graph.""" + import cudnn # pylint: disable=import-outside-toplevel + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) workspace = torch.empty( - max(graph.get_workspace_size(), 1), + workspace_size, device=device, dtype=torch.uint8, ) @@ -1349,6 +1443,307 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev ) +def _cudnn_score_mod_fwd_cache_key( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], +) -> Tuple[Any, ...]: + """Cache key for score_mod fprop execution plans.""" + return ( + "fwd", + is_training, + q_format, + kv_format, + attn_scale, + _score_mod_callback_cache_key(score_mod), + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_tensor_metadata(stats_bhs1) if stats_bhs1 is not None else None, + _score_mod_tensor_dict_metadata(score_mod_tensors), + ) + + +def _cudnn_score_mod_bwd_cache_key( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + 
score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> Tuple[Any, ...]: + """Cache key for score_mod bprop execution plans.""" + return ( + "bwd", + q_format, + kv_format, + attn_scale, + deterministic, + _score_mod_callback_cache_key(score_mod), + _score_mod_callback_cache_key(score_mod_bprop), + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_bhsd_tensor_metadata(d_out, q_format), + _score_mod_tensor_metadata(stats_bhs1), + _score_mod_tensor_dict_metadata(score_mod_tensors), + _score_mod_tensor_dict_metadata(score_mod_bprop_tensors), + ) + + +def _build_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod fprop.""" + import cudnn # pylint: disable=import-outside-toplevel + + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + 
score_mod=wrapped_score_mod, + ) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) + + if is_training: + assert stats_bhs1 is not None + stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( + stats_bhs1.stride() + ).set_data_type(cudnn.data_type.FLOAT) + else: + stats = None + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModFwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + stats=stats, + score_mod_graph_tensors=score_mod_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod fprop.""" + key = _cudnn_score_mod_fwd_cache_key( + is_training, + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +def _build_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + 
deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod bprop.""" + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + output = _bhsd_graph_tensor(graph, output_layer, q_format) + d_output = _bhsd_graph_tensor(graph, d_out, q_format) + stats = graph.tensor_like(stats_bhs1) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, score_mod_bprop_tensors) + if score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod(score_mod_bprop, score_mod_bprop_graph_tensors) + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, kv_format) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=deterministic, + ) + dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModBwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + d_output=d_output, + stats=stats, + dq=dq, + dk=dk, + dv=dv, + score_mod_graph_tensors=score_mod_graph_tensors, + 
score_mod_bprop_graph_tensors=score_mod_bprop_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod bprop.""" + key = _cudnn_score_mod_bwd_cache_key( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + class FusedAttentionWithScoreModFunc(torch.autograd.Function): """cuDNN frontend Python SDPA path with score_mod callback support.""" @@ -1370,52 +1765,47 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) - q = _bhsd_graph_tensor(graph, query_layer, q_format) - k = _bhsd_graph_tensor(graph, key_layer, kv_format) - v = _bhsd_graph_tensor(graph, value_layer, kv_format) - - score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) - wrapped_score_mod = _wrap_score_mod(score_mod, 
score_mod_graph_tensors) - output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) - output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) - output, stats = graph.sdpa( - name="te_score_mod_sdpa", - q=q, - k=k, - v=v, - generate_stats=is_training, - attn_scale=attn_scale, - use_causal_mask=False, - score_mod=wrapped_score_mod, - ) - output.set_output(True).set_dim(output_dim).set_stride(output_stride) - - variant_pack = { - q: query_layer, - k: key_layer, - v: value_layer, - output: output_layer, - } if is_training: stats_bhs1 = torch.empty( (*q_bhsd_dim[:-1], 1), device=query_layer.device, dtype=torch.float32, ) - stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( - stats_bhs1.stride() - ).set_data_type(cudnn.data_type.FLOAT) - variant_pack[stats] = stats_bhs1 else: stats_bhs1 = None - for name, graph_tensor in score_mod_graph_tensors.items(): + + entry = _get_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + } + if is_training: + variant_pack[entry.stats] = stats_bhs1 + for name, graph_tensor in entry.score_mod_graph_tensors.items(): variant_pack[graph_tensor] = score_mod_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) ctx.is_training = is_training ctx.q_format = q_format @@ -1444,66 +1834,47 @@ def backward(ctx, d_out: torch.Tensor): query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) - q = 
_bhsd_graph_tensor(graph, query_layer, ctx.q_format) - k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) - v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) - output = _bhsd_graph_tensor(graph, output_layer, ctx.q_format) - d_output = _bhsd_graph_tensor(graph, d_out, ctx.q_format) - stats = graph.tensor_like(stats_bhs1) - - score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) - score_mod_bprop_graph_tensors = ( - _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_bprop_tensors) - if ctx.score_mod_bprop is not None - else {} - ) - wrapped_score_mod = _wrap_score_mod(ctx.score_mod, score_mod_graph_tensors) - wrapped_score_mod_bprop = _wrap_score_mod( - ctx.score_mod_bprop, score_mod_bprop_graph_tensors - ) - dq_layer = torch.empty_like(query_layer) dk_layer = torch.empty_like(key_layer) dv_layer = torch.empty_like(value_layer) - dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, ctx.q_format) - dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, ctx.kv_format) - dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, ctx.kv_format) - dq, dk, dv = graph.sdpa_backward( - name="te_score_mod_sdpa_backward", - q=q, - k=k, - v=v, - o=output, - dO=d_output, - stats=stats, - attn_scale=ctx.attn_scale, - use_causal_mask=False, - score_mod=wrapped_score_mod, - score_mod_bprop=wrapped_score_mod_bprop, - use_deterministic_algorithm=ctx.deterministic, + entry = _get_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + ctx.q_format, + ctx.kv_format, + ctx.attn_scale, + ctx.score_mod, + ctx.score_mod_bprop, + ctx.score_mod_tensors, + ctx.score_mod_bprop_tensors, + ctx.deterministic, ) - dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) - dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) - dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) - variant_pack = { - q: query_layer, - k: key_layer, - v: value_layer, - output: output_layer, - d_output: d_out, - stats: stats_bhs1, - 
dq: dq_layer, - dk: dk_layer, - dv: dv_layer, + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + entry.d_output: d_out, + entry.stats: stats_bhs1, + entry.dq: dq_layer, + entry.dk: dk_layer, + entry.dv: dv_layer, } - for name, graph_tensor in score_mod_graph_tensors.items(): + for name, graph_tensor in entry.score_mod_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_tensors[name] - for name, graph_tensor in score_mod_bprop_graph_tensors.items(): + for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) return ( None, From ac4c60d03e6192ef4e5fd0f1cd1aebeea83c2791 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 22:40:47 +0000 Subject: [PATCH 07/12] Fix score_mod cache edge cases Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 45 +++++++++++++++++++ .../dot_product_attention/backends.py | 40 +++++++++++++---- .../dot_product_attention.py | 4 +- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a8f1811dc0..c88634fd61 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1662,6 +1662,51 @@ def fake_build( assert len(build_entries) == 2 +def test_score_mod_tensors_are_version_checked_for_backward(monkeypatch): + """In-place score_mod tensor updates before backward should be rejected.""" + + class FakeEntry: + graph = object() + q = object() + k = object() + v = object() + output = object() + stats = object() + score_mod_graph_tensors = {"softcap": object()} + workspace_size = 1 + + def fake_execute(graph, variant_pack, workspace_size, 
device): + del graph, variant_pack, workspace_size, device + + q, k, v, _, _ = _score_mod_cache_cpu_inputs() + q = q.requires_grad_() + k = k.requires_grad_() + v = v.requires_grad_() + softcap = torch.tensor(0.8, dtype=torch.float32) + + monkeypatch.setattr(dpa_backends, "_get_cudnn_score_mod_fwd_graph", lambda *args: FakeEntry()) + monkeypatch.setattr(dpa_backends, "_execute_cudnn_graph", fake_execute) + + out = dpa_backends.FusedAttentionWithScoreModFunc.apply( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + None, + {"softcap": softcap}, + None, + False, + ) + softcap.add_(1.0) + + with pytest.raises(RuntimeError, match="modified by an inplace operation"): + out.sum().backward() + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 66a49bbc64..2bc8e596d3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1765,6 +1765,8 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) + score_mod_tensors = dict(score_mod_tensors or {}) + score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) if is_training: @@ -1813,11 +1815,21 @@ def forward( ctx.attn_scale = attn_scale ctx.score_mod = score_mod ctx.score_mod_bprop = score_mod_bprop - ctx.score_mod_tensors = dict(score_mod_tensors or {}) - ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + 
ctx.score_mod_tensor_names = tuple(score_mod_tensors.keys()) + ctx.score_mod_bprop_tensor_names = tuple(score_mod_bprop_tensors.keys()) ctx.deterministic = deterministic if is_training: - ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer, stats_bhs1) + # save_for_backward records version counters without copying tensor data. + # This catches in-place score_mod tensor updates before backward. + ctx.save_for_backward( + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + *score_mod_tensors.values(), + *score_mod_bprop_tensors.values(), + ) else: ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) @@ -1831,7 +1843,15 @@ def backward(ctx, d_out: torch.Tensor): "score_mod backward requires DotProductAttention to be in training mode." ) - query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors + saved_tensors = ctx.saved_tensors + query_layer, key_layer, value_layer, output_layer, stats_bhs1 = saved_tensors[:5] + score_mod_tensors_end = 5 + len(ctx.score_mod_tensor_names) + score_mod_tensors = dict( + zip(ctx.score_mod_tensor_names, saved_tensors[5:score_mod_tensors_end]) + ) + score_mod_bprop_tensors = dict( + zip(ctx.score_mod_bprop_tensor_names, saved_tensors[score_mod_tensors_end:]) + ) d_out = d_out.contiguous() dq_layer = torch.empty_like(query_layer) @@ -1849,8 +1869,8 @@ def backward(ctx, d_out: torch.Tensor): ctx.attn_scale, ctx.score_mod, ctx.score_mod_bprop, - ctx.score_mod_tensors, - ctx.score_mod_bprop_tensors, + score_mod_tensors, + score_mod_bprop_tensors, ctx.deterministic, ) variant_pack = { @@ -1865,9 +1885,9 @@ def backward(ctx, d_out: torch.Tensor): entry.dv: dv_layer, } for name, graph_tensor in entry.score_mod_graph_tensors.items(): - variant_pack[graph_tensor] = ctx.score_mod_tensors[name] + variant_pack[graph_tensor] = score_mod_tensors[name] for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): - variant_pack[graph_tensor] = 
ctx.score_mod_bprop_tensors[name] + variant_pack[graph_tensor] = score_mod_bprop_tensors[name] _execute_cudnn_graph( entry.graph, @@ -2726,6 +2746,10 @@ def forward( assert ( not fp8 ), "score_mod is not supported with FP8 FusedAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" assert ( type(query_layer) is torch.Tensor and type(key_layer) is torch.Tensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7a7745d7c2..b887ed50a1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1249,6 +1249,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + if score_mod is not None: + assert inference_params is None, "score_mod is not supported with KV caching!" + # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -1456,7 +1459,6 @@ def forward( assert ( not context_parallel ), "score_mod is not supported with context parallelism!" - assert inference_params is None, "score_mod is not supported with KV caching!" assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" 
assert ( not user_supplied_seqlens From 6446825aeb7870c7088a1c7724fa101b151de61e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 03:19:08 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/backends.py | 12 ++--- .../dot_product_attention.py | 45 +++++++++---------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2bc8e596d3..fbb55250ef 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2740,16 +2740,10 @@ def forward( ) if score_mod is not None: - assert ( - not context_parallel - ), "score_mod is not supported with context parallelism!" - assert ( - not fp8 - ), "score_mod is not supported with FP8 FusedAttention!" + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert not fp8, "score_mod is not supported with FP8 FusedAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert ( - not self.return_max_logit - ), "score_mod is not supported with return_max_logit!" + assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" 
assert ( type(query_layer) is torch.Tensor and type(key_layer) is torch.Tensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b887ed50a1..95a0b53b4a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1435,15 +1435,18 @@ def forward( if score_mod is None: assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" - assert score_mod_bprop_tensors is None, "score_mod_bprop_tensors requires score_mod!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors requires score_mod!" else: assert callable(score_mod), "score_mod must be callable!" - assert ( - score_mod_bprop is None or callable(score_mod_bprop) + assert score_mod_bprop is None or callable( + score_mod_bprop ), "score_mod_bprop must be callable when provided!" - assert ( - query_layer.dtype in [torch.float16, torch.bfloat16] - ), "score_mod only supports FP16 and BF16 tensors!" + assert query_layer.dtype in [ + torch.float16, + torch.bfloat16, + ], "score_mod only supports FP16 and BF16 tensors!" assert ( key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype ), "score_mod requires Q, K and V tensors to have the same dtype!" @@ -1452,24 +1455,21 @@ def forward( and type(key_layer) is torch.Tensor and type(value_layer) is torch.Tensor ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" - assert ( - not self.fp8 - ), "score_mod is not supported with FP8 DotProductAttention!" + assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" 
- assert ( - not context_parallel - ), "score_mod is not supported with context parallelism!" + assert not context_parallel, "score_mod is not supported with context parallelism!" assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" assert ( not user_supplied_seqlens ), "score_mod is mutually exclusive with explicit sequence length metadata!" assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" - assert attention_mask is None, "score_mod is mutually exclusive with attention_mask!" assert ( - attn_mask_type == "no_mask" - ), "score_mod requires attn_mask_type='no_mask'!" - assert ( - window_size is None or window_size == (-1, -1) + attention_mask is None + ), "score_mod is mutually exclusive with attention_mask!" + assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" + assert window_size is None or window_size == ( + -1, + -1, ), "score_mod is mutually exclusive with sliding window attention!" assert ( core_attention_bias_type == "no_bias" and core_attention_bias is None @@ -1491,13 +1491,12 @@ def forward( not is_graph_capturing() ), "score_mod is not supported with CUDA graph capture!" assert num_splits == 1, "score_mod is not supported with num_splits != 1!" - assert ( - q_format in ["sbhd", "bshd"] and kv_format in ["sbhd", "bshd"] - ), "score_mod only supports SBHD/BSHD QKV formats!" + assert q_format in ["sbhd", "bshd"] and kv_format in [ + "sbhd", + "bshd", + ], "score_mod only supports SBHD/BSHD QKV formats!" if score_mod_tensors is not None: - assert isinstance( - score_mod_tensors, dict - ), "score_mod_tensors must be a dict!" + assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" 
assert all( isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in score_mod_tensors.items() From 58a5fb57ef08c9315313b6c36c4e558d9c635658 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 15 May 2026 00:48:58 +0000 Subject: [PATCH 09/12] Fix score_mod callback graph cache keys Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 124 +++++++++++++++- .../dot_product_attention/backends.py | 133 ++++++++++++++++-- 2 files changed, 245 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c88634fd61..253b3e4640 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1465,6 +1465,10 @@ class _ScoreModSoftcap: def __init__(self): self.before_tanh_activation = None + def score_mod_graph_cache_key(self): + """Graph topology key for softcap score_mod.""" + return ("softcap",) + def forward(self, score_mod_graph, score_tensor, tensors): """Apply softcap * tanh(score / softcap).""" import cudnn # pylint: disable=import-outside-toplevel @@ -1516,15 +1520,48 @@ def _score_mod_cache_cpu_inputs(): return q, k, v, o, stats -def test_score_mod_cache_bound_method_key_stable(): - """Bound method keys should be stable across repeated attribute access.""" +def test_score_mod_cache_bound_method_requires_explicit_key(): + """Unkeyed bound methods should be uncached instead of keyed by object id.""" + + class UnkeyedScoreMod: + def forward(self, _score_mod_graph, score_tensor, _tensors): + return score_tensor + + key = dpa_backends._score_mod_callback_cache_key(UnkeyedScoreMod().forward) + + assert key is dpa_backends._SCORE_MOD_UNCACHEABLE + + +def test_score_mod_cache_bound_method_explicit_key_stable(): + """Bound method keys should be stable when a structural graph key is provided.""" softcap = _ScoreModSoftcap() key_0 = dpa_backends._score_mod_callback_cache_key(softcap.forward) key_1 = 
dpa_backends._score_mod_callback_cache_key(softcap.forward) other_key = dpa_backends._score_mod_callback_cache_key(_ScoreModSoftcap().forward) assert key_0 == key_1 - assert key_0 != other_key + assert key_0 == other_key + + +def test_score_mod_cache_explicit_key_distinguishes_topology(): + """Stateful score_mods can opt into caching with topology-specific keys.""" + + class LayeredScoreMod: + def __init__(self, num_layers): + self.num_layers = num_layers + + def score_mod_graph_cache_key(self): + return {"num_layers": self.num_layers} + + def forward(self, _score_mod_graph, score_tensor, _tensors): + return score_tensor + + key_0 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(1).forward) + key_1 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(1).forward) + key_2 = dpa_backends._score_mod_callback_cache_key(LayeredScoreMod(2).forward) + + assert key_0 == key_1 + assert key_0 != key_2 def test_score_mod_cache_key_ignores_pass_by_value_values(): @@ -1662,6 +1699,87 @@ def fake_build( assert len(build_entries) == 2 +def test_score_mod_cache_fwd_skips_cache_for_unkeyed_bound_method(monkeypatch): + """Unkeyed bound methods should build fresh graphs instead of using an id-based key.""" + + class UnkeyedScoreMod: + def forward(self, _score_mod_graph, score_tensor, _tensors): + return score_tensor + + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + score_mod = UnkeyedScoreMod() + cache = dpa_backends._cudnn_score_mod_graph_cache + saved_cache = dict(cache) + build_entries = [] + + def fake_build( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ): + del ( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + entry = object() + build_entries.append(entry) + return entry + + monkeypatch.setattr(dpa_backends, 
"_build_cudnn_score_mod_fwd_graph", fake_build) + try: + cache.clear() + entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + score_mod.forward, + None, + o, + stats, + ) + entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + score_mod.forward, + None, + o, + stats, + ) + assert len(cache) == 0 + finally: + cache.clear() + cache.update(saved_cache) + + assert entry_0 is not entry_1 + assert len(build_entries) == 2 + + def test_score_mod_tensors_are_version_checked_for_backward(monkeypatch): """In-place score_mod tensor updates before backward should be rejected.""" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index fbb55250ef..22621c1b8d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -5,6 +5,7 @@ """Attention Backends.""" from contextlib import nullcontext from dataclasses import dataclass +import inspect from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError import os @@ -92,6 +93,7 @@ _flash_attn_varlen_bwd = None _cudnn_score_mod_handles: Dict[torch.device, Any] = {} _cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {} +_SCORE_MOD_UNCACHEABLE = object() # Try to import Flash Attention v2 try: @@ -1270,15 +1272,87 @@ def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) -def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: - """Create a stable cache key for a score_mod callable.""" +def _freeze_score_mod_cache_key(value: Any) -> Any: + """Convert a user-provided score_mod graph key into a hashable structure.""" + if isinstance(value, 
torch.Tensor): + raise TypeError( + "score_mod_graph_cache_key() must not include tensors. Pass runtime tensors " + "through score_mod_tensors or score_mod_bprop_tensors instead." + ) + if isinstance(value, dict): + items = ( + ( + _freeze_score_mod_cache_key(key), + _freeze_score_mod_cache_key(val), + ) + for key, val in value.items() + ) + return tuple(sorted(items, key=repr)) + if isinstance(value, (list, tuple)): + return tuple(_freeze_score_mod_cache_key(item) for item in value) + if isinstance(value, (set, frozenset)): + items = (_freeze_score_mod_cache_key(item) for item in value) + return tuple(sorted(items, key=repr)) + try: + hash(value) + except TypeError as exc: + raise TypeError( + "score_mod_graph_cache_key() must return a hashable value or a nested " + "combination of dict/list/tuple/set values." + ) from exc + return value + + +def _score_mod_explicit_cache_key(callback_owner: Any) -> Optional[Any]: + """Return a user-provided structural graph key for a score_mod callback.""" + explicit_key = getattr(callback_owner, "score_mod_graph_cache_key", None) + if explicit_key is None: + return None + explicit_key = explicit_key() if callable(explicit_key) else explicit_key + return _freeze_score_mod_cache_key(explicit_key) + + +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any: + """Create a stable graph cache key for a score_mod callable. + + Module-level functions are assumed to have stable topology. Stateful bound methods and + callable instances need an explicit score_mod_graph_cache_key(); otherwise their graphs + are left uncached to avoid reusing stale graphs after Python object address reuse. 
+ """ if callback is None: return None self_obj = getattr(callback, "__self__", None) func_obj = getattr(callback, "__func__", None) if self_obj is not None and func_obj is not None: - return ("bound_method", id(self_obj), id(func_obj)) - return ("callable", id(callback)) + explicit_key = _score_mod_explicit_cache_key(self_obj) + if explicit_key is None: + return _SCORE_MOD_UNCACHEABLE + return ( + "bound_method", + type(self_obj), + func_obj.__module__, + func_obj.__qualname__, + explicit_key, + ) + + explicit_key = _score_mod_explicit_cache_key(callback) + if explicit_key is not None: + return ( + "callable", + type(callback), + getattr(callback, "__module__", None), + getattr(callback, "__qualname__", None), + explicit_key, + ) + + if ( + inspect.isfunction(callback) + and callback.__closure__ is None + and "" not in callback.__qualname__ + ): + return ("function", callback.__module__, callback.__qualname__) + + return _SCORE_MOD_UNCACHEABLE def _score_mod_device_key(device: torch.device) -> Tuple[Any, ...]: @@ -1455,15 +1529,18 @@ def _cudnn_score_mod_fwd_cache_key( attn_scale: float, score_mod: Callable, score_mod_tensors: Optional[Dict[str, torch.Tensor]], -) -> Tuple[Any, ...]: +) -> Optional[Tuple[Any, ...]]: """Cache key for score_mod fprop execution plans.""" + score_mod_key = _score_mod_callback_cache_key(score_mod) + if score_mod_key is _SCORE_MOD_UNCACHEABLE: + return None return ( "fwd", is_training, q_format, kv_format, attn_scale, - _score_mod_callback_cache_key(score_mod), + score_mod_key, _score_mod_bhsd_tensor_metadata(query_layer, q_format), _score_mod_bhsd_tensor_metadata(key_layer, kv_format), _score_mod_bhsd_tensor_metadata(value_layer, kv_format), @@ -1488,16 +1565,23 @@ def _cudnn_score_mod_bwd_cache_key( score_mod_tensors: Optional[Dict[str, torch.Tensor]], score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], deterministic: bool, -) -> Tuple[Any, ...]: +) -> Optional[Tuple[Any, ...]]: """Cache key for score_mod bprop execution 
plans.""" + score_mod_key = _score_mod_callback_cache_key(score_mod) + score_mod_bprop_key = _score_mod_callback_cache_key(score_mod_bprop) + if ( + score_mod_key is _SCORE_MOD_UNCACHEABLE + or score_mod_bprop_key is _SCORE_MOD_UNCACHEABLE + ): + return None return ( "bwd", q_format, kv_format, attn_scale, deterministic, - _score_mod_callback_cache_key(score_mod), - _score_mod_callback_cache_key(score_mod_bprop), + score_mod_key, + score_mod_bprop_key, _score_mod_bhsd_tensor_metadata(query_layer, q_format), _score_mod_bhsd_tensor_metadata(key_layer, kv_format), _score_mod_bhsd_tensor_metadata(value_layer, kv_format), @@ -1594,6 +1678,20 @@ def _get_cudnn_score_mod_fwd_graph( score_mod, score_mod_tensors, ) + if key is None: + return _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) entry = _cudnn_score_mod_graph_cache.get(key) if entry is None: entry = _build_cudnn_score_mod_fwd_graph( @@ -1722,6 +1820,23 @@ def _get_cudnn_score_mod_bwd_graph( score_mod_bprop_tensors, deterministic, ) + if key is None: + return _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) entry = _cudnn_score_mod_graph_cache.get(key) if entry is None: entry = _build_cudnn_score_mod_bwd_graph( From c00a0b786ea7d63e33b797490de78572027ea2f3 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 15 May 2026 01:18:49 +0000 Subject: [PATCH 10/12] Address score_mod review feedback Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 354 +++++++----------- tests/pytorch/utils.py | 4 + .../dot_product_attention/backends.py | 170 +++++---- .../dot_product_attention.py | 77 ++-- 
.../attention/dot_product_attention/utils.py | 120 +++++- 5 files changed, 383 insertions(+), 342 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 253b3e4640..959350e080 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1393,7 +1393,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def _score_mod_causal(score_mod_graph, score_tensor, tensors): """cuDNN frontend score_mod implementing top-left causal masking.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = dpa_backends._import_cudnn_frontend() row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) row_index.set_data_type(cudnn.data_type.INT32) @@ -1414,7 +1414,7 @@ def _score_mod_causal(score_mod_graph, score_tensor, tensors): def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): """cuDNN frontend score_mod_bprop implementing top-left causal masking.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = dpa_backends._import_cudnn_frontend() row_index = score_mod_graph.gen_index(input=dP_tensor, axis=2) row_index.set_data_type(cudnn.data_type.INT32) @@ -1433,23 +1433,23 @@ def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): ) -def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors): - """cuDNN frontend score_mod adding relative position bias.""" - import cudnn # pylint: disable=import-outside-toplevel +def _score_mod_post_scale_bias(score_mod_graph, score_tensor, _tensors): + """cuDNN frontend score_mod adding post-scale bias.""" + cudnn = dpa_backends._import_cudnn_frontend() row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) row_index.set_data_type(cudnn.data_type.INT32) col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) col_index.set_data_type(cudnn.data_type.INT32) - relative_position = score_mod_graph.sub( + post_scale_bias = score_mod_graph.sub( 
a=row_index, b=col_index, compute_data_type=cudnn.data_type.FLOAT, ) - relative_position.set_data_type(cudnn.data_type.FLOAT) + post_scale_bias.set_data_type(cudnn.data_type.FLOAT) return score_mod_graph.add( a=score_tensor, - b=relative_position, + b=post_scale_bias, compute_data_type=cudnn.data_type.FLOAT, ) @@ -1471,7 +1471,7 @@ def score_mod_graph_cache_key(self): def forward(self, score_mod_graph, score_tensor, tensors): """Apply softcap * tanh(score / softcap).""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = dpa_backends._import_cudnn_frontend() self.before_tanh_activation = score_mod_graph.div( a=score_tensor, @@ -1489,7 +1489,7 @@ def forward(self, score_mod_graph, score_tensor, tensors): def backward(self, score_mod_graph, dP_tensor, tensors): """Apply softcap derivative to dP.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = dpa_backends._import_cudnn_frontend() d_tanh_out = score_mod_graph.mul( a=dP_tensor, @@ -1629,7 +1629,7 @@ def fake_build( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ): del ( is_training, @@ -1642,7 +1642,7 @@ def fake_build( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ) entry = object() build_entries.append(entry) @@ -1723,7 +1723,7 @@ def fake_build( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ): del ( is_training, @@ -1736,7 +1736,7 @@ def fake_build( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ) entry = object() build_entries.append(entry) @@ -1825,7 +1825,7 @@ def fake_execute(graph, variant_pack, workspace_size, device): out.sum().backward() -def _relative_position_bias(config, dtype): +def _post_scale_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) kv_idx = torch.arange(config.max_seqlen_kv, dtype=torch.float32, device="cuda").view( @@ -1864,116 +1864,64 
@@ def _pytorch_softcap_attention(q, k, v, qkv_format, softmax_scale, softcap): @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) -@pytest.mark.parametrize("scalar_loss", [False, True]) -def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): - """Compare score_mod causal masking against standard cuDNN causal attention.""" +@pytest.mark.parametrize( + "score_mod_case, scalar_loss", + [ + ("causal", False), + ("causal", True), + ("softcap", False), + ("post_scale_bias", False), + ], +) +def test_dot_product_attention_score_mod(dtype, qkv_format, score_mod_case, scalar_loss): + """Compare score_mod attention against equivalent reference implementations.""" try: - import cudnn # pylint: disable=unused-import,import-outside-toplevel + dpa_backends._import_cudnn_frontend() except ImportError: pytest.skip("cuDNN Python frontend is required for score_mod attention.") reset_rng_states() - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - config = ModelConfig(2, 64, 4, 64, attn_mask_type="no_mask") + config = ModelConfig( + 2, + 64 if score_mod_case == "causal" else 16, + 4, + 64, + attn_mask_type="no_mask", + ) available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + score_mod=True, + score_mod_bprop=True, ) if not available_backends[1] or not fused_attn_backends: pytest.skip("FusedAttention is not available for this score_mod configuration.") - if qkv_format == "sbhd": - q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) - kv_shape = q_shape - else: - q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) - 
kv_shape = q_shape - - q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() - k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() - v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() - q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] - - flex_attn = DotProductAttention( - config.num_heads, - config.head_dim_qk, - qkv_format=qkv_format, - attn_mask_type="no_mask", - layer_number=1, - ).to(dtype=dtype, device="cuda") - ref_attn = DotProductAttention( - config.num_heads, - config.head_dim_qk, - qkv_format=qkv_format, - attn_mask_type="causal", - layer_number=1, - ).to(dtype=dtype, device="cuda") - - out = flex_attn( - q, - k, - v, - qkv_format=qkv_format, - attn_mask_type="no_mask", - score_mod=_score_mod_causal, - score_mod_bprop=_score_mod_causal_bprop, - score_mod_tensors={"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, - score_mod_bprop_tensors={"zero": torch.full((1, 1, 1, 1), 0.0)}, - ) - out_ref = ref_attn( - q_ref, - k_ref, - v_ref, - qkv_format=qkv_format, - attn_mask_type="causal", - ) - - if scalar_loss: - out.sum().backward() - out_ref.sum().backward() - else: - d_out = torch.randn_like(out) - out.backward(d_out) - out_ref.backward(d_out) - - tols = dict(atol=5e-2, rtol=5e-2) - torch.testing.assert_close(out, out_ref, **tols) - torch.testing.assert_close(q.grad, q_ref.grad, **tols) - torch.testing.assert_close(k.grad, k_ref.grad, **tols) - torch.testing.assert_close(v.grad, v_ref.grad, **tols) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") -@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) -def test_dot_product_attention_score_mod_softcap(dtype, qkv_format): - """Compare softcap score_mod against PyTorch math attention.""" - try: - import cudnn # pylint: 
disable=unused-import,import-outside-toplevel - except ImportError: - pytest.skip("cuDNN Python frontend is required for score_mod attention.") + if score_mod_case == "post_scale_bias": + bias_config = ModelConfig( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + attn_mask_type="no_mask", + attn_bias_type="post_scale_bias", + bias_shape="1hss", + ) + bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends( + bias_config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not bias_available_backends[1] or not bias_fused_attn_backends: + pytest.skip("FusedAttention is not available for post_scale_bias reference.") - reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", - ) - if not available_backends[1] or not fused_attn_backends: - pytest.skip("FusedAttention is not available for this softcap configuration.") - if qkv_format == "sbhd": q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) kv_shape = q_shape @@ -1981,14 +1929,16 @@ def test_dot_product_attention_score_mod_softcap(dtype, qkv_format): q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) kv_shape = q_shape - q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_() - k = torch.randn(kv_shape, dtype=dtype, device="cuda").requires_grad_() - v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + if score_mod_case == "softcap": + q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_() + k = torch.randn(kv_shape, 
dtype=dtype, device="cuda").requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + else: + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] - softcap = 0.8 - softcap_tensor = torch.full((1, 1, 1, 1), softcap) - softcap_score_mod = _ScoreModSoftcap() flex_attn = DotProductAttention( config.num_heads, config.head_dim_qk, @@ -1997,109 +1947,64 @@ def test_dot_product_attention_score_mod_softcap(dtype, qkv_format): layer_number=1, ).to(dtype=dtype, device="cuda") - out = flex_attn( - q, - k, - v, - qkv_format=qkv_format, - attn_mask_type="no_mask", - score_mod=softcap_score_mod.forward, - score_mod_bprop=softcap_score_mod.backward, - score_mod_tensors={"softcap": softcap_tensor}, - score_mod_bprop_tensors={"softcap": softcap_tensor}, - ) - out_ref = _pytorch_softcap_attention( - q_ref, - k_ref, - v_ref, - qkv_format, - 1.0 / config.head_dim_qk**0.5, - softcap, - ) - - d_out = torch.randn_like(out) - out.backward(d_out) - out_ref.backward(d_out) - - tols = dict(atol=7e-2, rtol=7e-2) - torch.testing.assert_close(out, out_ref, **tols) - torch.testing.assert_close(q.grad, q_ref.grad, **tols) - torch.testing.assert_close(k.grad, k_ref.grad, **tols) - torch.testing.assert_close(v.grad, v_ref.grad, **tols) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") -@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) -def test_dot_product_attention_score_mod_relative_position(dtype, qkv_format): - """Compare relative-position score_mod against materialized post-scale bias.""" - try: - import cudnn # 
pylint: disable=unused-import,import-outside-toplevel - except ImportError: - pytest.skip("cuDNN Python frontend is required for score_mod attention.") - - reset_rng_states() - - config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") - bias_config = ModelConfig( - config.batch_size, - config.max_seqlen_q, - config.num_heads, - config.head_dim_qk, - attn_mask_type="no_mask", - attn_bias_type="post_scale_bias", - bias_shape="1hss", - ) - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", - ) - bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends( - bias_config, - qkv_dtype=dtype, - qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", - ) - if ( - not available_backends[1] - or not fused_attn_backends - or not bias_available_backends[1] - or not bias_fused_attn_backends - ): - pytest.skip("FusedAttention is not available for this relative-position configuration.") - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - - if qkv_format == "sbhd": - q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) - kv_shape = q_shape + if score_mod_case == "causal": + score_mod_kwargs = { + "score_mod": _score_mod_causal, + "score_mod_bprop": _score_mod_causal_bprop, + "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, + "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)}, + } + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="causal", + layer_number=1, + ).to(dtype=dtype, device="cuda") + out_ref = ref_attn(q_ref, k_ref, v_ref, qkv_format=qkv_format, attn_mask_type="causal") + tols = dict(atol=5e-2, rtol=5e-2) + elif score_mod_case == "softcap": + softcap = 0.8 + 
softcap_tensor = torch.full((1, 1, 1, 1), softcap) + softcap_score_mod = _ScoreModSoftcap() + score_mod_kwargs = { + "score_mod": softcap_score_mod.forward, + "score_mod_bprop": softcap_score_mod.backward, + "score_mod_tensors": {"softcap": softcap_tensor}, + "score_mod_bprop_tensors": {"softcap": softcap_tensor}, + } + out_ref = _pytorch_softcap_attention( + q_ref, + k_ref, + v_ref, + qkv_format, + 1.0 / config.head_dim_qk**0.5, + softcap, + ) + tols = dict(atol=7e-2, rtol=7e-2) else: - q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) - kv_shape = q_shape - - q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() - k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() - v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() - q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] - - flex_attn = DotProductAttention( - config.num_heads, - config.head_dim_qk, - qkv_format=qkv_format, - attn_mask_type="no_mask", - layer_number=1, - ).to(dtype=dtype, device="cuda") - ref_attn = DotProductAttention( - config.num_heads, - config.head_dim_qk, - qkv_format=qkv_format, - attn_mask_type="no_mask", - layer_number=1, - ).to(dtype=dtype, device="cuda") + assert score_mod_case == "post_scale_bias" + score_mod_kwargs = { + "score_mod": _score_mod_post_scale_bias, + "score_mod_bprop": _score_mod_identity_bprop, + } + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="no_mask", + core_attention_bias_type="post_scale_bias", + core_attention_bias=_post_scale_bias(config, dtype), + ) + tols = dict(atol=5e-2, rtol=5e-2) out = flex_attn( q, @@ -2107,24 +2012,17 @@ def test_dot_product_attention_score_mod_relative_position(dtype, 
qkv_format): v, qkv_format=qkv_format, attn_mask_type="no_mask", - score_mod=_score_mod_relative_position, - score_mod_bprop=_score_mod_identity_bprop, - ) - out_ref = ref_attn( - q_ref, - k_ref, - v_ref, - qkv_format=qkv_format, - attn_mask_type="no_mask", - core_attention_bias_type="post_scale_bias", - core_attention_bias=_relative_position_bias(config, dtype), + **score_mod_kwargs, ) - d_out = torch.randn_like(out) - out.backward(d_out) - out_ref.backward(d_out) + if scalar_loss: + out.sum().backward() + out_ref.sum().backward() + else: + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) - tols = dict(atol=5e-2, rtol=5e-2) torch.testing.assert_close(out, out_ref, **tols) torch.testing.assert_close(q.grad, q_ref.grad, **tols) torch.testing.assert_close(k.grad, k_ref.grad, **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 32e44be2af..16f6c08bcf 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -329,6 +329,8 @@ def get_available_attention_backends( fp8_meta: Optional[Dict[str, Any]] = None, is_training: bool = True, inference_params: Optional[InferenceParams] = None, + score_mod: bool = False, + score_mod_bprop: bool = False, ) -> Tuple[List, List]: """Check for all available attention backends that support a model configuration""" @@ -390,6 +392,8 @@ def test(): inference_params=inference_params, softmax_type=config.softmax_type, return_max_logit=config.return_max_logit, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, # allow all backends to pass so they can be used for testing; # check for FA3 availability later num_splits=1, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 22621c1b8d..f481d39a22 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -5,10 +5,13 @@ 
"""Attention Backends.""" from contextlib import nullcontext from dataclasses import dataclass +import importlib import inspect from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError import os +from pathlib import Path +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import logging @@ -94,6 +97,18 @@ _cudnn_score_mod_handles: Dict[torch.device, Any] = {} _cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {} _SCORE_MOD_UNCACHEABLE = object() +_CUDNN_FRONTEND_PYTHON_PATH = ( + Path(__file__).resolve().parents[4] / "3rdparty" / "cudnn-frontend" / "python" +) + + +def _import_cudnn_frontend(): + """Import the vendored cuDNN frontend if built, otherwise use the installed package.""" + cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) + cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" + if any(cudnn_frontend_package.glob("_compiled_module*")) and cudnn_frontend_path not in sys.path: + sys.path.insert(0, cudnn_frontend_path) + return importlib.import_module("cudnn") # Try to import Flash Attention v2 try: @@ -1263,7 +1278,9 @@ def _bhsd_dim_stride( (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)), ) - raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + raise ValueError( + f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}." + ) def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): @@ -1272,6 +1289,7 @@ def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) +# score_mod graph cache helpers. 
def _freeze_score_mod_cache_key(value: Any) -> Any: """Convert a user-provided score_mod graph key into a hashable structure.""" if isinstance(value, torch.Tensor): @@ -1397,6 +1415,7 @@ def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tenso return {name: graph.tensor_like(tensor) for name, tensor in tensors.items()} +# score_mod cuDNN frontend graph helpers. def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): """Adapt TE's score_mod signature to cuDNN frontend's two-argument callback.""" if score_mod is None: @@ -1428,7 +1447,7 @@ def _get_cudnn_current_stream_handle(cudnn, device: torch.device): def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = _import_cudnn_frontend() if dtype == torch.float16: io_data_type = cudnn.data_type.HALF @@ -1481,7 +1500,7 @@ class _CudnnScoreModBwdGraphEntry: def _finalize_cudnn_graph(graph) -> int: """Build a cuDNN frontend Python graph and return its workspace size.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = _import_cudnn_frontend() graph.validate() graph.build_operation_graph() @@ -1501,7 +1520,7 @@ def _execute_cudnn_graph( device: torch.device, ): """Execute a built cuDNN frontend Python graph.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = _import_cudnn_frontend() if device.type == "cuda" and device.index is None: device = torch.device("cuda", torch.cuda.current_device()) @@ -1523,7 +1542,7 @@ def _cudnn_score_mod_fwd_cache_key( key_layer: torch.Tensor, value_layer: torch.Tensor, output_layer: torch.Tensor, - stats_bhs1: Optional[torch.Tensor], + stats: Optional[torch.Tensor], q_format: str, kv_format: str, attn_scale: float, @@ -1545,7 +1564,7 @@ def _cudnn_score_mod_fwd_cache_key( _score_mod_bhsd_tensor_metadata(key_layer, kv_format), _score_mod_bhsd_tensor_metadata(value_layer, 
kv_format), _score_mod_bhsd_tensor_metadata(output_layer, q_format), - _score_mod_tensor_metadata(stats_bhs1) if stats_bhs1 is not None else None, + _score_mod_tensor_metadata(stats) if stats is not None else None, _score_mod_tensor_dict_metadata(score_mod_tensors), ) @@ -1556,7 +1575,7 @@ def _cudnn_score_mod_bwd_cache_key( value_layer: torch.Tensor, output_layer: torch.Tensor, d_out: torch.Tensor, - stats_bhs1: torch.Tensor, + stats: torch.Tensor, q_format: str, kv_format: str, attn_scale: float, @@ -1587,7 +1606,7 @@ def _cudnn_score_mod_bwd_cache_key( _score_mod_bhsd_tensor_metadata(value_layer, kv_format), _score_mod_bhsd_tensor_metadata(output_layer, q_format), _score_mod_bhsd_tensor_metadata(d_out, q_format), - _score_mod_tensor_metadata(stats_bhs1), + _score_mod_tensor_metadata(stats), _score_mod_tensor_dict_metadata(score_mod_tensors), _score_mod_tensor_dict_metadata(score_mod_bprop_tensors), ) @@ -1604,10 +1623,10 @@ def _build_cudnn_score_mod_fwd_graph( score_mod: Callable, score_mod_tensors: Optional[Dict[str, torch.Tensor]], output_layer: torch.Tensor, - stats_bhs1: Optional[torch.Tensor], + stats: Optional[torch.Tensor], ) -> _CudnnScoreModFwdGraphEntry: """Build a cached cuDNN frontend graph for score_mod fprop.""" - import cudnn # pylint: disable=import-outside-toplevel + cudnn = _import_cudnn_frontend() _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) q = _bhsd_graph_tensor(graph, query_layer, q_format) @@ -1618,7 +1637,7 @@ def _build_cudnn_score_mod_fwd_graph( wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) - output, stats = graph.sdpa( + output, stats_tensor = graph.sdpa( name="te_score_mod_sdpa", q=q, k=k, @@ -1631,12 +1650,12 @@ def _build_cudnn_score_mod_fwd_graph( output.set_output(True).set_dim(output_dim).set_stride(output_stride) if is_training: - assert stats_bhs1 is not None - 
stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( - stats_bhs1.stride() + assert stats is not None + stats_tensor.set_output(True).set_dim(stats.size()).set_stride( + stats.stride() ).set_data_type(cudnn.data_type.FLOAT) else: - stats = None + stats_tensor = None workspace_size = _finalize_cudnn_graph(graph) return _CudnnScoreModFwdGraphEntry( @@ -1645,7 +1664,7 @@ def _build_cudnn_score_mod_fwd_graph( k=k, v=v, output=output, - stats=stats, + stats=stats_tensor, score_mod_graph_tensors=score_mod_graph_tensors, workspace_size=workspace_size, ) @@ -1662,7 +1681,7 @@ def _get_cudnn_score_mod_fwd_graph( score_mod: Callable, score_mod_tensors: Optional[Dict[str, torch.Tensor]], output_layer: torch.Tensor, - stats_bhs1: Optional[torch.Tensor], + stats: Optional[torch.Tensor], ) -> _CudnnScoreModFwdGraphEntry: """Return a cached cuDNN frontend graph for score_mod fprop.""" key = _cudnn_score_mod_fwd_cache_key( @@ -1671,7 +1690,7 @@ def _get_cudnn_score_mod_fwd_graph( key_layer, value_layer, output_layer, - stats_bhs1, + stats, q_format, kv_format, attn_scale, @@ -1690,7 +1709,7 @@ def _get_cudnn_score_mod_fwd_graph( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ) entry = _cudnn_score_mod_graph_cache.get(key) if entry is None: @@ -1705,7 +1724,7 @@ def _get_cudnn_score_mod_fwd_graph( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ) _cudnn_score_mod_graph_cache[key] = entry return entry @@ -1717,7 +1736,7 @@ def _build_cudnn_score_mod_bwd_graph( value_layer: torch.Tensor, output_layer: torch.Tensor, d_out: torch.Tensor, - stats_bhs1: torch.Tensor, + stats: torch.Tensor, q_format: str, kv_format: str, attn_scale: float, @@ -1734,7 +1753,7 @@ def _build_cudnn_score_mod_bwd_graph( v = _bhsd_graph_tensor(graph, value_layer, kv_format) output = _bhsd_graph_tensor(graph, output_layer, q_format) d_output = _bhsd_graph_tensor(graph, d_out, q_format) - stats = graph.tensor_like(stats_bhs1) + stats_tensor = 
graph.tensor_like(stats) score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) score_mod_bprop_graph_tensors = ( @@ -1758,7 +1777,7 @@ def _build_cudnn_score_mod_bwd_graph( v=v, o=output, dO=d_output, - stats=stats, + stats=stats_tensor, attn_scale=attn_scale, use_causal_mask=False, score_mod=wrapped_score_mod, @@ -1777,7 +1796,7 @@ def _build_cudnn_score_mod_bwd_graph( v=v, output=output, d_output=d_output, - stats=stats, + stats=stats_tensor, dq=dq, dk=dk, dv=dv, @@ -1793,7 +1812,7 @@ def _get_cudnn_score_mod_bwd_graph( value_layer: torch.Tensor, output_layer: torch.Tensor, d_out: torch.Tensor, - stats_bhs1: torch.Tensor, + stats: torch.Tensor, q_format: str, kv_format: str, attn_scale: float, @@ -1810,7 +1829,7 @@ def _get_cudnn_score_mod_bwd_graph( value_layer, output_layer, d_out, - stats_bhs1, + stats, q_format, kv_format, attn_scale, @@ -1827,7 +1846,7 @@ def _get_cudnn_score_mod_bwd_graph( value_layer, output_layer, d_out, - stats_bhs1, + stats, q_format, kv_format, attn_scale, @@ -1845,7 +1864,7 @@ def _get_cudnn_score_mod_bwd_graph( value_layer, output_layer, d_out, - stats_bhs1, + stats, q_format, kv_format, attn_scale, @@ -1885,13 +1904,13 @@ def forward( output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) if is_training: - stats_bhs1 = torch.empty( + stats = torch.empty( (*q_bhsd_dim[:-1], 1), device=query_layer.device, dtype=torch.float32, ) else: - stats_bhs1 = None + stats = None entry = _get_cudnn_score_mod_fwd_graph( is_training, @@ -1904,7 +1923,7 @@ def forward( score_mod, score_mod_tensors, output_layer, - stats_bhs1, + stats, ) variant_pack = { entry.q: query_layer, @@ -1913,7 +1932,7 @@ def forward( entry.output: output_layer, } if is_training: - variant_pack[entry.stats] = stats_bhs1 + variant_pack[entry.stats] = stats for name, graph_tensor in entry.score_mod_graph_tensors.items(): variant_pack[graph_tensor] = 
score_mod_tensors[name] @@ -1941,7 +1960,7 @@ def forward( key_layer, value_layer, output_layer, - stats_bhs1, + stats, *score_mod_tensors.values(), *score_mod_bprop_tensors.values(), ) @@ -1959,7 +1978,7 @@ def backward(ctx, d_out: torch.Tensor): ) saved_tensors = ctx.saved_tensors - query_layer, key_layer, value_layer, output_layer, stats_bhs1 = saved_tensors[:5] + query_layer, key_layer, value_layer, output_layer, stats = saved_tensors[:5] score_mod_tensors_end = 5 + len(ctx.score_mod_tensor_names) score_mod_tensors = dict( zip(ctx.score_mod_tensor_names, saved_tensors[5:score_mod_tensors_end]) @@ -1978,7 +1997,7 @@ def backward(ctx, d_out: torch.Tensor): value_layer, output_layer, d_out, - stats_bhs1, + stats, ctx.q_format, ctx.kv_format, ctx.attn_scale, @@ -1994,7 +2013,7 @@ def backward(ctx, d_out: torch.Tensor): entry.v: value_layer, entry.output: output_layer, entry.d_output: d_out, - entry.stats: stats_bhs1, + entry.stats: stats, entry.dq: dq_layer, entry.dk: dk_layer, entry.dv: dv_layer, @@ -2854,41 +2873,19 @@ def forward( cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group ) - if score_mod is not None: - assert not context_parallel, "score_mod is not supported with context parallelism!" - assert not fp8, "score_mod is not supported with FP8 FusedAttention!" - assert not fp8_output, "score_mod is not supported with fp8_output!" - assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" + if context_parallel: assert ( - type(query_layer) is torch.Tensor - and type(key_layer) is torch.Tensor - and type(value_layer) is torch.Tensor - ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + score_mod is None + ), "score_mod is not supported with context parallelism!" assert ( - fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen - ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" 
+ score_mod_bprop is None + ), "score_mod_bprop is not supported with context parallelism!" assert ( - attn_mask_type == "no_mask" - and core_attention_bias_type == "no_bias" - and core_attention_bias is None - and self.softmax_type == "vanilla" - and self.attention_dropout == 0.0 - ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" - output = FusedAttentionWithScoreModFunc.apply( - self.training, - query_layer, - key_layer, - value_layer, - q_format, - kv_format, - self.softmax_scale, - score_mod, - score_mod_bprop, - score_mod_tensors, - score_mod_bprop_tensors, - self.deterministic, - ) - elif context_parallel: + score_mod_tensors is None + ), "score_mod_tensors is not supported with context parallelism!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors is not supported with context parallelism!" assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen @@ -2933,7 +2930,42 @@ def forward( fp8_output=fp8_output, layer_number=self.layer_number, return_max_logit=self.return_max_logit, - ) + ) + elif score_mod is not None: + assert not fp8, "score_mod is not supported with FP8 FusedAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" 
+ assert ( + attn_mask_type == "no_mask" + and core_attention_bias_type == "no_bias" + and core_attention_bias is None + and self.softmax_type == "vanilla" + and self.attention_dropout == 0.0 + ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" + output = FusedAttentionWithScoreModFunc.apply( + self.training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + self.softmax_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + self.deterministic, + ) else: with self.attention_dropout_ctx(): output = FusedAttnFunc.apply( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 95a0b53b4a..d71cfc34cb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1085,10 +1085,10 @@ def forward( to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled. score_mod: Optional[Callable], default = None - cuDNN frontend score modification callback. This is a cuDNN-only path and is mutually - exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, context parallelism, - THD format, KV caching, and return_max_logit. The callback signature is - ``score_mod(graph, score, tensors) -> score``. + Experimental cuDNN frontend score modification callback. This is a cuDNN-only path + and is mutually exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, + context parallelism, THD format, KV caching, and return_max_logit. The callback + signature is ``score_mod(graph, score, tensors) -> score``. score_mod_bprop: Optional[Callable], default = None Optional cuDNN frontend callback for the backward pass of score_mod. 
The callback signature is ``score_mod_bprop(graph, dP, tensors) -> dP``. @@ -1434,7 +1434,9 @@ def forward( if score_mod is None: assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" - assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" + assert ( + score_mod_tensors is None + ), "score_mod_tensors requires score_mod!" assert ( score_mod_bprop_tensors is None ), "score_mod_bprop_tensors requires score_mod!" @@ -1455,18 +1457,28 @@ def forward( and type(key_layer) is torch.Tensor and type(value_layer) is torch.Tensor ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" - assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" + assert ( + not self.fp8 + ), "score_mod is not supported with FP8 DotProductAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert not context_parallel, "score_mod is not supported with context parallelism!" - assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" + assert ( + not context_parallel + ), "score_mod is not supported with context parallelism!" + assert ( + qkv_format != "thd" + ), "score_mod is not supported with qkv_format='thd'!" assert ( not user_supplied_seqlens ), "score_mod is mutually exclusive with explicit sequence length metadata!" - assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" + assert ( + not pad_between_seqs + ), "score_mod is not supported with pad_between_seqs!" assert ( attention_mask is None ), "score_mod is mutually exclusive with attention_mask!" - assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" + assert ( + attn_mask_type == "no_mask" + ), "score_mod requires attn_mask_type='no_mask'!" assert window_size is None or window_size == ( -1, -1, @@ -1496,7 +1508,9 @@ def forward( "bshd", ], "score_mod only supports SBHD/BSHD QKV formats!" 
if score_mod_tensors is not None: - assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" + assert isinstance( + score_mod_tensors, dict + ), "score_mod_tensors must be a dict!" assert all( isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in score_mod_tensors.items() @@ -1540,46 +1554,21 @@ def forward( is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, + fp8_output=fp8_output, inference_params=inference_params, softmax_type=self.softmax_type, return_max_logit=self.return_max_logit, + checkpoint_core_attention=checkpoint_core_attention, cuda_graph=is_graph_capturing(), num_splits=num_splits, + has_attention_mask=attention_mask is not None, + has_core_attention_bias=core_attention_bias is not None, + user_supplied_seqlens=user_supplied_seqlens, + score_mod=score_mod is not None, + score_mod_bprop=score_mod_bprop is not None, ) global _attention_backends - if score_mod is not None: - use_flash_attention = False - flash_attention_backend = None - use_fused_attention = True - use_unfused_attention = False - q_type = dpa_utils.TE_DType[query_layer.dtype] - fused_attention_backend = tex.get_fused_attn_backend( - self.training, - q_type, - q_type, - dpa_utils.QKVLayout["bshd_bshd_bshd"], - dpa_utils.AttnBiasType["no_bias"], - dpa_utils.AttnMaskType["no_mask"], - dpa_utils.SoftmaxType["vanilla"], - 0.0, - num_attention_heads, - num_gqa_groups, - max_seqlen_q, - max_seqlen_kv, - head_dim_qk, - head_dim_v, - -1, - -1, - False, - is_graph_capturing(), - self.deterministic, - ) - if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend: - raise ValueError( - "score_mod requires a cuDNN FusedAttention backend, but no fused " - "attention backend supports the provided inputs." - ) - elif is_in_onnx_export_mode(): + if is_in_onnx_export_mode() and score_mod is None: # We do not want to call get_attention_backend() in ONNX mode # and we want to avoid using any global variables like _attention_backends. 
use_flash_attention = False diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7df5daabe5..f77bdecc02 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -246,16 +246,30 @@ class AttentionParams: Whether `DotProductAttention` is in an `autocast` region. fp8_meta : Optional[Dict[str Any]], default = None The FP8 metadata tensor of `DotProductAttention`. + fp8_output : bool, default = False + Whether output is requested in FP8. inference_params : Optional[InferenceParams], default = None Inference-related parameters. See InferenceParams for details. softmax_type : str, default = "vanilla" The type of softmax operation. See DotProductAttention for details. return_max_logit : bool, default = False Whether to output max_logit. + checkpoint_core_attention : bool, default = False + Whether core attention is recomputed during backward. cuda_graph : bool, default = `False` Whether support for cuda graph capture is needed or not. num_splits : int, default = 1 The number of kernels to split attention to. + has_attention_mask : bool, default = False + Whether an explicit attention mask tensor was provided. + has_core_attention_bias : bool, default = False + Whether an explicit core attention bias tensor was provided. + user_supplied_seqlens : bool, default = False + Whether explicit cu_seqlens metadata was provided. + score_mod : bool, default = False + Whether a score_mod callback was provided. + score_mod_bprop : bool, default = False + Whether a score_mod bprop callback was provided. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -284,11 +298,18 @@ class AttentionParams: is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None + fp8_output: bool = False inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" return_max_logit: bool = False + checkpoint_core_attention: bool = False cuda_graph: bool = False num_splits: int = 1 + has_attention_mask: bool = False + has_core_attention_bias: bool = False + user_supplied_seqlens: bool = False + score_mod: bool = False + score_mod_bprop: bool = False def __eq__(self, other): """ @@ -362,11 +383,18 @@ def get_attention_backend( is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta + fp8_output = attention_params.fp8_output inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type return_max_logit = attention_params.return_max_logit + checkpoint_core_attention = attention_params.checkpoint_core_attention cuda_graph = attention_params.cuda_graph num_splits = attention_params.num_splits + has_attention_mask = attention_params.has_attention_mask + has_core_attention_bias = attention_params.has_core_attention_bias + user_supplied_seqlens = attention_params.user_supplied_seqlens + score_mod = attention_params.score_mod + score_mod_bprop = attention_params.score_mod_bprop # Run config logger = logging.getLogger("DotProductAttention") @@ -432,7 +460,7 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. 
- qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) @@ -647,6 +675,85 @@ def get_attention_backend( use_unfused_attention = False logger.debug("Disabling all backends for max_logit with FP8 attention") + # Filter: score_mod + if score_mod_bprop and not score_mod: + logger.debug("Disabling all backends because score_mod_bprop requires score_mod") + use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_3 = False + use_flash_attention_4 = False + use_fused_attention = False + use_unfused_attention = False + if score_mod: + if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4: + logger.debug("Disabling FlashAttention for score_mod") + use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_3 = False + use_flash_attention_4 = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for score_mod") + use_unfused_attention = False + + score_mod_unsupported_reasons = [] + if qkv_dtype not in [torch.float16, torch.bfloat16]: + score_mod_unsupported_reasons.append( + f"unsupported qkv_dtype = {qkv_dtype}; supported: torch.float16, torch.bfloat16" + ) + if qkv_type is not torch.Tensor: + score_mod_unsupported_reasons.append( + f"unsupported qkv_type = {qkv_type}; supported: torch.Tensor" + ) + if fp8: + score_mod_unsupported_reasons.append("FP8 DotProductAttention is enabled") + if fp8_output: + score_mod_unsupported_reasons.append("fp8_output is enabled") + if inference_params is not None: + score_mod_unsupported_reasons.append("KV caching is enabled") + if context_parallel: + score_mod_unsupported_reasons.append("context parallelism is enabled") + if qkv_format == "thd" or q_format not in ["sbhd", "bshd"] or kv_format not in [ + "sbhd", + "bshd", + ]: + 
score_mod_unsupported_reasons.append( + f"unsupported QKV format: q_format = {q_format}, kv_format = {kv_format}" + ) + if user_supplied_seqlens: + score_mod_unsupported_reasons.append("explicit sequence length metadata was provided") + if pad_between_seqs: + score_mod_unsupported_reasons.append("pad_between_seqs is enabled") + if has_attention_mask: + score_mod_unsupported_reasons.append("attention_mask was provided") + if attn_mask_type != "no_mask": + score_mod_unsupported_reasons.append(f"attn_mask_type = {attn_mask_type}") + if window_size is not None and window_size != (-1, -1): + score_mod_unsupported_reasons.append(f"window_size = {window_size}") + if core_attention_bias_type != "no_bias" or has_core_attention_bias: + score_mod_unsupported_reasons.append( + f"core_attention_bias_type = {core_attention_bias_type}" + ) + if alibi_slopes_shape is not None: + score_mod_unsupported_reasons.append("ALiBi slopes were provided") + if softmax_type != "vanilla": + score_mod_unsupported_reasons.append(f"softmax_type = {softmax_type}") + if attention_dropout != 0.0: + score_mod_unsupported_reasons.append(f"attention_dropout = {attention_dropout}") + if return_max_logit: + score_mod_unsupported_reasons.append("return_max_logit is enabled") + if checkpoint_core_attention: + score_mod_unsupported_reasons.append("checkpoint_core_attention is enabled") + if cuda_graph: + score_mod_unsupported_reasons.append("CUDA graph capture is enabled") + if num_splits != 1: + score_mod_unsupported_reasons.append(f"num_splits = {num_splits}") + if score_mod_unsupported_reasons and use_fused_attention: + logger.debug( + "Disabling FusedAttention for score_mod because %s", + "; ".join(score_mod_unsupported_reasons), + ) + use_fused_attention = False + # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- @@ -1250,6 +1357,17 @@ def _is_fa3_supported(num_heads, 
num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None + elif ( + score_mod + and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + ): + logger.debug( + "Disabling FusedAttention for score_mod because sub-backend %s is not " + "F16/BF16 arbitrary-seqlen", + int(fused_attention_backend), + ) + use_fused_attention = False + fused_attention_backend = None # Filter: Determinism # backend | deterministic # --------------------------------------------- From a8ed67e652a636bee22fcd9a9f23d38ca82f88b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 01:23:19 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/backends.py | 25 +++++++---------- .../dot_product_attention.py | 28 +++++-------------- .../attention/dot_product_attention/utils.py | 18 ++++++------ 3 files changed, 27 insertions(+), 44 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index f481d39a22..8532809f39 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -106,10 +106,14 @@ def _import_cudnn_frontend(): """Import the vendored cuDNN frontend if built, otherwise use the installed package.""" cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" - if any(cudnn_frontend_package.glob("_compiled_module*")) and cudnn_frontend_path not in sys.path: + if ( + any(cudnn_frontend_package.glob("_compiled_module*")) + and cudnn_frontend_path not in sys.path + ): sys.path.insert(0, 
cudnn_frontend_path) return importlib.import_module("cudnn") + # Try to import Flash Attention v2 try: fa_utils.version = PkgVersion(PkgVersion(get_pkg_version("flash-attn")).public) @@ -1278,9 +1282,7 @@ def _bhsd_dim_stride( (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)), ) - raise ValueError( - f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}." - ) + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): @@ -1588,10 +1590,7 @@ def _cudnn_score_mod_bwd_cache_key( """Cache key for score_mod bprop execution plans.""" score_mod_key = _score_mod_callback_cache_key(score_mod) score_mod_bprop_key = _score_mod_callback_cache_key(score_mod_bprop) - if ( - score_mod_key is _SCORE_MOD_UNCACHEABLE - or score_mod_bprop_key is _SCORE_MOD_UNCACHEABLE - ): + if score_mod_key is _SCORE_MOD_UNCACHEABLE or score_mod_bprop_key is _SCORE_MOD_UNCACHEABLE: return None return ( "bwd", @@ -2874,9 +2873,7 @@ def forward( ) if context_parallel: - assert ( - score_mod is None - ), "score_mod is not supported with context parallelism!" + assert score_mod is None, "score_mod is not supported with context parallelism!" assert ( score_mod_bprop is None ), "score_mod_bprop is not supported with context parallelism!" @@ -2930,13 +2927,11 @@ def forward( fp8_output=fp8_output, layer_number=self.layer_number, return_max_logit=self.return_max_logit, - ) + ) elif score_mod is not None: assert not fp8, "score_mod is not supported with FP8 FusedAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert ( - not self.return_max_logit - ), "score_mod is not supported with return_max_logit!" + assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" 
assert ( type(query_layer) is torch.Tensor and type(key_layer) is torch.Tensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d71cfc34cb..a9d1a48f20 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1434,9 +1434,7 @@ def forward( if score_mod is None: assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" - assert ( - score_mod_tensors is None - ), "score_mod_tensors requires score_mod!" + assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" assert ( score_mod_bprop_tensors is None ), "score_mod_bprop_tensors requires score_mod!" @@ -1457,28 +1455,18 @@ def forward( and type(key_layer) is torch.Tensor and type(value_layer) is torch.Tensor ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" - assert ( - not self.fp8 - ), "score_mod is not supported with FP8 DotProductAttention!" + assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert ( - not context_parallel - ), "score_mod is not supported with context parallelism!" - assert ( - qkv_format != "thd" - ), "score_mod is not supported with qkv_format='thd'!" + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" assert ( not user_supplied_seqlens ), "score_mod is mutually exclusive with explicit sequence length metadata!" - assert ( - not pad_between_seqs - ), "score_mod is not supported with pad_between_seqs!" + assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" 
assert ( attention_mask is None ), "score_mod is mutually exclusive with attention_mask!" - assert ( - attn_mask_type == "no_mask" - ), "score_mod requires attn_mask_type='no_mask'!" + assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" assert window_size is None or window_size == ( -1, -1, @@ -1508,9 +1496,7 @@ def forward( "bshd", ], "score_mod only supports SBHD/BSHD QKV formats!" if score_mod_tensors is not None: - assert isinstance( - score_mod_tensors, dict - ), "score_mod_tensors must be a dict!" + assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" assert all( isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in score_mod_tensors.items() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f77bdecc02..0b6a2b85d5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -712,10 +712,15 @@ def get_attention_backend( score_mod_unsupported_reasons.append("KV caching is enabled") if context_parallel: score_mod_unsupported_reasons.append("context parallelism is enabled") - if qkv_format == "thd" or q_format not in ["sbhd", "bshd"] or kv_format not in [ - "sbhd", - "bshd", - ]: + if ( + qkv_format == "thd" + or q_format not in ["sbhd", "bshd"] + or kv_format + not in [ + "sbhd", + "bshd", + ] + ): score_mod_unsupported_reasons.append( f"unsupported QKV format: q_format = {q_format}, kv_format = {kv_format}" ) @@ -1357,10 +1362,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - elif ( - score_mod - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): + elif score_mod and fused_attention_backend != 
FusedAttnBackend["F16_arbitrary_seqlen"]:
             logger.debug(
                 "Disabling FusedAttention for score_mod because sub-backend %s is not "
                 "F16/BF16 arbitrary-seqlen",

From e2a69e130f5033b0348b2cdee91008599387f4c4 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Fri, 15 May 2026 01:37:39 +0000
Subject: [PATCH 12/12] Fix score_mod lambda cache keys

Signed-off-by: Vladimir Cherepanov
---
 tests/pytorch/attention/test_attention.py       | 17 +++++++++++++++++
 .../attention/dot_product_attention/backends.py | 10 +++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 959350e080..89dd0eb77c 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -1564,6 +1564,23 @@ def forward(self, _score_mod_graph, score_tensor, _tensors):
     assert key_0 != key_2
 
 
+def test_score_mod_cache_module_lambda_keys_do_not_collide():
+    """Module-level lambdas should not reuse graphs only because qualnames match."""
+    score_mod_0 = lambda _graph, score_tensor, _tensors: score_tensor  # noqa: E731
+    score_mod_1 = lambda _graph, score_tensor, _tensors: score_tensor  # noqa: E731
+    score_mod_0.__module__ = __name__
+    score_mod_1.__module__ = __name__
+    score_mod_0.__qualname__ = "<lambda>"
+    score_mod_1.__qualname__ = "<lambda>"
+
+    key_0 = dpa_backends._score_mod_callback_cache_key(score_mod_0)
+    key_1 = dpa_backends._score_mod_callback_cache_key(score_mod_1)
+
+    assert key_0 is not dpa_backends._SCORE_MOD_UNCACHEABLE
+    assert key_1 is not dpa_backends._SCORE_MOD_UNCACHEABLE
+    assert key_0 != key_1
+
+
 def test_score_mod_cache_key_ignores_pass_by_value_values():
     """Scalar CPU tensor values are runtime inputs, not execution-plan metadata."""
     q, k, v, o, stats = _score_mod_cache_cpu_inputs()
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 
8532809f39..70ba826138 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1335,9 +1335,11 @@ def _score_mod_explicit_cache_key(callback_owner: Any) -> Optional[Any]:
 def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any:
     """Create a stable graph cache key for a score_mod callable.
 
-    Module-level functions are assumed to have stable topology. Stateful bound methods and
-    callable instances need an explicit score_mod_graph_cache_key(); otherwise their graphs
-    are left uncached to avoid reusing stale graphs after Python object address reuse.
+    Module-level named functions are assumed to have stable topology. Anonymous functions
+    are keyed by code object because lambdas in the same module can share the same
+    qualname. Stateful bound methods and callable instances need an explicit
+    score_mod_graph_cache_key(); otherwise their graphs are left uncached to avoid reusing
+    stale graphs after Python object address reuse.
     """
     if callback is None:
         return None
@@ -1370,6 +1372,8 @@ def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any:
         and callback.__closure__ is None
         and "<locals>" not in callback.__qualname__
     ):
+        if callback.__name__ == "<lambda>" or not callback.__qualname__:
+            return ("function", callback.__module__, callback.__code__)
         return ("function", callback.__module__, callback.__qualname__)
 
     return _SCORE_MOD_UNCACHEABLE