Expose max_logit (and softmax aux statistics) from JAX fused attention higher-level APIs #2945

@zeryx

Description

Summary

(co-authored by Claude)

The PyTorch fused attention bindings support an optional return_max_logit parameter that makes the forward pass return per-head maximum attention scores, but the JAX fused attention API (transformer_engine.jax.attention.fused_attn) has no equivalent mechanism. Additionally, the softmax_aux tensor (log-sum-exp statistics) computed during the forward pass is kept internal to the custom VJP machinery and never exposed to callers.

Primary ask: Expose max_logit (the per-head maximum attention score) through transformer_engine.jax.attention.fused_attn, matching the PyTorch API's return_max_logit parameter.

Secondary ask: Optionally expose softmax_aux (log-sum-exp statistics) for users who need them.

Motivation

The max_logit value is already supported in the PyTorch frontend via return_max_logit: bool in transformer_engine/pytorch/cpp_extensions/fused_attn.py. Use cases include:

  • Custom loss functions and reward signals that depend on attention statistics (see the sketch after this list)
  • Numerical stability diagnostics and debugging
  • Implementing custom backward passes outside of JAX's built-in autodiff
  • Feature parity with the PyTorch frontend
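
For instance, the first use case above might look like the following hypothetical sketch; loss_with_logit_penalty, the squared-error stand-in task loss, and coeff are all invented for illustration and are not TE code:

import jax.numpy as jnp

# Hypothetical example: a z-loss-style penalty that discourages runaway
# attention logits, built on the per-head max_logit statistic this issue
# asks fused_attn to expose.
def loss_with_logit_penalty(output, max_logit, labels, coeff=1e-4):
    task_loss = jnp.mean((output - labels) ** 2)      # stand-in task loss
    logit_penalty = coeff * jnp.mean(max_logit ** 2)  # penalize large logits
    return task_loss + logit_penalty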

Current Behavior

PyTorch (supports max_logit):
In transformer_engine/pytorch/cpp_extensions/fused_attn.py, the forward function accepts return_max_logit: bool and returns max_logit as a separate output tensor when requested.

JAX (no max_logit support):
In transformer_engine/jax/attention.py, the public API fused_attn (line ~1394) returns only the output tensor. The internal forward rule _fused_attn_fwd_rule (line ~1272) calls tex.fused_attn_fwd, which returns (output, softmax_aux, rng_state), but only output is surfaced to the caller. The softmax_aux and rng_state are stored in the custom VJP context for the backward pass and discarded from the public return value.

# Current JAX public API — no way to get aux stats
def fused_attn(qkv, bias, sequence_descriptor, seed, ...) -> jnp.ndarray:
    ...
    return output  # softmax_aux is only available internally for backward pass

Expected Behavior

from transformer_engine.jax.attention import fused_attn

# Primary: max_logit support
output, aux = fused_attn(q, k, v, ..., return_max_logit=True)
max_logit = aux["max_logit"]  # per-head max attention scores

# Secondary: softmax_aux (log-sum-exp stats)
output, aux = fused_attn(q, k, v, ..., return_softmax_aux=True)
softmax_aux = aux["softmax_aux"]  # shape [B, H, Sq, 1], float32 (log-sum-exp)

Proposed Changes

1. Add return_max_logit support to the JAX C++ extension layer

The underlying cuDNN kernels already support computing max_logit (as evidenced by PyTorch support). Wire this through fused_attn_fwd in transformer_engine/jax/cpp_extensions/attention.py.
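
A hypothetical sketch of what that wiring might look like; all names below are illustrative assumptions, not the actual TE primitive API (per Current Behavior above, the forward wrapper today returns (output, softmax_aux, rng_state)):

# Illustrative sketch only; _fused_attn_fwd_prim is an invented stand-in
# for the underlying primitive bind in cpp_extensions/attention.py.
def fused_attn_fwd(qkv, bias, seqlens, rng_seed, return_max_logit=False, **config):
    output, softmax_aux, rng_state, max_logit = _fused_attn_fwd_prim(
        qkv, bias, seqlens, rng_seed, return_max_logit=return_max_logit, **config
    )
    if return_max_logit:
        return output, softmax_aux, rng_state, max_logit
    return output, softmax_aux, rng_state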

2. Propagate through the custom VJP wrapper

Update _fused_attn and its forward/backward rules in transformer_engine/jax/attention.py to optionally return max_logit alongside output. Care is needed since changing the return signature affects jax.custom_vjp — one approach is to always compute max_logit when requested and pass it through the VJP context without requiring gradients for it.
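
To make the custom-VJP concern concrete, here is a minimal, self-contained sketch of the pattern on a toy attention function (none of this is TE's actual code): the statistic is returned as a second primal output, and its cotangent is simply discarded in the backward rule, so no gradient support is needed for it.

import jax
import jax.numpy as jnp

# Toy stand-ins for TE's rules, showing the "aux output without gradients" pattern.
@jax.custom_vjp
def attn_with_max_logit(q, k):
    logits = q @ k.T
    return jax.nn.softmax(logits, axis=-1), logits.max(axis=-1)

def _fwd(q, k):
    logits = q @ k.T
    out = jax.nn.softmax(logits, axis=-1)
    return (out, logits.max(axis=-1)), (q, k, out)

def _bwd(res, cts):
    q, k, out = res
    g_out, _ = cts  # cotangent for max_logit is discarded: no grad flows through it
    # softmax VJP: dL/dlogits = p * (g - sum(g * p))
    g_logits = out * (g_out - (g_out * out).sum(axis=-1, keepdims=True))
    return g_logits @ k, g_logits.T @ q

attn_with_max_logit.defvjp(_fwd, _bwd)

# Gradients flow through `out` as usual; max_logit is a pure statistic.
q, k = jnp.ones((2, 4)), jnp.ones((3, 4))
dq = jax.grad(lambda q: attn_with_max_logit(q, k)[0].sum())(q)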

3. Expose in the public fused_attn API

Add return_max_logit: bool = False (and optionally return_softmax_aux: bool = False) to the fused_attn signature at line ~1394. When enabled, return a tuple (output, aux_dict) instead of just output.
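
A hypothetical sketch of the gating, where _fused_attn_impl is an invented stand-in for the internal custom-VJP call, assumed here to hand back the aux tensors alongside the output:

# Illustrative only; names and plumbing are assumptions, not TE's code.
def fused_attn(*args, return_max_logit=False, return_softmax_aux=False, **kwargs):
    output, max_logit, softmax_aux = _fused_attn_impl(*args, **kwargs)
    if not (return_max_logit or return_softmax_aux):
        return output  # existing callers see unchanged behavior
    aux = {}
    if return_max_logit:
        aux["max_logit"] = max_logit
    if return_softmax_aux:
        aux["softmax_aux"] = softmax_aux
    return output, aux

Gating the tuple return on the flags keeps the default signature backward compatible, mirroring how the PyTorch frontend only returns max_logit when requested.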

4. Current fused_attn signature for reference

def fused_attn(
    qkv, bias, sequence_descriptor, seed,
    attn_bias_type, attn_mask_type, qkv_layout, softmax_type,
    scaling_factor, dropout_probability, is_training,
    max_segments_per_seq=1, window_size=None,
    context_parallel_strategy=CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced=False,
    context_parallel_axis="", context_checkpoint_name="context",
    softmax_offset=None, stripe_size=None,
) -> jnp.ndarray

Softmax Aux Details

The softmax_aux tensor in JAX contains log-sum-exp statistics with shape determined by cuDNN version:

  • cuDNN ≥ 9.6: [B, H, Sq, 1] for BSHD layouts, [B, Sq, H, 1] for THD layouts
  • cuDNN < 9.6: [B, H, Sq, max_segments_per_seq]
  • Always: float32

Note: softmax_aux contains log(Σ exp(x - max(x))), not the raw max logits. max_logit is a separate output tensor.
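
A small numeric sketch of that relationship (plain JAX, independent of TE): the two statistics together reconstruct the full log-sum-exp, which is what a custom backward pass needs to rebuild the softmax.

import jax.numpy as jnp

# softmax_aux holds log(sum(exp(x - max(x)))); max_logit holds max(x).
# Their sum is the full log-sum-exp (cf. jax.scipy.special.logsumexp).
x = jnp.array([1.0, 3.0, 2.0])                # toy row of attention logits
max_logit = x.max()
softmax_aux = jnp.log(jnp.exp(x - max_logit).sum())
lse = max_logit + softmax_aux                 # full log-sum-exp
probs = jnp.exp(x - lse)                      # numerically stable softmax row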

Environment

  • TransformerEngine version: main branch (HEAD)
  • Framework: JAX
