Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog
**New Features**

- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use a custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow, but it makes pruning slightly faster and may result in a slightly different pruned model because of different kernels and numerics.
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). For every M consecutive key positions, the top-N attention scores are kept and the rest are set to -inf before softmax. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.

**Bug Fixes**

Expand Down
9 changes: 9 additions & 0 deletions modelopt/torch/kernels/hf_triton_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def triton_attention_forward(
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["max_input_len_k"] = seq_k

# N:M sparse softmax
if getattr(module, "_apply_sparse_nm", False):
method = getattr(module, "_sparse_method_instance", None)
if method is not None:
kw["sparsity_n"] = getattr(method, "sparsity_n", 2)
kw["sparsity_m"] = getattr(method, "sparsity_m", 4)
kw["num_sink_blocks"] = getattr(method, "num_sink_blocks", 0)
kw["dense_window_blocks"] = getattr(method, "dense_window_blocks", 1)

o = attention(q, k, v, **kw)

attn_output = o.view(batch, seq_len, num_heads, head_dim)
Expand Down
206 changes: 204 additions & 2 deletions modelopt/torch/kernels/triton_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,116 @@
_FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)]


# ---------------------------------------------------------------------------
# N:M sparse softmax helpers
# ---------------------------------------------------------------------------
@triton.jit
def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr):
    """Select the top-N of 4 values using only boolean logic (6 comparisons).

    All comparisons use ``>=`` so ties resolve in favor of the lower index,
    which guarantees exactly N masks are True for any input (including the
    all-equal case).

    An element survives when it wins "at least K of its 3 pairwise duels":
        N=1 -> K=3: AND of all three duels (strict maximum only)
        N=2 -> K=2: majority of the three duels
        N=3 -> K=1: OR of the three duels (everything but the minimum)
    """
    # Pairwise duels; w_ij means "x_i wins against x_j" (lower index wins ties).
    w01 = x0 >= x1
    w02 = x0 >= x2
    w03 = x0 >= x3
    w12 = x1 >= x2
    w13 = x1 >= x3
    w23 = x2 >= x3

    # Complements: the higher-indexed element wins the duel.
    l01 = ~w01
    l02 = ~w02
    l03 = ~w03
    l12 = ~w12
    l13 = ~w13
    l23 = ~w23

    if N == 1:
        # Keep the max only: each survivor must win all 3 of its duels.
        m0 = w01 & w02 & w03
        m1 = l01 & w12 & w13
        m2 = l02 & l12 & w23
        m3 = l03 & l13 & l23
    elif N == 2:
        # Majority vote: each survivor must win at least 2 of its 3 duels.
        m0 = (w01 & w02) | (w01 & w03) | (w02 & w03)
        m1 = (l01 & w12) | (l01 & w13) | (w12 & w13)
        m2 = (l02 & l12) | (l02 & w23) | (l12 & w23)
        m3 = (l03 & l13) | (l03 & l23) | (l13 & l23)
    elif N == 3:
        # Drop only the min: each survivor must win at least 1 duel.
        m0 = w01 | w02 | w03
        m1 = l01 | w12 | w13
        m2 = l02 | l12 | w23
        m3 = l03 | l13 | l23
    else:
        tl.static_assert(False, "N must be 1, 2, or 3 for M=4")

    return m0, m1, m2, m3


@triton.jit
def _apply_sparse_nm_to_qk_tile(
    qk,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SPARSITY_N: tl.constexpr,
    SPARSITY_M: tl.constexpr,
):
    """Apply N:M sparse softmax to a QK score tile.

    For every ``SPARSITY_M`` consecutive elements along the N (key) dimension,
    keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf`` so they
    contribute nothing after softmax. ``BLOCK_N`` must be divisible by
    ``SPARSITY_M``.

    For M=4, exactly N values are retained (ties broken by position).
    For M=8, a threshold-based approach (``tl.sort``) may retain more
    than N values when ties straddle the threshold boundary.
    """
    tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8")  # noqa: PLR1714
    MASK_VAL: tl.constexpr = float("-inf")

    if SPARSITY_M == 4:
        tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4")
        # View the tile as (rows, groups, 4) so each group of 4 keys is a lane.
        reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4))
        # Triton has no direct per-lane indexing; extract each of the 4 lanes
        # via a masked sum (the where() zeroes the other 3 lanes, so the sum
        # over axis 2 yields exactly the selected lane's value).
        cols = tl.arange(0, 4)[None, None, :]
        x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2)
        x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2)
        x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2)
        x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2)

        # Exact top-N selection via boolean logic (no sort needed for M=4).
        m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N)

        # Scatter each lane back: keep the value where its mask is True,
        # otherwise write -inf so softmax ignores it.
        out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype)
        out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out)
        out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out)
        out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out)
        out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out)
        return tl.reshape(out, (BLOCK_M, BLOCK_N))

    else:  # SPARSITY_M == 8
        tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8")
        reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8))

        # Sort each group of 8 ascending; N-th largest is at index (8 - N)
        sorted_vals = tl.sort(reshaped, dim=2)
        KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N  # index of N-th largest in ascending order

        # Extract the threshold value at KTH_IDX via masked sum
        # Use 0.0 as fill (not -inf) so sum equals just the KTH element
        cols = tl.arange(0, 8)[None, None, :]
        threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2)

        # Mask: keep elements >= threshold (may keep >N on ties — acceptable)
        mask = reshaped >= tl.expand_dims(threshold, 2)
        return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N))


