19 changes: 13 additions & 6 deletions tests/jax/test_fused_attn.py
@@ -36,6 +36,7 @@
fused_attn,
run_length_fill,
make_swa_mask,
check_set_window_size,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
@@ -1065,10 +1066,13 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu

cuDNN < 9.2: skip (no SWA support).
cuDNN >= 9.2: left-only window (s_kv // 10, 0).
cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose
bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK).
Other mask types keep the left-only window: causal-family masks would
collapse (W, W) -> (W, 0), hence not tested here.
cuDNN >= 9.6: bidirectional asymmetric window (s_kv // 10, s_kv // 10 + 5) for the mask
types whose bidirectional fused dispatch is meaningful here (NO_MASK,
PADDING_MASK). Other mask types keep the left-only window: causal-family
masks would collapse (W, W) -> (W, 0), hence not tested here.

The chosen ``(left, right)`` is then routed through :func:`check_set_window_size`, which
is the same canonicalizer the production modules call at construction time.
"""
cudnn_version = get_cudnn_version()
if cudnn_version < 90200:
@@ -1080,8 +1084,11 @@
AttnMaskType.NO_MASK,
AttnMaskType.PADDING_MASK,
):
return (left_window_size, right_window_size)
return (left_window_size, 0)
candidate = (left_window_size, right_window_size)
else:
candidate = (left_window_size, 0)
# Validate the window size against the contract and return the canonicalized value.
return check_set_window_size(attn_mask_type, candidate)
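
As a quick illustration of the helper's contract, a minimal sketch (not part of the diff) of the values it returns, assuming cuDNN >= 9.6 and s_kv = 100:

# Minimal sketch, assuming cuDNN >= 9.6; s_kv = 100 gives left = 10, right = 15.
assert _get_swa_window_size_for_test(100, AttnMaskType.NO_MASK) == (10, 15)
assert _get_swa_window_size_for_test(100, AttnMaskType.PADDING_MASK) == (10, 15)
# Causal-family masks keep the left-only candidate, which is already canonical:
assert _get_swa_window_size_for_test(100, AttnMaskType.CAUSAL_MASK) == (10, 0)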


@pytest.mark.parametrize(
128 changes: 112 additions & 16 deletions transformer_engine/jax/attention.py
@@ -323,6 +323,88 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
)


def check_set_window_size(
attn_mask_type: Union[str, AttnMaskType],
window_size: Optional[Tuple[int, int]] = None,
*,
warn: bool = True,
) -> Tuple[int, int]:
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.

attn_mask_type                                   | window_size
-------------------------------------------------+------------------------
no_mask, padding                                 | (-1, -1) or (>=0, >=0)
causal, padding_causal                           | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)

``(-1, -1)`` and ``(-1, 0)`` are sentinels meaning "no window" (full attention) and
"infinite-left causal" respectively. Negative entries are otherwise rejected.

Args:
attn_mask_type: Either a canonical ``attn_mask_type`` string (e.g. ``"no_mask"``,
``"padding"``, ``"causal"``, ``"padding_causal"``, ``"causal_bottom_right"``,
``"padding_causal_bottom_right"``) or an :class:`AttnMaskType` enum value.
window_size: ``(left, right)`` tuple, or ``None`` to use the natural default for the
given mask type.
warn: When ``True`` (default), emit a :class:`UserWarning` whenever the supplied
``window_size`` is coerced to the canonical form for ``attn_mask_type``.
Set to ``False`` for internal call sites that do not need to emit warnings.
Hard-error branches (negative bounds outside the recognized sentinels) are
not gated by this flag and always raise.

Returns:
The canonicalized ``(left, right)`` tuple.
"""
if isinstance(attn_mask_type, str):
attn_mask_type_enum = canonicalize_attn_mask_type(attn_mask_type)
attn_mask_type_str = attn_mask_type
else:
attn_mask_type_enum = attn_mask_type
attn_mask_type_str = attn_mask_type.name

orig_window_size = window_size
if attn_mask_type_enum.is_causal():
if orig_window_size is None:
window_size = (-1, 0)
# Coerce the right side window to 0.
elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0):
window_size = (orig_window_size[0], 0)
if warn:
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
f"coercing to {window_size}."
)
# Assert if invalid window size is provided.
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
raise AssertionError(
"window_size should be (-1, 0) or (>=0, 0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
)
elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK):
if orig_window_size is None:
window_size = (-1, -1)
# Coerce the right side window to -1.
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
Collaborator comment:

I wonder if we should do this, or go the other direction and change the mask? Technically, this could be a valid combination right? no_mask/padding + swa(left, 0) -> essentially causal + swa(left,0)?

if warn:
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
f"coercing to {window_size}."
)
# Assert if invalid window size is provided.
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
raise AssertionError(
"window_size should be (-1, -1) or (>=0, >=0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
)
else:
raise AssertionError(f"Invalid attn_mask_type: {attn_mask_type_str}")
return window_size
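
To make the table above concrete, a minimal sketch of the canonicalization behavior, derived directly from the branches above:

# Minimal sketch of check_set_window_size's contract.
assert check_set_window_size("no_mask", None) == (-1, -1)        # default: full attention
assert check_set_window_size("causal", None) == (-1, 0)          # default: infinite-left causal
assert check_set_window_size("causal", (64, 64)) == (64, 0)      # right bound coerced (UserWarning)
assert check_set_window_size("padding", (-1, 0), warn=False) == (-1, -1)  # coerced silently
# check_set_window_size("causal", (-2, 0)) raises AssertionError: the negative
# left bound falls outside the recognized (-1, 0) sentinel.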


def is_fused_attn_kernel_available(
is_training,
q_dtype,
@@ -343,7 +425,9 @@ def is_fused_attn_kernel_available(
"""
Check whether the fused attention kernel is supported.
"""
window_size_tuple = (-1, -1) if window_size is None else window_size
# Canonicalize at the CPP-extension boundary so direct callers see the same
# canonical encoding as users of the DPA/MHA API.
window_size_tuple = check_set_window_size(attn_mask_type, window_size)

def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
@@ -688,9 +772,9 @@ def _segment_ids_pos_to_seqlens_offsets(
segment_ids_kv,
segment_pos_q,
segment_pos_kv,
attn_mask_type,
window_size,
max_segments_per_seq,
attn_mask_type: AttnMaskType,
window_size: Tuple[int, int],
max_segments_per_seq: int,
):
"""Compute per-segment seqlens and start offsets(currently only used for THD)
Given segment-id and segment-position tensors for Q and KV,
@@ -708,21 +792,24 @@
attn_mask_type: AttnMaskType. Selects the mask predicate used to decide
which positions are valid (top-left causal vs. bottom-right causal
vs. padding-only).
window_size: Optional sliding-window tuple ``(left, right)`` or None
Used here only as a fast-path eligibility hint
window_size: Sliding-window tuple ``(left, right)``. Required (never
``None``); the received value must already be canonicalized by
check_set_window_size.
max_segments_per_seq: Maximum number of segments expected per row.
Used to size the bincount / argwhere outputs.

Routing (only invoked for THD qkv_layout):
1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
O(T) per row. Counts all segment tokens via bincount on
segment_ids and trims at most one token per segment at the
boundary. Used for:
- top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
- SWA with ``window_size == (-1, -1)`` and not bottom-right
Bottom-right causal cross-attention is excluded: the boundary
trim leaves kv_seqlen short by one per active segment, which
shifts the BRCM bottom-right alignment by one KV per Q row.
boundary. Used for any non-bottom-right mask with no finite
sliding window, i.e. ``window_size`` in
``{(-1, -1), (-1, 0)}``. ``window_size`` is guaranteed to be
non-``None`` here because it is already canonicalized by check_set_window_size.
Bottom-right causal cross-attention is excluded:
the boundary trim leaves kv_seqlen short by one per active
segment, which shifts the BRCM bottom-right alignment by one KV
per Q row.

2. Slow path -- ``_get_seqlens_offsets_thd``.
O(T * max_segments_per_seq) per row. Per-segment min/max
@@ -755,9 +842,11 @@
# must route bottom-right masks to the slow path.

# Fast path: O(T) per row.
if (
attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None
) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()):
# "No finite window" is encoded as (-1, -1) for non-causal masks and (-1, 0) for
# causal-family masks; both share window_size[0] == -1, which is therefore the
# mask-type-agnostic SWA-presence sentinel.
no_finite_window = window_size[0] == -1
if no_finite_window and not attn_mask_type.is_bottom_right():
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
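
A minimal sketch of the resulting dispatch. ``takes_fast_path`` is a hypothetical helper (not in the diff) that mirrors the predicate above; window sizes are assumed already canonicalized by check_set_window_size:

def takes_fast_path(attn_mask_type, window_size):
    # Mirrors the dispatch above: no finite window and not bottom-right.
    return window_size[0] == -1 and not attn_mask_type.is_bottom_right()

assert takes_fast_path(AttnMaskType.CAUSAL_MASK, (-1, 0))       # top-left causal, no SWA
assert takes_fast_path(AttnMaskType.NO_MASK, (-1, -1))          # full attention
assert not takes_fast_path(AttnMaskType.CAUSAL_MASK, (128, 0))  # finite left window
assert not takes_fast_path(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, (-1, 0))  # BRCM excluded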
@@ -825,10 +914,17 @@ def tree_unflatten(cls, aux_data, children):
return cls(*children)

def get_seqlens_and_offsets(
self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
self,
attn_mask_type: "AttnMaskType",
qkv_layout: "QKVLayout",
window_size: Tuple[int, int],
max_segments_per_seq: int,
):
"""
Acquire the seqlens/offsets for the cuDNN backend.

``window_size`` must be a ``Tuple[int, int]`` (never ``None``)
and already canonicalized by check_set_window_size.
"""
q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos
Expand Down
38 changes: 32 additions & 6 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -25,6 +25,7 @@
QKVFormat,
CPStrategy,
SequenceDescriptor,
check_set_window_size,
)
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available

@@ -2479,7 +2480,19 @@ def check_supported(self):
)

def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
"""Returns a _FusedAttnConfig for single CP step call to fused attention.

Ring CP overrides ``attn_mask_type`` per step (e.g. ``CAUSAL_MASK`` -> ``NO_MASK``
for off-diagonal steps where the kv chunk is fully past or fully future of the
local q chunk; see ``ring_attn_fwd_impl`` / ``ring_attn_bwd_impl``). The user's
``window_size`` is the canonical no-SWA form for the *original* mask, so we
re-canonicalize it for the per-step mask.

``warn=False`` because the user's ``window_size`` was already canonicalized
by check_set_window_size upstream. This per-step coercion is an internal mask
switch (for ring P2P) which if reported, may confuse the user.
"""
per_step_window = check_set_window_size(attn_mask_type, self.config.window_size, warn=False)
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type,
Expand All @@ -2489,7 +2502,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
window_size=per_step_window,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
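
The case this guards, as a minimal sketch: a causal run carries the user-level canonical window (-1, 0); if an off-diagonal ring step that switches to NO_MASK reused it verbatim, the right bound of 0 would re-impose a causal-style restriction the step is meant to drop:

# Minimal sketch of the per-step re-canonicalization for an off-diagonal step.
assert check_set_window_size(AttnMaskType.NO_MASK, (-1, 0), warn=False) == (-1, -1)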
@@ -3149,7 +3162,11 @@ def compute(config):
config=config,
)

if config.window_size != (-1, -1):
# Trigger the striped-window adjustment only when there is a finite SWA.
# window_size[0] == -1 is the unified "no finite window" sentinel that
# covers both the non-causal (-1, -1) form and the causal-family (-1, 0)
# form produced by check_set_window_size.
if config.window_size[0] != -1:
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
@@ -3302,7 +3319,10 @@ def compute(config):
)
return dq_per_step, dkv_per_step, dbias_per_step

if config.window_size != (-1, -1):
# See the fwd path above: window_size[0] != -1 detects a finite SWA under the
# unified sentinel, covering both the non-causal (-1, -1) and causal-family
# (-1, 0) canonical forms produced by check_set_window_size.
if config.window_size[0] != -1:
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
@@ -3486,7 +3506,10 @@ def fused_attn_fwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
# Canonicalize at the CPP-extension boundary so every downstream primitive this
# function dispatches to (default fused-attn and the CP all-gather/ring variants)
# sees the same canonical encoding as the DPA/MHA API.
window_size=check_set_window_size(attn_mask_type, window_size),
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
@@ -3661,7 +3684,10 @@ def fused_attn_bwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
# Canonicalize at the CPP-extension boundary so every downstream primitive this
# function dispatches to (default fused-attn and the CP all-gather/ring variants)
# sees the same canonical encoding as the DPA/MHA API.
window_size=check_set_window_size(attn_mask_type, window_size),
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),