diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index abdce7fdac..33f76de3f5 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -545,6 +545,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) + // T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path. + const int device_id_fwd = cuda::current_device(); + const int sm_arch_fwd = cuda::sm_arch(device_id_fwd); + if (sm_arch_fwd >= 120 && + (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { + NVTE_ERROR( + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. " + "Use thd_thd_thd or other THD layouts instead."); + } fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, @@ -644,6 +653,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) + // T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path. + const int device_id_bwd = cuda::current_device(); + const int sm_arch_bwd = cuda::sm_arch(device_id_bwd); + if (sm_arch_bwd >= 120 && + (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { + NVTE_ERROR( + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
"Use thd_thd_thd or other THD layouts instead."); + } size_t i = 0; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..044d3874ac 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -85,6 +85,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -96,11 +99,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // replace batch size and maximum sequence lengths with maximum token counts - // for query and key/value so the graph is static within each quantization bucket - b = max_b; - s_q = is_ragged_q ? max_t_q : s_q; - s_kv = is_ragged_kv ? max_t_kv : s_kv; + // On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3] + // as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build + // so the check passes; ragged offset still provides variable-length boundaries.
+ if (sm_arch_ < 120) { + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; @@ -336,7 +344,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } std::shared_ptr Max, Sum_Exp; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -353,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_name("Sum_Exp") .set_dim({b, h, s_q, 1}) .set_data_type(fe::DataType_t::FLOAT)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { @@ -381,7 +389,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (!return_max_logit) { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { Stats->set_stride({h * s_q, s_q, 1, 1}); @@ -407,9 +415,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) - ? std::make_tuple(offset_stats) - : std::make_tuple(nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -443,7 +450,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; @@ -510,7 +517,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { devOffsetsS = static_cast(devOffsets) + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; @@ -529,7 +536,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; } - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { variant_pack[offset_stats] = devOffsetsS; } } @@ -587,6 +594,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -598,13 +606,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // replace batch size and maximum sequence lengths with maximum token counts - // for query and key/value so the graph 
is static within each quantization bucket - b = max_b; - s_q = is_ragged_q ? max_t_q : s_q; - s_kv = is_ragged_kv ? max_t_kv : s_kv; + // On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd). + if (sm_arch_ < 120) { + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } } - // We choose between 32-bit and 64-bit offsets depending on need. // This allows us to support older cuDNN runtimes gracefully. const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; @@ -765,7 +775,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_name("stats") .set_dim({b, h, s_q, 1}) .set_data_type(fe::DataType_t::FLOAT)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -791,7 +801,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { sdpa_backward_options.set_max_total_seq_len_q(s_q); } if (is_ragged_kv && cudnn_runtime_version >= 90600) { @@ -914,9 +924,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) - ? std::make_tuple(offset_stats) - : std::make_tuple(nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -949,7 +958,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; @@ -1019,7 +1028,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { devOffsetsS = static_cast(devOffsets) + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; @@ -1038,7 +1047,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; } - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { variant_pack[offset_stats] = devOffsetsS; } } @@ -1102,6 +1111,9 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; @@ -1128,7 +1140,8 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } 
else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1136,7 +1149,8 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1145,7 +1159,8 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index bd6b626b64..69681104ce 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1491,7 +1491,11 @@ def forward( softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: - softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + softmax_lse_in_packed_format = get_cudnn_version() >= ( + 9, + 6, + 0, + ) and get_device_compute_capability() < (12, 0) else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..9478fb999a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -554,11 +554,15 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version - # until the cuDNN bug is resolved - if device_compute_capability == (8, 9): - logger.debug("Disabling FusedAttention for KV caching for sm89") + # Temporarily disabling fused attention for kv caching for sm89/sm120 irrespective of + # cuDNN version until the cuDNN bug is resolved. + if device_compute_capability in ((8, 9), (12, 0)): + logger.debug("Disabling FusedAttention for KV caching for sm89/sm120") use_fused_attention = False + # Temporarily disable FlashAttention for KV caching on sm120 + if device_compute_capability == (12, 0): + logger.debug("Disabling FlashAttention for KV caching for sm120") + use_flash_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") use_flash_attention = False @@ -690,11 +694,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120" + " not supported for compute capability = sm120 and cuDNN version < 9.18.1" ) use_fused_attention = False diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..7d462db2ec 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -339,13 +339,22 @@ def fused_attn_fwd( if return_max_logit: qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] - # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] - # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] - # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] + # thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) - amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3) + max_tensor = output_tensors[1] + if qkv_format == "thd" and max_tensor.ndim == 4: + # For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded + # sequence positions. Exclude those padded positions when computing max_logit. 
+ seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device) + sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1) + valid = sq_idx < seqlens_q.view(-1, 1, 1, 1) + max_tensor = max_tensor.masked_fill(~valid, float("-inf")) + amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3) # Max -> max_logit [h] - max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype) + max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype) aux_ctx_tensors = [stats] aux_ctx_tensors.extend(output_tensors[3:]) return output_tensors[0], aux_ctx_tensors, max_logit