Skip to content

[JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends#2970

Open
KshitijLakhani wants to merge 8 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/feat/attn-swa-enh-jax
Open

[JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends#2970
KshitijLakhani wants to merge 8 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/feat/attn-swa-enh-jax

Commits

Commits on May 8, 2026