diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2282620bc..11ef2aaa6 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog **New Features** - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. +- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). For every M consecutive key positions, the top-N attention scores are kept and the rest are set to -inf before softmax. See ``examples/llm_sparsity/attention_sparsity/README.md`` for usage. **Bug Fixes** diff --git a/modelopt/torch/kernels/hf_triton_attention.py b/modelopt/torch/kernels/hf_triton_attention.py index 5f10df250..73db8b69a 100644 --- a/modelopt/torch/kernels/hf_triton_attention.py +++ b/modelopt/torch/kernels/hf_triton_attention.py @@ -105,6 +105,15 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k + # N:M sparse softmax + if getattr(module, "_apply_sparse_nm", False): + method = getattr(module, "_sparse_method_instance", None) + if method is not None: + kw["sparsity_n"] = getattr(method, "sparsity_n", 2) + kw["sparsity_m"] = getattr(method, "sparsity_m", 4) + kw["num_sink_blocks"] = getattr(method, "num_sink_blocks", 0) + kw["dense_window_blocks"] = getattr(method, "dense_window_blocks", 1) + o = attention(q, k, v, **kw) attn_output = o.view(batch, seq_len, num_heads, head_dim) diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index b9184788b..391c9cd7c 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -45,6 +45,116 @@ _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, 
"BLOCK_N": 64}, num_stages=1, num_warps=4)] +# --------------------------------------------------------------------------- +# N:M sparse softmax helpers +# --------------------------------------------------------------------------- +@triton.jit +def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): + """Top-N of 4 selection via pure boolean logic (6 comparisons, no int casts). + + Uses ``>=`` so that ties are broken by index (lower index wins). + Guarantees exactly N masks are True for any input including all-equal. + + Boolean formulas for "at least K of 3 wins": + K=3 (N=1): AND of all — must beat all 3 others + K=2 (N=2): majority — must beat at least 2 (sorting network) + K=1 (N=3): OR of all — must beat at least 1 + """ + c01 = x0 >= x1 + c02 = x0 >= x2 + c03 = x0 >= x3 + c12 = x1 >= x2 + c13 = x1 >= x3 + c23 = x2 >= x3 + + nc01 = ~c01 + nc02 = ~c02 + nc03 = ~c03 + nc12 = ~c12 + nc13 = ~c13 + nc23 = ~c23 + + if N == 1: + # Keep max only: must beat all 3 + m0 = c01 & c02 & c03 + m1 = nc01 & c12 & c13 + m2 = nc02 & nc12 & c23 + m3 = nc03 & nc13 & nc23 + elif N == 2: + # Majority vote: must beat at least 2 of 3 + m0 = (c01 & c02) | (c01 & c03) | (c02 & c03) + m1 = (nc01 & c12) | (nc01 & c13) | (c12 & c13) + m2 = (nc02 & nc12) | (nc02 & c23) | (nc12 & c23) + m3 = (nc03 & nc13) | (nc03 & nc23) | (nc13 & nc23) + elif N == 3: + # Keep all but min: must beat at least 1 + m0 = c01 | c02 | c03 + m1 = nc01 | c12 | c13 + m2 = nc02 | nc12 | c23 + m3 = nc03 | nc13 | nc23 + else: + tl.static_assert(False, "N must be 1, 2, or 3 for M=4") + + return m0, m1, m2, m3 + + +@triton.jit +def _apply_sparse_nm_to_qk_tile( + qk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, +): + """Apply N:M sparse softmax to a QK score tile. + + For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, + keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. + ``BLOCK_N`` must be divisible by ``SPARSITY_M``. 
+ + For M=4, exactly N values are retained (ties broken by position). + For M=8, a threshold-based approach (``tl.sort``) may retain more + than N values when ties straddle the threshold boundary. + """ + tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 + MASK_VAL: tl.constexpr = float("-inf") + + if SPARSITY_M == 4: + tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4)) + cols = tl.arange(0, 4)[None, None, :] + x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) + x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) + x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) + x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) + + m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N) + + out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype) + out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out) + out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out) + out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out) + out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out) + return tl.reshape(out, (BLOCK_M, BLOCK_N)) + + else: # SPARSITY_M == 8 + tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8)) + + # Sort each group of 8 ascending; N-th largest is at index (8 - N) + sorted_vals = tl.sort(reshaped, dim=2) + KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order + + # Extract the threshold value at KTH_IDX via masked sum + # Use 0.0 as fill (not -inf) so sum equals just the KTH element + cols = tl.arange(0, 8)[None, None, :] + threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) + + # Mask: keep elements >= threshold (may keep >N on ties — acceptable) + mask = reshaped >= tl.expand_dims(threshold, 2) + 
return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N)) + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -105,6 +215,10 @@ def _attn_fwd( IS_CAUSAL: tl.constexpr, # Whether to apply causal mask HEAD_DIM: tl.constexpr, # Actual head dimension (for d_mask) STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass + SPARSITY_N: tl.constexpr = 0, # N:M sparsity — keep top-N of every M elements (0 = disabled) + SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) + NUM_SINK_BLOCKS: tl.constexpr = 0, # Leading KV blocks kept dense (attention sinks) + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, # Local blocks near diagonal kept dense ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -162,6 +276,21 @@ def _attn_fwd( scores = tl.dot(q, k) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + # --- Optional N:M sparse softmax --- + if SPARSITY_N > 0: + # Check if this KV tile should be kept dense + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS + # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q) + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + # --- Online softmax update --- # 1. 
Update running max m_new = tl.maximum(row_max, tl.max(scores, 1)) @@ -278,6 +407,10 @@ def _attn_bwd_dq( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_BLOCKS: tl.constexpr = 0, + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -343,6 +476,20 @@ def _attn_bwd_dq( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply N:M sparse softmax to match forward pass + if SPARSITY_N > 0: + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K @@ -392,6 +539,10 @@ def _attn_bwd_dkdv( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_BLOCKS: tl.constexpr = 0, + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 2 of backward: compute dK, dV for one KV tile. 
@@ -465,6 +616,20 @@ def _attn_bwd_dkdv( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q_tile, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply N:M sparse softmax to match forward pass + if SPARSITY_N > 0: + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (qi * BLOCK_M + causal_offset) // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dV += P^T @ dO @@ -498,6 +663,10 @@ def forward( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_blocks, + dense_window_blocks, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -552,6 +721,10 @@ def grid(META): IS_CAUSAL=is_causal, HEAD_DIM=HEAD_DIM, STORE_LSE=True, + SPARSITY_N=sparsity_n, + SPARSITY_M=sparsity_m, + NUM_SINK_BLOCKS=num_sink_blocks, + DENSE_WINDOW_BLOCKS=dense_window_blocks, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -566,6 +739,10 @@ def grid(META): ctx.num_q_heads = num_q_heads ctx.num_kv_heads = num_kv_heads ctx.batch = batch + ctx.sparsity_n = sparsity_n + ctx.sparsity_m = sparsity_m + ctx.num_sink_blocks = num_sink_blocks + ctx.dense_window_blocks = dense_window_blocks return o @staticmethod @@ -640,6 +817,10 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + NUM_SINK_BLOCKS=ctx.num_sink_blocks, + DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, ) @@ -659,11 +840,15 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + 
SPARSITY_M=ctx.sparsity_m, + NUM_SINK_BLOCKS=ctx.num_sink_blocks, + DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None def attention( @@ -678,8 +863,13 @@ def attention( b_start_loc_k: torch.Tensor | None = None, b_seq_len_k: torch.Tensor | None = None, max_input_len_k: int | None = None, + *, + sparsity_n: int = 0, + sparsity_m: int = 4, + num_sink_blocks: int = 0, + dense_window_blocks: int = 1, ) -> torch.Tensor: - """Variable-length flash attention with GQA and autograd support. + """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -693,6 +883,14 @@ def attention( b_start_loc_k: [batch] start offset for K/V (None = same as Q). b_seq_len_k: [batch] length for K/V (None = same as Q). max_input_len_k: Maximum K/V sequence length (None = same as Q). + sparsity_n: N:M sparsity — keep top-N of every M attention scores + along the key dimension. Set to 0 to disable. Examples: + ``sparsity_n=2, sparsity_m=4`` for 2:4 sparsity; + ``sparsity_n=4, sparsity_m=8`` for 4:8 sparsity. + sparsity_m: N:M sparsity — group size (4 or 8). + num_sink_blocks: Number of leading KV blocks to keep dense (attention sinks). + dense_window_blocks: KV blocks within this distance from the query + diagonal are kept dense (local attention window). Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. 
@@ -710,6 +908,10 @@ def attention( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_blocks, + dense_window_blocks, ) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 4baf5bbe6..21b5b8657 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -96,6 +96,39 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + sparsity_n: int = ModeloptField( + default=2, + title="N in N:M sparsity.", + description=( + "Keep top-N of every M attention scores. Only used by triton_sparse_softmax. " + "Set to 0 to disable sparsity." + ), + ) + + sparsity_m: int = ModeloptField( + default=4, + title="M in N:M sparsity.", + description="Group size for N:M sparsity (4 or 8). Only used by triton_sparse_softmax.", + ) + + num_sink_blocks: int = ModeloptField( + default=0, + title="Number of sink blocks.", + description=( + "Number of leading KV blocks to keep dense (attention sinks). " + "Only used by triton_sparse_softmax." + ), + ) + + dense_window_blocks: int = ModeloptField( + default=1, + title="Dense window blocks.", + description=( + "Number of local attention blocks around diagonal to keep dense. " + "Only used by triton_sparse_softmax." 
+ ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -434,9 +467,27 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# Default N:M sparse softmax configuration +SPARSE_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_blocks": 0, + "dense_window_blocks": 1, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SPARSE_SOFTMAX_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..1bd9a547d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax +from . import flash_skip_softmax, triton_sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py new file mode 100644 index 000000000..df99c4ac1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""N:M sparse softmax method for attention scores via Triton kernel.""" + +from contextlib import contextmanager + +import torch + +from .registry import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("triton_sparse_softmax") +class TritonSparseSoftmaxMethod(SparseAttentionMethod): + """N:M sparse softmax applied to attention scores via Triton kernel. + + Sparsity is applied inside the fused Triton flash attention kernel, + not as a separate pre/post-processing step. For every M consecutive + K positions, the top-N attention scores are kept; the other M-N are + set to -inf before softmax. + + Config params: + sparsity_n: Keep top-N of every M attention scores (0 to disable). + sparsity_m: Group size (4 or 8). + num_sink_blocks: Number of leading KV blocks kept dense (attention sinks). + dense_window_blocks: Local attention blocks kept dense near diagonal. 
+ """ + + def __init__(self, method_config=None): + """Initialize with N:M sparsity parameters from config.""" + super().__init__() + method_config = method_config or {} + self.sparsity_n = method_config.get("sparsity_n", 2) + self.sparsity_m = method_config.get("sparsity_m", 4) + self.num_sink_blocks = method_config.get("num_sink_blocks", 0) + self.dense_window_blocks = method_config.get("dense_window_blocks", 1) + + @property + def name(self) -> str: + """Method name identifier.""" + return "triton_sparse_softmax" + + def calculate_sparsity(self, attention_scores): + """Return a no-op mask (sparsity is applied inside the Triton kernel).""" + mask = torch.ones_like(attention_scores, dtype=torch.bool) + return mask, {} + + def apply_sparsity(self, attention_scores, sparse_mask=None): + """Not supported — sparsity is fused into the Triton kernel.""" + raise NotImplementedError( + "triton_sparse_softmax applies sparsity inside the Triton kernel. " + "Use backend='triton' or backend='vllm', not backend='pytorch'." + ) + + def get_sparse_context(self, module): + """Return context manager that activates N:M sparse softmax during forward.""" + + @contextmanager + def _sparse_nm_context(): + module._apply_sparse_nm = True + try: + yield + finally: + module._apply_sparse_nm = False + + return _sparse_nm_context() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/conftest.py b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py new file mode 100644 index 000000000..fa4f61771 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures and helpers for Triton flash attention tests.""" + +import pytest +import torch +import torch.nn.functional as F + + +def make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16): + """Create packed Q, K, V tensors.""" + q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + return q, k, v + + +def make_varlen_meta(seq_lens, device="cuda"): + """Create b_start_loc and b_seq_len from a list of sequence lengths.""" + b_seq_len = torch.tensor(seq_lens, device=device, dtype=torch.int32) + b_start_loc = torch.zeros(len(seq_lens), device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0) + return b_start_loc, b_seq_len + + +def sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): + """SDPA reference. Supports GQA. 
Returns [total_tokens, num_heads, dim].""" + batch = b_seq_len.shape[0] + num_q, num_kv = q.shape[1], k.shape[1] + parts = [] + for b in range(batch): + s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) + qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + if num_q != num_kv: + r = num_q // num_kv + kb = kb.repeat_interleave(r, dim=1) + vb = vb.repeat_interleave(r, dim=1) + ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal) + parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) + return torch.cat(parts, dim=0) + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" + from _test_utils.torch.transformers_models import create_tiny_llama_dir + + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + max_position_embeddings=64, + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index c86a4131e..73d20ba52 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""GPU tests for Triton flash attention kernel.""" +"""GPU tests for Triton flash attention kernel (dense path).""" import pytest import torch import torch.nn.functional as F +from conftest import make_qkv, make_varlen_meta, sdpa_reference pytestmark = [ pytest.mark.filterwarnings("ignore::UserWarning"), @@ -34,45 +35,14 @@ register_triton_attention() -def _sdpa_reference(q, k, v, b_start_loc, b_seq_len): - """SDPA causal reference. 
Supports GQA. Returns [total_tokens, num_heads, dim].""" - batch = b_seq_len.shape[0] - num_q, num_kv = q.shape[1], k.shape[1] - parts = [] - for b in range(batch): - s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) - qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - if num_q != num_kv: - r = num_q // num_kv - kb = kb.repeat_interleave(r, dim=1) - vb = vb.repeat_interleave(r, dim=1) - ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True) - parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) - return torch.cat(parts, dim=0) - - -@pytest.fixture(scope="module") -def tiny_llama_dir(tmp_path_factory): - """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" - from _test_utils.torch.transformers_models import create_tiny_llama_dir - - return create_tiny_llama_dir( - tmp_path_factory.mktemp("tiny_llama"), - with_tokenizer=True, - num_hidden_layers=2, - hidden_size=64, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=64, - max_position_embeddings=64, - ) +# --------------------------------------------------------------------------- +# Forward correctness +# --------------------------------------------------------------------------- @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestTritonFaVsSdpa: - """Triton flash attention matches PyTorch SDPA for prefill and decode.""" +class TestForward: + """Forward pass correctness for dense attention.""" @pytest.mark.parametrize( ("dtype", "num_heads", "num_kv_heads", "head_dim"), @@ -84,42 +54,27 @@ class TestTritonFaVsSdpa: ids=["fp32_mha", "fp16_gqa", "bf16_gqa_hdim128"], ) def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim): - """Prefill matches SDPA.""" + """Dense prefill matches SDPA.""" seq_lens = [8, 12] total = sum(seq_lens) scale = 1.0 / (head_dim**0.5) torch.manual_seed(123) - q = 
torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) - k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) - - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) - torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=dtype) + locs, lens = make_varlen_meta(seq_lens) + + o = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + torch.testing.assert_close(o, sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) def test_decode_matches_sdpa(self): - """Decode matches SDPA.""" + """Dense decode matches SDPA.""" batch = 2 - seq_lens_k = [5, 9] # KV lengths (context + current token) + seq_lens_k = [5, 9] num_heads, num_kv_heads, head_dim = 4, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(103) - # Q: one token per batch element -> flat [batch, num_heads, head_dim] q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float32) - - # K/V: variable-length, packed into flat tensors total_kv = sum(seq_lens_k) k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) @@ -136,9 +91,9 @@ def test_decode_matches_sdpa(self): q_flat, k_flat, v_flat, - b_start_loc=b_start_loc_q, - b_seq_len=b_seq_len_q, - max_input_len=1, + b_start_loc_q, + b_seq_len_q, + 1, is_causal=False, softmax_scale=scale, b_start_loc_k=b_start_loc_k, @@ -149,7 +104,7 @@ def test_decode_matches_sdpa(self): for i in range(batch): sl = seq_lens_k[i] s = cumsum[i] - qb = q_flat[i : i + 
1].unsqueeze(2) # [1, heads, 1, dim] + qb = q_flat[i : i + 1].unsqueeze(2) kb = k_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) vb = v_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) kb = kb.repeat_interleave(num_heads // num_kv_heads, dim=1) @@ -157,10 +112,152 @@ def test_decode_matches_sdpa(self): ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2) torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-3, atol=1e-3) + def test_sparse_disabled_matches_dense(self): + """sparsity_n=0 produces bit-identical output to default (dense).""" + seq_lens = [128, 128] + total = sum(seq_lens) + scale = 1.0 / (64**0.5) + + torch.manual_seed(99) + q, k, v = make_qkv(total, 4, 2, 64) + locs, lens = make_varlen_meta(seq_lens) + + out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) + out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) + assert torch.equal(out_dense, out_n0) + + +# --------------------------------------------------------------------------- +# Backward correctness (dense) +# --------------------------------------------------------------------------- + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestSparseAttentionIntegration: - """HF model + mtsa.sparsify integration.""" +class TestBackward: + """Backward pass gradient correctness for dense attention.""" + + def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): + """Run SDPA forward+backward, return gradients.""" + q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + num_q, num_kv = q_ref.shape[1], k_ref.shape[1] + if num_q != num_kv: + r = num_q // num_kv + k_exp = k_ref.repeat_interleave(r, dim=1) + v_exp = v_ref.repeat_interleave(r, dim=1) + else: + k_exp, v_exp = k_ref, v_ref + o_ref = F.scaled_dot_product_attention( + 
q_ref, k_exp, v_exp, is_causal=is_causal, scale=scale + ) + o_ref.sum().backward() + dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) + return dq.detach(), dk.detach(), dv.detach() + + def test_dense_causal_matches_sdpa(self): + """dQ, dK, dV match SDPA for causal self-attention.""" + seq_len, num_heads, num_kv_heads, head_dim = 16, 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(42) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_gqa_matches_sdpa(self): + """Dense backward with GQA (4 q-heads, 2 kv-heads), seq_len=256.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(43) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_multi_batch_variable_length(self): + """Multi-batch variable-length backward matches 
per-sample SDPA.""" + seq_lens = [8, 12] + total = sum(seq_lens) + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(45) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta(seq_lens) + + attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale).sum().backward() + + dq_ref = torch.zeros_like(q) + dk_ref = torch.zeros_like(k) + dv_ref = torch.zeros_like(v) + for b in range(len(seq_lens)): + s, n = int(locs[b].item()), seq_lens[b] + dq_b, dk_b, dv_b = self._sdpa_backward_ref( + q.detach()[s : s + n], + k.detach()[s : s + n], + v.detach()[s : s + n], + scale, + ) + dq_ref[s : s + n] = dq_b + dk_ref[s : s + n] = dk_b + dv_ref[s : s + n] = dv_b + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_longer_sequences(self): + """Dense backward with seq_len=512, GQA, exercises multi-tile loops.""" + seq_len, num_heads, num_kv_heads, head_dim = 512, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(49) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + +# --------------------------------------------------------------------------- +# HuggingFace integration +# 
--------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestHFIntegration: + """HF model integration with Triton attention backend.""" def test_triton_matches_eager(self, tiny_llama_dir): """Triton attention produces same logits and generated tokens as eager.""" @@ -172,7 +269,6 @@ def test_triton_matches_eager(self, tiny_llama_dir): tok.pad_token_id = tok.eos_token_id ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda") - # Eager baseline model_eager = AutoModelForCausalLM.from_pretrained( tiny_llama_dir, attn_implementation="eager", @@ -190,7 +286,6 @@ def test_triton_matches_eager(self, tiny_llama_dir): ) del model_eager - # Triton model_triton = AutoModelForCausalLM.from_pretrained( tiny_llama_dir, attn_implementation="modelopt_triton", @@ -207,15 +302,13 @@ def test_triton_matches_eager(self, tiny_llama_dir): pad_token_id=tok.pad_token_id, ) - # Logits should be close (bf16 tolerance) torch.testing.assert_close(logits_triton, logits_eager, rtol=2e-2, atol=2e-2) - # Generated tokens must be identical (greedy decoding is deterministic) assert torch.equal(out_triton, out_eager), ( f"Generated tokens differ:\n eager: {out_eager}\n triton: {out_triton}" ) def test_triton_padded_batch(self, tiny_llama_dir): - """Padded batch (2D attention mask) produces valid logits for each sequence.""" + """Padded batch produces valid logits.""" pytest.importorskip("transformers") from transformers import AutoModelForCausalLM, AutoTokenizer @@ -240,200 +333,60 @@ def test_triton_padded_batch(self, tiny_llama_dir): logits = model(**inputs).logits assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + def test_sparse_nm_via_sparsify(self, tiny_llama_dir): + """mtsa.sparsify() with N:M sparse softmax produces finite logits that differ from dense.""" + pytest.importorskip("transformers") + from transformers import 
AutoModelForCausalLM, AutoTokenizer -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestBackward: - """Backward pass gradient correctness tests.""" - - def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): - """Run SDPA forward+backward, return output and gradients.""" - q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - num_q, num_kv = q_ref.shape[1], k_ref.shape[1] - if num_q != num_kv: - r = num_q // num_kv - k_exp = k_ref.repeat_interleave(r, dim=1) - v_exp = v_ref.repeat_interleave(r, dim=1) - else: - k_exp, v_exp = k_ref, v_ref - o_ref = F.scaled_dot_product_attention( - q_ref, k_exp, v_exp, is_causal=is_causal, scale=scale - ) - o_ref.sum().backward() - dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) - dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) - dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) - return o_ref.permute(0, 2, 1, 3).squeeze(0).detach(), dq.detach(), dk.detach(), dv.detach() - - def test_backward_causal_matches_sdpa(self): - """dQ, dK, dV match SDPA backward for causal self-attention.""" - from modelopt.torch.kernels import attention - - seq_len = 16 - num_heads, num_kv_heads, head_dim = 2, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(42) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - 
o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_gqa(self): - """Backward with GQA (4 q-heads, 2 kv-heads), multi-tile (seq_len=256).""" - from modelopt.torch.kernels import attention - - seq_len = 256 - num_heads, num_kv_heads, head_dim = 4, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(43) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_multi_batch_variable_length(self): - """Multi-batch variable-length causal backward matches per-sample SDPA.""" - from modelopt.torch.kernels import attention - - seq_lens = [8, 12] - total = sum(seq_lens) - num_heads, num_kv_heads, head_dim = 2, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(45) - q = torch.randn( - total, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = 
torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) - - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - # Per-sample SDPA reference - dq_ref = torch.zeros_like(q) - dk_ref = torch.zeros_like(k) - dv_ref = torch.zeros_like(v) - for b in range(len(seq_lens)): - s, n = int(locs[b].item()), seq_lens[b] - _, dq_b, dk_b, dv_b = self._sdpa_backward_ref( - q.detach()[s : s + n], - k.detach()[s : s + n], - v.detach()[s : s + n], - scale, - is_causal=True, - ) - dq_ref[s : s + n] = dq_b - dk_ref[s : s + n] = dk_b - dv_ref[s : s + n] = dv_b - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_longer_sequences(self): - """Backward with seq_len=512, GQA, exercises multi-tile loops.""" - from modelopt.torch.kernels import attention + import modelopt.torch.sparsity.attention_sparsity as mtsa - seq_len = 512 - num_heads, num_kv_heads, head_dim = 4, 2, 64 - scale = 1.0 / (head_dim**0.5) + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + # Use a long input (fill max_position_embeddings=64) so sparsity has tiles to prune + ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") - torch.manual_seed(49) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - 
) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + # Dense baseline (triton backend, no sparsity) + model_dense = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, + model_dense.eval() + with torch.no_grad(): + logits_dense = model_dense(input_ids=ids).logits + del model_dense + + # Sparse via mtsa.sparsify() with dense_window_blocks=0 to force sparsity on all tiles + sparse_cfg = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_blocks": 0, + "dense_window_blocks": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + model_sparse = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True + mtsa.sparsify(model_sparse, sparse_cfg) + model_sparse.eval() + with torch.no_grad(): + logits_sparse = model_sparse(input_ids=ids).logits + + # Sparse output should be finite + assert not torch.isnan(logits_sparse).any(), "NaN in sparse logits" + assert not torch.isinf(logits_sparse).any(), "Inf in sparse logits" + # Sparse output should differ from dense (sparsity changes attention) + assert not torch.allclose(logits_sparse, logits_dense, atol=1e-2), ( + "Sparse logits identical to dense — sparsity may not be applied" ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, 
dv_ref, rtol=5e-3, atol=5e-3) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py new file mode 100644 index 000000000..7f89be325 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa: N803 — Triton JIT wrapper uses uppercase for constexpr and tensor args + +"""GPU tests for N:M sparse softmax on the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + import triton + import triton.language as tl + + from modelopt.torch.kernels import attention + from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile + + @triton.jit + def _test_apply_sparse_nm( + In, + Out, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, + ): + """Test wrapper: apply N:M sparsity to a tile and store result.""" + offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + qk = tl.load(In + offs) + tl.store( + Out + offs, + _apply_sparse_nm_to_qk_tile(qk, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M), + ) + + +# --------------------------------------------------------------------------- +# N:M sparsity behavior (prefill only) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseNM: + """N:M sparse softmax behavior on attention scores.""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_output_shape(self, 
n, m): + """Output shape matches Q shape for all N:M patterns.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert out.shape == q.shape + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_no_nan(self, n, m): + """All N:M patterns produce finite output.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert not torch.isnan(out).any(), f"NaN in output for {n}:{m}" + assert not torch.isinf(out).any(), f"Inf in output for {n}:{m}" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (1, 8), (4, 8)], + ids=["1:4", "2:4", "1:8", "4:8"], + ) + def test_sparse_differs_from_dense(self, n, m): + """Sparse output should differ from dense for long sequences.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_sparse = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + assert not torch.allclose(out_sparse, out_dense, atol=1e-3) + + @pytest.mark.parametrize( + ("n_values", "m"), + [([1, 2, 3], 4), ([1, 2, 4], 8)], + ids=["m4", "m8"], + ) + def test_more_sparsity_more_error(self, n_values, m): + """Keeping more elements should deviate less from dense (monotonic decreasing error).""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + errors = [] + for n in n_values: + out = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + errors.append((out - out_dense).abs().mean().item()) + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1], ( + f"Errors not 
monotonically decreasing for M={m}: " + + ", ".join(f"{n}:{m}={e:.6f}" for n, e in zip(n_values, errors)) + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_dense_window_preserves_local(self, n, m): + """Large dense_window_blocks makes sparse output closer to dense.""" + q, k, v, locs, lens = self._make_inputs(seq_len=256) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_small = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=1, + ) + out_large = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=100, + ) + err_small = (out_small - out_dense).abs().mean().item() + err_large = (out_large - out_dense).abs().mean().item() + assert err_large < err_small + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sink_blocks_preserve_early_kv(self, n, m): + """num_sink_blocks keeps early KV blocks dense, reducing error vs fully sparse.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_no_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_blocks=0, + ) + out_with_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_blocks=2, + ) + err_no_sink = (out_no_sink - out_dense).abs().mean().item() + err_with_sink = (out_with_sink - out_dense).abs().mean().item() + assert err_with_sink < err_no_sink, ( + f"Sink tokens should reduce error: no_sink={err_no_sink:.6f}, with_sink={err_with_sink:.6f}" + ) + + # NOTE: N:M sparse attention is for prefill only, not decode. 
+ + +# --------------------------------------------------------------------------- +# Sparsity tile structure +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseTileStructure: + """Direct unit tests for _apply_sparse_nm_to_qk_tile via wrapper kernel.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_sparsity_structure(self, n, m): + """Verify N:M structure: exactly N kept per group of M.""" + bm, bn = 32, 64 + torch.manual_seed(88) + tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + assert (kept == n).all(), ( + f"Expected {n} kept per group of {m}, got min={kept.min()}, max={kept.max()}" + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparsity_structure_ties(self, n, m): + """M=4 keeps exactly N on ties; M=8 (tl.sort) may keep >= N on ties.""" + bm, bn = 32, 64 + tile = torch.ones(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + if m == 4: + assert (kept == n).all(), ( + f"M=4 tie: expected {n}, got min={kept.min()}, max={kept.max()}" + ) + else: + assert (kept >= n).all(), f"M=8 tie: expected >= {n}, got min={kept.min()}" + assert (kept <= m).all(), f"M=8 tie: expected <= {m}, got max={kept.max()}" + + +# --------------------------------------------------------------------------- +# Sparse backward sanity +# 
--------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseBackward: + """Backward pass sanity checks with N:M sparsity enabled.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparse_gradients_finite(self, n, m): + """Backward with N:M sparsity produces finite, non-zero gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 128, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(55) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + ).sum().backward() + + for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: + assert grad is not None, f"{name} is None for {n}:{m}" + assert not torch.isnan(grad).any(), f"NaN in {name} for {n}:{m}" + assert not torch.isinf(grad).any(), f"Inf in {name} for {n}:{m}" + assert grad.abs().sum() > 0, f"{name} is all zeros for {n}:{m}" + + def test_sparse_gradients_differ_from_dense(self): + """Gradients with 2:4 sparsity should differ from dense gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(66) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + locs, lens = make_varlen_meta([seq_len]) + + q_d = q.clone().requires_grad_(True) + k_d = k.clone().requires_grad_(True) + v_d = v.clone().requires_grad_(True) + attention(q_d, k_d, v_d, locs, lens, seq_len, softmax_scale=scale).sum().backward() + + q_s = q.clone().requires_grad_(True) + k_s = k.clone().requires_grad_(True) + v_s = v.clone().requires_grad_(True) + attention( + q_s, + k_s, + v_s, + locs, + lens, + 
seq_len, + softmax_scale=scale, + sparsity_n=2, + sparsity_m=4, + ).sum().backward() + + assert not torch.allclose(q_d.grad, q_s.grad, atol=1e-3), ( + "dQ same with and without sparsity" + ) + assert not torch.allclose(k_d.grad, k_s.grad, atol=1e-3), ( + "dK same with and without sparsity" + ) + assert not torch.allclose(v_d.grad, v_s.grad, atol=1e-3), ( + "dV same with and without sparsity" + )