
Inconsistent fused_attn_bwd results of dk, dv under THD layout + GQA on multiple calls #2448

@Big-TRex


Describe the bug

With the THD layout and GQA, repeated calls to fused_attn_bwd produce inconsistent dk and dv results. Observations:

  • The second call appears to reuse the workspace from the previous call
  • This causes inconsistent and unexpected outputs
  • Changing the execution order also affects the results

In MHA mode (hq == hk), the same sequence of calls produces correct and consistent outputs.
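GQA adds one step to the backward pass that MHA does not need. Each shared KV head's gradient is the sum of the contributions from every query head in its group (here hq=2, hk=1). The pure-Python sketch below shows this for dV only, using the textbook formula dV[j] = sum_i P[i, j] * dO[i] for one unmasked sequence. It is not Transformer Engine's kernel, and all helper names are hypothetical. It checks that the GQA dV equals the group-summed MHA dV, and that a sequence with no query tokens gets dV = 0.

```python
import math
import random

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def rand(*shape):
    """Nested list of uniform random floats with the given shape."""
    if len(shape) == 1:
        return [random.uniform(-1.0, 1.0) for _ in range(shape[0])]
    return [rand(*shape[1:]) for _ in range(shape[0])]

def reference_dv(q, k, dout, hq, hk, scale):
    """Hypothetical reference dV for one sequence, no masking.

    q, dout: [s_q][hq][d]; k: [s_kv][hk][d]; returns dv: [s_kv][hk][d].
    Query head h reads KV head h // (hq // hk), so under GQA each KV head
    accumulates the contributions of hq // hk query heads.
    """
    group = hq // hk
    s_kv, d = len(k), len(k[0][0])
    dv = [[[0.0] * d for _ in range(hk)] for _ in range(s_kv)]
    for i in range(len(q)):  # never runs when the sequence has no queries
        for h in range(hq):
            kh = h // group  # KV head shared by query head h
            scores = [scale * sum(a * b for a, b in zip(q[i][h], k[j][kh]))
                      for j in range(s_kv)]
            p = softmax(scores)
            for j in range(s_kv):
                for c in range(d):
                    # Query heads of the same group all add into dv[j][kh].
                    dv[j][kh][c] += p[j] * dout[i][h][c]
    return dv

random.seed(0)
hq, hk, d, s_q, s_kv, scale = 2, 1, 4, 2, 3, 0.1
q, k, dout = rand(s_q, hq, d), rand(s_kv, hk, d), rand(s_q, hq, d)

dv_gqa = reference_dv(q, k, dout, hq, hk, scale)
# MHA equivalent: give every query head its own copy of the shared KV head.
k_rep = [[k[j][h // (hq // hk)] for h in range(hq)] for j in range(s_kv)]
dv_mha = reference_dv(q, k_rep, dout, hq, hq, scale)
# A sequence with no valid queries, like sequence 0 in the second call.
dv_empty = reference_dv([], k, [], hq, hk, scale)
```

This extra reduction is one place where the GQA and MHA code paths can differ. The sketch does not establish whether it is related to the reported workspace reuse.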

Steps/Code to reproduce bug

```python
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_bwd,
)

# Two packed sequences, padded to 4 and 8 tokens (12 tokens in total).
# GQA: hq=2 query heads share hk=1 KV head.
seqlen=12
hq=2
hk=1

torch.manual_seed(42)

for i in range(2):
    # Valid query lengths: [1, 2] on call 0, [0, 3] on call 1.
    cu_seqlens_q_per_step=[
        torch.tensor([0,1,3],dtype=torch.int32,device=torch.device("cuda")),
        torch.tensor([0,0,3],dtype=torch.int32,device=torch.device("cuda")),
    ]
    # Valid KV lengths: [3, 6] on both calls.
    cu_seqlens_kv_per_step=[
        torch.tensor([0,3,9],dtype=torch.int32,device=torch.device("cuda")),
        torch.tensor([0,3,9],dtype=torch.int32,device=torch.device("cuda")),
    ]
    # Random inputs; out, lse (softmax stats) and rng stand in for forward-pass outputs.
    q=torch.randn((seqlen,hq,8),dtype=torch.float16,device=torch.device("cuda"))
    k=torch.randn((seqlen,hk,8),dtype=torch.float16,device=torch.device("cuda"))
    v=torch.randn((seqlen,hk,8),dtype=torch.float16,device=torch.device("cuda"))
    out=torch.randn((seqlen,hq,8),dtype=torch.float16,device=torch.device("cuda"))
    dout=torch.randn((seqlen,hq,8),dtype=torch.float16,device=torch.device("cuda"))
    lse=torch.randn((seqlen,hq,1),dtype=torch.float16,device=torch.device("cuda"))
    rng=torch.randn(hq,dtype=torch.float16,device=torch.device("cuda"))
    # Start offsets of each padded sequence in the THD buffers.
    cu_seqlens_q_padded=torch.tensor([0,4,12],dtype=torch.int32,device=torch.device("cuda"))
    cu_seqlens_kv_padded=torch.tensor([0,4,12],dtype=torch.int32,device=torch.device("cuda"))

    fused_attn_meta_args = [
        torch.float16,
        TE_DType[dout.dtype],
        [lse,rng],
        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
        ]
    fused_attn_meta_kwargs = {
        "attn_scale": 0.1,
        "dropout": 0.0,
        "qkv_layout": "thd_thd_thd",
        "attn_mask_type": "padding",
        "attn_bias_type": "no_bias",
        "deterministic": True,
    }

    print(cu_seqlens_q_per_step[i])
    print(cu_seqlens_kv_per_step[i])
    # max_seqlen_q = max_seqlen_kv = 8 (the longest padded sequence).
    dq,dk0,dv0, _, _ = fused_attn_bwd(
        8,
        8,
        cu_seqlens_q_per_step[i],
        cu_seqlens_kv_per_step[i],
        q,
        k,
        v,
        out,
        dout,
        *fused_attn_meta_args,
        cu_seqlens_q_padded=cu_seqlens_q_padded,
        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
        **fused_attn_meta_kwargs,
    )
    print(f"{i=},{dk0=}")
```

Expected behavior

I simulated two sequences, padded to lengths 4 and 8.

  • First call:
    • Query valid lengths: 1 and 2
    • Key/Value valid lengths: 3 and 6
  • Second call:
    • Query valid lengths: 0 and 3
    • Key/Value valid lengths: 3 and 6
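Each valid length above is the difference between consecutive entries of the corresponding cu_seqlens tensor in the repro. The small pure-Python sketch below recomputes them and lists which token indices of the padded THD buffer hold real tokens. The helper names are hypothetical, and it assumes that the valid tokens of each sequence occupy the start of its padded slot, beginning at the offset given by cu_seqlens_*_padded.

```python
def seq_lengths(cu_seqlens):
    """Per-sequence lengths: consecutive differences of cumulative offsets."""
    return [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]

def valid_positions(cu_seqlens, cu_seqlens_padded):
    """Indices in the padded THD buffer that hold real (non-padding) tokens.

    Assumes sequence s starts at cu_seqlens_padded[s] and its first
    seq_lengths(cu_seqlens)[s] slots are valid.
    """
    positions = []
    for start, n in zip(cu_seqlens_padded, seq_lengths(cu_seqlens)):
        positions.extend(range(start, start + n))
    return positions

padded = [0, 4, 12]  # cu_seqlens_q_padded / cu_seqlens_kv_padded in the repro
print(seq_lengths(padded))                              # [4, 8]
print(seq_lengths([0, 1, 3]), seq_lengths([0, 0, 3]))   # [1, 2] [0, 3]
print(seq_lengths([0, 3, 9]))                           # [3, 6]
print(valid_positions([0, 0, 3], padded))               # [4, 5, 6]
print(valid_positions([0, 3, 9], padded))               # [0, 1, 2, 4, 5, 6, 7, 8, 9]
```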

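Expected result for the second call:

  • The first sequence has no valid query tokens, so no gradient flows to its keys and values. The first 4 tokens of dk and dv (its padded slot) should therefore be all zeros.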
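Under standard attention math with a padding mask and no causal masking, dV[j] = sum_i P[i, j] * dO[i] and dK[j] = scale * sum_i dS[i, j] * Q[i]. Both sums run only over the valid queries of the same sequence that attend to key j. A KV row therefore has zero gradient if it is padding or if its sequence has no valid queries. The hypothetical helper below lists those rows. The reporter's expectation only concerns the first 4 tokens, and whether the kernel also writes zeros into padding slots is an implementation detail this sketch does not assume.

```python
def expected_zero_kv_rows(cu_seqlens_q, cu_seqlens_kv, cu_seqlens_kv_padded):
    """Padded KV indices whose mathematical dk/dv gradient is zero.

    Assumes a padding mask only (no causal mask) and that the valid tokens
    of each sequence sit at the start of its padded slot.
    """
    lens_q = [b - a for a, b in zip(cu_seqlens_q, cu_seqlens_q[1:])]
    lens_kv = [b - a for a, b in zip(cu_seqlens_kv, cu_seqlens_kv[1:])]
    zero = []
    for s, (lq, lkv) in enumerate(zip(lens_q, lens_kv)):
        start, end = cu_seqlens_kv_padded[s], cu_seqlens_kv_padded[s + 1]
        for t in range(start, end):
            # No queries in this sequence, or t lies in the padding region.
            if lq == 0 or t - start >= lkv:
                zero.append(t)
    return zero

print(expected_zero_kv_rows([0, 1, 3], [0, 3, 9], [0, 4, 12]))  # [3, 10, 11]
print(expected_zero_kv_rows([0, 0, 3], [0, 3, 9], [0, 4, 12]))  # [0, 1, 2, 3, 10, 11]
```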

Observed behavior for the second call:

  • With THD layout + GQA, the first 4 tokens of dk and dv are not zero and appear to contain the previous call's dk/dv values.
  • Switching to MHA (hk equal to hq) or swapping the order of the two calls produces the correct result: the first 4 tokens of dk and dv are zero.
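One way to confirm that the non-zero rows are stale copies and not arbitrary garbage is to compare them with the previous call's output. The helper below is a hypothetical pure-Python diagnostic. It takes dk from the current and previous calls as nested lists [token][head][dim] and labels each row that should be zero as "zero", "stale" (identical to the previous call), or "other". In the repro, saving dk0.tolist() at the end of each iteration gives this form, and rows 0 to 3 are the ones to check on the second call.

```python
def _flat(row):
    """Flatten one token's [head][dim] gradient into a single list."""
    return [x for head in row for x in head]

def classify_rows(dk_now, dk_prev, rows, tol=0.0):
    """Label each expected-zero row as 'zero', 'stale' or 'other'.

    'stale' means the row matches the previous call's output within tol,
    which is what reuse of a previous buffer would look like.
    """
    result = {}
    for t in rows:
        now, prev = _flat(dk_now[t]), _flat(dk_prev[t])
        if all(abs(x) <= tol for x in now):
            result[t] = "zero"
        elif all(abs(x - y) <= tol for x, y in zip(now, prev)):
            result[t] = "stale"
        else:
            result[t] = "other"
    return result

prev = [[[1.0, 2.0]], [[3.0, 4.0]]]
r_stale = classify_rows([[[1.0, 2.0]], [[0.0, 0.0]]], prev, [0, 1])
r_other = classify_rows([[[5.0, 2.0]], [[0.0, 0.0]]], prev, [0, 1])
print(r_stale)  # {0: 'stale', 1: 'zero'}
print(r_other)  # {0: 'other', 1: 'zero'}
```

Because the repro draws new random q, k, v and dout on every iteration, a second-call row that exactly matches the first call's output is strong evidence of leftover data rather than a freshly computed gradient.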

Environment overview

sudo docker run -it -d --privileged --gpus all --network host --ipc host --name {name} nvcr.io/nvidia/pytorch:25.10-py3

Environment details

None.

Device details

  • NVIDIA H100 GPU.
