From ca43999f8c22099e8c5e8009ffc994ae37b15658 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 7 May 2026 14:34:55 -0700 Subject: [PATCH 1/8] Add check_set_window_size to JAX attn to mimic the behavior of Pyt attn - Verifies if a window is correct for a given mask type. If it isn't either force sentinel values or assert. If forcing sentinel values then warn the user - All possible ways of using attn, i.e. DPA, MHA, TL, fused attn APIs are all now guaranteeing that window size will not be None and appropriately set before passing downstream to internal APIs, primitives or classes. Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 20 ++- transformer_engine/jax/attention.py | 134 +++++++++++++++--- .../jax/cpp_extensions/attention.py | 26 +++- transformer_engine/jax/flax/transformer.py | 96 ++++++++++++- 4 files changed, 245 insertions(+), 31 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..271905d437 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -36,6 +36,7 @@ fused_attn, run_length_fill, make_swa_mask, + check_set_window_size, SequenceDescriptor, CPStrategy, ReorderStrategy, @@ -1065,10 +1066,15 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu cuDNN < 9.2: skip (no SWA support). cuDNN >= 9.2: left-only window (s_kv // 10, 0). - cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose - bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK). - Other mask types keep the left-only window: causal-family masks would - collapse (W, W) -> (W, 0), hence not tested here. + cuDNN >= 9.6: bidirectional asymmetric window (s_kv // 10, s_kv // 10 + 5) for the mask + types whose bidirectional fused dispatch is meaningful here (NO_MASK, + PADDING_MASK). Other mask types keep the left-only window: causal-family + masks would collapse (W, W) -> (W, 0), hence not tested here. + + The chosen ``(left, right)`` is then routed through :func:`check_set_window_size`, which + is the same canonicalizer the production modules call at construction time. For the + candidates above this is always a no-op, so it acts as a contract self-check rather + than a value transformation. """ cudnn_version = get_cudnn_version() if cudnn_version < 90200: @@ -1080,8 +1086,10 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, ): - return (left_window_size, right_window_size) - return (left_window_size, 0) + candidate = (left_window_size, right_window_size) + else: + candidate = (left_window_size, 0) + return check_set_window_size(attn_mask_type, candidate) @pytest.mark.parametrize( diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..31f346f675 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -323,6 +323,81 @@ def canonicalize_attn_mask_type(attn_mask_type: str): ) +def check_set_window_size( + attn_mask_type: Union[str, AttnMaskType], + window_size: Optional[Tuple[int, int]] = None, +) -> Tuple[int, int]: + """Check if sliding window size is compliant with attention mask type. + If not, set it to the appropriate size. + + attn_mask_type | window_size + ---------------------------------------------------------------------------- + no_mask, padding | (-1, -1) or (>=0, >=0) + causal, padding_causal | (-1, 0) or (>=0, 0) + causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0) + + ``(-1, -1)`` and ``(-1, 0)`` are sentinels meaning "no window" (full attention) and + "infinite-left causal" respectively. Negative entries are otherwise rejected. + + Args: + attn_mask_type: Either a canonical ``attn_mask_type`` string (e.g. ``"no_mask"``, + ``"padding"``, ``"causal"``, ``"padding_causal"``, ``"causal_bottom_right"``, + ``"padding_causal_bottom_right"``) or an :class:`AttnMaskType` enum value. + window_size: ``(left, right)`` tuple, or ``None`` to use the natural default for the + given mask type. + + Returns: + The canonicalized ``(left, right)`` tuple. + """ + if isinstance(attn_mask_type, str): + attn_mask_type_enum = canonicalize_attn_mask_type(attn_mask_type) + attn_mask_type_str = attn_mask_type + else: + attn_mask_type_enum = attn_mask_type + attn_mask_type_str = attn_mask_type.name + + orig_window_size = window_size + if attn_mask_type_enum.is_causal(): + if orig_window_size is None: + window_size = (-1, 0) + elif orig_window_size == (-1, -1) or ( + orig_window_size[0] >= 0 and orig_window_size[1] != 0 + ): + window_size = (orig_window_size[0], 0) + warnings.warn( + f"window_size should be (-1, 0) or (>=0, 0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " + f"coercing to {window_size}." + ) + elif orig_window_size != (-1, 0) and ( + orig_window_size[0] < 0 or orig_window_size[1] != 0 + ): + raise AssertionError( + f"window_size should be (-1, 0) or (>=0, 0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}." + ) + elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): + if orig_window_size is None: + window_size = (-1, -1) + elif orig_window_size == (-1, 0): + window_size = (-1, -1) + warnings.warn( + f"window_size should be (-1, -1) or (>=0, >=0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " + f"coercing to {window_size}." + ) + elif orig_window_size != (-1, -1) and ( + orig_window_size[0] < 0 or orig_window_size[1] < 0 + ): + raise AssertionError( + f"window_size should be (-1, -1) or (>=0, >=0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}." + ) + else: + raise AssertionError(f"Invalid attn_mask_type: {attn_mask_type_str}") + return window_size + + def is_fused_attn_kernel_available( is_training, q_dtype, @@ -343,7 +418,11 @@ def is_fused_attn_kernel_available( """ To check whether the fused attention kernel is supported """ - window_size_tuple = (-1, -1) if window_size is None else window_size + # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. + # The flax modules already canonicalize at __post_init__; this covers direct + # callers (e.g. tests, internal dispatch helpers) so the value reaching the + # FusedAttnHelper is always self-consistent with attn_mask_type. + window_size_tuple = check_set_window_size(attn_mask_type, window_size) def make_helper(attn_mask_type): return tex.FusedAttnHelper( @@ -688,9 +767,9 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids_kv, segment_pos_q, segment_pos_kv, - attn_mask_type, - window_size, - max_segments_per_seq, + attn_mask_type: AttnMaskType, + window_size: Tuple[int, int], + max_segments_per_seq: int, ): """Compute per-segment seqlens and start offsets(currently only used for THD) Given segment-id and segment-position tensors for Q and KV, @@ -708,8 +787,13 @@ def _segment_ids_pos_to_seqlens_offsets( attn_mask_type: AttnMaskType. Selects the mask predicate used to decide which positions are valid (top-left causal vs bottom-right causal vs. padding-only) - window_size: Optional sliding-window tuple ``(left, right)`` or None - Used here only as a fast-path eligibility hint + window_size: Sliding-window tuple ``(left, right)`` in canonical form. + Required (not Optional): every public boundary + (``check_set_window_size``, ``fused_attn_fwd`` / + ``fused_attn_bwd``, ``is_fused_attn_kernel_available``, + flax ``__post_init__``s) canonicalizes ``None`` upstream + before reaching this internal helper. Used here only + as a fast-path eligibility hint. max_segments_per_seq: maximum number of segments expected per row Used to size the bincount / argwhere outputs @@ -717,12 +801,18 @@ def _segment_ids_pos_to_seqlens_offsets( 1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``. O(T) per row. Counts all segment tokens via bincount on segment_ids and trims at most one token per segment at the - boundary. Used for: - - top-left CAUSAL / PADDING_CAUSAL with ``window_size is None`` - - SWA with ``window_size == (-1, -1)`` and not bottom-right - Bottom-right causal cross-attention is excluded: the boundary - trim leaves kv_seqlen short by one per active segment, which - shifts the BRCM bottom-right alignment by one KV per Q row. + boundary. Used for any non-bottom-right mask with no finite + sliding window, i.e. ``window_size`` in + ``{(-1, -1), (-1, 0)}``. ``(-1, 0)`` is the canonical causal-family + "no-SWA" sentinel produced by :func:`check_set_window_size`; + ``(-1, -1)`` is the corresponding non-causal sentinel. + ``window_size`` is guaranteed non-``None`` here because every + caller routes through ``check_set_window_size`` (flax modules at + ``__post_init__`` and the public ``fused_attn_fwd`` / ``fused_attn_bwd`` + entrypoints). Bottom-right causal cross-attention is excluded: + the boundary trim leaves kv_seqlen short by one per active + segment, which shifts the BRCM bottom-right alignment by one KV + per Q row. 2. Slow path -- ``_get_seqlens_offsets_thd``. O(T * max_segments_per_seq) per row. Per-segment min/max @@ -755,9 +845,12 @@ def _segment_ids_pos_to_seqlens_offsets( # must route bottom-right masks to the slow path. # Fast path: O(T) per row. - if ( - attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None - ) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()): + # window_size is canonical here (Tuple[int, int], never None -- see signature). + # "No finite window" is encoded as (-1, -1) for non-causal masks and (-1, 0) for + # causal-family masks; both share window_size[0] == -1, which is therefore the + # mask-type-agnostic SWA-presence sentinel. + no_finite_window = window_size[0] == -1 + if no_finite_window and not attn_mask_type.is_bottom_right(): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq ) @@ -825,10 +918,19 @@ def tree_unflatten(cls, aux_data, children): return cls(*children) def get_seqlens_and_offsets( - self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq + self, + attn_mask_type: "AttnMaskType", + qkv_layout: "QKVLayout", + window_size: Tuple[int, int], + max_segments_per_seq: int, ): """ Acquire the seqlens/offsets for cuDNN backend. + + ``window_size`` must be in canonical form (``Tuple[int, int]``, + never ``None``). The callers (``FusedAttn{Fwd,Bwd}Primitive.impl``) + construct ``_FusedAttnConfig`` via ``check_set_window_size`` at the + public boundary, which guarantees this invariant. """ q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 40d02f40e1..c5e1b5155c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -25,6 +25,7 @@ QKVFormat, CPStrategy, SequenceDescriptor, + check_set_window_size, ) from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available @@ -3149,7 +3150,13 @@ def compute(config): config=config, ) - if config.window_size != (-1, -1): + # Trigger striped-window adjustment only when there is a finite SWA. After + # check_set_window_size canonicalization, "no finite window" is encoded as + # (-1, -1) for non-causal masks and (-1, 0) for causal-family masks; both + # have window_size[0] == -1, so we use that as the unified SWA sentinel + # (matches the convention used elsewhere in this file, e.g. is_sliding_window + # at line ~2475). + if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values cp_striped_window_size = adjust_cp_striped_window_size( @@ -3302,7 +3309,10 @@ def compute(config): ) return dq_per_step, dkv_per_step, dbias_per_step - if config.window_size != (-1, -1): + # See fwd path above: window_size[0] != -1 is the unified "finite SWA" + # sentinel that handles both the (-1, -1) non-causal and (-1, 0) causal + # canonicalizations from check_set_window_size. + if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values cp_striped_window_size = adjust_cp_striped_window_size( @@ -3486,7 +3496,11 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - window_size=(-1, -1) if window_size is None else window_size, + # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. + # The flax modules already canonicalize at __post_init__; this covers direct + # callers of fused_attn_fwd/_bwd so the invariant holds at the C-extension + # boundary too. + window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), @@ -3661,7 +3675,11 @@ def fused_attn_bwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - window_size=(-1, -1) if window_size is None else window_size, + # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. + # The flax modules already canonicalize at __post_init__; this covers direct + # callers of fused_attn_fwd/_bwd so the invariant holds at the C-extension + # boundary too. + window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a2e7920843..2826bd6a3e 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -30,7 +30,12 @@ QKVLayout, SequenceDescriptor, ) -from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type +from ..attention import ( + is_fused_attn_kernel_available, + make_swa_mask, + canonicalize_attn_mask_type, + check_set_window_size, +) from ..attention import fused_attn from ..attention import CPStrategy from ..softmax import SoftmaxFusionType @@ -127,6 +132,13 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- window_size: Optional[Tuple[int, int]] = None softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX + def __post_init__(self): + # Defensive canonicalization: in normal flow window_size arrives canonical + # from DotProductAttention, but this guarantees the inner module never has + # to handle window_size=None. + self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) + super().__post_init__() + @nn.compact def __call__( self, @@ -242,7 +254,10 @@ def convert_to_softmax_fusion_type(attn_mask_type, mask): # mask is ignored for no_mask and causal_mask without sliding window if attn_mask_type == AttnMaskType.NO_MASK: mask = None - if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: + # check_set_window_size (called in __post_init__) canonicalizes + # "no SWA + causal" to (-1, 0), so this is the only sentinel we need + # to recognize here. + if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size == (-1, 0): mask = None if mask is not None: mask = apply_swa_mask(mask) @@ -306,6 +321,14 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me context_checkpoint_name: str = "context" softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX + def __post_init__(self): + # Defensive canonicalization: parallels _UnfusedDotProductAttention. The + # value is also re-canonicalized inside fused_attn_fwd/_bwd at the + # C-extension boundary, so this is purely about keeping the invariant + # uniform across the private module layer. + self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) + super().__post_init__() + @nn.compact def __call__( self, @@ -565,7 +588,25 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). window_size: Optional[Tuple[int, int]], default = None - Sliding window size. The default value is no sliding window. + Sliding window size as ``(left, right)``. The default value of ``None`` means no + sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for + causal-family masks). + + Allowed values per ``attn_mask_type``: + + * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or + ``(>=0, >=0)``. + * ``causal``, ``padding_causal``, ``causal_bottom_right``, + ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left + causal) or ``(>=0, 0)``. + + Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` + is replaced with the sentinel for the given mask type, and inconsistent sentinels + such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values + with negative ``left`` or ``right`` outside the listed sentinels raise an + ``AssertionError``. Bidirectional sliding windows + (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused + backend; left-only sliding windows require cuDNN >= 9.2. max_segments_per_seq: Optional[int], default = 1 The maximum number of segments per sequence, also used for THD format (sequence packing). context_parallel_causal_load_balanced: bool @@ -638,6 +679,8 @@ def __post_init__(self): " TransformerEngine v2.10" ) self.transpose_batch_sequence = False + # Validate / canonicalize window_size against attn_mask_type. + self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) super().__post_init__() def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout): @@ -1151,7 +1194,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods fuse_qkv: bool, default = None Deprecated. Please refer `fuse_qkv_params` window_size: Optional[Tuple[int, int]], default = None - Sliding window size. Default value is no sliding window. + Sliding window size as ``(left, right)``. The default value of ``None`` means no + sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for + causal-family masks). + + Allowed values per ``attn_mask_type``: + + * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or + ``(>=0, >=0)``. + * ``causal``, ``padding_causal``, ``causal_bottom_right``, + ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left + causal) or ``(>=0, 0)``. + + Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` + is replaced with the sentinel for the given mask type, and inconsistent sentinels + such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values + with negative ``left`` or ``right`` outside the listed sentinels raise an + ``AssertionError``. Bidirectional sliding windows + (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused + backend; left-only sliding windows require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks @@ -1266,6 +1327,8 @@ def __post_init__(self): ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads + # Validate / canonicalize window_size against attn_mask_type. + self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) super().__post_init__() @nn.compact @@ -1911,7 +1974,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. window_size: Optional[Tuple[int, int]], default = None - Sliding window size. Default value is no sliding window. + Sliding window size as ``(left, right)``. The default value of ``None`` means no + sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for + causal-family masks). + + Allowed values per ``attn_mask_type``: + + * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or + ``(>=0, >=0)``. + * ``causal``, ``padding_causal``, ``causal_bottom_right``, + ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left + causal) or ``(>=0, 0)``. + + Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` + is replaced with the sentinel for the given mask type, and inconsistent sentinels + such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values + with negative ``left`` or ``right`` outside the listed sentinels raise an + ``AssertionError``. Bidirectional sliding windows + (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused + backend; left-only sliding windows require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks @@ -2017,6 +2098,11 @@ def __post_init__(self): ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads + # Validate / canonicalize window_size against self_attn_mask_type. The + # cross-attention block (encoder-decoder MHA) reuses the same window_size + # internally; that combo is re-validated against its fixed "padding" mask + # type inside MultiHeadAttention. + self.window_size = check_set_window_size(self.self_attn_mask_type, self.window_size) super().__post_init__() @nn.compact From d96e402563bdbe59096ae221634cce0d1724ff38 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 7 May 2026 16:39:09 -0700 Subject: [PATCH 2/8] TransformerLayer APi does not need to check the window size. That contract responsibility can be handled by MHA and lower APIs Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/flax/transformer.py | 70 ++++++++++++++-------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 2826bd6a3e..df67202cf6 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -592,7 +592,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for causal-family masks). - Allowed values per ``attn_mask_type``: + Allowed values per :attr:`attn_mask_type`: * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or ``(>=0, >=0)``. @@ -600,11 +600,13 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` - is replaced with the sentinel for the given mask type, and inconsistent sentinels - such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values - with negative ``left`` or ``right`` outside the listed sentinels raise an - ``AssertionError``. Bidirectional sliding windows + Inputs are validated and canonicalized at construction time. ``None`` is + replaced silently with the sentinel for :attr:`attn_mask_type`: ``(-1, -1)`` + for ``no_mask`` / ``padding`` and ``(-1, 0)`` for the causal family. + Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``, or + ``(W, R)`` with ``R != 0`` paired with a causal-family mask) are coerced with + a warning. Values with negative ``left`` or ``right`` outside the listed + sentinels raise an ``AssertionError``. Bidirectional sliding windows (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused backend; left-only sliding windows require cuDNN >= 9.2. max_segments_per_seq: Optional[int], default = 1 @@ -1198,7 +1200,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for causal-family masks). - Allowed values per ``attn_mask_type``: + Allowed values per :attr:`attn_mask_type`: * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or ``(>=0, >=0)``. @@ -1206,11 +1208,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` - is replaced with the sentinel for the given mask type, and inconsistent sentinels - such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values - with negative ``left`` or ``right`` outside the listed sentinels raise an - ``AssertionError``. Bidirectional sliding windows + Inputs are validated and canonicalized at construction time. ``None`` is + replaced silently with the sentinel for :attr:`attn_mask_type`: ``(-1, -1)`` + for ``no_mask`` / ``padding`` and ``(-1, 0)`` for the causal family. + Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``, or + ``(W, R)`` with ``R != 0`` paired with a causal-family mask) are coerced with + a warning. Values with negative ``left`` or ``right`` outside the listed + sentinels raise an ``AssertionError``. Bidirectional sliding windows (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused backend; left-only sliding windows require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' @@ -1978,7 +1982,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for causal-family masks). - Allowed values per ``attn_mask_type``: + This value is forwarded as-is to both the self-attention block (which uses + :attr:`self_attn_mask_type`) and, when :attr:`layer_type` is ``DECODER``, the + cross-attention block (whose mask type is internally fixed to ``padding``). + ``TransformerLayer`` deliberately does not canonicalize ``window_size``: a + decoder layer carries two different mask-type contracts simultaneously, so each + ``MultiHeadAttention`` block canonicalizes against its own mask type in its + own ``__post_init__``. Concretely, for ``self_attn_mask_type`` in the causal + family with ``window_size=None``, the self-attention block silently expands + ``None`` to ``(-1, 0)`` while the cross-attention block silently expands the + same ``None`` to ``(-1, -1)`` -- both blocks land on the correct sentinel for + their own mask type without warnings. + + Allowed values per mask type: * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or ``(>=0, >=0)``. @@ -1986,13 +2002,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and lightly canonicalized at construction time (e.g. ``None`` - is replaced with the sentinel for the given mask type, and inconsistent sentinels - such as ``(-1, 0)`` paired with ``no_mask`` are coerced with a warning). Values - with negative ``left`` or ``right`` outside the listed sentinels raise an - ``AssertionError``. Bidirectional sliding windows - (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused - backend; left-only sliding windows require cuDNN >= 9.2. + Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``) are coerced + with a warning inside the relevant ``MultiHeadAttention`` block. Values with + negative ``left`` or ``right`` outside the listed sentinels raise an + ``AssertionError``. Bidirectional sliding windows (``right > 0`` with + ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused backend; left-only + sliding windows require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks @@ -2098,11 +2113,16 @@ def __post_init__(self): ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads - # Validate / canonicalize window_size against self_attn_mask_type. The - # cross-attention block (encoder-decoder MHA) reuses the same window_size - # internally; that combo is re-validated against its fixed "padding" mask - # type inside MultiHeadAttention. - self.window_size = check_set_window_size(self.self_attn_mask_type, self.window_size) + # window_size is intentionally NOT canonicalized here. A decoder layer + # constructs two MultiHeadAttention blocks with different mask types + # (self-attention uses self_attn_mask_type; cross-attention is hardcoded + # to "padding"). Canonicalizing at this layer would force a single + # mask-type contract onto a dual-contract object and produce a sentinel + # that is correct for one block but not the other (e.g. None + causal + # self-attn would yield (-1, 0), which the padding cross-attn would + # then warn-coerce). Instead, the raw user value is forwarded as-is to + # both blocks; each block canonicalizes against its own mask type in + # MultiHeadAttention.__post_init__. super().__post_init__() @nn.compact From 4a0f5739524cc778aaf91faa7e982ab036344818 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 7 May 2026 22:53:19 -0700 Subject: [PATCH 3/8] Update the window size via check set window size API per rank for CP fused attn Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 81 ++++++++++++------- .../jax/cpp_extensions/attention.py | 59 +++++++++----- 2 files changed, 92 insertions(+), 48 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 31f346f675..c3363907a8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -326,6 +326,8 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def check_set_window_size( attn_mask_type: Union[str, AttnMaskType], window_size: Optional[Tuple[int, int]] = None, + *, + warn: bool = True, ) -> Tuple[int, int]: """Check if sliding window size is compliant with attention mask type. If not, set it to the appropriate size. @@ -345,6 +347,16 @@ def check_set_window_size( ``"padding_causal_bottom_right"``) or an :class:`AttnMaskType` enum value. window_size: ``(left, right)`` tuple, or ``None`` to use the natural default for the given mask type. + warn: When ``True`` (default), emit a :class:`UserWarning` whenever the supplied + ``window_size`` is silently coerced to the canonical form for ``attn_mask_type`` + (e.g. ``(-1, -1)`` → ``(-1, 0)`` for causal masks, or ``(-1, 0)`` → ``(-1, -1)`` + for non-causal masks). Set to ``False`` for internal call sites that route an + already-validated value through this function purely for re-canonicalization + (for example, the context-parallel ring primitive's per-step mask switch from + ``CAUSAL_MASK`` to ``NO_MASK``), where the warning would target user-supplied + input that the user never provided in non-canonical form. Hard-error branches + (negative bounds outside the recognized sentinels) are not gated by this flag + and always raise. Returns: The canonicalized ``(left, right)`` tuple. @@ -364,11 +376,12 @@ def check_set_window_size( orig_window_size[0] >= 0 and orig_window_size[1] != 0 ): window_size = (orig_window_size[0], 0) - warnings.warn( - f"window_size should be (-1, 0) or (>=0, 0) for " - f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " - f"coercing to {window_size}." - ) + if warn: + warnings.warn( + f"window_size should be (-1, 0) or (>=0, 0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " + f"coercing to {window_size}." + ) elif orig_window_size != (-1, 0) and ( orig_window_size[0] < 0 or orig_window_size[1] != 0 ): @@ -381,11 +394,12 @@ def check_set_window_size( window_size = (-1, -1) elif orig_window_size == (-1, 0): window_size = (-1, -1) - warnings.warn( - f"window_size should be (-1, -1) or (>=0, >=0) for " - f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " - f"coercing to {window_size}." - ) + if warn: + warnings.warn( + f"window_size should be (-1, -1) or (>=0, >=0) for " + f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " + f"coercing to {window_size}." + ) elif orig_window_size != (-1, -1) and ( orig_window_size[0] < 0 or orig_window_size[1] < 0 ): @@ -418,10 +432,11 @@ def is_fused_attn_kernel_available( """ To check whether the fused attention kernel is supported """ - # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. - # The flax modules already canonicalize at __post_init__; this covers direct - # callers (e.g. tests, internal dispatch helpers) so the value reaching the - # FusedAttnHelper is always self-consistent with attn_mask_type. + # Canonicalize at the C-extension boundary so direct callers see the same + # canonical encoding as flax users (whose flax DPA/MHA __post_init__ already + # canonicalizes via check_set_window_size). This keeps backend-availability + # queries consistent regardless of how the caller spelled the no-SWA + # sentinel ((-1, -1) vs (-1, 0) vs None). window_size_tuple = check_set_window_size(attn_mask_type, window_size) def make_helper(attn_mask_type): @@ -787,11 +802,12 @@ def _segment_ids_pos_to_seqlens_offsets( attn_mask_type: AttnMaskType. Selects the mask predicate used to decide which positions are valid (top-left causal vs bottom-right causal vs. padding-only) - window_size: Sliding-window tuple ``(left, right)`` in canonical form. - Required (not Optional): every public boundary - (``check_set_window_size``, ``fused_attn_fwd`` / - ``fused_attn_bwd``, ``is_fused_attn_kernel_available``, - flax ``__post_init__``s) canonicalizes ``None`` upstream + window_size: Sliding-window tuple ``(left, right)``. Required (not + Optional): the flax DPA / MHA ``__post_init__`` calls + :func:`check_set_window_size` to canonicalize user + input, and the public ``fused_attn_fwd`` / + ``fused_attn_bwd`` / ``is_fused_attn_kernel_available`` + entrypoints default ``None`` to ``(-1, -1)`` upstream before reaching this internal helper. Used here only as a fast-path eligibility hint. max_segments_per_seq: maximum number of segments expected per row @@ -803,13 +819,16 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids and trims at most one token per segment at the boundary. Used for any non-bottom-right mask with no finite sliding window, i.e. ``window_size`` in - ``{(-1, -1), (-1, 0)}``. ``(-1, 0)`` is the canonical causal-family - "no-SWA" sentinel produced by :func:`check_set_window_size`; - ``(-1, -1)`` is the corresponding non-causal sentinel. - ``window_size`` is guaranteed non-``None`` here because every - caller routes through ``check_set_window_size`` (flax modules at - ``__post_init__`` and the public ``fused_attn_fwd`` / ``fused_attn_bwd`` - entrypoints). Bottom-right causal cross-attention is excluded: + ``{(-1, -1), (-1, 0)}``. Both encodings represent "no SWA": + ``(-1, 0)`` is the canonical causal-family form produced by + :func:`check_set_window_size` at the flax DPA / MHA boundary, and + ``(-1, -1)`` is the corresponding non-causal form (and is also the + form that internal callers such as ring-attention CP primitives + continue to use unmodified). ``window_size`` is guaranteed + non-``None`` here because the flax modules canonicalize at + ``__post_init__`` and the public ``fused_attn_fwd`` / + ``fused_attn_bwd`` entrypoints default ``None`` to ``(-1, -1)``. + Bottom-right causal cross-attention is excluded: the boundary trim leaves kv_seqlen short by one per active segment, which shifts the BRCM bottom-right alignment by one KV per Q row. @@ -927,10 +946,12 @@ def get_seqlens_and_offsets( """ Acquire the seqlens/offsets for cuDNN backend. - ``window_size`` must be in canonical form (``Tuple[int, int]``, - never ``None``). The callers (``FusedAttn{Fwd,Bwd}Primitive.impl``) - construct ``_FusedAttnConfig`` via ``check_set_window_size`` at the - public boundary, which guarantees this invariant. + ``window_size`` must be a ``Tuple[int, int]`` (never ``None``). The + callers (``FusedAttn{Fwd,Bwd}Primitive.impl``) construct + ``_FusedAttnConfig`` via ``fused_attn_fwd`` / ``fused_attn_bwd``, + which default ``None`` to ``(-1, -1)`` at the public boundary; the + flax DPA / MHA modules additionally canonicalize user input through + :func:`check_set_window_size` at ``__post_init__`` upstream of that. """ q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c5e1b5155c..9447eebb58 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2480,7 +2480,26 @@ def check_supported(self): ) def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: - """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + """Returns a _FusedAttnConfig for single CP step call to fused attention. + + Ring CP overrides ``attn_mask_type`` per step (e.g. ``CAUSAL_MASK`` -> ``NO_MASK`` + for off-diagonal steps where the kv chunk is fully past or fully future of the + local q chunk; see ``ring_attn_fwd_impl`` / ``ring_attn_bwd_impl``). The user's + ``window_size`` is the canonical no-SWA form for the *original* mask, so we + re-canonicalize it for the per-step mask: a causal-canonical ``(-1, 0)`` + propagated unchanged to a ``NO_MASK`` step would otherwise be interpreted by + cuDNN as "no mask + right=0 SWA", erroneously restricting attention to + ``kv_pos <= q_pos`` within that step. + + The partition-time gate (``FusedRingAttnFwdPrimitive.partition`` / + ``FusedRingAttnBwdPrimitive.partition``) guarantees no finite SWA reaches this + primitive, so the recanonicalization is unambiguous. ``warn=False`` because the + user's ``window_size`` was already validated upstream and this per-step coercion + is an internal mask switch the user neither requested nor can act on. + """ + per_step_window = check_set_window_size( + attn_mask_type, self.config.window_size, warn=False + ) return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=attn_mask_type, @@ -2490,7 +2509,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: dropout_probability=self.config.dropout_probability, is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, - window_size=self.config.window_size, + window_size=per_step_window, bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, @@ -3150,12 +3169,13 @@ def compute(config): config=config, ) - # Trigger striped-window adjustment only when there is a finite SWA. After - # check_set_window_size canonicalization, "no finite window" is encoded as - # (-1, -1) for non-causal masks and (-1, 0) for causal-family masks; both - # have window_size[0] == -1, so we use that as the unified SWA sentinel - # (matches the convention used elsewhere in this file, e.g. is_sliding_window - # at line ~2475). + # Trigger striped-window adjustment only when there is a finite SWA. + # window_size[0] == -1 is the unified "no finite window" sentinel that + # covers both the non-causal (-1, -1) form and the causal-family (-1, 0) + # form produced by check_set_window_size at the flax boundary; window_size[0] + # is the SWA's left bound, so any finite SWA has window_size[0] >= 0 + # (matches the convention used elsewhere in this file, e.g. + # is_sliding_window at line ~2475). if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values @@ -3310,8 +3330,9 @@ def compute(config): return dq_per_step, dkv_per_step, dbias_per_step # See fwd path above: window_size[0] != -1 is the unified "finite SWA" - # sentinel that handles both the (-1, -1) non-causal and (-1, 0) causal - # canonicalizations from check_set_window_size. + # sentinel that handles both the non-causal (-1, -1) and causal-family + # (-1, 0) canonical forms produced by check_set_window_size at the flax + # boundary. if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values @@ -3496,10 +3517,13 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. - # The flax modules already canonicalize at __post_init__; this covers direct - # callers of fused_attn_fwd/_bwd so the invariant holds at the C-extension - # boundary too. + # Canonicalize at the C-extension boundary so direct callers (e.g. the + # context-parallel ring primitives and unit tests that bypass flax) + # see the same canonical encoding as flax users (whose flax DPA/MHA + # __post_init__ already canonicalizes via check_set_window_size). + # For the ring CP P2P path, the per-step mask switch from CAUSAL_MASK + # to NO_MASK is handled inside _FusedAttnCPWithP2PHelper.get_step_config, + # which re-canonicalizes the window_size for the per-step mask type. window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, @@ -3675,10 +3699,9 @@ def fused_attn_bwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - # Canonicalize: None -> (-1, -1) for non-causal, (-1, 0) for causal-family. - # The flax modules already canonicalize at __post_init__; this covers direct - # callers of fused_attn_fwd/_bwd so the invariant holds at the C-extension - # boundary too. + # Canonicalize at the C-extension boundary; see fused_attn_fwd for the + # rationale. The ring CP P2P path's per-step mask switch is handled + # inside _FusedAttnCPWithP2PHelper.get_step_config. window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, From c770934a60daff9a5c3b2b073aba6f838f374b0b Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 8 May 2026 11:25:43 -0700 Subject: [PATCH 4/8] Code clean up Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 5 +- transformer_engine/jax/attention.py | 47 ++++-------- .../jax/cpp_extensions/attention.py | 39 ++++------ transformer_engine/jax/flax/transformer.py | 72 +++++-------------- 4 files changed, 45 insertions(+), 118 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 271905d437..7420dcfdc2 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1072,9 +1072,7 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu masks would collapse (W, W) -> (W, 0), hence not tested here. The chosen ``(left, right)`` is then routed through :func:`check_set_window_size`, which - is the same canonicalizer the production modules call at construction time. For the - candidates above this is always a no-op, so it acts as a contract self-check rather - than a value transformation. + is the same canonicalizer the production modules call at construction time. """ cudnn_version = get_cudnn_version() if cudnn_version < 90200: @@ -1089,6 +1087,7 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu candidate = (left_window_size, right_window_size) else: candidate = (left_window_size, 0) + # Validate the window size against the contract and return the canonicalized value. return check_set_window_size(attn_mask_type, candidate) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index c3363907a8..e683bbb833 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -349,13 +349,8 @@ def check_set_window_size( given mask type. warn: When ``True`` (default), emit a :class:`UserWarning` whenever the supplied ``window_size`` is silently coerced to the canonical form for ``attn_mask_type`` - (e.g. ``(-1, -1)`` → ``(-1, 0)`` for causal masks, or ``(-1, 0)`` → ``(-1, -1)`` - for non-causal masks). Set to ``False`` for internal call sites that route an - already-validated value through this function purely for re-canonicalization - (for example, the context-parallel ring primitive's per-step mask switch from - ``CAUSAL_MASK`` to ``NO_MASK``), where the warning would target user-supplied - input that the user never provided in non-canonical form. Hard-error branches - (negative bounds outside the recognized sentinels) are not gated by this flag + Set to ``False`` for internal call sites that do not need to emit warnings. + Hard-error branches (negative bounds outside the recognized sentinels) are not gated by this flag and always raise. Returns: @@ -372,6 +367,7 @@ def check_set_window_size( if attn_mask_type_enum.is_causal(): if orig_window_size is None: window_size = (-1, 0) + # Coerce the right side window to 0. elif orig_window_size == (-1, -1) or ( orig_window_size[0] >= 0 and orig_window_size[1] != 0 ): @@ -382,6 +378,7 @@ def check_set_window_size( f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " f"coercing to {window_size}." ) + # Assert if invalid window size is provided. elif orig_window_size != (-1, 0) and ( orig_window_size[0] < 0 or orig_window_size[1] != 0 ): @@ -392,6 +389,7 @@ def check_set_window_size( elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): if orig_window_size is None: window_size = (-1, -1) + # Coerce the right side window to -1. elif orig_window_size == (-1, 0): window_size = (-1, -1) if warn: @@ -400,6 +398,7 @@ def check_set_window_size( f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " f"coercing to {window_size}." ) + # Assert if invalid window size is provided. elif orig_window_size != (-1, -1) and ( orig_window_size[0] < 0 or orig_window_size[1] < 0 ): @@ -433,10 +432,7 @@ def is_fused_attn_kernel_available( To check whether the fused attention kernel is supported """ # Canonicalize at the C-extension boundary so direct callers see the same - # canonical encoding as flax users (whose flax DPA/MHA __post_init__ already - # canonicalizes via check_set_window_size). This keeps backend-availability - # queries consistent regardless of how the caller spelled the no-SWA - # sentinel ((-1, -1) vs (-1, 0) vs None). + # canonical encoding as users of DPA/MHA API to ensure consistency. window_size_tuple = check_set_window_size(attn_mask_type, window_size) def make_helper(attn_mask_type): @@ -803,13 +799,8 @@ def _segment_ids_pos_to_seqlens_offsets( which positions are valid (top-left causal vs bottom-right causal vs. padding-only) window_size: Sliding-window tuple ``(left, right)``. Required (not - Optional): the flax DPA / MHA ``__post_init__`` calls - :func:`check_set_window_size` to canonicalize user - input, and the public ``fused_attn_fwd`` / - ``fused_attn_bwd`` / ``is_fused_attn_kernel_available`` - entrypoints default ``None`` to ``(-1, -1)`` upstream - before reaching this internal helper. Used here only - as a fast-path eligibility hint. + Optional): Tuple[int, int]. Window size received should be + already canonicalized by check_set_window_size. max_segments_per_seq: maximum number of segments expected per row Used to size the bincount / argwhere outputs @@ -819,15 +810,8 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids and trims at most one token per segment at the boundary. Used for any non-bottom-right mask with no finite sliding window, i.e. ``window_size`` in - ``{(-1, -1), (-1, 0)}``. Both encodings represent "no SWA": - ``(-1, 0)`` is the canonical causal-family form produced by - :func:`check_set_window_size` at the flax DPA / MHA boundary, and - ``(-1, -1)`` is the corresponding non-causal form (and is also the - form that internal callers such as ring-attention CP primitives - continue to use unmodified). ``window_size`` is guaranteed - non-``None`` here because the flax modules canonicalize at - ``__post_init__`` and the public ``fused_attn_fwd`` / - ``fused_attn_bwd`` entrypoints default ``None`` to ``(-1, -1)``. + ``{(-1, -1), (-1, 0)}``. ``window_size`` is guaranteed to be + non-``None`` here because it is already canonicalized by check_set_window_size. Bottom-right causal cross-attention is excluded: the boundary trim leaves kv_seqlen short by one per active segment, which shifts the BRCM bottom-right alignment by one KV @@ -864,7 +848,6 @@ def _segment_ids_pos_to_seqlens_offsets( # must route bottom-right masks to the slow path. # Fast path: O(T) per row. - # window_size is canonical here (Tuple[int, int], never None -- see signature). # "No finite window" is encoded as (-1, -1) for non-causal masks and (-1, 0) for # causal-family masks; both share window_size[0] == -1, which is therefore the # mask-type-agnostic SWA-presence sentinel. @@ -946,12 +929,8 @@ def get_seqlens_and_offsets( """ Acquire the seqlens/offsets for cuDNN backend. - ``window_size`` must be a ``Tuple[int, int]`` (never ``None``). The - callers (``FusedAttn{Fwd,Bwd}Primitive.impl``) construct - ``_FusedAttnConfig`` via ``fused_attn_fwd`` / ``fused_attn_bwd``, - which default ``None`` to ``(-1, -1)`` at the public boundary; the - flax DPA / MHA modules additionally canonicalize user input through - :func:`check_set_window_size` at ``__post_init__`` upstream of that. + ``window_size`` must be a ``Tuple[int, int]`` (never ``None``) + and already canonicalized by check_set_window_size. """ q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 9447eebb58..a03e91c2d9 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2486,16 +2486,11 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: for off-diagonal steps where the kv chunk is fully past or fully future of the local q chunk; see ``ring_attn_fwd_impl`` / ``ring_attn_bwd_impl``). The user's ``window_size`` is the canonical no-SWA form for the *original* mask, so we - re-canonicalize it for the per-step mask: a causal-canonical ``(-1, 0)`` - propagated unchanged to a ``NO_MASK`` step would otherwise be interpreted by - cuDNN as "no mask + right=0 SWA", erroneously restricting attention to - ``kv_pos <= q_pos`` within that step. - - The partition-time gate (``FusedRingAttnFwdPrimitive.partition`` / - ``FusedRingAttnBwdPrimitive.partition``) guarantees no finite SWA reaches this - primitive, so the recanonicalization is unambiguous. ``warn=False`` because the - user's ``window_size`` was already validated upstream and this per-step coercion - is an internal mask switch the user neither requested nor can act on. + re-canonicalize it for the per-step mask. + + ``warn=False`` because the user's ``window_size`` was already canonicalized + by check_set_window_size upstream. This per-step coercion is an internal mask + switch (for ring P2P) which if reported, may confuse the user. """ per_step_window = check_set_window_size( attn_mask_type, self.config.window_size, warn=False @@ -3172,10 +3167,7 @@ def compute(config): # Trigger striped-window adjustment only when there is a finite SWA. # window_size[0] == -1 is the unified "no finite window" sentinel that # covers both the non-causal (-1, -1) form and the causal-family (-1, 0) - # form produced by check_set_window_size at the flax boundary; window_size[0] - # is the SWA's left bound, so any finite SWA has window_size[0] >= 0 - # (matches the convention used elsewhere in this file, e.g. - # is_sliding_window at line ~2475). + # form produced by check_set_window_size if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values @@ -3331,8 +3323,7 @@ def compute(config): # See fwd path above: window_size[0] != -1 is the unified "finite SWA" # sentinel that handles both the non-causal (-1, -1) and causal-family - # (-1, 0) canonical forms produced by check_set_window_size at the flax - # boundary. + # (-1, 0) canonical forms produced by check_set_window_size. if config.window_size[0] != -1: kv_src_rank = (cp_size + cp_rank - idx) % cp_size # Note: all inputs of adjust_cp_striped_window_size should be host values @@ -3517,13 +3508,9 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - # Canonicalize at the C-extension boundary so direct callers (e.g. the - # context-parallel ring primitives and unit tests that bypass flax) - # see the same canonical encoding as flax users (whose flax DPA/MHA - # __post_init__ already canonicalizes via check_set_window_size). - # For the ring CP P2P path, the per-step mask switch from CAUSAL_MASK - # to NO_MASK is handled inside _FusedAttnCPWithP2PHelper.get_step_config, - # which re-canonicalizes the window_size for the per-step mask type. + # Canonicalize at the CPP-extension boundary so every downstream primitive this + # function dispatches to (default fused-attn and the CP all-gather/ring variants) + # sees the same canonical encoding as the DPA/MHA API. window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, @@ -3699,9 +3686,9 @@ def fused_attn_bwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, - # Canonicalize at the C-extension boundary; see fused_attn_fwd for the - # rationale. The ring CP P2P path's per-step mask switch is handled - # inside _FusedAttnCPWithP2PHelper.get_step_config. + # Canonicalize at the CPP-extension boundary so every downstream primitive this + # function dispatches to (default fused-attn and the CP all-gather/ring variants) + # sees the same canonical encoding as the DPA/MHA API. window_size=check_set_window_size(attn_mask_type, window_size), bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index df67202cf6..c416ee4901 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -133,9 +133,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX def __post_init__(self): - # Defensive canonicalization: in normal flow window_size arrives canonical - # from DotProductAttention, but this guarantees the inner module never has - # to handle window_size=None. + # Canonicalize window_size so that the inner module never has to handle window_size=None. self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) super().__post_init__() @@ -251,12 +249,12 @@ def apply_swa_mask(original_mask: Array) -> Array: def convert_to_softmax_fusion_type(attn_mask_type, mask): """Convert the attn_mask_type to SoftmaxFusionType""" + #TODO(KshitijLakhani): Fix swa mask construction and softmax fusion selection for + # missing mask cases. # mask is ignored for no_mask and causal_mask without sliding window if attn_mask_type == AttnMaskType.NO_MASK: mask = None - # check_set_window_size (called in __post_init__) canonicalizes - # "no SWA + causal" to (-1, 0), so this is the only sentinel we need - # to recognize here. + # No SWA + causal mask is equivalent to no mask if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size == (-1, 0): mask = None if mask is not None: @@ -322,10 +320,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX def __post_init__(self): - # Defensive canonicalization: parallels _UnfusedDotProductAttention. The - # value is also re-canonicalized inside fused_attn_fwd/_bwd at the - # C-extension boundary, so this is purely about keeping the invariant - # uniform across the private module layer. + # Canonicalize window_size so that the inner modules do not need to handle window_size=None. self.window_size = check_set_window_size(self.attn_mask_type, self.window_size) super().__post_init__() @@ -600,15 +595,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and canonicalized at construction time. ``None`` is - replaced silently with the sentinel for :attr:`attn_mask_type`: ``(-1, -1)`` - for ``no_mask`` / ``padding`` and ``(-1, 0)`` for the causal family. - Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``, or - ``(W, R)`` with ``R != 0`` paired with a causal-family mask) are coerced with - a warning. Values with negative ``left`` or ``right`` outside the listed - sentinels raise an ``AssertionError``. Bidirectional sliding windows - (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused - backend; left-only sliding windows require cuDNN >= 9.2. + Inputs are validated and canonicalized at construction time via check_set_window_size. + Bidirectional sliding windows (right > 0 with no_mask / padding) for fused attention + require cuDNN >= 9.6; left-only sliding windows (right = 0) require cuDNN >= 9.2. max_segments_per_seq: Optional[int], default = 1 The maximum number of segments per sequence, also used for THD format (sequence packing). context_parallel_causal_load_balanced: bool @@ -1208,15 +1197,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and canonicalized at construction time. ``None`` is - replaced silently with the sentinel for :attr:`attn_mask_type`: ``(-1, -1)`` - for ``no_mask`` / ``padding`` and ``(-1, 0)`` for the causal family. - Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``, or - ``(W, R)`` with ``R != 0`` paired with a causal-family mask) are coerced with - a warning. Values with negative ``left`` or ``right`` outside the listed - sentinels raise an ``AssertionError``. Bidirectional sliding windows - (``right > 0`` with ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused - backend; left-only sliding windows require cuDNN >= 9.2. + Inputs are validated and canonicalized at construction time via check_set_window_size. + Bidirectional sliding windows (right > 0 with no_mask / padding) for fused attention + require cuDNN >= 9.6; left-only sliding windows (right = 0) require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks @@ -1982,17 +1965,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for causal-family masks). - This value is forwarded as-is to both the self-attention block (which uses - :attr:`self_attn_mask_type`) and, when :attr:`layer_type` is ``DECODER``, the - cross-attention block (whose mask type is internally fixed to ``padding``). - ``TransformerLayer`` deliberately does not canonicalize ``window_size``: a - decoder layer carries two different mask-type contracts simultaneously, so each - ``MultiHeadAttention`` block canonicalizes against its own mask type in its - own ``__post_init__``. Concretely, for ``self_attn_mask_type`` in the causal - family with ``window_size=None``, the self-attention block silently expands - ``None`` to ``(-1, 0)`` while the cross-attention block silently expands the - same ``None`` to ``(-1, -1)`` -- both blocks land on the correct sentinel for - their own mask type without warnings. + ``TransformerLayer`` deliberately does not canonicalize ``window_size`` and passes + on the responsibility to the inner modules (MultiHeadAttention) to canonicalize. Allowed values per mask type: @@ -2002,12 +1976,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inconsistent sentinels (e.g. ``(-1, 0)`` paired with ``no_mask``) are coerced - with a warning inside the relevant ``MultiHeadAttention`` block. Values with - negative ``left`` or ``right`` outside the listed sentinels raise an - ``AssertionError``. Bidirectional sliding windows (``right > 0`` with - ``no_mask`` / ``padding``) require cuDNN >= 9.6 in the fused backend; left-only - sliding windows require cuDNN >= 9.2. + Bidirectional sliding windows (right > 0 with no_mask / padding) for fused attention + require cuDNN >= 9.6; left-only sliding windows (right = 0) require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks @@ -2113,16 +2083,8 @@ def __post_init__(self): ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads - # window_size is intentionally NOT canonicalized here. A decoder layer - # constructs two MultiHeadAttention blocks with different mask types - # (self-attention uses self_attn_mask_type; cross-attention is hardcoded - # to "padding"). Canonicalizing at this layer would force a single - # mask-type contract onto a dual-contract object and produce a sentinel - # that is correct for one block but not the other (e.g. None + causal - # self-attn would yield (-1, 0), which the padding cross-attn would - # then warn-coerce). Instead, the raw user value is forwarded as-is to - # both blocks; each block canonicalizes against its own mask type in - # MultiHeadAttention.__post_init__. + # window_size is intentionally NOT canonicalized here and is passed on to the inner modules + # (MultiHeadAttention) to canonicalize against the mask type. super().__post_init__() @nn.compact From f1a67b85d0cb6161bac364228cb311f620e57903 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 18:29:08 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 18 +++++++----------- .../jax/cpp_extensions/attention.py | 8 +++----- transformer_engine/jax/flax/transformer.py | 6 +++--- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index e683bbb833..c9416c9bf4 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -374,16 +374,14 @@ def check_set_window_size( window_size = (orig_window_size[0], 0) if warn: warnings.warn( - f"window_size should be (-1, 0) or (>=0, 0) for " + "window_size should be (-1, 0) or (>=0, 0) for " f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " f"coercing to {window_size}." ) # Assert if invalid window size is provided. - elif orig_window_size != (-1, 0) and ( - orig_window_size[0] < 0 or orig_window_size[1] != 0 - ): + elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): raise AssertionError( - f"window_size should be (-1, 0) or (>=0, 0) for " + "window_size should be (-1, 0) or (>=0, 0) for " f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}." ) elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): @@ -394,16 +392,14 @@ def check_set_window_size( window_size = (-1, -1) if warn: warnings.warn( - f"window_size should be (-1, -1) or (>=0, >=0) for " + "window_size should be (-1, -1) or (>=0, >=0) for " f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; " f"coercing to {window_size}." ) # Assert if invalid window size is provided. - elif orig_window_size != (-1, -1) and ( - orig_window_size[0] < 0 or orig_window_size[1] < 0 - ): + elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): raise AssertionError( - f"window_size should be (-1, -1) or (>=0, >=0) for " + "window_size should be (-1, -1) or (>=0, >=0) for " f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}." ) else: @@ -929,7 +925,7 @@ def get_seqlens_and_offsets( """ Acquire the seqlens/offsets for cuDNN backend. - ``window_size`` must be a ``Tuple[int, int]`` (never ``None``) + ``window_size`` must be a ``Tuple[int, int]`` (never ``None``) and already canonicalized by check_set_window_size. """ q_segment_ids, kv_segment_ids = self.segment_ids diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a03e91c2d9..ce0eac834a 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2488,13 +2488,11 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: ``window_size`` is the canonical no-SWA form for the *original* mask, so we re-canonicalize it for the per-step mask. - ``warn=False`` because the user's ``window_size`` was already canonicalized - by check_set_window_size upstream. This per-step coercion is an internal mask + ``warn=False`` because the user's ``window_size`` was already canonicalized + by check_set_window_size upstream. This per-step coercion is an internal mask switch (for ring P2P) which if reported, may confuse the user. """ - per_step_window = check_set_window_size( - attn_mask_type, self.config.window_size, warn=False - ) + per_step_window = check_set_window_size(attn_mask_type, self.config.window_size, warn=False) return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index c416ee4901..53f580f883 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -249,7 +249,7 @@ def apply_swa_mask(original_mask: Array) -> Array: def convert_to_softmax_fusion_type(attn_mask_type, mask): """Convert the attn_mask_type to SoftmaxFusionType""" - #TODO(KshitijLakhani): Fix swa mask construction and softmax fusion selection for + # TODO(KshitijLakhani): Fix swa mask construction and softmax fusion selection for # missing mask cases. # mask is ignored for no_mask and causal_mask without sliding window if attn_mask_type == AttnMaskType.NO_MASK: @@ -595,7 +595,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and canonicalized at construction time via check_set_window_size. + Inputs are validated and canonicalized at construction time via check_set_window_size. Bidirectional sliding windows (right > 0 with no_mask / padding) for fused attention require cuDNN >= 9.6; left-only sliding windows (right = 0) require cuDNN >= 9.2. max_segments_per_seq: Optional[int], default = 1 @@ -1197,7 +1197,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left causal) or ``(>=0, 0)``. - Inputs are validated and canonicalized at construction time via check_set_window_size. + Inputs are validated and canonicalized at construction time via check_set_window_size. Bidirectional sliding windows (right > 0 with no_mask / padding) for fused attention require cuDNN >= 9.6; left-only sliding windows (right = 0) require cuDNN >= 9.2. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' From 6c3c9b8aaf33fa0fd5422488be1b7d91b4741fa0 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 8 May 2026 14:16:31 -0700 Subject: [PATCH 6/8] Fix small conditional check on right side window for causal when coercing Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index c9416c9bf4..1b69607119 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -369,7 +369,7 @@ def check_set_window_size( window_size = (-1, 0) # Coerce the right side window to 0. elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] != 0 + orig_window_size[0] >= 0 and orig_window_size[1] > 0 ): window_size = (orig_window_size[0], 0) if warn: @@ -427,7 +427,7 @@ def is_fused_attn_kernel_available( """ To check whether the fused attention kernel is supported """ - # Canonicalize at the C-extension boundary so direct callers see the same + # Canonicalize at the CPP-extension boundary so direct callers see the same # canonical encoding as users of DPA/MHA API to ensure consistency. window_size_tuple = check_set_window_size(attn_mask_type, window_size) From 1c53da0f5183c7555a9e23e9c3d3f9c5f5ceb31a Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 8 May 2026 14:20:19 -0700 Subject: [PATCH 7/8] Fix small conditional check on right side window for causal when coercing for PyTorch framework code Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ed87423534..7a2f1788ed 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2337,7 +2337,7 @@ def check_set_window_size( if orig_window_size is None: window_size = (-1, 0) elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] != 0 + orig_window_size[0] >= 0 and orig_window_size[1] > 0 ): window_size = (orig_window_size[0], 0) warnings.warn( From 43a8d27b46bd9706b6480bd93a7f08cc8168e119 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 21:21:36 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 4 +--- .../pytorch/attention/dot_product_attention/utils.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 1b69607119..c4664cb764 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -368,9 +368,7 @@ def check_set_window_size( if orig_window_size is None: window_size = (-1, 0) # Coerce the right side window to 0. - elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] > 0 - ): + elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0): window_size = (orig_window_size[0], 0) if warn: warnings.warn( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7a2f1788ed..30a310bb9a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2336,9 +2336,7 @@ def check_set_window_size( if "causal" in attn_mask_type: if orig_window_size is None: window_size = (-1, 0) - elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] > 0 - ): + elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0): window_size = (orig_window_size[0], 0) warnings.warn( "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type