# ---------------------------------------------------------------------------
# Masking helper
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -105,6 +215,10 @@ def _attn_fwd(
IS_CAUSAL: tl.constexpr, # Whether to apply causal mask
HEAD_DIM: tl.constexpr, # Actual head dimension (for d_mask)
STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass
SPARSITY_N: tl.constexpr = 0, # N:M sparsity — keep top-N of every M elements (0 = disabled)
SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8)
NUM_SINK_BLOCKS: tl.constexpr = 0, # Leading KV blocks kept dense (attention sinks)
DENSE_WINDOW_BLOCKS: tl.constexpr = 1, # Local blocks near diagonal kept dense
):
# --- Grid: (batch, num_q_heads, num_q_tiles) ---
# Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128
Expand Down Expand Up @@ -162,6 +276,21 @@ def _attn_fwd(
scores = tl.dot(q, k) * qk_scale
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)

# --- Optional N:M sparse softmax ---
if SPARSITY_N > 0:
# Check if this KV tile should be kept dense
kv_block_idx = kv_start // BLOCK_N
is_sink = kv_block_idx < NUM_SINK_BLOCKS
# causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
causal_offset = seq_len_kv - seq_len_q
q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N
block_distance = q_abs_block - kv_block_idx
is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0)
if not is_sink and not is_local:
scores = _apply_sparse_nm_to_qk_tile(
scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
)
Comment on lines +279 to +292
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Skip N:M masking during decode.

This branch only checks SPARSITY_N > 0, so cached decode (seq_len_q == 1 with separate KV metadata) gets sparsified too. The feature is described as prefill-only; without a decode guard, generation will start pruning the KV cache as soon as sparse mode is enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The N:M sparsity
branch is being applied during decode because it only checks SPARSITY_N > 0;
change the condition to skip sparsification when doing cached decode (seq_len_q
== 1). Update the if guarding the block (the one that currently reads "if
SPARSITY_N > 0:") to also require not decoding (e.g., "if SPARSITY_N > 0 and
seq_len_q != 1:") so _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N,
SPARSITY_N, SPARSITY_M) is only called during prefill/non-decoding paths; keep
the existing local/sink logic (kv_start, tile_q, q_abs_block, is_local/is_sink)
unchanged.

⚠️ Potential issue | 🔴 Critical

Don't key the sparse mask to autotuned tile sizes.

num_sink_blocks and dense_window_blocks are interpreted in units of BLOCK_N, but forward autotunes BLOCK_N over 32/64/128 while both backward kernels hardcode BLOCK_N=64. That means forward and backward can apply different sparse masks, and the forward q_abs_block check is already wrong for later rows in a tile whenever BLOCK_M > BLOCK_N. Please normalize these regions to a fixed logical token block size (or token counts) and derive locality from each row's absolute position before reusing the same rule in forward and backward.

Also applies to: 480-491, 620-631

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The sparse-mask
logic currently mixes tile-sized units (BLOCK_M/BLOCK_N) with sparsity
parameters (NUM_SINK_BLOCKS, DENSE_WINDOW_BLOCKS) leading to inconsistent masks
between forward/backward; fix by computing mask membership in token-space
instead of tile-space: derive each KV block index and query row absolute token
position from actual token counts (use seq_len_kv, seq_len_q, kv_start and
per-row start = tile_q * BLOCK_M + row_offset or for whole-tile use
tile_token_start = tile_q * BLOCK_M) and then map those token positions into
logical token-blocks of a fixed reference block size (choose the constant used
by backward kernels, e.g., 64 tokens) before comparing to NUM_SINK_BLOCKS and
DENSE_WINDOW_BLOCKS; update q_abs_block, kv_block_idx, is_sink and is_local
computations (the branch that calls _apply_sparse_nm_to_qk_tile) and apply same
normalization in the other occurrences mentioned (around the other two
locations) so forward and backward use the same token-blocking semantics.


