Merged

Changes from all commits · 25 commits
ab542fc
Plumbing correct bias dims from TE to cudnn
KshitijLakhani Dec 20, 2025
c86328e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
fddf0ac
Make changes for cp bias code
KshitijLakhani Jan 9, 2026
d3aa7ec
Add dbias and dbias_ to run_dpa_with_cp test
KshitijLakhani Jan 9, 2026
4d295c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
f4f9cc6
Fix: Use output_dBias instead of input_dBias to extract the shape
KshitijLakhani Jan 9, 2026
f9f4fb8
Add guards for bias/bias_/dbias/dbias_ being None
KshitijLakhani Jan 21, 2026
d9547fa
Add support for bias shape 111s in addition to the original 1hss, 11s…
KshitijLakhani Jan 22, 2026
7ede1fe
Add support for dbias calculation and variant packing for the dbias s…
KshitijLakhani Feb 6, 2026
c20d67a
Add support for 111s bias shape in DPA
KshitijLakhani Feb 6, 2026
303aee7
Allow fused attn for dbias calculation for 11ss, b1ss, bhss. Disable …
KshitijLakhani Feb 6, 2026
e9f88f0
Disable requires_grad for bias for shape 111s in tests
KshitijLakhani Feb 6, 2026
6bf73e1
Disable bias grad / training flag for 111s bias in the non-CP attn te…
KshitijLakhani Feb 6, 2026
ebee29b
Fix to correctly create the bias shape tensor instead of the hard cod…
KshitijLakhani Feb 6, 2026
126be03
Add fused attn cp test cases for all supported bias shapes
KshitijLakhani Feb 6, 2026
7b0f942
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2026
f795056
nit: switch to elif for bias grad conditional
KshitijLakhani Feb 13, 2026
0e74dcf
Add CP support for bias/dbias shape 111s
KshitijLakhani Feb 13, 2026
0acf8f8
Add support for is_training in CP attn tests
KshitijLakhani Feb 13, 2026
2133bd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
89e90a5
nit: Fix incorrect comment
KshitijLakhani Feb 17, 2026
5a25d9c
nit: Fix incorrect comment and assert string
KshitijLakhani Feb 17, 2026
0e2a72f
Create the dbias graph tensor only if it is a cuDNN supported bias shape
KshitijLakhani Feb 18, 2026
f066c88
Fix the dim that is being compared for the two cp chunks in the test
KshitijLakhani Feb 18, 2026
ff174a8
nit: Reinstate the original test for right side swa
KshitijLakhani Feb 18, 2026
413 changes: 288 additions & 125 deletions tests/pytorch/attention/run_attention_with_cp.py

Large diffs are not rendered by default.
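For context on the diffs below: the commits plumb additional post-scale-bias shapes through the context-parallel (CP) attention path and into cuDNN. A minimal PyTorch sketch (illustrative only, not the TransformerEngine implementation) of what post_scale_bias means and how the shape strings used throughout this PR broadcast against the attention logits:

import torch

def post_scale_bias_attention(q, k, v, bias, scale):
    # q: (b, h, sq, d), k/v: (b, h, skv, d); bias broadcasts over singleton dims
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale  # (b, h, sq, skv)
    logits = logits + bias
    return torch.matmul(torch.softmax(logits, dim=-1), v)  # (b, h, sq, d)

b, h, sq, skv, d = 2, 12, 128, 128, 64
q = torch.randn(b, h, sq, d)
k = torch.randn(b, h, skv, d)
v = torch.randn(b, h, skv, d)
# "1hss", "b1ss", "bhss", "11ss", "111s" -> broadcastable 4D bias shapes
for shape in [(1, h, sq, skv), (b, 1, sq, skv), (b, h, sq, skv),
              (1, 1, sq, skv), (1, 1, 1, skv)]:
    out = post_scale_bias_attention(q, k, v, torch.randn(shape), d ** -0.5)
    assert out.shape == (b, h, sq, d)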

22 changes: 19 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -162,7 +162,16 @@ def test_dot_product_attention(
)

# Get backends
# For 111s, dbias calculation is not supported as of cuDNN 9.18; hence, test fwd only for 111s.
# For all other shapes, test fwd+bwd.
is_training = True
# TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s.
if config.bias_shape == "111s":
is_training = False
logging.info(
"Setting is_training to False as cuDNN does not support dbias for"
f" {config.bias_shape=} "
)
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
@@ -636,7 +645,8 @@ def test_dpa_bias(dtype, model_configs, model):
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
"bias_1_4": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"),
"bias_1_5": ModelConfig(
4,
2048,
24,
@@ -646,7 +656,7 @@ def test_dpa_bias(dtype, model_configs, model):
bias_shape="1hss",
alibi_type="custom",
),
"bias_1_5": ModelConfig(
"bias_1_6": ModelConfig(
2,
2048,
24,
@@ -1143,10 +1153,16 @@ def _run_dot_product_attention(
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
# For 1hss, 11ss, b1ss, bhss
shape_cache = shape
shape = shape.replace("_s_s", "_sq_skv")
# For 111s
if shape == shape_cache:
shape = shape.replace("_1_s", "_1_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
# For 111s, dbias calculation is not supported as of cuDNN 9.18
if config.bias_shape == "111s":
bias.requires_grad = False

# Create RNG
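The hunk above expands the bias-shape string into a concrete tensor shape. A small standalone sketch of that mapping, with an illustrative dim_to_num (in the real test it is derived from the model config):

dim_to_num = {"1": 1, "b": 2, "h": 16, "sq": 128, "skv": 128}

def bias_tensor_shape(bias_shape):
    shape = "_".join(bias_shape)                 # e.g. "111s" -> "1_1_1_s"
    expanded = shape.replace("_s_s", "_sq_skv")  # 1hss / 11ss / b1ss / bhss
    if expanded == shape:                        # no "_s_s" tail: the 111s case
        expanded = shape.replace("_1_s", "_1_skv")
    return [dim_to_num[j] for j in expanded.split("_")]

assert bias_tensor_shape("1hss") == [1, 16, 128, 128]
assert bias_tensor_shape("b1ss") == [2, 1, 128, 128]
assert bias_tensor_shape("111s") == [1, 1, 1, 128]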
38 changes: 36 additions & 2 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -147,7 +147,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_1_4": ModelConfig(
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"
), # MHA
"cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
@@ -160,9 +163,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
bias_shape="11ss",
), # GQA
"cp_2_4": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
bias_shape="111s",
return_max_logit=True,
), # GQA
"cp_2_5": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_6": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
@@ -171,6 +195,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_3_4": ModelConfig(
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64
), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
@@ -191,10 +218,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_0",
"cp_1_1",
"cp_1_4",
"cp_1_5",
"cp_2_0",
"cp_2_2",
"cp_2_3",
"cp_2_4",
"cp_3_2",
"cp_3_4",
"cp_4_2",
]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
@@ -324,12 +354,15 @@ def test_cp_with_fused_attention(
Float8CurrentScaling(fp8_dpa=True),
DelayedScaling(fp8_dpa=True),
]
# For 111s, dbias calculation is not supported as of cuDNN 9.18; hence, test fwd only for 111s.
is_training = False if config.bias_shape == "111s" else True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
@@ -348,6 +381,7 @@
fp8_mha=fp8_mha,
scaling_mode=scaling_mode,
f16_O=f16_O,
is_training=is_training,
log_level=pytest_logging_level,
),
check=True,
5 changes: 3 additions & 2 deletions tests/pytorch/utils.py
@@ -271,7 +271,6 @@ def get_available_attention_backends(
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
@@ -289,7 +288,9 @@
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
# TODO(KshitijLakhani): Remove this guard when cuDNN starts supporting dbias calculation for bias shape 111s
if core_attention_bias_shape != "111s":
core_attention_bias_requires_grad = True

fused_attn_backends = []
available_backends = None
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -52,13 +52,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -121,6 +122,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
max_pages_per_seq_v,
bias_b,
bias_h,
bias_sq,
bias_skv,
scaling_factor,
is_training,
dropout_probability,
@@ -270,10 +273,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_alibi_mask(is_alibi);

if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
bias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_options.set_bias(bias);
}
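The stride tuple set on the graph tensor above is simply the contiguous row-major stride of a (bias_b, bias_h, bias_sq, bias_skv) tensor, so a broadcast shape such as 111s now describes a genuinely small buffer rather than a full (s_q, s_kv) one. A quick Python check (PyTorch assumed only for its stride() helper):

import torch

def cudnn_bias_strides(b, h, sq, skv):
    # mirrors set_stride({bias_h*bias_sq*bias_skv, bias_sq*bias_skv, bias_skv, 1})
    return (h * sq * skv, sq * skv, skv, 1)

for shape in [(1, 12, 128, 128),   # 1hss
              (2, 1, 128, 128),    # b1ss
              (2, 12, 128, 128),   # bhss
              (1, 1, 128, 128),    # 11ss
              (1, 1, 1, 128)]:     # 111s
    assert torch.empty(shape).stride() == cudnn_bias_strides(*shape)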

@@ -549,16 +553,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -623,6 +627,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
0,
bias_b,
bias_h,
bias_sq,
bias_skv,
scaling_factor,
true,
dropout_probability,
@@ -812,19 +818,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_alibi_mask(is_alibi);

if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
bias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_backward_options.set_bias(bias);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// are not supported for dbias calculation but they are
// supported for forward bias calculation
if ((bias_b == 1) && (bias_h == h)) {
// bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation
// bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18
if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) {
dBias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_backward_options.set_dbias(dBias);
}
}
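The conditional above is the heart of the dbias gating: every supported bias shape except (1, 1, 1, s) gets a dBias graph tensor. A one-line sketch of the predicate and how it classifies the five shape strings used in the tests:

def dbias_supported(bias_b, bias_h, bias_sq):
    # mirrors: !((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))
    return not (bias_b == 1 and bias_h == 1 and bias_sq == 1)

assert dbias_supported(1, 12, 128)   # 1hss
assert dbias_supported(2, 1, 128)    # b1ss
assert dbias_supported(2, 12, 128)   # bhss
assert dbias_supported(1, 1, 128)    # 11ss
assert not dbias_supported(1, 1, 1)  # 111s -> forward bias only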
@@ -975,10 +982,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
if (dBias != nullptr) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
}

@@ -1084,10 +1089,14 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
size_t bias_sq = 0;
size_t bias_skv = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
bias_sq = input_Bias->data.shape[2];
bias_skv = input_Bias->data.shape[3];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1153,7 +1162,7 @@
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
output_bias->data.dtype = QKV_type;
}

@@ -1198,10 +1207,10 @@
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
@@ -1245,11 +1254,15 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
size_t bias_sq = 0;
size_t bias_skv = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
bias_sq = output_dBias->data.shape[2];
bias_skv = output_dBias->data.shape[3];
}

size_t max_batch_size = 0;
@@ -1292,11 +1305,11 @@

fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
