[JAX] Support for cuDNN-backed flex attention#2985

Open
vcherepanov-nv wants to merge 4 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new code path for the FusedAttention backend, used when score_mod (and the related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds a score_mod / score_mod_bprop code path to the JAX FusedAttention backend, routing user-supplied cuDNN frontend graph callbacks through a new _fused_attn_score_mod custom_vjp primitive. At JAX trace time, cuDNN pygraph objects are built once per unique shape/dtype/config, registered in a C++ side-table keyed by a monotone integer ID, and replayed at CUDA dispatch time via an XLA FFI handler that acquires the GIL to call the Python graph's execution method.

  • New Python layer (cpp_extensions/attention.py): config builder, fwd/bwd cuDNN graph construction, module-level graph cache, and ffi_call wrappers for the two new FFI targets.
  • New C++ layer (attention.cpp): ScoreModGraphRegistry holding raw PyObject* entries, ExecuteScoreModGraph with GIL acquisition in CUDA dispatch, and a separate thread-local cuDNN handle.
  • Tests cover input validation, cache-key stability, and three numerical correctness cases; the softcap test uses a stateful callback class whose backward depends on an undocumented cuDNN callback-ordering guarantee.
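The register-once, replay-by-ID pattern described above can be sketched in a few lines of Python. This is an illustrative analogue, not the PR's actual API: `GraphRegistry`, `register`, and `get` are hypothetical names standing in for the C++ `ScoreModGraphRegistry` side-table.

```python
# Hypothetical minimal analogue of the C++ side-table: graphs are registered
# once at trace time and fetched by a monotonically increasing integer ID at
# dispatch time. Names are illustrative only.
import itertools
import threading


class GraphRegistry:
    """Maps a monotone integer ID to a registered graph object."""

    def __init__(self):
        self._lock = threading.Lock()
        self._next_id = itertools.count()  # monotone ID source
        self._entries = {}

    def register(self, graph):
        with self._lock:
            graph_id = next(self._next_id)
            self._entries[graph_id] = graph
            return graph_id

    def get(self, graph_id):
        with self._lock:
            return self._entries[graph_id]


registry = GraphRegistry()
gid = registry.register("fwd-graph")
assert registry.get(gid) == "fwd-graph"
```

Note that, as in the PR's C++ registry, nothing here ever removes an entry, which is exactly the unbounded-growth concern raised below.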

Confidence Score: 3/5

Mergeable as a new-feature branch but carries several design-level risks that should be resolved before wide adoption.

The core library wiring is sound and the tests pass numerically. The C++ registry permanently leaks Python graph objects, the GIL is held across a live CUDA stream in every kernel dispatch, and the graph cache grows without bound. The new standalone cuDNN handle in ExecuteScoreModGraph is independent of the one prepared by CudnnHandleInitHandler. The softcap test demonstrates a stateful callback pattern that depends on cuDNN calling the forward score_mod callback before the backward one during sdpa_backward graph construction.

The highest-risk files are transformer_engine/jax/csrc/extensions/attention.cpp (registry lifetime, GIL, duplicate handle) and transformer_engine/jax/cpp_extensions/attention.py (private cuDNN API, unbounded cache, id-based cache keys).

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/attention.py Adds the full Python-side score_mod path: config builder, cuDNN pygraph construction for fwd/bwd, graph cache, and FFI wrappers. Uses id()-based cache keys, a private cuDNN API, and per-call workspace allocation.
transformer_engine/jax/csrc/extensions/attention.cpp Adds ScoreModGraphRegistry (unbounded, no Py_DECREF on eviction), ExecuteScoreModGraph with GIL acquisition in CUDA dispatch, and a second thread-local cuDNN handle that duplicates the existing CudnnHandleInitHandler handle.
transformer_engine/jax/attention.py Adds _fused_attn_score_mod custom_vjp, validation helper, and score_mod parameters to fused_attn. Logic is straightforward; sequence_descriptor type annotation correctly loosened to Optional.
tests/jax/test_fused_attn.py Adds validation tests and three numerical correctness tests. _ScoreModSoftcap stores mutable graph-construction state shared between forward and backward callbacks, relying on an undocumented cuDNN callback-ordering guarantee.
transformer_engine/jax/csrc/extensions.h Adds handler declarations and RegisterFusedAttnScoreModGraph signature. Clean header changes.
transformer_engine/jax/csrc/extensions/pybind.cpp Registers the two new FFI handlers and the register_fused_attn_score_mod_graph binding. Minimal, correct changes.
tests/jax/test_distributed_fused_attn.py Adds distributed score_mod test cases mirroring the single-device tests. No new logic concerns.

Sequence Diagram

sequenceDiagram
    participant User
    participant attention.py
    participant cpp_ext as cpp_extensions/attention.py
    participant CppReg as C++ ScoreModGraphRegistry
    participant cuDNN as cuDNN pygraph
    participant FFI as XLA FFI (C++)

    User->>attention.py: "fused_attn(..., score_mod=fn)"
    attention.py->>cpp_ext: make_fused_attn_score_mod_config()
    cpp_ext-->>attention.py: config, tensor_operands, bprop_tensor_operands
    attention.py->>attention.py: _fused_attn_score_mod (custom_vjp)

    note over cpp_ext,cuDNN: JAX trace time
    attention.py->>cpp_ext: fused_attn_score_mod_fwd(qkv, tensors, config)
    cpp_ext->>cuDNN: "pygraph() + sdpa(score_mod=callback)"
    cuDNN->>User: score_mod(graph, score, tensors) called
    cuDNN-->>cpp_ext: graph + workspace_size
    cpp_ext->>CppReg: register_fused_attn_score_mod_graph(graph, uids)
    CppReg-->>cpp_ext: graph_id
    cpp_ext->>FFI: ffi_call te_fused_attn_score_mod_forward_ffi

    note over FFI,cuDNN: CUDA execution time
    FFI->>CppReg: GetScoreModGraphEntry(graph_id)
    FFI->>FFI: acquire GIL
    FFI->>cuDNN: py_graph._execute_with_ptrs(user_ptrs, workspace, handle)
    FFI-->>User: output, softmax_stats

    note over cpp_ext,cuDNN: Backward graph build
    attention.py->>cpp_ext: fused_attn_score_mod_bwd(...)
    cpp_ext->>cuDNN: pygraph() + sdpa_backward(score_mod, score_mod_bprop)
    cuDNN->>User: score_mod called then score_mod_bprop called
    cpp_ext->>CppReg: register bwd graph
    cpp_ext->>FFI: ffi_call te_fused_attn_score_mod_backward_ffi
    FFI-->>User: dq, dk, dv

Reviews (2): Last reviewed commit: "Add distributed JAX score mod attention ..."

Comment on lines +706 to +713
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};
Contributor


P1 Python reference leak: Py_INCREF without a matching Py_DECREF

ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.

Comment on lines +684 to +692
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)

q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)
Contributor


P2 id()-based cache keys can produce false cache hits after GC

_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.

Comment on lines +765 to +807
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}
Contributor


P2 GIL held across a CUDA FFI call boundary

ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.

_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000
Contributor


P2 _score_mod_graph_cache and C++ registry grow without bound

_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
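A minimal LRU containment sketch along the lines the comment suggests, assuming a hypothetical `release_graph` hook that would tell the C++ registry to drop (and Py_DECREF) an evicted graph:

```python
# Illustrative LRU cache for (graph_id, workspace_size) entries. The
# `release_graph` callback is a hypothetical C++-side API for dropping the
# registry entry of an evicted graph; it does not exist in the PR.
from collections import OrderedDict


class LRUGraphCache:
    def __init__(self, capacity, release_graph=lambda graph_id: None):
        self._capacity = capacity
        self._release = release_graph
        self._entries = OrderedDict()  # key -> (graph_id, workspace_size)

    def get(self, key):
        value = self._entries.get(key)
        if value is not None:
            self._entries.move_to_end(key)  # mark as recently used
        return value

    def put(self, key, graph_id, workspace_size):
        self._entries[key] = (graph_id, workspace_size)
        self._entries.move_to_end(key)
        while len(self._entries) > self._capacity:
            # Evict the least recently used entry and let the registry
            # release its reference to the cuDNN graph.
            _, (old_id, _) = self._entries.popitem(last=False)
            self._release(old_id)
```

Pairing eviction here with an explicit release on the C++ side is what actually contains the growth; evicting only the Python dict entry would still leave the immortalised graph in the registry.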

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +143 to +195

def forward(self, graph, score, tensors):
    import cudnn  # pylint: disable=import-outside-toplevel

    self.before_tanh_activation = graph.div(
        a=score,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )
    self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
    tanh_out = graph.tanh(input=self.before_tanh_activation)
    tanh_out.set_data_type(cudnn.data_type.FLOAT)
    return graph.mul(
        a=tanh_out,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )

def backward(self, graph, dscore, tensors):
    import cudnn  # pylint: disable=import-outside-toplevel

    d_tanh_out = graph.mul(
        a=dscore,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )
    d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
    d_before_tanh_activation = graph.tanh_backward(
        loss=d_tanh_out,
        input=self.before_tanh_activation,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
    return graph.div(
        a=d_before_tanh_activation,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )


def _reference_attention(
    query, key, value, scale, *, causal=False, relative_position=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if relative_position:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:
Contributor


P1 _ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering

backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
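Mathematically, the softcap backward only needs `tanh(score / softcap)`: for f(s) = c·tanh(s/c), f′(s) = 1 − tanh(s/c)². So if the backward callback has access to the raw score (whether it does depends on what cuDNN exposes there), it could rebuild the div/tanh nodes itself instead of reading state written by forward. A plain-Python sketch of the math, with a finite-difference check; this is illustrative scalar arithmetic, not cuDNN graph code:

```python
# Softcap score_mod math: forward applies c * tanh(s / c); backward applies
# the chain rule, recomputing tanh(s / c) locally rather than reading state
# stored by the forward callback.
import math


def softcap_fwd(s, c):
    return c * math.tanh(s / c)


def softcap_bwd(s, c, dscore):
    # d/ds [c * tanh(s / c)] = 1 - tanh(s / c) ** 2
    return dscore * (1.0 - math.tanh(s / c) ** 2)


# Finite-difference check that the analytic backward matches the forward.
s, c, eps = 1.3, 2.0, 1e-6
numeric = (softcap_fwd(s + eps, c) - softcap_fwd(s - eps, c)) / (2 * eps)
assert abs(numeric - softcap_bwd(s, c, 1.0)) < 1e-6
```

A backward built this way is stateless, so it holds regardless of the order in which cuDNN invokes the two callbacks during sdpa_backward graph construction.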