# --- Online softmax update ---
# 1. Update running max
m_new = tl.maximum(row_max, tl.max(scores, 1))
Expand Down Expand Up @@ -278,6 +407,10 @@ def _attn_bwd_dq(
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
HEAD_DIM: tl.constexpr,
SPARSITY_N: tl.constexpr = 0,
SPARSITY_M: tl.constexpr = 4,
NUM_SINK_BLOCKS: tl.constexpr = 0,
DENSE_WINDOW_BLOCKS: tl.constexpr = 1,
):
"""Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles.

Expand Down Expand Up @@ -343,6 +476,20 @@ def _attn_bwd_dq(
# Recompute attention: S = Q @ K^T, P = exp2(S - LSE)
scores = tl.dot(q, kT) * qk_scale
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)

# Re-apply N:M sparse softmax to match forward pass
if SPARSITY_N > 0:
kv_block_idx = kv_start // BLOCK_N
is_sink = kv_block_idx < NUM_SINK_BLOCKS
causal_offset = seq_len_kv - seq_len_q
q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N
block_distance = q_abs_block - kv_block_idx
is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0)
if not is_sink and not is_local:
scores = _apply_sparse_nm_to_qk_tile(
scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
)

p = tl.math.exp2(scores - lse[:, None])

# dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K
Expand Down Expand Up @@ -392,6 +539,10 @@ def _attn_bwd_dkdv(
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
HEAD_DIM: tl.constexpr,
SPARSITY_N: tl.constexpr = 0,
SPARSITY_M: tl.constexpr = 4,
NUM_SINK_BLOCKS: tl.constexpr = 0,
DENSE_WINDOW_BLOCKS: tl.constexpr = 1,
):
"""Phase 2 of backward: compute dK, dV for one KV tile.

Expand Down Expand Up @@ -465,6 +616,20 @@ def _attn_bwd_dkdv(
# Recompute attention: S = Q @ K^T, P = exp2(S - LSE)
scores = tl.dot(q_tile, kT) * qk_scale
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)

# Re-apply N:M sparse softmax to match forward pass
if SPARSITY_N > 0:
kv_block_idx = kv_start // BLOCK_N
is_sink = kv_block_idx < NUM_SINK_BLOCKS
causal_offset = seq_len_kv - seq_len_q
q_abs_block = (qi * BLOCK_M + causal_offset) // BLOCK_N
block_distance = q_abs_block - kv_block_idx
is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0)
if not is_sink and not is_local:
scores = _apply_sparse_nm_to_qk_tile(
scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
)

p = tl.math.exp2(scores - lse[:, None])

# dV += P^T @ dO
Expand Down Expand Up @@ -498,6 +663,10 @@ def forward(
b_start_loc_k,
b_seq_len_k,
max_input_len_k,
sparsity_n,
sparsity_m,
num_sink_blocks,
dense_window_blocks,
):
HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
Expand Down Expand Up @@ -552,6 +721,10 @@ def grid(META):
IS_CAUSAL=is_causal,
HEAD_DIM=HEAD_DIM,
STORE_LSE=True,
SPARSITY_N=sparsity_n,
SPARSITY_M=sparsity_m,
NUM_SINK_BLOCKS=num_sink_blocks,
DENSE_WINDOW_BLOCKS=dense_window_blocks,
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
)

Expand All @@ -566,6 +739,10 @@ def grid(META):
ctx.num_q_heads = num_q_heads
ctx.num_kv_heads = num_kv_heads
ctx.batch = batch
ctx.sparsity_n = sparsity_n
ctx.sparsity_m = sparsity_m
ctx.num_sink_blocks = num_sink_blocks
ctx.dense_window_blocks = dense_window_blocks
return o

@staticmethod
Expand Down Expand Up @@ -640,6 +817,10 @@ def backward(ctx, grad_output):
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
NUM_SINK_BLOCKS=ctx.num_sink_blocks,
DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks,
num_warps=num_warps,
num_stages=1,
)
Expand All @@ -659,11 +840,15 @@ def backward(ctx, grad_output):
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
NUM_SINK_BLOCKS=ctx.num_sink_blocks,
DENSE_WINDOW_BLOCKS=ctx.dense_window_blocks,
num_warps=num_warps,
num_stages=1,
)

return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None


def attention(
Expand All @@ -678,8 +863,13 @@ def attention(
b_start_loc_k: torch.Tensor | None = None,
b_seq_len_k: torch.Tensor | None = None,
max_input_len_k: int | None = None,
*,
sparsity_n: int = 0,
sparsity_m: int = 4,
num_sink_blocks: int = 0,
dense_window_blocks: int = 1,
) -> torch.Tensor:
"""Variable-length flash attention with GQA and autograd support.
"""Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax.

Args:
q: [total_q_tokens, num_q_heads, head_dim]
Expand All @@ -693,6 +883,14 @@ def attention(
b_start_loc_k: [batch] start offset for K/V (None = same as Q).
b_seq_len_k: [batch] length for K/V (None = same as Q).
max_input_len_k: Maximum K/V sequence length (None = same as Q).
sparsity_n: N:M sparsity — keep top-N of every M attention scores
along the key dimension. Set to 0 to disable. Examples:
``sparsity_n=2, sparsity_m=4`` for 2:4 sparsity;
``sparsity_n=4, sparsity_m=8`` for 4:8 sparsity.
sparsity_m: N:M sparsity — group size (4 or 8).
num_sink_blocks: Number of leading KV blocks to keep dense (attention sinks).
dense_window_blocks: KV blocks within this distance from the query
diagonal are kept dense (local attention window).

Returns:
Output tensor [total_q_tokens, num_q_heads, head_dim].
Expand All @@ -710,6 +908,10 @@ def attention(
b_start_loc_k,
b_seq_len_k,
max_input_len_k,
sparsity_n,
sparsity_m,
num_sink_blocks,
dense_window_blocks,
)


Expand Down
Loading
Loading