From bb9211cad4d8c3a2bf3830f6617e1c6d7f399602 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 19 Mar 2026 16:27:40 -0700 Subject: [PATCH 1/2] Add 2:4 sparse softmax to the Triton flash attention kernel Signed-off-by: Kai Xu --- modelopt/torch/kernels/triton_fa.py | 203 +++++- .../attention_sparsity/test_triton_fa.py | 612 +++++++++++------- 2 files changed, 589 insertions(+), 226 deletions(-) diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index b9184788b..c39c23165 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -45,6 +45,112 @@ _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +# --------------------------------------------------------------------------- +# N:M structured sparsity helpers +# --------------------------------------------------------------------------- +@triton.jit +def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): + """Top-N of 4 selection via pure boolean logic (6 comparisons, no int casts). + + Uses ``>=`` so that ties are broken by index (lower index wins). + Guarantees exactly N masks are True for any input including all-equal. + + Boolean formulas for "at least K of 3 wins": + K=3 (N=1): AND of all — must beat all 3 others + K=2 (N=2): majority — must beat at least 2 (sorting network) + K=1 (N=3): OR of all — must beat at least 1 + """ + c01 = x0 >= x1 + c02 = x0 >= x2 + c03 = x0 >= x3 + c12 = x1 >= x2 + c13 = x1 >= x3 + c23 = x2 >= x3 + + nc01 = ~c01 + nc02 = ~c02 + nc03 = ~c03 + nc12 = ~c12 + nc13 = ~c13 + nc23 = ~c23 + + if N == 1: + # Keep max only: must beat all 3 + m0 = c01 & c02 & c03 + m1 = nc01 & c12 & c13 + m2 = nc02 & nc12 & c23 + m3 = nc03 & nc13 & nc23 + elif N == 2: + # Majority vote: must beat at least 2 of 3 + m0 = (c01 & c02) | (c01 & c03) | (c02 & c03) + m1 = (nc01 & c12) | (nc01 & c13) | (c12 & c13) + m2 = (nc02 & nc12) | (nc02 & c23) | (nc12 & c23) + m3 = (nc03 & nc13) | (nc03 & nc23) | (nc13 & nc23) + elif N == 3: + # Keep all but min: must beat at least 1 + m0 = c01 | c02 | c03 + m1 = nc01 | c12 | c13 + m2 = nc02 | nc12 | c23 + m3 = nc03 | nc13 | nc23 + else: + tl.static_assert(False, "N must be 1, 2, or 3 for M=4") + + return m0, m1, m2, m3 + + +@triton.jit +def _apply_sparse_nm_to_qk_tile( + qk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, +): + """Apply N:M structured sparsity to a QK score tile. + + For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, + keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. + ``BLOCK_N`` must be divisible by ``SPARSITY_M``. + """ + tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 + MASK_VAL: tl.constexpr = float("-inf") + + if SPARSITY_M == 4: + tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4)) + cols = tl.arange(0, 4)[None, None, :] + x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) + x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) + x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) + x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) + + m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N) + + out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype) + out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out) + out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out) + out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out) + out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out) + return tl.reshape(out, (BLOCK_M, BLOCK_N)) + + else: # SPARSITY_M == 8 + tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8)) + + # Sort each group of 8 ascending; N-th largest is at index (8 - N) + sorted_vals = tl.sort(reshaped, dim=2) + KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order + + # Extract the threshold value (one extraction vs eight before) + # Use 0.0 as fill (not -inf) so sum equals just the KTH element + cols = tl.arange(0, 8)[None, None, :] + threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) + + # Mask: keep elements >= threshold (may keep >N on ties — acceptable) + mask = reshaped >= tl.expand_dims(threshold, 2) + return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N)) + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -105,6 +211,10 @@ def _attn_fwd( IS_CAUSAL: tl.constexpr, # Whether to apply causal mask HEAD_DIM: tl.constexpr, # Actual head dimension (for d_mask) STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass + SPARSITY_N: tl.constexpr = 0, # N:M sparsity — keep top-N of every M elements (0 = disabled) + SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) + NUM_SINK_TOKENS: tl.constexpr = 0, # First N tokens kept dense (attention sinks) + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, # Local blocks near diagonal kept dense ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -162,6 +272,21 @@ def _attn_fwd( scores = tl.dot(q, k) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + # --- Optional 2:4 structured sparsity --- + if SPARSITY_N > 0: + # Check if this KV tile should be kept dense + is_sink = kv_start < NUM_SINK_TOKENS + # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q) + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N + kv_block_idx = kv_start // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + # --- Online softmax update --- # 1. Update running max m_new = tl.maximum(row_max, tl.max(scores, 1)) @@ -278,6 +403,10 @@ def _attn_bwd_dq( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_TOKENS: tl.constexpr = 0, + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -343,6 +472,20 @@ def _attn_bwd_dq( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply 2:4 sparsity to match forward pass + if SPARSITY_N > 0: + is_sink = kv_start < NUM_SINK_TOKENS + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N + kv_block_idx = kv_start // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K @@ -392,6 +535,10 @@ def _attn_bwd_dkdv( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_TOKENS: tl.constexpr = 0, + DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 2 of backward: compute dK, dV for one KV tile. @@ -465,6 +612,20 @@ def _attn_bwd_dkdv( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q_tile, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply 2:4 sparsity to match forward pass + if SPARSITY_N > 0: + is_sink = kv_start < NUM_SINK_TOKENS + causal_offset = seq_len_kv - seq_len_q + q_abs_block = (qi * BLOCK_M + causal_offset) // BLOCK_N + kv_block_idx = kv_start // BLOCK_N + block_distance = q_abs_block - kv_block_idx + is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) + if not is_sink and not is_local: + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dV += P^T @ dO @@ -498,6 +659,10 @@ def forward( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_tokens, + dense_window_blocks, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -552,6 +717,10 @@ def grid(META): IS_CAUSAL=is_causal, HEAD_DIM=HEAD_DIM, STORE_LSE=True, + SPARSITY_N=sparsity_n, + SPARSITY_M=sparsity_m, + NUM_SINK_TOKENS=num_sink_tokens, + DENSE_WINDOW_BLOCKS=dense_window_blocks, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -566,6 +735,10 @@ def grid(META): ctx.num_q_heads = num_q_heads ctx.num_kv_heads = num_kv_heads ctx.batch = batch + ctx.sparsity_n = sparsity_n + ctx.sparsity_m = sparsity_m + ctx.num_sink_tokens = num_sink_tokens + ctx.dense_window_blocks = dense_window_blocks return o @staticmethod @@ -640,6 +813,10 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + NUM_SINK_TOKENS=ctx.num_sink_tokens, + DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, ) @@ -659,11 +836,15 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + NUM_SINK_TOKENS=ctx.num_sink_tokens, + DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None def attention( @@ -678,8 +859,13 @@ def attention( b_start_loc_k: torch.Tensor | None = None, b_seq_len_k: torch.Tensor | None = None, max_input_len_k: int | None = None, + *, + sparsity_n: int = 0, + sparsity_m: int = 4, + num_sink_tokens: int = 0, + dense_window_blocks: int = 1, ) -> torch.Tensor: - """Variable-length flash attention with GQA and autograd support. + """Variable-length flash attention with GQA, autograd, and optional N:M sparsity. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -693,6 +879,15 @@ def attention( b_start_loc_k: [batch] start offset for K/V (None = same as Q). b_seq_len_k: [batch] length for K/V (None = same as Q). max_input_len_k: Maximum K/V sequence length (None = same as Q). + sparsity_n: N:M sparsity — keep top-N of every M attention scores + along the key dimension. Set to 0 to disable. Examples: + ``sparsity_n=2, sparsity_m=4`` for 2:4 sparsity; + ``sparsity_n=4, sparsity_m=8`` for 4:8 sparsity. + sparsity_m: N:M sparsity — group size (4 or 8). + num_sink_tokens: KV blocks containing any of the first N token positions + are kept dense (attention sinks). Granularity is per-block. + dense_window_blocks: KV blocks within this distance from the query + diagonal are kept dense (local attention window). Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. @@ -710,6 +905,10 @@ def attention( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_tokens, + dense_window_blocks, ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index c86a4131e..e8ea18077 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -28,14 +28,31 @@ from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: + import triton + import triton.language as tl + from modelopt.torch.kernels import attention, register_triton_attention + from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile if register_triton_attention is not None: register_triton_attention() + @triton.jit + def _test_apply_sparse_nm(In, Out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, SPARSITY_M: tl.constexpr): + """Test wrapper: apply N:M sparsity to a tile and store result.""" + offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + qk = tl.load(In + offs) + tl.store(Out + offs, _apply_sparse_nm_to_qk_tile(qk, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + -def _sdpa_reference(q, k, v, b_start_loc, b_seq_len): - """SDPA causal reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" +def _sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): + """SDPA reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" batch = b_seq_len.shape[0] num_q, num_kv = q.shape[1], k.shape[1] parts = [] @@ -48,11 +65,30 @@ def _sdpa_reference(q, k, v, b_start_loc, b_seq_len): r = num_q // num_kv kb = kb.repeat_interleave(r, dim=1) vb = vb.repeat_interleave(r, dim=1) - ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True) + ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal) parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) return torch.cat(parts, dim=0) +def _make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16): + """Create packed Q, K, V tensors.""" + q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + return q, k, v + + +def _make_varlen_meta(seq_lens, device="cuda"): + """Create b_start_loc and b_seq_len from a list of sequence lengths.""" + b_seq_len = torch.tensor(seq_lens, device=device, dtype=torch.int32) + b_start_loc = torch.zeros(len(seq_lens), device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0) + return b_start_loc, b_seq_len + + +# --------------------------------------------------------------------------- +# Forward correctness +# --------------------------------------------------------------------------- @pytest.fixture(scope="module") def tiny_llama_dir(tmp_path_factory): """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" @@ -71,8 +107,8 @@ def tiny_llama_dir(tmp_path_factory): @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestTritonFaVsSdpa: - """Triton flash attention matches PyTorch SDPA for prefill and decode.""" +class TestForward: + """Forward pass correctness for dense and sparse attention.""" @pytest.mark.parametrize( ("dtype", "num_heads", "num_kv_heads", "head_dim"), @@ -84,42 +120,27 @@ class TestTritonFaVsSdpa: ids=["fp32_mha", "fp16_gqa", "bf16_gqa_hdim128"], ) def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim): - """Prefill matches SDPA.""" + """Dense prefill matches SDPA.""" seq_lens = [8, 12] total = sum(seq_lens) scale = 1.0 / (head_dim**0.5) torch.manual_seed(123) - q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) - k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) + q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=dtype) + locs, lens = _make_varlen_meta(seq_lens) - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) + o = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) def test_decode_matches_sdpa(self): - """Decode matches SDPA.""" + """Dense decode matches SDPA.""" batch = 2 - seq_lens_k = [5, 9] # KV lengths (context + current token) + seq_lens_k = [5, 9] num_heads, num_kv_heads, head_dim = 4, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(103) - # Q: one token per batch element -> flat [batch, num_heads, head_dim] q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float32) - - # K/V: variable-length, packed into flat tensors total_kv = sum(seq_lens_k) k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) @@ -136,9 +157,9 @@ def test_decode_matches_sdpa(self): q_flat, k_flat, v_flat, - b_start_loc=b_start_loc_q, - b_seq_len=b_seq_len_q, - max_input_len=1, + b_start_loc_q, + b_seq_len_q, + 1, is_causal=False, softmax_scale=scale, b_start_loc_k=b_start_loc_k, @@ -149,7 +170,7 @@ def test_decode_matches_sdpa(self): for i in range(batch): sl = seq_lens_k[i] s = cumsum[i] - qb = q_flat[i : i + 1].unsqueeze(2) # [1, heads, 1, dim] + qb = q_flat[i : i + 1].unsqueeze(2) kb = k_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) vb = v_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) kb = kb.repeat_interleave(num_heads // num_kv_heads, dim=1) @@ -157,93 +178,27 @@ def test_decode_matches_sdpa(self): ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2) torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-3, atol=1e-3) + def test_sparse_disabled_matches_dense(self): + """sparsity_n=0 produces bit-identical output to default (dense).""" + seq_lens = [128, 128] + total = sum(seq_lens) + scale = 1.0 / (64**0.5) -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestSparseAttentionIntegration: - """HF model + mtsa.sparsify integration.""" - - def test_triton_matches_eager(self, tiny_llama_dir): - """Triton attention produces same logits and generated tokens as eager.""" - pytest.importorskip("transformers") - from transformers import AutoModelForCausalLM, AutoTokenizer - - tok = AutoTokenizer.from_pretrained(tiny_llama_dir) - if tok.pad_token_id is None: - tok.pad_token_id = tok.eos_token_id - ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda") - - # Eager baseline - model_eager = AutoModelForCausalLM.from_pretrained( - tiny_llama_dir, - attn_implementation="eager", - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model_eager.eval() - with torch.no_grad(): - logits_eager = model_eager(input_ids=ids).logits - out_eager = model_eager.generate( - ids, - max_new_tokens=5, - do_sample=False, - pad_token_id=tok.pad_token_id, - ) - del model_eager - - # Triton - model_triton = AutoModelForCausalLM.from_pretrained( - tiny_llama_dir, - attn_implementation="modelopt_triton", - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model_triton.eval() - with torch.no_grad(): - logits_triton = model_triton(input_ids=ids).logits - out_triton = model_triton.generate( - ids, - max_new_tokens=5, - do_sample=False, - pad_token_id=tok.pad_token_id, - ) - - # Logits should be close (bf16 tolerance) - torch.testing.assert_close(logits_triton, logits_eager, rtol=2e-2, atol=2e-2) - # Generated tokens must be identical (greedy decoding is deterministic) - assert torch.equal(out_triton, out_eager), ( - f"Generated tokens differ:\n eager: {out_eager}\n triton: {out_triton}" - ) - - def test_triton_padded_batch(self, tiny_llama_dir): - """Padded batch (2D attention mask) produces valid logits for each sequence.""" - pytest.importorskip("transformers") - from transformers import AutoModelForCausalLM, AutoTokenizer - - model = AutoModelForCausalLM.from_pretrained( - tiny_llama_dir, - attn_implementation="modelopt_triton", - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model.eval() - tok = AutoTokenizer.from_pretrained(tiny_llama_dir) - if tok.pad_token_id is None: - tok.pad_token_id = tok.eos_token_id - tok.padding_side = "right" + torch.manual_seed(99) + q, k, v = _make_qkv(total, 4, 2, 64) + locs, lens = _make_varlen_meta(seq_lens) - inputs = tok( - ["Hello world", "The capital of France is Paris and"], - return_tensors="pt", - padding=True, - ).to("cuda") - with torch.no_grad(): - logits = model(**inputs).logits - assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) + out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) + assert torch.equal(out_dense, out_n0) +# --------------------------------------------------------------------------- +# Backward correctness +# --------------------------------------------------------------------------- @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestBackward: - """Backward pass gradient correctness tests.""" + """Backward pass gradient correctness for dense and sparse attention.""" def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): """Run SDPA forward+backward, return output and gradients.""" @@ -264,132 +219,73 @@ def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) - return o_ref.permute(0, 2, 1, 3).squeeze(0).detach(), dq.detach(), dk.detach(), dv.detach() - - def test_backward_causal_matches_sdpa(self): - """dQ, dK, dV match SDPA backward for causal self-attention.""" - from modelopt.torch.kernels import attention + return dq.detach(), dk.detach(), dv.detach() - seq_len = 16 - num_heads, num_kv_heads, head_dim = 2, 2, 32 + # --- Dense backward vs SDPA --- + def test_dense_causal_matches_sdpa(self): + """dQ, dK, dV match SDPA for causal self-attention.""" + seq_len, num_heads, num_kv_heads, head_dim = 16, 2, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(42) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) + q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = _make_varlen_meta([seq_len]) - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - def test_backward_gqa(self): - """Backward with GQA (4 q-heads, 2 kv-heads), multi-tile (seq_len=256).""" - from modelopt.torch.kernels import attention - - seq_len = 256 - num_heads, num_kv_heads, head_dim = 4, 2, 32 + def test_dense_gqa_matches_sdpa(self): + """Dense backward with GQA (4 q-heads, 2 kv-heads), seq_len=256.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(43) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) + q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = _make_varlen_meta([seq_len]) - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - def test_backward_multi_batch_variable_length(self): - """Multi-batch variable-length causal backward matches per-sample SDPA.""" - from modelopt.torch.kernels import attention - + def test_dense_multi_batch_variable_length(self): + """Multi-batch variable-length backward matches per-sample SDPA.""" seq_lens = [8, 12] total = sum(seq_lens) num_heads, num_kv_heads, head_dim = 2, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(45) - q = torch.randn( - total, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) + q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = _make_varlen_meta(seq_lens) - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() + attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale).sum().backward() - # Per-sample SDPA reference dq_ref = torch.zeros_like(q) dk_ref = torch.zeros_like(k) dv_ref = torch.zeros_like(v) for b in range(len(seq_lens)): s, n = int(locs[b].item()), seq_lens[b] - _, dq_b, dk_b, dv_b = self._sdpa_backward_ref( + dq_b, dk_b, dv_b = self._sdpa_backward_ref( q.detach()[s : s + n], k.detach()[s : s + n], v.detach()[s : s + n], scale, - is_causal=True, ) dq_ref[s : s + n] = dq_b dk_ref[s : s + n] = dk_b @@ -399,41 +295,309 @@ def test_backward_multi_batch_variable_length(self): torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - def test_backward_longer_sequences(self): - """Backward with seq_len=512, GQA, exercises multi-tile loops.""" - from modelopt.torch.kernels import attention - - seq_len = 512 - num_heads, num_kv_heads, head_dim = 4, 2, 64 + def test_dense_longer_sequences(self): + """Dense backward with seq_len=512, GQA, exercises multi-tile loops.""" + seq_len, num_heads, num_kv_heads, head_dim = 512, 4, 2, 64 scale = 1.0 / (head_dim**0.5) torch.manual_seed(49) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = _make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + # --- Sparse backward sanity checks --- + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparse_gradients_finite(self, n, m): + """Backward with N:M sparsity produces finite, non-zero gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 128, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(55) + q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = _make_varlen_meta([seq_len]) + + attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + ).sum().backward() + + for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: + assert grad is not None, f"{name} is None for {n}:{m}" + assert not torch.isnan(grad).any(), f"NaN in {name} for {n}:{m}" + assert not torch.isinf(grad).any(), f"Inf in {name} for {n}:{m}" + assert grad.abs().sum() > 0, f"{name} is all zeros for {n}:{m}" + + def test_sparse_gradients_differ_from_dense(self): + """Gradients with 2:4 sparsity should differ from dense gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(66) + q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + locs, lens = _make_varlen_meta([seq_len]) + + q_d = q.clone().requires_grad_(True) + k_d = k.clone().requires_grad_(True) + v_d = v.clone().requires_grad_(True) + attention(q_d, k_d, v_d, locs, lens, seq_len, softmax_scale=scale).sum().backward() + + q_s = q.clone().requires_grad_(True) + k_s = k.clone().requires_grad_(True) + v_s = v.clone().requires_grad_(True) + attention( + q_s, + k_s, + v_s, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=2, + sparsity_m=4, + ).sum().backward() + + assert not torch.allclose(q_d.grad, q_s.grad, atol=1e-3), ( + "dQ same with and without sparsity" ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + assert not torch.allclose(k_d.grad, k_s.grad, atol=1e-3), ( + "dK same with and without sparsity" ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + + +# --------------------------------------------------------------------------- +# N:M sparsity +# --------------------------------------------------------------------------- +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseNM: + """N:M structured sparsity behavior on attention scores (prefill only).""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(99) + q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = _make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_output_shape(self, n, m): + """Output shape matches Q shape for all N:M patterns.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m ) + assert out.shape == q.shape - o = attention( + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_no_nan(self, n, m): + """All N:M patterns produce finite output.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert not torch.isnan(out).any(), f"NaN in output for {n}:{m}" + assert not torch.isinf(out).any(), f"Inf in output for {n}:{m}" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (1, 8), (4, 8)], + ids=["1:4", "2:4", "1:8", "4:8"], + ) + def test_sparse_differs_from_dense(self, n, m): + """Sparse output should differ from dense for long sequences.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_sparse = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + assert not torch.allclose(out_sparse, out_dense, atol=1e-3) + + @pytest.mark.parametrize( + ("n_values", "m"), + [([1, 2, 3], 4), ([1, 2, 4], 8)], + ids=["m4", "m8"], + ) + def test_more_sparsity_more_error(self, n_values, m): + """Keeping more elements should deviate less from dense (monotonic decreasing error).""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + errors = [] + for n in n_values: + out = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + errors.append((out - out_dense).abs().mean().item()) + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1], ( + f"Errors not monotonically decreasing for M={m}: " + + ", ".join(f"{n}:{m}={e:.6f}" for n, e in zip(n_values, errors)) + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_dense_window_preserves_local(self, n, m): + """Large dense_window_blocks makes sparse output closer to dense.""" + q, k, v, locs, lens = self._make_inputs(seq_len=256) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_small = attention( q, k, v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, + locs, + lens, + 256, softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=1, ) - o.sum().backward() + out_large = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=100, + ) + err_small = (out_small - out_dense).abs().mean().item() + err_large = (out_large - out_dense).abs().mean().item() + assert err_large < err_small + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_sparsity_structure(self, n, m): + """Verify N:M structure: exactly N kept per group of M.""" + BM, BN = 32, 64 + torch.manual_seed(88) + tile = torch.randn(BM, BN, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=BM, BLOCK_N=BN, SPARSITY_N=n, SPARSITY_M=m) - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True + kept = (out.reshape(BM, BN // m, m) != float("-inf")).sum(dim=-1) + assert (kept == n).all(), f"Expected {n} kept per group of {m}, got min={kept.min()}, max={kept.max()}" + + +# --------------------------------------------------------------------------- +# HuggingFace integration +# --------------------------------------------------------------------------- +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestHFIntegration: + """HF model integration with Triton attention backend.""" + + def test_triton_matches_eager(self, tiny_llama_dir): + """Triton attention produces same logits and generated tokens as eager.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda") + + model_eager = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + device_map="cuda", ) + model_eager.eval() + with torch.no_grad(): + logits_eager = model_eager(input_ids=ids).logits + out_eager = model_eager.generate( + ids, + max_new_tokens=5, + do_sample=False, + pad_token_id=tok.pad_token_id, + ) + del model_eager - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + model_triton = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model_triton.eval() + with torch.no_grad(): + logits_triton = model_triton(input_ids=ids).logits + out_triton = model_triton.generate( + ids, + max_new_tokens=5, + do_sample=False, + pad_token_id=tok.pad_token_id, + ) + + torch.testing.assert_close(logits_triton, logits_eager, rtol=2e-2, atol=2e-2) + assert torch.equal(out_triton, out_eager), ( + f"Generated tokens differ:\n eager: {out_eager}\n triton: {out_triton}" + ) + + def test_triton_padded_batch(self, tiny_llama_dir): + """Padded batch produces valid logits.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model.eval() + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.padding_side = "right" + + inputs = tok( + ["Hello world", "The capital of France is Paris and"], + return_tensors="pt", + padding=True, + ).to("cuda") + with torch.no_grad(): + logits = model(**inputs).logits + assert not torch.isnan(logits).any() and not torch.isinf(logits).any() From 31655ce9e09d1cbd4d9156559f49a445de66dafb Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 19 Mar 2026 21:26:27 -0700 Subject: [PATCH 2/2] Add unit test for 2:4 sparse softmax Signed-off-by: Kai Xu --- CHANGELOG.rst | 1 + modelopt/torch/kernels/hf_triton_attention.py | 9 + modelopt/torch/kernels/triton_fa.py | 53 +-- .../sparsity/attention_sparsity/config.py | 51 +++ .../attention_sparsity/methods/__init__.py | 2 +- .../methods/triton_sparse_softmax.py | 78 ++++ .../sparsity/attention_sparsity/conftest.py | 72 ++++ .../attention_sparsity/test_triton_fa.py | 371 ++++-------------- .../test_triton_fa_sparse_nm.py | 348 ++++++++++++++++ 9 files changed, 668 insertions(+), 317 deletions(-) create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/conftest.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2282620bc..11ef2aaa6 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog **New Features** - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. +- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). For every M consecutive key positions, the top-N attention scores are kept and the rest are set to -inf before softmax. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. **Bug Fixes** diff --git a/modelopt/torch/kernels/hf_triton_attention.py b/modelopt/torch/kernels/hf_triton_attention.py index 5f10df250..73db8b69a 100644 --- a/modelopt/torch/kernels/hf_triton_attention.py +++ b/modelopt/torch/kernels/hf_triton_attention.py @@ -105,6 +105,15 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k + # N:M sparse softmax + if getattr(module, "_apply_sparse_nm", False): + method = getattr(module, "_sparse_method_instance", None) + if method is not None: + kw["sparsity_n"] = getattr(method, "sparsity_n", 2) + kw["sparsity_m"] = getattr(method, "sparsity_m", 4) + kw["num_sink_blocks"] = getattr(method, "num_sink_blocks", 0) + kw["dense_window_blocks"] = getattr(method, "dense_window_blocks", 1) + o = attention(q, k, v, **kw) attn_output = o.view(batch, seq_len, num_heads, head_dim) diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index c39c23165..391c9cd7c 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -46,7 +46,7 @@ # --------------------------------------------------------------------------- -# N:M structured sparsity helpers +# N:M sparse softmax helpers # --------------------------------------------------------------------------- @triton.jit def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): @@ -106,11 +106,15 @@ def _apply_sparse_nm_to_qk_tile( SPARSITY_N: tl.constexpr, SPARSITY_M: tl.constexpr, ): - """Apply N:M structured sparsity to a QK score tile. + """Apply N:M sparse softmax to a QK score tile. For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. ``BLOCK_N`` must be divisible by ``SPARSITY_M``. + + For M=4, exactly N values are retained (ties broken by position). + For M=8, a threshold-based approach (``tl.sort``) may retain more + than N values when ties straddle the threshold boundary. """ tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 MASK_VAL: tl.constexpr = float("-inf") @@ -141,7 +145,7 @@ def _apply_sparse_nm_to_qk_tile( sorted_vals = tl.sort(reshaped, dim=2) KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order - # Extract the threshold value (one extraction vs eight before) + # Extract the threshold value at KTH_IDX via masked sum # Use 0.0 as fill (not -inf) so sum equals just the KTH element cols = tl.arange(0, 8)[None, None, :] threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) @@ -213,7 +217,7 @@ def _attn_fwd( STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass SPARSITY_N: tl.constexpr = 0, # N:M sparsity — keep top-N of every M elements (0 = disabled) SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) - NUM_SINK_TOKENS: tl.constexpr = 0, # First N tokens kept dense (attention sinks) + NUM_SINK_BLOCKS: tl.constexpr = 0, # Leading KV blocks kept dense (attention sinks) DENSE_WINDOW_BLOCKS: tl.constexpr = 1, # Local blocks near diagonal kept dense ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- @@ -272,14 +276,14 @@ def _attn_fwd( scores = tl.dot(q, k) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) - # --- Optional 2:4 structured sparsity --- + # --- Optional N:M sparse softmax --- if SPARSITY_N > 0: # Check if this KV tile should be kept dense - is_sink = kv_start < NUM_SINK_TOKENS + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q) causal_offset = seq_len_kv - seq_len_q q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N - kv_block_idx = kv_start // BLOCK_N block_distance = q_abs_block - kv_block_idx is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) if not is_sink and not is_local: @@ -405,7 +409,7 @@ def _attn_bwd_dq( HEAD_DIM: tl.constexpr, SPARSITY_N: tl.constexpr = 0, SPARSITY_M: tl.constexpr = 4, - NUM_SINK_TOKENS: tl.constexpr = 0, + NUM_SINK_BLOCKS: tl.constexpr = 0, DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -473,12 +477,12 @@ def _attn_bwd_dq( scores = tl.dot(q, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) - # Re-apply 2:4 sparsity to match forward pass + # Re-apply N:M sparse softmax to match forward pass if SPARSITY_N > 0: - is_sink = kv_start < NUM_SINK_TOKENS + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS causal_offset = seq_len_kv - seq_len_q q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N - kv_block_idx = kv_start // BLOCK_N block_distance = q_abs_block - kv_block_idx is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) if not is_sink and not is_local: @@ -537,7 +541,7 @@ def _attn_bwd_dkdv( HEAD_DIM: tl.constexpr, SPARSITY_N: tl.constexpr = 0, SPARSITY_M: tl.constexpr = 4, - NUM_SINK_TOKENS: tl.constexpr = 0, + NUM_SINK_BLOCKS: tl.constexpr = 0, DENSE_WINDOW_BLOCKS: tl.constexpr = 1, ): """Phase 2 of backward: compute dK, dV for one KV tile. @@ -613,12 +617,12 @@ def _attn_bwd_dkdv( scores = tl.dot(q_tile, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) - # Re-apply 2:4 sparsity to match forward pass + # Re-apply N:M sparse softmax to match forward pass if SPARSITY_N > 0: - is_sink = kv_start < NUM_SINK_TOKENS + kv_block_idx = kv_start // BLOCK_N + is_sink = kv_block_idx < NUM_SINK_BLOCKS causal_offset = seq_len_kv - seq_len_q q_abs_block = (qi * BLOCK_M + causal_offset) // BLOCK_N - kv_block_idx = kv_start // BLOCK_N block_distance = q_abs_block - kv_block_idx is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0) if not is_sink and not is_local: @@ -661,7 +665,7 @@ def forward( max_input_len_k, sparsity_n, sparsity_m, - num_sink_tokens, + num_sink_blocks, dense_window_blocks, ): HEAD_DIM = q.shape[2] @@ -719,7 +723,7 @@ def grid(META): STORE_LSE=True, SPARSITY_N=sparsity_n, SPARSITY_M=sparsity_m, - NUM_SINK_TOKENS=num_sink_tokens, + NUM_SINK_BLOCKS=num_sink_blocks, DENSE_WINDOW_BLOCKS=dense_window_blocks, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -737,7 +741,7 @@ def grid(META): ctx.batch = batch ctx.sparsity_n = sparsity_n ctx.sparsity_m = sparsity_m - ctx.num_sink_tokens = num_sink_tokens + ctx.num_sink_blocks = num_sink_blocks ctx.dense_window_blocks = dense_window_blocks return o @@ -815,7 +819,7 @@ def backward(ctx, grad_output): HEAD_DIM=HEAD_DIM, SPARSITY_N=ctx.sparsity_n, SPARSITY_M=ctx.sparsity_m, - NUM_SINK_TOKENS=ctx.num_sink_tokens, + NUM_SINK_BLOCKS=ctx.num_sink_blocks, DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, @@ -838,7 +842,7 @@ def backward(ctx, grad_output): HEAD_DIM=HEAD_DIM, SPARSITY_N=ctx.sparsity_n, SPARSITY_M=ctx.sparsity_m, - NUM_SINK_TOKENS=ctx.num_sink_tokens, + NUM_SINK_BLOCKS=ctx.num_sink_blocks, DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks, num_warps=num_warps, num_stages=1, @@ -862,10 +866,10 @@ def attention( *, sparsity_n: int = 0, sparsity_m: int = 4, - num_sink_tokens: int = 0, + num_sink_blocks: int = 0, dense_window_blocks: int = 1, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparsity. + """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -884,8 +888,7 @@ def attention( ``sparsity_n=2, sparsity_m=4`` for 2:4 sparsity; ``sparsity_n=4, sparsity_m=8`` for 4:8 sparsity. sparsity_m: N:M sparsity — group size (4 or 8). - num_sink_tokens: KV blocks containing any of the first N token positions - are kept dense (attention sinks). Granularity is per-block. + num_sink_blocks: Number of leading KV blocks to keep dense (attention sinks). dense_window_blocks: KV blocks within this distance from the query diagonal are kept dense (local attention window). @@ -907,7 +910,7 @@ def attention( max_input_len_k, sparsity_n, sparsity_m, - num_sink_tokens, + num_sink_blocks, dense_window_blocks, ) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 4baf5bbe6..21b5b8657 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -96,6 +96,39 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + sparsity_n: int = ModeloptField( + default=2, + title="N in N:M sparsity.", + description=( + "Keep top-N of every M attention scores. Only used by triton_sparse_softmax. " + "Set to 0 to disable sparsity." + ), + ) + + sparsity_m: int = ModeloptField( + default=4, + title="M in N:M sparsity.", + description="Group size for N:M sparsity (4 or 8). Only used by triton_sparse_softmax.", + ) + + num_sink_blocks: int = ModeloptField( + default=0, + title="Number of sink blocks.", + description=( + "Number of leading KV blocks to keep dense (attention sinks). " + "Only used by triton_sparse_softmax." + ), + ) + + dense_window_blocks: int = ModeloptField( + default=1, + title="Dense window blocks.", + description=( + "Number of local attention blocks around diagonal to keep dense. " + "Only used by triton_sparse_softmax." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -434,9 +467,27 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# Default N:M sparse softmax configuration +SPARSE_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_blocks": 0, + "dense_window_blocks": 1, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SPARSE_SOFTMAX_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..1bd9a547d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax +from . import flash_skip_softmax, triton_sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py new file mode 100644 index 000000000..df99c4ac1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""N:M sparse softmax method for attention scores via Triton kernel.""" + +from contextlib import contextmanager + +import torch + +from .registry import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("triton_sparse_softmax") +class TritonSparseSoftmaxMethod(SparseAttentionMethod): + """N:M sparse softmax applied to attention scores via Triton kernel. + + Sparsity is applied inside the fused Triton flash attention kernel, + not as a separate pre/post-processing step. For every M consecutive + K positions, the top-N attention scores are kept; the other M-N are + set to -inf before softmax. + + Config params: + sparsity_n: Keep top-N of every M attention scores (0 to disable). + sparsity_m: Group size (4 or 8). + num_sink_blocks: Number of leading KV blocks kept dense (attention sinks). + dense_window_blocks: Local attention blocks kept dense near diagonal. + """ + + def __init__(self, method_config=None): + """Initialize with N:M sparsity parameters from config.""" + super().__init__() + method_config = method_config or {} + self.sparsity_n = method_config.get("sparsity_n", 2) + self.sparsity_m = method_config.get("sparsity_m", 4) + self.num_sink_blocks = method_config.get("num_sink_blocks", 0) + self.dense_window_blocks = method_config.get("dense_window_blocks", 1) + + @property + def name(self) -> str: + """Method name identifier.""" + return "triton_sparse_softmax" + + def calculate_sparsity(self, attention_scores): + """Return a no-op mask (sparsity is applied inside the Triton kernel).""" + mask = torch.ones_like(attention_scores, dtype=torch.bool) + return mask, {} + + def apply_sparsity(self, attention_scores, sparse_mask=None): + """Not supported — sparsity is fused into the Triton kernel.""" + raise NotImplementedError( + "triton_sparse_softmax applies sparsity inside the Triton kernel. " + "Use backend='triton' or backend='vllm', not backend='pytorch'." + ) + + def get_sparse_context(self, module): + """Return context manager that activates N:M sparse softmax during forward.""" + + @contextmanager + def _sparse_nm_context(): + module._apply_sparse_nm = True + try: + yield + finally: + module._apply_sparse_nm = False + + return _sparse_nm_context() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/conftest.py b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py new file mode 100644 index 000000000..fa4f61771 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures and helpers for Triton flash attention tests.""" + +import pytest +import torch +import torch.nn.functional as F + + +def make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16): + """Create packed Q, K, V tensors.""" + q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + return q, k, v + + +def make_varlen_meta(seq_lens, device="cuda"): + """Create b_start_loc and b_seq_len from a list of sequence lengths.""" + b_seq_len = torch.tensor(seq_lens, device=device, dtype=torch.int32) + b_start_loc = torch.zeros(len(seq_lens), device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0) + return b_start_loc, b_seq_len + + +def sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): + """SDPA reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" + batch = b_seq_len.shape[0] + num_q, num_kv = q.shape[1], k.shape[1] + parts = [] + for b in range(batch): + s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) + qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + if num_q != num_kv: + r = num_q // num_kv + kb = kb.repeat_interleave(r, dim=1) + vb = vb.repeat_interleave(r, dim=1) + ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal) + parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) + return torch.cat(parts, dim=0) + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" + from _test_utils.torch.transformers_models import create_tiny_llama_dir + + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + max_position_embeddings=64, + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index e8ea18077..73d20ba52 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""GPU tests for Triton flash attention kernel.""" +"""GPU tests for Triton flash attention kernel (dense path).""" import pytest import torch import torch.nn.functional as F +from conftest import make_qkv, make_varlen_meta, sdpa_reference pytestmark = [ pytest.mark.filterwarnings("ignore::UserWarning"), @@ -28,87 +29,20 @@ from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - import triton - import triton.language as tl - from modelopt.torch.kernels import attention, register_triton_attention - from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile if register_triton_attention is not None: register_triton_attention() - @triton.jit - def _test_apply_sparse_nm(In, Out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - SPARSITY_N: tl.constexpr, SPARSITY_M: tl.constexpr): - """Test wrapper: apply N:M sparsity to a tile and store result.""" - offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] - qk = tl.load(In + offs) - tl.store(Out + offs, _apply_sparse_nm_to_qk_tile(qk, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)) - - -# --------------------------------------------------------------------------- -# Shared helpers -# --------------------------------------------------------------------------- - - -def _sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): - """SDPA reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" - batch = b_seq_len.shape[0] - num_q, num_kv = q.shape[1], k.shape[1] - parts = [] - for b in range(batch): - s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) - qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - if num_q != num_kv: - r = num_q // num_kv - kb = kb.repeat_interleave(r, dim=1) - vb = vb.repeat_interleave(r, dim=1) - ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal) - parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) - return torch.cat(parts, dim=0) - - -def _make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16): - """Create packed Q, K, V tensors.""" - q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype) - k = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) - v = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) - return q, k, v - - -def _make_varlen_meta(seq_lens, device="cuda"): - """Create b_start_loc and b_seq_len from a list of sequence lengths.""" - b_seq_len = torch.tensor(seq_lens, device=device, dtype=torch.int32) - b_start_loc = torch.zeros(len(seq_lens), device=device, dtype=torch.int32) - b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0) - return b_start_loc, b_seq_len - # --------------------------------------------------------------------------- # Forward correctness # --------------------------------------------------------------------------- -@pytest.fixture(scope="module") -def tiny_llama_dir(tmp_path_factory): - """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" - from _test_utils.torch.transformers_models import create_tiny_llama_dir - - return create_tiny_llama_dir( - tmp_path_factory.mktemp("tiny_llama"), - with_tokenizer=True, - num_hidden_layers=2, - hidden_size=64, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=64, - max_position_embeddings=64, - ) @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestForward: - """Forward pass correctness for dense and sparse attention.""" + """Forward pass correctness for dense attention.""" @pytest.mark.parametrize( ("dtype", "num_heads", "num_kv_heads", "head_dim"), @@ -126,11 +60,11 @@ def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim): scale = 1.0 / (head_dim**0.5) torch.manual_seed(123) - q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=dtype) - locs, lens = _make_varlen_meta(seq_lens) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=dtype) + locs, lens = make_varlen_meta(seq_lens) o = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) - torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o, sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) def test_decode_matches_sdpa(self): """Dense decode matches SDPA.""" @@ -185,8 +119,8 @@ def test_sparse_disabled_matches_dense(self): scale = 1.0 / (64**0.5) torch.manual_seed(99) - q, k, v = _make_qkv(total, 4, 2, 64) - locs, lens = _make_varlen_meta(seq_lens) + q, k, v = make_qkv(total, 4, 2, 64) + locs, lens = make_varlen_meta(seq_lens) out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) @@ -194,14 +128,16 @@ def test_sparse_disabled_matches_dense(self): # --------------------------------------------------------------------------- -# Backward correctness +# Backward correctness (dense) # --------------------------------------------------------------------------- + + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestBackward: - """Backward pass gradient correctness for dense and sparse attention.""" + """Backward pass gradient correctness for dense attention.""" def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): - """Run SDPA forward+backward, return output and gradients.""" + """Run SDPA forward+backward, return gradients.""" q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) @@ -221,18 +157,17 @@ def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) return dq.detach(), dk.detach(), dv.detach() - # --- Dense backward vs SDPA --- def test_dense_causal_matches_sdpa(self): """dQ, dK, dV match SDPA for causal self-attention.""" seq_len, num_heads, num_kv_heads, head_dim = 16, 2, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(42) - q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - locs, lens = _make_varlen_meta([seq_len]) + locs, lens = make_varlen_meta([seq_len]) attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) @@ -247,11 +182,11 @@ def test_dense_gqa_matches_sdpa(self): scale = 1.0 / (head_dim**0.5) torch.manual_seed(43) - q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - locs, lens = _make_varlen_meta([seq_len]) + locs, lens = make_varlen_meta([seq_len]) attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) @@ -268,11 +203,11 @@ def test_dense_multi_batch_variable_length(self): scale = 1.0 / (head_dim**0.5) torch.manual_seed(45) - q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=torch.float32) q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - locs, lens = _make_varlen_meta(seq_lens) + locs, lens = make_varlen_meta(seq_lens) attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale).sum().backward() @@ -301,11 +236,11 @@ def test_dense_longer_sequences(self): scale = 1.0 / (head_dim**0.5) torch.manual_seed(49) - q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - locs, lens = _make_varlen_meta([seq_len]) + locs, lens = make_varlen_meta([seq_len]) attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) @@ -314,216 +249,12 @@ def test_dense_longer_sequences(self): torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - # --- Sparse backward sanity checks --- - - @pytest.mark.parametrize( - ("n", "m"), - [(2, 4), (4, 8)], - ids=["2:4", "4:8"], - ) - def test_sparse_gradients_finite(self, n, m): - """Backward with N:M sparsity produces finite, non-zero gradients.""" - seq_len, num_heads, num_kv_heads, head_dim = 128, 4, 2, 64 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(55) - q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) - q.requires_grad_(True) - k.requires_grad_(True) - v.requires_grad_(True) - locs, lens = _make_varlen_meta([seq_len]) - - attention( - q, - k, - v, - locs, - lens, - seq_len, - softmax_scale=scale, - sparsity_n=n, - sparsity_m=m, - ).sum().backward() - - for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: - assert grad is not None, f"{name} is None for {n}:{m}" - assert not torch.isnan(grad).any(), f"NaN in {name} for {n}:{m}" - assert not torch.isinf(grad).any(), f"Inf in {name} for {n}:{m}" - assert grad.abs().sum() > 0, f"{name} is all zeros for {n}:{m}" - - def test_sparse_gradients_differ_from_dense(self): - """Gradients with 2:4 sparsity should differ from dense gradients.""" - seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 64 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(66) - q, k, v = _make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) - locs, lens = _make_varlen_meta([seq_len]) - - q_d = q.clone().requires_grad_(True) - k_d = k.clone().requires_grad_(True) - v_d = v.clone().requires_grad_(True) - attention(q_d, k_d, v_d, locs, lens, seq_len, softmax_scale=scale).sum().backward() - - q_s = q.clone().requires_grad_(True) - k_s = k.clone().requires_grad_(True) - v_s = v.clone().requires_grad_(True) - attention( - q_s, - k_s, - v_s, - locs, - lens, - seq_len, - softmax_scale=scale, - sparsity_n=2, - sparsity_m=4, - ).sum().backward() - - assert not torch.allclose(q_d.grad, q_s.grad, atol=1e-3), ( - "dQ same with and without sparsity" - ) - assert not torch.allclose(k_d.grad, k_s.grad, atol=1e-3), ( - "dK same with and without sparsity" - ) - # --------------------------------------------------------------------------- -# N:M sparsity +# HuggingFace integration # --------------------------------------------------------------------------- -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestSparseNM: - """N:M structured sparsity behavior on attention scores (prefill only).""" - - def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): - total = batch * seq_len - torch.manual_seed(99) - q, k, v = _make_qkv(total, num_heads, num_kv_heads, head_dim) - locs, lens = _make_varlen_meta([seq_len] * batch) - return q, k, v, locs, lens - - @pytest.mark.parametrize( - ("n", "m"), - [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], - ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], - ) - def test_output_shape(self, n, m): - """Output shape matches Q shape for all N:M patterns.""" - q, k, v, locs, lens = self._make_inputs() - out = attention( - q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m - ) - assert out.shape == q.shape - - @pytest.mark.parametrize( - ("n", "m"), - [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], - ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], - ) - def test_no_nan(self, n, m): - """All N:M patterns produce finite output.""" - q, k, v, locs, lens = self._make_inputs() - out = attention( - q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m - ) - assert not torch.isnan(out).any(), f"NaN in output for {n}:{m}" - assert not torch.isinf(out).any(), f"Inf in output for {n}:{m}" - @pytest.mark.parametrize( - ("n", "m"), - [(1, 4), (2, 4), (1, 8), (4, 8)], - ids=["1:4", "2:4", "1:8", "4:8"], - ) - def test_sparse_differs_from_dense(self, n, m): - """Sparse output should differ from dense for long sequences.""" - q, k, v, locs, lens = self._make_inputs(seq_len=512) - scale = 1.0 / (64**0.5) - out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) - out_sparse = attention( - q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m - ) - assert not torch.allclose(out_sparse, out_dense, atol=1e-3) - - @pytest.mark.parametrize( - ("n_values", "m"), - [([1, 2, 3], 4), ([1, 2, 4], 8)], - ids=["m4", "m8"], - ) - def test_more_sparsity_more_error(self, n_values, m): - """Keeping more elements should deviate less from dense (monotonic decreasing error).""" - q, k, v, locs, lens = self._make_inputs(seq_len=512) - scale = 1.0 / (64**0.5) - out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) - errors = [] - for n in n_values: - out = attention( - q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m - ) - errors.append((out - out_dense).abs().mean().item()) - for i in range(len(errors) - 1): - assert errors[i] > errors[i + 1], ( - f"Errors not monotonically decreasing for M={m}: " - + ", ".join(f"{n}:{m}={e:.6f}" for n, e in zip(n_values, errors)) - ) - - @pytest.mark.parametrize( - ("n", "m"), - [(2, 4), (4, 8)], - ids=["2:4", "4:8"], - ) - def test_dense_window_preserves_local(self, n, m): - """Large dense_window_blocks makes sparse output closer to dense.""" - q, k, v, locs, lens = self._make_inputs(seq_len=256) - scale = 1.0 / (64**0.5) - out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) - out_small = attention( - q, - k, - v, - locs, - lens, - 256, - softmax_scale=scale, - sparsity_n=n, - sparsity_m=m, - dense_window_blocks=1, - ) - out_large = attention( - q, - k, - v, - locs, - lens, - 256, - softmax_scale=scale, - sparsity_n=n, - sparsity_m=m, - dense_window_blocks=100, - ) - err_small = (out_small - out_dense).abs().mean().item() - err_large = (out_large - out_dense).abs().mean().item() - assert err_large < err_small - @pytest.mark.parametrize( - ("n", "m"), - [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], - ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], - ) - def test_sparsity_structure(self, n, m): - """Verify N:M structure: exactly N kept per group of M.""" - BM, BN = 32, 64 - torch.manual_seed(88) - tile = torch.randn(BM, BN, device="cuda", dtype=torch.float32) - out = torch.empty_like(tile) - _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=BM, BLOCK_N=BN, SPARSITY_N=n, SPARSITY_M=m) - - kept = (out.reshape(BM, BN // m, m) != float("-inf")).sum(dim=-1) - assert (kept == n).all(), f"Expected {n} kept per group of {m}, got min={kept.min()}, max={kept.max()}" - - -# --------------------------------------------------------------------------- -# HuggingFace integration -# --------------------------------------------------------------------------- @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestHFIntegration: """HF model integration with Triton attention backend.""" @@ -601,3 +332,61 @@ def test_triton_padded_batch(self, tiny_llama_dir): with torch.no_grad(): logits = model(**inputs).logits assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + + def test_sparse_nm_via_sparsify(self, tiny_llama_dir): + """mtsa.sparsify() with N:M sparse softmax produces finite logits that differ from dense.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + import modelopt.torch.sparsity.attention_sparsity as mtsa + + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + # Use a long input (fill max_position_embeddings=64) so sparsity has tiles to prune + ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") + + # Dense baseline (triton backend, no sparsity) + model_dense = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model_dense.eval() + with torch.no_grad(): + logits_dense = model_dense(input_ids=ids).logits + del model_dense + + # Sparse via mtsa.sparsify() with dense_window_blocks=0 to force sparsity on all tiles + sparse_cfg = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_blocks": 0, + "dense_window_blocks": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + model_sparse = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + mtsa.sparsify(model_sparse, sparse_cfg) + model_sparse.eval() + with torch.no_grad(): + logits_sparse = model_sparse(input_ids=ids).logits + + # Sparse output should be finite + assert not torch.isnan(logits_sparse).any(), "NaN in sparse logits" + assert not torch.isinf(logits_sparse).any(), "Inf in sparse logits" + # Sparse output should differ from dense (sparsity changes attention) + assert not torch.allclose(logits_sparse, logits_dense, atol=1e-2), ( + "Sparse logits identical to dense — sparsity may not be applied" + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py new file mode 100644 index 000000000..7f89be325 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: N803 — Triton JIT wrapper uses uppercase for constexpr and tensor args + +"""GPU tests for N:M sparse softmax on the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + import triton + import triton.language as tl + + from modelopt.torch.kernels import attention + from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile + + @triton.jit + def _test_apply_sparse_nm( + In, + Out, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, + ): + """Test wrapper: apply N:M sparsity to a tile and store result.""" + offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + qk = tl.load(In + offs) + tl.store( + Out + offs, + _apply_sparse_nm_to_qk_tile(qk, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M), + ) + + +# --------------------------------------------------------------------------- +# N:M sparsity behavior (prefill only) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseNM: + """N:M sparse softmax behavior on attention scores.""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_output_shape(self, n, m): + """Output shape matches Q shape for all N:M patterns.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert out.shape == q.shape + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_no_nan(self, n, m): + """All N:M patterns produce finite output.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert not torch.isnan(out).any(), f"NaN in output for {n}:{m}" + assert not torch.isinf(out).any(), f"Inf in output for {n}:{m}" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (1, 8), (4, 8)], + ids=["1:4", "2:4", "1:8", "4:8"], + ) + def test_sparse_differs_from_dense(self, n, m): + """Sparse output should differ from dense for long sequences.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_sparse = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + assert not torch.allclose(out_sparse, out_dense, atol=1e-3) + + @pytest.mark.parametrize( + ("n_values", "m"), + [([1, 2, 3], 4), ([1, 2, 4], 8)], + ids=["m4", "m8"], + ) + def test_more_sparsity_more_error(self, n_values, m): + """Keeping more elements should deviate less from dense (monotonic decreasing error).""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + errors = [] + for n in n_values: + out = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + errors.append((out - out_dense).abs().mean().item()) + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1], ( + f"Errors not monotonically decreasing for M={m}: " + + ", ".join(f"{n}:{m}={e:.6f}" for n, e in zip(n_values, errors)) + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_dense_window_preserves_local(self, n, m): + """Large dense_window_blocks makes sparse output closer to dense.""" + q, k, v, locs, lens = self._make_inputs(seq_len=256) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_small = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=1, + ) + out_large = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_blocks=100, + ) + err_small = (out_small - out_dense).abs().mean().item() + err_large = (out_large - out_dense).abs().mean().item() + assert err_large < err_small + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sink_blocks_preserve_early_kv(self, n, m): + """num_sink_blocks keeps early KV blocks dense, reducing error vs fully sparse.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_no_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_blocks=0, + ) + out_with_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_blocks=2, + ) + err_no_sink = (out_no_sink - out_dense).abs().mean().item() + err_with_sink = (out_with_sink - out_dense).abs().mean().item() + assert err_with_sink < err_no_sink, ( + f"Sink tokens should reduce error: no_sink={err_no_sink:.6f}, with_sink={err_with_sink:.6f}" + ) + + # NOTE: N:M sparse attention is for prefill only, not decode. + + +# --------------------------------------------------------------------------- +# Sparsity tile structure +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseTileStructure: + """Direct unit tests for _apply_sparse_nm_to_qk_tile via wrapper kernel.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_sparsity_structure(self, n, m): + """Verify N:M structure: exactly N kept per group of M.""" + bm, bn = 32, 64 + torch.manual_seed(88) + tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + assert (kept == n).all(), ( + f"Expected {n} kept per group of {m}, got min={kept.min()}, max={kept.max()}" + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparsity_structure_ties(self, n, m): + """M=4 keeps exactly N on ties; M=8 (tl.sort) may keep >= N on ties.""" + bm, bn = 32, 64 + tile = torch.ones(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + if m == 4: + assert (kept == n).all(), ( + f"M=4 tie: expected {n}, got min={kept.min()}, max={kept.max()}" + ) + else: + assert (kept >= n).all(), f"M=8 tie: expected >= {n}, got min={kept.min()}" + assert (kept <= m).all(), f"M=8 tie: expected <= {m}, got max={kept.max()}" + + +# --------------------------------------------------------------------------- +# Sparse backward sanity +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseBackward: + """Backward pass sanity checks with N:M sparsity enabled.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparse_gradients_finite(self, n, m): + """Backward with N:M sparsity produces finite, non-zero gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 128, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(55) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + ).sum().backward() + + for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: + assert grad is not None, f"{name} is None for {n}:{m}" + assert not torch.isnan(grad).any(), f"NaN in {name} for {n}:{m}" + assert not torch.isinf(grad).any(), f"Inf in {name} for {n}:{m}" + assert grad.abs().sum() > 0, f"{name} is all zeros for {n}:{m}" + + def test_sparse_gradients_differ_from_dense(self): + """Gradients with 2:4 sparsity should differ from dense gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(66) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + locs, lens = make_varlen_meta([seq_len]) + + q_d = q.clone().requires_grad_(True) + k_d = k.clone().requires_grad_(True) + v_d = v.clone().requires_grad_(True) + attention(q_d, k_d, v_d, locs, lens, seq_len, softmax_scale=scale).sum().backward() + + q_s = q.clone().requires_grad_(True) + k_s = k.clone().requires_grad_(True) + v_s = v.clone().requires_grad_(True) + attention( + q_s, + k_s, + v_s, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=2, + sparsity_m=4, + ).sum().backward() + + assert not torch.allclose(q_d.grad, q_s.grad, atol=1e-3), ( + "dQ same with and without sparsity" + ) + assert not torch.allclose(k_d.grad, k_s.grad, atol=1e-3), ( + "dK same with and without sparsity" + ) + assert not torch.allclose(v_d.grad, v_s.grad, atol=1e-3), ( + "dV same with and without sparsity" + )