From ab542fc42f9d274356347075a89c5d5e5f2a7978 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Sat, 20 Dec 2025 00:54:50 +0000 Subject: [PATCH 01/25] Plumbing correct bias dims from TE to cudnn Signed-off-by: Kshitij Lakhani --- .../fused_attn_f16_arbitrary_seqlen.cu | 34 +++++++++++++------ .../common/fused_attn/fused_attn_fp8.cu | 20 +++++++---- transformer_engine/common/fused_attn/utils.h | 6 ++-- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 53023361e4..9c79a57398 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, @@ -121,6 +121,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( max_pages_per_seq_v, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -272,8 +274,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bias) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q 
* s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_options.set_bias(bias); } @@ -548,7 +550,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, @@ -623,6 +625,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -814,12 +818,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bias) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); dBias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dBias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // are not supported for dbias calculation but they are @@ -1084,10 +1088,14 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != 
NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; } void *devPtrSoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -1153,7 +1161,7 @@ void fused_attn_arbitrary_seqlen_fwd( if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv}; output_bias->data.dtype = QKV_type; } @@ -1198,7 +1206,7 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, @@ -1245,11 +1253,15 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; devPtrdBias = output_dBias->data.dptr; bias_b = output_dBias->data.shape[0]; bias_h = output_dBias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; } size_t 
max_batch_size = 0; @@ -1292,7 +1304,7 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f886ec77f4..b2f323441e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1( bool is_dropout = (is_training && dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + auto bias_sq = s_q; + auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -1697,6 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -1818,8 +1822,8 @@ void fused_attn_fp8_fwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_options.set_bias(bias); // } @@ -1999,6 +2003,8 @@ void 
fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + auto bias_sq = s_q; + auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -2027,6 +2033,8 @@ void fused_attn_fp8_bwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -2194,12 +2202,12 @@ void fused_attn_fp8_bwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("dBias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_backward_options.set_bias(bias); // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // // are not supported for dbias calculation but they are diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index fdfc4abe82..ae27223980 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -101,6 +101,8 @@ struct FADescriptor_v1 { std::int64_t max_pages_per_seq_v; std::int64_t bias_b; std::int64_t bias_h; + std::int64_t bias_sq; + std::int64_t bias_skv; float attnScale; bool isTraining; float dropoutProbability; @@ -120,14 +122,14 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return 
std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, From c86328e2d97bf4619a0dd24f8ad48e41203c9063 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:22:38 +0000 Subject: [PATCH 02/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn_f16_arbitrary_seqlen.cu | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 9c79a57398..02e2df5644 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -272,10 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = 
mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_options.set_bias(bias); } @@ -816,14 +817,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); - dBias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dBias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + dBias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("dBias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // are not supported for dbias calculation but they are From fddf0acbae454d2b15d26071492a107377a65798 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 00:27:48 +0000 Subject: [PATCH 03/25] Make changes for cp bias code Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 5 ++--- .../pytorch/attention/dot_product_attention/utils.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py 
b/tests/pytorch/attention/test_attention.py index bd0ac41974..e094da1eec 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -526,7 +526,7 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: ModelConfig(b, sq, hq, dqk) - "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), @@ -1143,11 +1143,10 @@ def _run_dot_product_attention( bias = None if config.attn_bias_type == "post_scale_bias": shape = "_".join(config.bias_shape) + shape = shape.replace("_1_s", "_1_skv") shape = shape.replace("_s_s", "_sq_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") - if config.bias_shape != "1hss": - bias.requires_grad = False # Create RNG _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 0c5a519813..4db5d5ea7c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -966,12 +966,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt and fu_core_attention_bias_type == "post_scale_bias" and fu_core_attention_bias_shape != "1hss" ): - if fu_core_attention_bias_requires_grad: - # remove this line when cuDNN adds bwd support for - # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] - logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") - 
use_fused_attention = False - else: + if not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" From d3aa7eccb953185bc367855556cfcf06c3523a43 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 02:19:48 +0000 Subject: [PATCH 04/25] Add dbias and dbias_ to run_dpa_with_cp test Signed-off-by: Kshitij Lakhani --- .../attention/run_attention_with_cp.py | 174 +++++++++++++----- 1 file changed, 126 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3efb516b57..49ad88be1d 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -307,6 +307,7 @@ def run_dpa_with_cp( if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + bias.requires_grad = True else: bias = None @@ -338,7 +339,7 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) - dq, dk, dv = q.grad, k.grad, v.grad + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad d_softmax_offset = None if config.softmax_type != "vanilla": d_softmax_offset = core_attn.softmax_offset.grad @@ -394,6 +395,7 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + bias_.requires_grad = True # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, @@ -433,23 +435,23 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() # get outputs - tensors = [out, 
dq, dk, dv, out_, dq_, dk_, dv_] + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: - tensors[0], tensors[4] = tensors_to_deq + tensors[0], tensors[5] = tensors_to_deq for tensor in tensors: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) - out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": @@ -467,6 +469,26 @@ def run_dpa_with_cp( x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) for x in [dq_, dk_, dv_, out_] ] + if dbias is not None and dbias_ is not None: + dbias = dbias.view( + dbias.shape[0], + dbias.shape[1], + 2 * world_size, + dbias.shape[2] // (2 * world_size), + dbias.shape[3] + ) + # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) + dbias = dbias.index_select(2, seq_idx) + # Flatten + dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1]) + dbias_ = dbias_.view( + dbias_.shape[0], + dbias_.shape[1], + 2, + dbias_.shape[2] // 2, + dbias_.shape[3] + ) + elif qkv_format == "thd": dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] @@ -509,9 +531,9 @@ def run_dpa_with_cp( ) atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", 
"d_softmax_offset", "max_logit"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names] is_fp8 = dtype == "fp8" @@ -519,47 +541,103 @@ def run_dpa_with_cp( if t is not None: if "softmax_offset" not in names[i] and "max_logit" not in names[i]: if qkv_format == "bshd": - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # After reshaping: (1, 1, 2, seq_q//2, seq_kv) + # Compare along dimension 2 (the split sequence dimension) + compare_and_assert( + t[:, :, 0], # First sequence chunk + tensors_cp[i][:, :, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, :, 1], # Second sequence chunk + tensors_cp[i][:, :, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "sbhd": - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # After reshaping: (1, 1, 2, seq_q//2, seq_kv) + # Compare along dimension 2 (the split sequence dimension) + compare_and_assert( + t[:, :, 0], # First sequence chunk + 
tensors_cp[i][:, :, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, :, 1], # Second sequence chunk + tensors_cp[i][:, :, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "thd": compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 From 4d295c464e1d32037432090d127fee6422340cc1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 02:22:04 +0000 Subject: [PATCH 05/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 49ad88be1d..2ae5937a14 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -475,18 +475,14 @@ def run_dpa_with_cp( dbias.shape[1], 2 * world_size, dbias.shape[2] // (2 * world_size), - dbias.shape[3] + dbias.shape[3], ) # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) dbias = dbias.index_select(2, seq_idx) # Flatten dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1]) dbias_ = dbias_.view( - dbias_.shape[0], - dbias_.shape[1], - 2, - dbias_.shape[2] // 2, - dbias_.shape[3] + dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3] ) elif qkv_format == "thd": From 
f4f9cc65fd74902382c6489eca36c579dce55300 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 02:27:06 +0000 Subject: [PATCH 06/25] Fix: Use output_dBias instead of input_dBias to extract the shape Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 02e2df5644..4cf72d9dca 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1263,8 +1263,8 @@ void fused_attn_arbitrary_seqlen_bwd( devPtrdBias = output_dBias->data.dptr; bias_b = output_dBias->data.shape[0]; bias_h = output_dBias->data.shape[1]; - bias_sq = input_Bias->data.shape[2]; - bias_skv = input_Bias->data.shape[3]; + bias_sq = output_dBias->data.shape[2]; + bias_skv = output_dBias->data.shape[3]; } size_t max_batch_size = 0; From f9f4fb8688f5c06bf52f700926d24399e5a85e13 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 21 Jan 2026 19:31:44 +0000 Subject: [PATCH 07/25] Add guards for bias/bias_/dbias/dbias_ being None Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/run_attention_with_cp.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 2ae5937a14..baf48bc407 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -339,7 +339,7 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None d_softmax_offset = None if config.softmax_type != "vanilla": 
d_softmax_offset = core_attn.softmax_offset.grad @@ -435,7 +435,7 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad + dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() @@ -445,12 +445,16 @@ def run_dpa_with_cp( if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): - tensors_to_deq[i] = tensor.dequantize() + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq for tensor in tensors: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ From d9547fa1e4a1c1b4698207c23c075de34f0a8333 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 22 Jan 2026 00:47:25 +0000 Subject: [PATCH 08/25] Add support for bias shape 111s in addition to the original 1hss, 11ss, b1ss and bhss Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index e094da1eec..1f9fff479c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1143,8 +1143,12 @@ def _run_dot_product_attention( bias = None if config.attn_bias_type == "post_scale_bias": shape = "_".join(config.bias_shape) - shape = shape.replace("_1_s", "_1_skv") + # For 1hss, 
11ss, b1ss, bhss + shape_cache = shape shape = shape.replace("_s_s", "_sq_skv") + if shape==shape_cache: + # For 111s + shape = shape.replace("_1_s", "_1_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") From 7ede1fe2bc2db94c3639d2da475c95419cad1544 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 17:53:55 +0000 Subject: [PATCH 09/25] Add support for dbias calculation and variant packing for the dbias shapes b1ss, bhss, 11ss in addition to the already supported 1hss Signed-off-by: Kshitij Lakhani --- .../fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 11 ++++++----- .../common/fused_attn/fused_attn_fp8.cu | 11 +++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4cf72d9dca..bf0075232c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -828,10 +828,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({bias_b, bias_h, bias_sq, bias_skv}) .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); - // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // are not supported for dbias calculation but they are - // supported for forward bias calculation - if ((bias_b == 1) && (bias_h == h)) { + // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation + // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 + if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { sdpa_backward_options.set_dbias(dBias); } } @@ -982,7 +981,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bias) { variant_pack[bias] = devPtrBias; - if ((bias_b == 1) && (bias_h == h)) { + 
// bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation + // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 + if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { variant_pack[dBias] = devPtrdBias; } else { variant_pack[dBias] = nullptr; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index b2f323441e..71f90aee26 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2209,12 +2209,11 @@ void fused_attn_fp8_bwd_impl_v1( // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_backward_options.set_bias(bias); - // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // // are not supported for dbias calculation but they are - // // supported for forward bias calculation - // if ((bias_b == 1) && (bias_h == h)) { - // sdpa_backward_options.set_dbias(dBias); - // } + // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation + // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 + // if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { + // sdpa_backward_options.set_dbias(dBias); + // } // } if (is_padding) { From c20d67ad1e881fdcf779b1ca211635b3a7094251 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 18:14:49 +0000 Subject: [PATCH 10/25] Add support for 111s bias shape in DPA Signed-off-by: Kshitij Lakhani --- .../dot_product_attention/dot_product_attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 5d830dca33..64db4646f6 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1318,11 +1318,14 @@ def forward( ): core_attention_bias_shape = "b1ss" elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1: - core_attention_bias_shape = "11ss" + if core_attention_bias.shape[2] == 1: + core_attention_bias_shape = "111s" + else: + core_attention_bias_shape = "11ss" else: assert ( False - ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss, 111s} shapes" # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: From 303aee77c61d5202ad4b79843e8b9214867ada5d Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 18:18:05 +0000 Subject: [PATCH 11/25] Allow fused attn for dbias calculation for 11ss, b1ss, bhss. Disable fused attn if dbias calculation for 111s is required, else enable Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 4db5d5ea7c..3502a90ad2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -966,6 +966,10 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt and fu_core_attention_bias_type == "post_scale_bias" and fu_core_attention_bias_shape != "1hss" ): + # dbias calculation is not supported for 111s as of cuDNN 9.18. So, use fused attention backend only if bias does not require grad. 
+ if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s": + logger.warning("Disabling FusedAttention as dbias calculation is not supported for 111s") + use_fused_attention = False if not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" From e9f88f0d7dafee0cabd50dbc766b6a2cbff65d4b Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 18:44:03 +0000 Subject: [PATCH 12/25] Disable requires_grad for bias for shape 111s in tests Signed-off-by: Kshitij Lakhani --- tests/pytorch/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index b6a84a8e2b..d898543421 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -271,7 +271,6 @@ def get_available_attention_backends( os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - alibi_slopes_shape = None if config.attn_bias_type == "alibi" and config.alibi_type == "custom": if config.bias_shape == "1hss": @@ -289,7 +288,9 @@ def get_available_attention_backends( and config.head_dim_qk <= 128 and config.head_dim_v <= 128 ): - core_attention_bias_requires_grad = True + #TODO(KshitijLakhani): Remove this guard when cuDNN starts support dbias calculation for bias shape 111s + if core_attention_bias_shape != "111s": + core_attention_bias_requires_grad = True fused_attn_backends = [] available_backends = None From 6bf73e1b5db7a1a45a0fbe71959a2ba84be5bd40 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 19:00:17 +0000 Subject: [PATCH 13/25] Disable bias grad / training flag for 111s bias in the non-CP attn tests. 
Add bias shape 111s to test_dpa_bias_shapes Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 1f9fff479c..2977bab592 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -162,7 +162,13 @@ def test_dot_product_attention( ) # Get backends + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. + # For all other shapes test fwd+bwd is_training = True + # TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s. + if config.bias_shape == "111s": + is_training = False + logging.info(f"Setting is_training to False as cuDNN does not support dbias for {config.bias_shape=} ") available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, @@ -526,7 +532,7 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: ModelConfig(b, sq, hq, dqk) - "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"), + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), @@ -636,7 +642,8 @@ def test_dpa_bias(dtype, model_configs, model): "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), - "bias_1_4": ModelConfig( + "bias_1_4": ModelConfig(2, 
2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"), + "bias_1_5": ModelConfig( 4, 2048, 24, @@ -646,7 +653,7 @@ def test_dpa_bias(dtype, model_configs, model): bias_shape="1hss", alibi_type="custom", ), - "bias_1_5": ModelConfig( + "bias_1_6": ModelConfig( 2, 2048, 24, @@ -1146,11 +1153,14 @@ def _run_dot_product_attention( # For 1hss, 11ss, b1ss, bhss shape_cache = shape shape = shape.replace("_s_s", "_sq_skv") - if shape==shape_cache: - # For 111s + # For 111s + if shape == shape_cache: shape = shape.replace("_1_s", "_1_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") + # For 111s, dbias calculation is not supported as of cuDNN 9.18 + if config.bias_shape == "111s": + bias.requires_grad = False # Create RNG _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() From ebee29bee46cd763ae2b1c787a29f20a52350e55 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 6 Feb 2026 23:14:34 +0000 Subject: [PATCH 14/25] Fix to correctly create the bias shape tensor instead of the hard coded shape. 
Fix the comparison logic shapes for bias/dbias Signed-off-by: Kshitij Lakhani --- .../attention/run_attention_with_cp.py | 100 +++++++++++++----- 1 file changed, 75 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index baf48bc407..d6300ef5d9 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -305,9 +305,20 @@ def run_dpa_with_cp( x.requires_grad = True if config.attn_bias_type not in ["no_bias", "alibi"]: - attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": (config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() - bias.requires_grad = True + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False else: bias = None @@ -390,12 +401,30 @@ def run_dpa_with_cp( q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: - bias_ = bias_.view( - *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1] - ) - bias_ = bias_.index_select(2, seq_idx) - bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) - bias_.requires_grad = True + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if 
qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + #TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass + else: + bias_ = bias_.view( + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size + ) + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, @@ -474,19 +503,26 @@ def run_dpa_with_cp( for x in [dq_, dk_, dv_, out_] ] if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension dbias = dbias.view( - dbias.shape[0], - dbias.shape[1], + *shape_before_seq, 2 * world_size, - dbias.shape[2] // (2 * world_size), - dbias.shape[3], + seq_q_size // (2 * world_size), + seq_kv_size ) - # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) - dbias = dbias.index_select(2, seq_idx) - # Flatten - dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1]) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) dbias_ = dbias_.view( - dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3] + *shape_before_seq, + 2, + dbias_.shape[seq_q_dim] // 2, + seq_kv_size ) elif qkv_format == "thd": @@ -546,9 +582,17 @@ def run_dpa_with_cp( if names[i] == "dbias": # After reshaping: (1, 1, 2, seq_q//2, seq_kv) # Compare along 
dimension 2 (the split sequence dimension)
+            ndim_bias = t.ndim
+            seq_q_dim_bias = ndim_bias - 2  # Query sequence dimension
+            # After reshaping both have shape: [..., 2, seq_q//2, seq_kv]
+            # The split dimension is at seq_q_dim_bias
+            slice_0 = [slice(None)] * ndim_bias
+            slice_0[seq_q_dim_bias] = 0
+            slice_1 = [slice(None)] * ndim_bias
+            slice_1[seq_q_dim_bias] = 1
             compare_and_assert(
-                t[:, :, 0],  # First sequence chunk
-                tensors_cp[i][:, :, 0],
+                t[tuple(slice_0)],  # First sequence chunk
+                tensors_cp[i][tuple(slice_0)],
                 names_no_cp[i],
                 names_cp[i],
                 atol,
@@ -557,8 +601,8 @@
             compare_and_assert(
-                t[:, :, 1],  # Second sequence chunk
-                tensors_cp[i][:, :, 1],
+                t[tuple(slice_1)],  # Second sequence chunk
+                tensors_cp[i][tuple(slice_1)],
                 names_no_cp[i],
                 names_cp[i],
                 atol,
@@ -595,9 +639,15 @@ def run_dpa_with_cp(
         if names[i] == "dbias":
             # After reshaping: (1, 1, 2, seq_q//2, seq_kv)
             # Compare along dimension 2 (the split sequence dimension)
+            ndim_bias = t.ndim
+            seq_q_dim_bias = ndim_bias - 2
+            slice_0 = [slice(None)] * ndim_bias
+            slice_0[seq_q_dim_bias] = 0
+            slice_1 = [slice(None)] * ndim_bias
+            slice_1[seq_q_dim_bias] = 1
             compare_and_assert(
-                t[:, :, 0],  # First sequence chunk
-                tensors_cp[i][:, :, 0],
+                t[tuple(slice_0)],  # First sequence chunk
+                tensors_cp[i][tuple(slice_0)],
                 names_no_cp[i],
                 names_cp[i],
                 atol,
@@ -606,8 +656,8 @@ def run_dpa_with_cp(
                 is_fp8,
             )
             compare_and_assert(
-                t[:, :, 1],  # Second sequence chunk
-                tensors_cp[i][:, :, 1],
+                t[tuple(slice_1)],  # Second sequence chunk
+                tensors_cp[i][tuple(slice_1)],
                 names_no_cp[i],
                 names_cp[i],
                 atol,

From 126be036f7931c540b7e4c897b5242c2cf3b55f5 Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Fri, 6 Feb 2026 23:15:31 +0000
Subject: [PATCH 15/25] Add fused attn cp test cases for all supported bias shapes

Signed-off-by: Kshitij Lakhani
---
 .../attention/test_attention_with_cp.py | 41 +++++++++++++------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git 
a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 836598087b..84e793e9ea 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -147,7 +147,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA - "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -160,10 +161,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): attn_bias_type="post_scale_bias", ), # GQA "cp_2_3": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + bias_shape="11ss", ), # GQA "cp_2_4": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + bias_shape="111s", + ), # GQA + "cp_2_5": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + ), # GQA + "cp_2_6": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA @@ -171,6 +192,7 
@@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_3_4": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64), # MLA "cp_4_0": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" ), # GQA @@ -187,16 +209,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_2_0", - "cp_2_2", - "cp_2_4", - "cp_3_2", - "cp_4_2", - ] + configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_1_5", "cp_2_0", "cp_2_2", "cp_2_3", "cp_2_4", "cp_3_2", "cp_3_4", "cp_4_2"] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] @@ -257,6 +270,8 @@ def test_cp_with_fused_attention( pytest.skip("FP8 attention cannot work with sliding window yet!") if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") + if "p2p" in cp_comm_type and config.attn_bias_type != "no_bias" and config.bias_shape == "111s": + pytest.skip(f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": From 7b0f942a67f0d5fe39aebb05e6b59e92bec2fd2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:50:31 +0000 Subject: [PATCH 
16/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 30 ++++------ tests/pytorch/attention/test_attention.py | 7 ++- .../attention/test_attention_with_cp.py | 26 +++++++-- tests/pytorch/utils.py | 2 +- .../fused_attn_f16_arbitrary_seqlen.cu | 55 ++++++++++--------- transformer_engine/common/fused_attn/utils.h | 21 +++---- .../attention/dot_product_attention/utils.py | 4 +- 7 files changed, 82 insertions(+), 63 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index d6300ef5d9..4a8ab747ce 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -309,7 +309,12 @@ def run_dpa_with_cp( "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), - "bhss": (config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), "111s": (1, 1, 1, config.max_seqlen_kv), } attn_bias_shape = bias_shape_map.get(config.bias_shape) @@ -411,16 +416,13 @@ def run_dpa_with_cp( seq_q_size = bias_.shape[seq_q_dim] seq_kv_size = bias_.shape[-1] if seq_q_size == 1: - #TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s - bias_.requires_grad = False + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False # Bias is broadcast, no need to partition along sequence dimension pass else: bias_ = bias_.view( - *shape_before_seq, - 2 * world_size, - seq_q_size // (2 * world_size), - seq_kv_size + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) 
bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) @@ -505,25 +507,17 @@ def run_dpa_with_cp( if dbias is not None and dbias_ is not None: ndim = dbias.ndim # Query seq is at dim -2 - seq_q_dim = ndim - 2 + seq_q_dim = ndim - 2 shape_before_seq = dbias.shape[:seq_q_dim] seq_q_size = dbias.shape[seq_q_dim] seq_kv_size = dbias.shape[-1] # Reshape to split seq_q dimension dbias = dbias.view( - *shape_before_seq, - 2 * world_size, - seq_q_size // (2 * world_size), - seq_kv_size + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) # Index select on the newly created dimension (now at position seq_q_dim) dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view( - *shape_before_seq, - 2, - dbias_.shape[seq_q_dim] // 2, - seq_kv_size - ) + dbias_ = dbias_.view(*shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size) elif qkv_format == "thd": dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2977bab592..01b2aac453 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -162,13 +162,16 @@ def test_dot_product_attention( ) # Get backends - # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. # For all other shapes test fwd+bwd is_training = True # TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s. 
if config.bias_shape == "111s": is_training = False - logging.info(f"Setting is_training to False as cuDNN does not support dbias for {config.bias_shape=} ") + logging.info( + "Setting is_training to False as cuDNN does not support dbias for" + f" {config.bias_shape=} " + ) available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 84e793e9ea..769302a50a 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -147,7 +147,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA - "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), # MHA + "cp_1_4": ModelConfig( + 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" + ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA @@ -192,7 +194,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA - "cp_3_4": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64), # MLA + "cp_3_4": ModelConfig( + 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64 + ), # MLA "cp_4_0": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" ), 
# GQA @@ -209,7 +213,19 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_1_5", "cp_2_0", "cp_2_2", "cp_2_3", "cp_2_4", "cp_3_2", "cp_3_4", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_1_5", + "cp_2_0", + "cp_2_2", + "cp_2_3", + "cp_2_4", + "cp_3_2", + "cp_3_4", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] @@ -271,7 +287,9 @@ def test_cp_with_fused_attention( if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if "p2p" in cp_comm_type and config.attn_bias_type != "no_bias" and config.bias_shape == "111s": - pytest.skip(f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2") + pytest.skip( + f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2" + ) if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index d898543421..672c863389 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -288,7 +288,7 @@ def get_available_attention_backends( and config.head_dim_qk <= 128 and config.head_dim_v <= 128 ): - #TODO(KshitijLakhani): Remove this guard when cuDNN starts support dbias calculation for bias shape 111s + # TODO(KshitijLakhani): Remove this guard when cuDNN starts support dbias calculation for bias shape 111s if core_attention_bias_shape != "111s": core_attention_bias_requires_grad = True diff --git 
a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index bf0075232c..fcfd0bb9bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,13 +52,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, - bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, - void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, - void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, + bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void 
*devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -551,17 +552,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, - void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, - void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, - size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void 
*devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1210,10 +1211,10 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, - devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1308,11 +1309,11 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, 
window_size_left, window_size_right, - bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, + devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index ae27223980..08a56cda6b 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -122,18 +122,19 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, - attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, - bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, - generate_max_sum_exp) < + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, + bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, + 
dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, - rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, - rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, + rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, + rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 3502a90ad2..3a37ed4197 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -968,7 +968,9 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ): # dbias calculation is not supported for 111s as of cuDNN 9.18. So, use fused attention backend only if bias does not require grad. 
if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s": - logger.warning("Disabling FusedAttention as dbias calculation is not supported for 111s") + logger.warning( + "Disabling FusedAttention as dbias calculation is not supported for 111s" + ) use_fused_attention = False if not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] From f79505642ec5632f9e1a4dcc00477b66118871e7 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 13 Feb 2026 13:44:14 -0800 Subject: [PATCH 17/25] nit: switch to elif for bias grad conditional Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 3a37ed4197..3432fd832f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -972,7 +972,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "Disabling FusedAttention as dbias calculation is not supported for 111s" ) use_fused_attention = False - if not fu_core_attention_bias_requires_grad: + elif not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" From 0e74dcff5cb79ee31d53db860125dcb372f8f49e Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 13 Feb 2026 14:24:29 -0800 Subject: [PATCH 18/25] Add CP support for bias/dbias shape 111s Signed-off-by: Kshitij Lakhani --- .../dot_product_attention/context_parallel.py | 76 ++++++++++++------- 1 file changed, 50 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 
a5931188dc..aee50e18e4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -840,13 +840,24 @@ def cp_p2p_fwd_fused_attn( q_part = q_part.contiguous() if attn_bias is not None: idx = (rank - step) % cp_size - attn_bias_inputs = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() + # For bias shape 111s, only the s_kv dim is split, i.e. [b, h, sq, 2*cp, sk//(2*cp)]) + if attn_bias.shape[-3] == 1: + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., :, idx, :], + attn_bias_[..., :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + # For bias shapes 1hss, 11ss, bhss, b1ss, the s_kv and s_q dims are split, i.e. [b, h, 2, sq//2, 2*cp, sk//(2*cp)]) + else: + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() max_seqlen_q_ = max_seqlen_q // 2 max_seqlen_kv_ = max_seqlen_kv cu_seqlens_q_ = cu_seqlens_q_per_step @@ -1442,20 +1453,30 @@ def forward( attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( - "Only support bias shape of [b, h, sq, sk] for forward, " - "and [1, h, sq, sk] for backward!" - ) - assert ( - attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 - ), "Sequence length does not meet divisible requirements!" - # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] - attn_bias_ = attn_bias.view( - *attn_bias.shape[:-2], - 2, - attn_bias.shape[-2] // 2, - 2 * cp_size, - attn_bias.shape[-1] // (2 * cp_size), + "Only support bias shape of [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv], [1,1,sq,skv] for forward, " + "and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv] for backward!" 
) + # For all bias shapes except 111s, sq must be divisible by 2 and sk must be divisible by 2*cp_size + # For bias shape 111s, only sq must be divisible by 2 + if attn_bias.shape[-2] != 1: + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-2], + 2, + attn_bias.shape[-2] // 2, + 2 * cp_size, + attn_bias.shape[-1] // (2 * cp_size), + ) + else: + assert attn_bias.shape[-1] % (2 * cp_size) == 0, "Sequence length does not meet divisible requirements!" + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) + ) + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -2076,10 +2097,13 @@ def backward(ctx, dout, *_args): attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] - attn_dbias_ = attn_dbias.view( - *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] - ) + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] only when sq > 1 (i.e. 
all supported bias shapes except 111s) + if attn_dbias.shape[-3] > 1: + attn_dbias_ = attn_dbias.view( + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] + ) + else: + attn_dbias_ = None else: attn_dbias = None attn_dbias_ = None @@ -2507,8 +2531,8 @@ def backward(ctx, dout, *_args): elif i >= (cp_size - rank - 1): # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) - else: - # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] + elif attn_dbias_ is not None: + # upper-triangle: [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) From 0acf8f815d07fd36e33535d9d3425a7b12442296 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 13 Feb 2026 14:35:23 -0800 Subject: [PATCH 19/25] Add support for is_training in CP attn tests Signed-off-by: Kshitij Lakhani --- .../attention/run_attention_with_cp.py | 213 +++++++++++------- .../attention/test_attention_with_cp.py | 9 +- 2 files changed, 135 insertions(+), 87 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 4a8ab747ce..5f25e94bbe 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -179,10 +179,13 @@ def run_dpa_with_cp( fp8_mha="False", scaling_mode="delayed", f16_O="False", + is_training="True", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) + # When is_training is False, gradient outputs are None. 
+ is_training = is_training == "True" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" @@ -257,7 +260,9 @@ def run_dpa_with_cp( softmax_type=config.softmax_type, return_max_logit=config.return_max_logit, ).cuda() - if config.softmax_type != "vanilla": + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": core_attn.softmax_offset.requires_grad = True # generate attention inputs @@ -350,15 +355,20 @@ def run_dpa_with_cp( ) if config.return_max_logit: out, max_logit = out - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None - d_softmax_offset = None - if config.softmax_type != "vanilla": - d_softmax_offset = core_attn.softmax_offset.grad + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) + if is_training: + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, None, None, None + d_softmax_offset = None ############ run with CP ############ logging.info(f"[Rank {rank}] Run with context parallelism") @@ -404,7 +414,8 @@ def run_dpa_with_cp( dout_quantizer.amax.fill_(0.0) if fp8_mha: q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: ndim = bias_.ndim seq_q_dim = ndim - 2 @@ -461,15 +472,27 @@ def run_dpa_with_cp( ) if config.return_max_logit: out_, max_logit_ = out_ - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - dq_, dk_, dv_, dbias_ = q_.grad, 
k_.grad, v_.grad, bias_.grad if bias_ is not None else None - d_softmax_offset_ = None - if config.softmax_type != "vanilla": - d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) + if is_training: + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, + ) + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() + if config.softmax_type != "vanilla" + else None + ) + else: + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] @@ -490,75 +513,99 @@ def run_dpa_with_cp( ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": - dq, dk, dv, out = [ - x.view( - *x.shape[:seq_dim], + if is_training: + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [dq, dk, dv, out] + ] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] + ] + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + ) + else: + # Forward-only: reshape only 
out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], ) - for x in [dq, dk, dv, out] - ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] - ] - if dbias is not None and dbias_ is not None: - ndim = dbias.ndim - # Query seq is at dim -2 - seq_q_dim = ndim - 2 - shape_before_seq = dbias.shape[:seq_q_dim] - seq_q_size = dbias.shape[seq_q_dim] - seq_kv_size = dbias.shape[-1] - # Reshape to split seq_q dimension - dbias = dbias.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] ) - # Index select on the newly created dimension (now at position seq_q_dim) - dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view(*shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size) elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True - ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] - ).item() - 
== 0 - ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True - ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ - b + 1 + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] ] - ] - ).item() - == 0 - ) + ).item() + == 0 + ) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ + b + 1 + ] + ] + ).item() + == 0 + ) + else: + # 
Forward-only: reshape only out/out_ for comparison + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ atol, rtol, rmse_tol = get_tols(config, dtype) tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 769302a50a..1052896c0f 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -181,6 +181,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): attn_mask_type="causal", attn_bias_type="post_scale_bias", bias_shape="111s", + return_max_logit=True ), # GQA "cp_2_5": ModelConfig( 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" @@ -286,10 +287,6 @@ def test_cp_with_fused_attention( pytest.skip("FP8 attention cannot work with sliding window yet!") if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if "p2p" in cp_comm_type and config.attn_bias_type != "no_bias" and config.bias_shape == "111s": - pytest.skip( - f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2" - ) if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": @@ -357,12 +354,15 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. 
+ is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, fp8_meta=fp8_meta, + is_training=is_training, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -381,6 +381,7 @@ def test_cp_with_fused_attention( fp8_mha=fp8_mha, scaling_mode=scaling_mode, f16_O=f16_O, + is_training=is_training, log_level=pytest_logging_level, ), check=True, From 2133bd8ddf8112d67efd6ce20df158ccd48a867b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:37:03 +0000 Subject: [PATCH 20/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 10 ++++------ tests/pytorch/attention/test_attention_with_cp.py | 2 +- .../dot_product_attention/context_parallel.py | 9 ++++++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 5f25e94bbe..fce55c8f6c 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -486,9 +486,7 @@ def run_dpa_with_cp( bias_.grad if bias_ is not None else None, ) d_softmax_offset_ = ( - core_attn.softmax_offset.grad.clone() - if config.softmax_type != "vanilla" - else None + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None ) else: dq_, dk_, dv_, dbias_ = None, None, None, None @@ -595,9 +593,9 @@ def run_dpa_with_cp( num_pads_kv[b] == 0 or torch.count_nonzero( x[ - (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ - b + 1 - ] + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] ] ).item() == 0 
diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 1052896c0f..523a1a4c38 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -181,7 +181,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): attn_mask_type="causal", attn_bias_type="post_scale_bias", bias_shape="111s", - return_max_logit=True + return_max_logit=True, ), # GQA "cp_2_5": ModelConfig( 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index aee50e18e4..e632f28051 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1453,8 +1453,9 @@ def forward( attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( - "Only support bias shape of [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv], [1,1,sq,skv] for forward, " - "and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv] for backward!" + "Only support bias shape of [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv]," + " [1,1,sq,skv] for forward, and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv]," + " [b,h,sq,skv] for backward!" ) # For all bias shapes except 111s, sq must be divisible by 2 and sk must be divisible by 2*cp_size # For bias shape 111s, only sq must be divisible by 2 @@ -1471,7 +1472,9 @@ def forward( attn_bias.shape[-1] // (2 * cp_size), ) else: - assert attn_bias.shape[-1] % (2 * cp_size) == 0, "Sequence length does not meet divisible requirements!" + assert ( + attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" 
# [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) From 89e90a58b8ebd1bf87c6d129d0d690a34ea8ea03 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 17 Feb 2026 11:30:46 -0800 Subject: [PATCH 21/25] nit: Fix incorrect comment Signed-off-by: Kshitij Lakhani --- .../attention/dot_product_attention/context_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index e632f28051..c76d21c17e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1457,8 +1457,8 @@ def forward( " [1,1,sq,skv] for forward, and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv]," " [b,h,sq,skv] for backward!" ) - # For all bias shapes except 111s, sq must be divisible by 2 and sk must be divisible by 2*cp_size - # For bias shape 111s, only sq must be divisible by 2 + # For all bias shapes except 111s, sq must be divisible by 2 and skv must be divisible by 2*cp_size + # For bias shape 111s, only skv must be divisible by 2 if attn_bias.shape[-2] != 1: assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 From 5a25d9c4274cc43b71092407d4325e3137092ac8 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 17 Feb 2026 12:08:34 -0800 Subject: [PATCH 22/25] nit: Fix incorrect comment and assert string Signed-off-by: Kshitij Lakhani --- .../attention/dot_product_attention/context_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c76d21c17e..bd6b626b64 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1454,11 +1454,11 @@ def forward( if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv]," - " [1,1,sq,skv] for forward, and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv]," + " [1,1,1,skv] for forward, and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv]," " [b,h,sq,skv] for backward!" ) # For all bias shapes except 111s, sq must be divisible by 2 and skv must be divisible by 2*cp_size - # For bias shape 111s, only skv must be divisible by 2 + # For bias shape 111s, only skv must be divisible by 2*cp_size if attn_bias.shape[-2] != 1: assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 From 0e2a72fa94d96a70a39e0f62ef1408c9166d01e6 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 17 Feb 2026 17:00:51 -0800 Subject: [PATCH 23/25] Create the dbias graph tensor only if it is a cuDNN supported bias shape Signed-off-by: Kshitij Lakhani --- .../fused_attn_f16_arbitrary_seqlen.cu | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index fcfd0bb9bd..be2e0dea06 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -823,15 +823,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_name("bias") .set_dim({bias_b, bias_h, bias_sq, bias_skv}) .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); - dBias = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("dBias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, 
bias_skv, 1})); sdpa_backward_options.set_bias(bias); // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { + dBias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("dBias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_dbias(dBias); } } @@ -982,12 +982,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bias) { variant_pack[bias] = devPtrBias; - // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation - // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 - if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { + if (dBias != nullptr) { variant_pack[dBias] = devPtrdBias; - } else { - variant_pack[dBias] = nullptr; } } From f066c889ad53c25fcefc8f9904d91af617190987 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 17 Feb 2026 17:21:42 -0800 Subject: [PATCH 24/25] Fix the dim that is being compared for the two cp chunks in the test Signed-off-by: Kshitij Lakhani --- .../attention/run_attention_with_cp.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index fce55c8f6c..0f36a8816d 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -619,18 +619,15 @@ def run_dpa_with_cp( # Compare the two sequence chunks separately # Compare dbias if names[i] == "dbias": - # After reshaping: (1, 1, 2, seq_q//2, seq_kv) - # Compare along dimension 2 (the split sequence dimension) + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 
ndim_bias = t.ndim - seq_q_dim_bias = ndim_bias - 2 # Query sequence dimension - # After reshaping both have shape: [..., 2, seq_q//2, seq_kv] - # The split dimension is at seq_q_dim_bias slice_0 = [slice(None)] * ndim_bias slice_0[seq_q_dim_bias] = 0 slice_1 = [slice(None)] * ndim_bias slice_1[seq_q_dim_bias] = 1 compare_and_assert( - t[tuple(slice_0)], # First sequence chunk + t[tuple(slice_0)], tensors_cp[i][tuple(slice_0)], names_no_cp[i], names_cp[i], @@ -640,7 +637,7 @@ def run_dpa_with_cp( is_fp8, ) compare_and_assert( - t[tuple(slice_1)], # First sequence chunk + t[tuple(slice_1)], tensors_cp[i][tuple(slice_1)], names_no_cp[i], names_cp[i], @@ -651,7 +648,7 @@ def run_dpa_with_cp( ) # Compare Q/K/V/out else: - # Compare along dimension 1 (the split sequence dimension) + # Compare the two chunks along dimension 1 (the split sequence dimension) compare_and_assert( t[:, 0], tensors_cp[i][:, 0], @@ -676,16 +673,15 @@ def run_dpa_with_cp( # Compare the two sequence chunks separately # Compare dbias (same as BSHD) if names[i] == "dbias": - # After reshaping: (1, 1, 2, seq_q//2, seq_kv) - # Compare along dimension 2 (the split sequence dimension) + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 ndim_bias = t.ndim - seq_q_dim_bias = ndim_bias - 2 slice_0 = [slice(None)] * ndim_bias slice_0[seq_q_dim_bias] = 0 slice_1 = [slice(None)] * ndim_bias slice_1[seq_q_dim_bias] = 1 compare_and_assert( - t[tuple(slice_0)], # First sequence chunk + t[tuple(slice_0)], tensors_cp[i][tuple(slice_0)], names_no_cp[i], names_cp[i], @@ -695,7 +691,7 @@ def run_dpa_with_cp( is_fp8, ) compare_and_assert( - t[tuple(slice_1)], # First sequence chunk + t[tuple(slice_1)], tensors_cp[i][tuple(slice_1)], names_no_cp[i], names_cp[i], @@ -706,7 +702,7 @@ def run_dpa_with_cp( ) # Compare Q/K/V/out else: - # Compare along dimension 0 (the split sequence dimension) + # Compare the two chunks along dimension 0 (the split sequence 
dimension) compare_and_assert( t[0], tensors_cp[i][0], From ff174a8a080e392f334ee8fce7b769370fdd352c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 17 Feb 2026 18:30:33 -0800 Subject: [PATCH 25/25] nit: Reinstate the original test for right side swa Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention_with_cp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 523a1a4c38..ecd0090a3b 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -150,7 +150,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA - "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -187,7 +187,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" ), # GQA "cp_2_6": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA