Describe the bug
When using the THD layout with GQA, repeated calls to fused_attn_bwd produce unexpected dk and dv results. Observations:
- The second call appears to reuse the workspace left over from the previous call.
- As a result, the outputs are inconsistent and unexpected.
- Changing the execution order also changes the results.
In MHA mode, the same operations produce correct and consistent outputs.
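To make the suspected mechanism concrete, here is a toy model in plain Python. It does not involve Transformer Engine at all: ToyBackward and its one-scalar-per-row workspace are hypothetical. It assumes, as the observations suggest but do not prove, that the K/V rows of a sequence with no query tokens are skipped rather than zero-filled. Under that assumption, the result of the second call depends on what ran before it.

```python
class ToyBackward:
    """Hypothetical stand-in for a backward pass with a persistent dK workspace.

    NOT Transformer Engine code. It only models the suspected failure mode:
    rows belonging to sequences with zero query tokens are skipped instead of
    zero-filled, so stale values from an earlier call leak into the output.
    """

    def __init__(self, total_tokens):
        # One scalar per token row; persists across calls like a reused workspace.
        self.workspace = [0.0] * total_tokens

    def dk(self, q_lens, kv_padded_offsets, grad_value):
        for seq, q_len in enumerate(q_lens):
            start, end = kv_padded_offsets[seq], kv_padded_offsets[seq + 1]
            if q_len == 0:
                # Modelled bug: nothing is written, so old contents survive.
                continue
            for row in range(start, end):
                self.workspace[row] = grad_value
        return list(self.workspace)


offsets = [0, 4, 12]  # cu_seqlens_kv_padded from the repro

# Repro order: first call, then second call on the same workspace.
shared = ToyBackward(total_tokens=12)
shared.dk(q_lens=[1, 2], kv_padded_offsets=offsets, grad_value=1.0)
after_first = shared.dk(q_lens=[0, 3], kv_padded_offsets=offsets, grad_value=2.0)

# Second call alone, on a fresh workspace.
alone = ToyBackward(total_tokens=12).dk(q_lens=[0, 3], kv_padded_offsets=offsets, grad_value=2.0)

print(after_first[:4])  # [1.0, 1.0, 1.0, 1.0]  (stale values from the first call)
print(alone[:4])        # [0.0, 0.0, 0.0, 0.0]
```

In the toy model only the rows of the query-less sequence differ between the two orders; the rows of sequence 1 are identical either way.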
Steps/Code to reproduce bug
```python
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_bwd,
)

device = torch.device("cuda")
seqlen = 12  # total tokens in the padded THD layout (4 + 8)
hq = 2       # query heads
hk = 1       # key/value heads; hq != hk makes this GQA
torch.manual_seed(42)
for i in range(2):
    # Cumulative offsets of the valid tokens for each call.
    cu_seqlens_q_per_step = [
        torch.tensor([0, 1, 3], dtype=torch.int32, device=device),  # call 0: q lengths 1, 2
        torch.tensor([0, 0, 3], dtype=torch.int32, device=device),  # call 1: q lengths 0, 3
    ]
    cu_seqlens_kv_per_step = [
        torch.tensor([0, 3, 9], dtype=torch.int32, device=device),  # kv lengths 3, 6
        torch.tensor([0, 3, 9], dtype=torch.int32, device=device),  # kv lengths 3, 6
    ]
    q = torch.randn((seqlen, hq, 8), dtype=torch.float16, device=device)
    k = torch.randn((seqlen, hk, 8), dtype=torch.float16, device=device)
    v = torch.randn((seqlen, hk, 8), dtype=torch.float16, device=device)
    out = torch.randn((seqlen, hq, 8), dtype=torch.float16, device=device)
    dout = torch.randn((seqlen, hq, 8), dtype=torch.float16, device=device)
    # Auxiliary forward-pass tensors (random placeholders in this repro).
    lse = torch.randn((seqlen, hq, 1), dtype=torch.float16, device=device)
    rng = torch.randn(hq, dtype=torch.float16, device=device)
    # Padded offsets: sequence 0 occupies tokens 0-3, sequence 1 tokens 4-11.
    cu_seqlens_q_padded = torch.tensor([0, 4, 12], dtype=torch.int32, device=device)
    cu_seqlens_kv_padded = torch.tensor([0, 4, 12], dtype=torch.int32, device=device)
    fused_attn_meta_args = [
        torch.float16,          # fake_dtype
        TE_DType[dout.dtype],   # dqkv_dtype
        [lse, rng],             # aux_ctx_tensors
        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
    ]
    fused_attn_meta_kwargs = {
        "attn_scale": 0.1,
        "dropout": 0.0,
        "qkv_layout": "thd_thd_thd",
        "attn_mask_type": "padding",
        "attn_bias_type": "no_bias",
        "deterministic": True,
    }
    print(cu_seqlens_q_per_step[i])
    print(cu_seqlens_kv_per_step[i])
    dq, dk0, dv0, _, _ = fused_attn_bwd(
        8,  # max_seqlen_q
        8,  # max_seqlen_kv
        cu_seqlens_q_per_step[i],
        cu_seqlens_kv_per_step[i],
        q,
        k,
        v,
        out,
        dout,
        *fused_attn_meta_args,
        cu_seqlens_q_padded=cu_seqlens_q_padded,
        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
        **fused_attn_meta_kwargs,
    )
    print(f"{i=},{dk0=}")
```
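In this repro hq=2 and hk=1, so both query heads share a single KV head. The gradient for a shared KV head is the sum of the partial gradients from every query head in its group, a reduction step that MHA (hq == hk) does not need. The sketch below shows only that arithmetic, assuming the usual contiguous grouping (query head h uses KV head h // (hq // hk)). The helper names are hypothetical, and whether fused_attn_bwd performs this reduction through a workspace is not confirmed here.

```python
def gqa_groups(num_q_heads, num_kv_heads):
    """Map each KV head to the query heads that share it (contiguous grouping assumed)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return {kv: list(range(kv * group_size, (kv + 1) * group_size)) for kv in range(num_kv_heads)}


def reduce_kv_grad(per_q_head_grads, num_kv_heads):
    """Sum per-query-head partial dK (or dV) contributions into per-KV-head gradients."""
    groups = gqa_groups(len(per_q_head_grads), num_kv_heads)
    return [sum(per_q_head_grads[h] for h in groups[kv]) for kv in range(num_kv_heads)]


print(gqa_groups(2, 1))                # {0: [0, 1]}     (GQA: one shared KV head)
print(gqa_groups(2, 2))                # {0: [0], 1: [1]} (MHA: no sharing)
print(reduce_kv_grad([0.5, 0.25], 1))  # [0.75]
```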
Expected behavior
I simulated two sequences, padded to lengths 4 and 8.
- First call:
  - Query valid lengths: 1 and 2
  - Key/Value valid lengths: 3 and 6
- Second call:
  - Query valid lengths: 0 and 3
  - Key/Value valid lengths: 3 and 6
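The valid and padded lengths listed above follow from the cumulative offset tensors in the repro: each length is the difference of consecutive entries. A minimal check in plain Python (seq_lengths is a hypothetical helper, not a Transformer Engine function):

```python
def seq_lengths(cu_seqlens):
    """Convert cumulative offsets [0, a, b, ...] into per-sequence lengths."""
    return [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]


padded_lengths = seq_lengths([0, 4, 12])  # cu_seqlens_{q,kv}_padded
first_q = seq_lengths([0, 1, 3])          # cu_seqlens_q_per_step[0]
first_kv = seq_lengths([0, 3, 9])         # cu_seqlens_kv_per_step[0]
second_q = seq_lengths([0, 0, 3])         # cu_seqlens_q_per_step[1]
second_kv = seq_lengths([0, 3, 9])        # cu_seqlens_kv_per_step[1]

print(padded_lengths, first_q, first_kv, second_q, second_kv)
# [4, 8] [1, 2] [3, 6] [0, 3] [3, 6]
```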
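Expected result for the second call:
- The first 4 tokens of dk and dv should be all zeros. With a padding mask, keys and values receive gradient only from queries in the same sequence, and the first sequence has no query tokens in this call.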
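The expectation can be derived mechanically: any padded K/V row that belongs to a sequence with zero query tokens should receive no gradient. The sketch below computes those rows from the offsets. rows_without_queries is a hypothetical helper, and it deliberately makes no claim about the padding rows of sequences that do have queries.

```python
def rows_without_queries(cu_seqlens_q, cu_seqlens_kv_padded):
    """Padded K/V token rows whose sequence has zero query tokens in this call."""
    rows = []
    for seq in range(len(cu_seqlens_q) - 1):
        q_len = cu_seqlens_q[seq + 1] - cu_seqlens_q[seq]
        if q_len == 0:
            # No queries -> no attention probabilities -> no dK/dV for these rows.
            rows.extend(range(cu_seqlens_kv_padded[seq], cu_seqlens_kv_padded[seq + 1]))
    return rows


print(rows_without_queries([0, 1, 3], [0, 4, 12]))  # first call: []
print(rows_without_queries([0, 0, 3], [0, 4, 12]))  # second call: [0, 1, 2, 3]
```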
Observed behavior for the second call:
- With THD layout + GQA, the first 4 tokens of dk and dv are not zero; they appear to contain the dk/dv values left over from the previous call.
- Switching to MHA mode, or swapping the call order, produces the correct result: the first 4 tokens of dk and dv are zero.
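To turn the observation into a pass/fail check, the expected-zero rows can be tested directly. The sketch works on nested Python lists shaped [tokens][heads][head_dim], such as the output of dk0.tolist(); nonzero_rows is a hypothetical helper, and the example data is synthetic, not a captured result.

```python
def nonzero_rows(rows_data, rows):
    """Return the rows in `rows` that contain any nonzero value.

    `rows_data` is a nested list shaped [tokens][heads][head_dim],
    e.g. the result of calling .tolist() on dk or dv.
    """
    return [r for r in rows if any(x != 0 for head in rows_data[r] for x in head)]


# Synthetic data shaped like dk in the repro: 12 tokens, 1 KV head, head_dim 8.
dk_rows = [[[0.0] * 8] for _ in range(12)]
dk_rows[2][0][5] = 0.125  # pretend a stale value leaked into row 2
print(nonzero_rows(dk_rows, [0, 1, 2, 3]))  # [2]
```

With the real tensors, nonzero_rows(dk0.tolist(), [0, 1, 2, 3]) (and likewise for dv0) should return an empty list for the second call if the behavior matches the expectation.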
Environment overview
sudo docker run -it -d --privileged --gpus all --network host --ipc host --name {name} nvcr.io/nvidia/pytorch:25.10-py3
Environment details
None.
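The environment and device sections were left empty. PyTorch ships `python -m torch.utils.collect_env` for a general report. The sketch below additionally records the Transformer Engine and cuDNN versions, which are likely relevant to a fused-attention issue. summarize_env is a hypothetical helper. It assumes transformer_engine exposes `__version__` and falls back to "unknown" if it does not.

```python
import importlib
import platform


def summarize_env():
    """Collect Python, PyTorch, CUDA/cuDNN, GPU and Transformer Engine versions."""
    info = {"python": platform.python_version()}
    try:
        import torch

        info["torch"] = torch.__version__
        info["cuda (torch build)"] = torch.version.cuda
        info["cudnn"] = torch.backends.cudnn.version()
        info["gpu"] = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "no CUDA device"
    except Exception as exc:  # broad on purpose: library loading can fail with OSError too
        info["torch"] = f"unavailable ({type(exc).__name__})"
    try:
        te = importlib.import_module("transformer_engine")
        info["transformer_engine"] = getattr(te, "__version__", "unknown")
    except Exception as exc:
        info["transformer_engine"] = f"unavailable ({type(exc).__name__})"
    return info


if __name__ == "__main__":
    for key, value in summarize_env().items():
        print(f"{key}: {value}")
```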
Device details