diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 1fb0108068..7420dcfdc2 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -36,6 +36,7 @@
     fused_attn,
     run_length_fill,
     make_swa_mask,
+    check_set_window_size,
     SequenceDescriptor,
     CPStrategy,
     ReorderStrategy,
@@ -1065,10 +1066,13 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu
    cuDNN < 9.2: skip (no SWA support).
    cuDNN >= 9.2: left-only window (s_kv // 10, 0).
-    cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose
-        bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK).
-        Other mask types keep the left-only window: causal-family masks would
-        collapse (W, W) -> (W, 0), hence not tested here.
+    cuDNN >= 9.6: bidirectional asymmetric window (s_kv // 10, s_kv // 10 + 5) for the mask
+        types whose bidirectional fused dispatch is meaningful here (NO_MASK,
+        PADDING_MASK). Other mask types keep the left-only window: causal-family
+        masks would collapse (W, W) -> (W, 0), hence not tested here.
+
+    The chosen ``(left, right)`` is then routed through :func:`check_set_window_size`, which
+    is the same canonicalizer the production modules call at construction time.
    """
    cudnn_version = get_cudnn_version()
    if cudnn_version < 90200:
@@ -1080,8 +1084,11 @@
        AttnMaskType.NO_MASK,
        AttnMaskType.PADDING_MASK,
    ):
-        return (left_window_size, right_window_size)
-    return (left_window_size, 0)
+        candidate = (left_window_size, right_window_size)
+    else:
+        candidate = (left_window_size, 0)
+    # Validate the window size against the contract and return the canonicalized value.
+    return check_set_window_size(attn_mask_type, candidate)


 @pytest.mark.parametrize(
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index f54a043fd2..c4664cb764 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -323,6 +323,88 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
    )


+def check_set_window_size(
+    attn_mask_type: Union[str, AttnMaskType],
+    window_size: Optional[Tuple[int, int]] = None,
+    *,
+    warn: bool = True,
+) -> Tuple[int, int]:
+    """Check if the sliding window size is compliant with the attention mask type.
+    If not, set it to the appropriate size.
+
+    attn_mask_type                                   | window_size
+    ----------------------------------------------------------------------------
+    no_mask, padding                                 | (-1, -1) or (>=0, >=0)
+    causal, padding_causal                           | (-1, 0) or (>=0, 0)
+    causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
+
+    ``(-1, -1)`` and ``(-1, 0)`` are sentinels meaning "no window" (full attention) and
+    "infinite-left causal" respectively. Negative entries are otherwise rejected.
+
+    Args:
+        attn_mask_type: Either a canonical ``attn_mask_type`` string (e.g. ``"no_mask"``,
+            ``"padding"``, ``"causal"``, ``"padding_causal"``, ``"causal_bottom_right"``,
+            ``"padding_causal_bottom_right"``) or an :class:`AttnMaskType` enum value.
+        window_size: ``(left, right)`` tuple, or ``None`` to use the natural default for the
+            given mask type.
+        warn: When ``True`` (default), emit a :class:`UserWarning` whenever the supplied
+            ``window_size`` is coerced to the canonical form for ``attn_mask_type``.
+            Set to ``False`` for internal call sites that do not need to emit warnings.
+            Hard-error branches (negative bounds outside the recognized sentinels) are
+            not gated by this flag and always raise.
+
+    Returns:
+        The canonicalized ``(left, right)`` tuple.
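+
+    Examples (illustrative values)::
+
+        >>> check_set_window_size("no_mask", None)
+        (-1, -1)
+        >>> check_set_window_size("causal", (128, 5))  # right bound coerced to 0; warns
+        (128, 0)
+        >>> check_set_window_size("padding", (-1, 0))  # coerced to the full-attention sentinel
+        (-1, -1)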
+    """
+    if isinstance(attn_mask_type, str):
+        attn_mask_type_enum = canonicalize_attn_mask_type(attn_mask_type)
+        attn_mask_type_str = attn_mask_type
+    else:
+        attn_mask_type_enum = attn_mask_type
+        attn_mask_type_str = attn_mask_type.name
+
+    orig_window_size = window_size
+    if attn_mask_type_enum.is_causal():
+        if orig_window_size is None:
+            window_size = (-1, 0)
+        # Coerce the right side window to 0.
+        elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0):
+            window_size = (orig_window_size[0], 0)
+            if warn:
+                warnings.warn(
+                    "window_size should be (-1, 0) or (>=0, 0) for "
+                    f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
+                    f"coercing to {window_size}."
+                )
+        # Raise if an invalid window size is provided.
+        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
+            raise AssertionError(
+                "window_size should be (-1, 0) or (>=0, 0) for "
+                f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
+            )
+    elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK):
+        if orig_window_size is None:
+            window_size = (-1, -1)
+        # Coerce the right side window to -1.
+        elif orig_window_size == (-1, 0):
+            window_size = (-1, -1)
+            if warn:
+                warnings.warn(
+                    "window_size should be (-1, -1) or (>=0, >=0) for "
+                    f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
+                    f"coercing to {window_size}."
+                )
+        # Raise if an invalid window size is provided.
+        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
+            raise AssertionError(
+                "window_size should be (-1, -1) or (>=0, >=0) for "
+                f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
+            )
+    else:
+        raise AssertionError(f"Invalid attn_mask_type: {attn_mask_type_str}")
+    return window_size
+
+
 def is_fused_attn_kernel_available(
     is_training,
     q_dtype,
@@ -343,7 +425,9 @@
    """
    To check whether the fused attention kernel is supported
    """
-    window_size_tuple = (-1, -1) if window_size is None else window_size
+    # Canonicalize at the CPP-extension boundary so direct callers see the same
+    # canonical encoding as users of the DPA/MHA APIs, ensuring consistency.
+    window_size_tuple = check_set_window_size(attn_mask_type, window_size)

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
@@ -688,9 +772,9 @@
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
-    attn_mask_type,
-    window_size,
-    max_segments_per_seq,
+    attn_mask_type: AttnMaskType,
+    window_size: Tuple[int, int],
+    max_segments_per_seq: int,
 ):
    """Compute per-segment seqlens and start offsets(currently only used for THD)
    Given segment-id and segment-position tensors for Q and KV,
@@ -708,8 +792,9 @@
        attn_mask_type: AttnMaskType. Selects the mask predicate used to decide
            which positions are valid (top-left causal vs bottom-right causal
            vs. padding-only)
-        window_size: Optional sliding-window tuple ``(left, right)`` or None
-            Used here only as a fast-path eligibility hint
+        window_size: Sliding-window tuple ``(left, right)``. Required (never
+            ``None``); the received value must already be canonicalized by
+            ``check_set_window_size``.
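+            For example (illustrative), plain causal attention arrives here as
+            ``(-1, 0)`` and a 64-token left window as ``(64, 0)``.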
        max_segments_per_seq: maximum number of segments expected per row
            Used to size the bincount / argwhere outputs

@@ -717,12 +802,14 @@
    1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
       O(T) per row. Counts all segment tokens via bincount on
       segment_ids and trims at most one token per segment at the
-       boundary. Used for:
-        - top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
-        - SWA with ``window_size == (-1, -1)`` and not bottom-right
-       Bottom-right causal cross-attention is excluded: the boundary
-       trim leaves kv_seqlen short by one per active segment, which
-       shifts the BRCM bottom-right alignment by one KV per Q row.
+       boundary. Used for any non-bottom-right mask with no finite
+       sliding window, i.e. ``window_size`` in ``{(-1, -1), (-1, 0)}``.
+       ``window_size`` is guaranteed to be non-``None`` here because it
+       is already canonicalized by ``check_set_window_size``.
+       Bottom-right causal cross-attention is excluded:
+       the boundary trim leaves kv_seqlen short by one per active
+       segment, which shifts the BRCM bottom-right alignment by one KV
+       per Q row.

    2. Slow path -- ``_get_seqlens_offsets_thd``.
       O(T * max_segments_per_seq) per row. Per-segment min/max
@@ -755,9 +842,11 @@
    # must route bottom-right masks to the slow path.

    # Fast path: O(T) per row.
-    if (
-        attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None
-    ) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()):
+    # "No finite window" is encoded as (-1, -1) for non-causal masks and (-1, 0)
+    # for causal-family masks; both share window_size[0] == -1, a mask-type-agnostic
+    # test for the absence of a finite window.
+    no_finite_window = window_size[0] == -1
+    if no_finite_window and not attn_mask_type.is_bottom_right():
        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
        )
@@ -825,10 +914,17 @@ def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def get_seqlens_and_offsets(
-        self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
+        self,
+        attn_mask_type: "AttnMaskType",
+        qkv_layout: "QKVLayout",
+        window_size: Tuple[int, int],
+        max_segments_per_seq: int,
    ):
        """
        Acquire the seqlens/offsets for cuDNN backend.
+
+        ``window_size`` must be a ``Tuple[int, int]`` (never ``None``)
+        and already canonicalized by ``check_set_window_size``.
        """
        q_segment_ids, kv_segment_ids = self.segment_ids
        q_segment_pos, kv_segment_pos = self.segment_pos
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index 40d02f40e1..ce0eac834a 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -25,6 +25,7 @@
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
+    check_set_window_size,
 )
 from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available
@@ -2479,7 +2480,19 @@ def check_supported(self):
        )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
-        """Returns a _FusedAttnConfig for single CP step call to fused attention."""
+        """Returns a _FusedAttnConfig for a single CP step call to fused attention.
+
+        Ring CP overrides ``attn_mask_type`` per step (e.g. ``CAUSAL_MASK`` -> ``NO_MASK``
+        for off-diagonal steps where the kv chunk lies entirely before or entirely after
+        the local q chunk; see ``ring_attn_fwd_impl`` / ``ring_attn_bwd_impl``). The
+        user's ``window_size`` is the canonical no-SWA form for the *original* mask, so
+        we re-canonicalize it for the per-step mask.
+
+        ``warn=False`` because the user's ``window_size`` was already canonicalized by
+        ``check_set_window_size`` upstream. This per-step coercion is an internal mask
+        switch (for ring P2P) which, if reported, would only confuse the user.
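+
+        For example (illustrative): with user-level ``attn_mask_type="causal"``, the
+        stored window ``(-1, 0)`` is re-canonicalized to ``(-1, -1)`` for a step whose
+        per-step mask is ``NO_MASK``.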
+        """
+        per_step_window = check_set_window_size(attn_mask_type, self.config.window_size, warn=False)
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
@@ -2489,7 +2502,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
-            window_size=self.config.window_size,
+            window_size=per_step_window,
            bottom_right_diagonal=attn_mask_type.is_bottom_right(),
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
@@ -3149,7 +3162,11 @@ def compute(config):
                    config=config,
                )

-            if config.window_size != (-1, -1):
+            # Trigger striped-window adjustment only when there is a finite SWA.
+            # window_size[0] == -1 is the unified "no finite window" sentinel that
+            # covers both the non-causal (-1, -1) form and the causal-family (-1, 0)
+            # form produced by check_set_window_size.
+            if config.window_size[0] != -1:
                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                # Note: all inputs of adjust_cp_striped_window_size should be host values
                cp_striped_window_size = adjust_cp_striped_window_size(
@@ -3302,7 +3319,10 @@ def compute(config):
                )
                return dq_per_step, dkv_per_step, dbias_per_step

-            if config.window_size != (-1, -1):
+            # See the fwd path above: config.window_size[0] != -1 indicates a finite SWA
+            # window, excluding both the non-causal (-1, -1) and causal-family (-1, 0)
+            # canonical forms produced by check_set_window_size.
+            if config.window_size[0] != -1:
                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                # Note: all inputs of adjust_cp_striped_window_size should be host values
                cp_striped_window_size = adjust_cp_striped_window_size(
@@ -3486,7 +3506,10 @@ def fused_attn_fwd(
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
-        window_size=(-1, -1) if window_size is None else window_size,
+        # Canonicalize at the CPP-extension boundary so every downstream primitive this
+        # function dispatches to (default fused-attn and the CP all-gather/ring variants)
+        # sees the same canonical encoding as the DPA/MHA API.
+        window_size=check_set_window_size(attn_mask_type, window_size),
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
@@ -3661,7 +3684,10 @@ def fused_attn_bwd(
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
-        window_size=(-1, -1) if window_size is None else window_size,
+        # Canonicalize at the CPP-extension boundary so every downstream primitive this
+        # function dispatches to (default fused-attn and the CP all-gather/ring variants)
+        # sees the same canonical encoding as the DPA/MHA API.
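+        # For example (illustrative): AttnMaskType.CAUSAL_MASK with window_size=None
+        # becomes (-1, 0) here, matching what the DPA/MHA modules store after their
+        # own canonicalization.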
+        window_size=check_set_window_size(attn_mask_type, window_size),
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index a2e7920843..53f580f883 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -30,7 +30,12 @@
    QKVLayout,
    SequenceDescriptor,
 )
-from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
+from ..attention import (
+    is_fused_attn_kernel_available,
+    make_swa_mask,
+    canonicalize_attn_mask_type,
+    check_set_window_size,
+)
 from ..attention import fused_attn
 from ..attention import CPStrategy
 from ..softmax import SoftmaxFusionType
@@ -127,6 +132,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

+    def __post_init__(self):
+        # Canonicalize window_size so the module body never has to handle window_size=None.
+        self.window_size = check_set_window_size(self.attn_mask_type, self.window_size)
+        super().__post_init__()
+
    @nn.compact
    def __call__(
        self,
@@ -239,10 +249,13 @@ def apply_swa_mask(original_mask: Array) -> Array:

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType"""
+            # TODO(KshitijLakhani): Fix swa mask construction and softmax fusion selection for
+            # missing mask cases.
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
-            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
+            # No SWA + causal mask is equivalent to no mask
+            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size == (-1, 0):
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
@@ -306,6 +319,11 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

+    def __post_init__(self):
+        # Canonicalize window_size so the inner modules never have to handle window_size=None.
+        self.window_size = check_set_window_size(self.attn_mask_type, self.window_size)
+        super().__post_init__()
+
    @nn.compact
    def __call__(
        self,
@@ -565,7 +583,21 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
        and sequence length dimension. If set to True, the input tensors should be
        in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
-        Sliding window size. The default value is no sliding window.
+        Sliding window size as ``(left, right)``. The default value of ``None`` means no
+        sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for
+        causal-family masks).
+
+        Allowed values per :attr:`attn_mask_type`:
+
+        * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or
+          ``(>=0, >=0)``.
+        * ``causal``, ``padding_causal``, ``causal_bottom_right``,
+          ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left
+          causal) or ``(>=0, 0)``.
+
+        Inputs are validated and canonicalized at construction time via
+        ``check_set_window_size``. Bidirectional sliding windows (right > 0 with
+        no_mask / padding) for fused attention require cuDNN >= 9.6; left-only
+        sliding windows (right = 0) require cuDNN >= 9.2.
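+
+        For example (illustrative), ``attn_mask_type="causal"`` with
+        ``window_size=(128, 0)`` lets each query attend to at most the 128 preceding
+        tokens.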
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced: bool
@@ -638,6 +670,8 @@ def __post_init__(self):
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False
+        # Validate / canonicalize window_size against attn_mask_type.
+        self.window_size = check_set_window_size(self.attn_mask_type, self.window_size)
        super().__post_init__()

    def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
@@ -1151,7 +1185,21 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
    window_size: Optional[Tuple[int, int]], default = None
-        Sliding window size. Default value is no sliding window.
+        Sliding window size as ``(left, right)``. The default value of ``None`` means no
+        sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for
+        causal-family masks).
+
+        Allowed values per :attr:`attn_mask_type`:
+
+        * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or
+          ``(>=0, >=0)``.
+        * ``causal``, ``padding_causal``, ``causal_bottom_right``,
+          ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left
+          causal) or ``(>=0, 0)``.
+
+        Inputs are validated and canonicalized at construction time via
+        ``check_set_window_size``. Bidirectional sliding windows (right > 0 with
+        no_mask / padding) for fused attention require cuDNN >= 9.6; left-only
+        sliding windows (right = 0) require cuDNN >= 9.2.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax type as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
@@ -1266,6 +1314,8 @@ def __post_init__(self):
        )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
+        # Validate / canonicalize window_size against attn_mask_type.
+        self.window_size = check_set_window_size(self.attn_mask_type, self.window_size)
        super().__post_init__()

    @nn.compact
@@ -1911,7 +1961,23 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    window_size: Optional[Tuple[int, int]], default = None
-        Sliding window size. Default value is no sliding window.
+        Sliding window size as ``(left, right)``. The default value of ``None`` means no
+        sliding window (full attention for ``no_mask`` / ``padding``, infinite-left for
+        causal-family masks).
+
+        ``TransformerLayer`` deliberately does not canonicalize ``window_size``; it
+        delegates that to the inner ``MultiHeadAttention`` module.
+
+        Allowed values per mask type:
+
+        * ``no_mask``, ``padding``: ``(-1, -1)`` (sentinel for full attention) or
+          ``(>=0, >=0)``.
+        * ``causal``, ``padding_causal``, ``causal_bottom_right``,
+          ``padding_causal_bottom_right``: ``(-1, 0)`` (sentinel for infinite-left
+          causal) or ``(>=0, 0)``.
+
+        Bidirectional sliding windows (right > 0 with no_mask / padding) for fused
+        attention require cuDNN >= 9.6; left-only sliding windows (right = 0) require
+        cuDNN >= 9.2.
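+
+        For example (illustrative), ``window_size=(256, 0)`` passed together with a
+        causal mask type is forwarded as-is and canonicalized inside
+        ``MultiHeadAttention.__post_init__``.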
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax type as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
@@ -2017,6 +2083,8 @@ def __post_init__(self):
        )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
+        # window_size is intentionally NOT canonicalized here; it is passed on to the
+        # inner MultiHeadAttention module to canonicalize against the mask type.
        super().__post_init__()

    @nn.compact
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index ed87423534..30a310bb9a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -2336,9 +2336,7 @@ def check_set_window_size(
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
-        elif orig_window_size == (-1, -1) or (
-            orig_window_size[0] >= 0 and orig_window_size[1] != 0
-        ):
+        elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0):
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type