
Add 2:4 sparse softmax to the Triton flash attention kernel#1078

Open
kaix-nv wants to merge 2 commits into main from kaix/triton_fa_sparse24

Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Mar 19, 2026

What does this PR do?

Type of change: New feature

Add N:M structured sparsity support to the Triton flash attention kernel (modelopt/torch/kernels/triton_fa.py). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax. Sparsity is applied during prefill only.

Supported patterns: Any N:M where M=4 (N=1,2,3) or M=8 (N=1..4).

  • Optional sink tokens and dense window blocks preserve attention sinks and local attention
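To make the masking rule concrete, here is a minimal NumPy sketch of the top-N-of-M selection applied to a row of attention scores. It illustrates the semantics described above; it is a stand-in, not the Triton kernel, and the function name and shapes are illustrative.

```python
import numpy as np

def sparse_nm_mask(scores: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Keep the top-n scores in every group of m consecutive key
    positions; set the rest to -inf so softmax zeros them out."""
    q_len, k_len = scores.shape
    assert k_len % m == 0, "key length must be a multiple of m"
    groups = scores.reshape(q_len, k_len // m, m)
    # indices of the n largest entries within each group of m
    top = np.argsort(groups, axis=-1)[..., -n:]
    keep = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(keep, top, True, axis=-1)
    return np.where(keep.reshape(q_len, k_len), scores, -np.inf)

scores = np.array([[0.1, 0.9, 0.5, 0.3, 2.0, -1.0, 0.0, 1.5]])
masked = sparse_nm_mask(scores, n=2, m=4)
# each group of 4 retains exactly its 2 largest scores:
# [-inf, 0.9, 0.5, -inf, 2.0, -inf, -inf, 1.5]
```

Because the discarded scores become -inf rather than being dropped, the subsequent softmax simply renormalizes over the kept entries, which is why the change slots into the existing online-softmax accumulation.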

Performance (TFLOPS at seq_len=16384, RTX 6000):

Pattern      TFLOPS   % of Dense
Dense        89.3     100%
2:4 (M=4)    69.5     78%
4:8 (M=8)    57.3     64%

Usage

from modelopt.torch.kernels import attention

# 2:4 sparsity (keep top 2 of every 4 K positions)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=2, sparsity_m=4)

# 4:8 sparsity with sink tokens and dense window
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=4, sparsity_m=8,
                num_sink_tokens=4, dense_window_blocks=2)

# Dense (default, zero overhead)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len)
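The sink-token option can be pictured the same way: below is a hedged NumPy sketch in which the first num_sink_tokens key positions are exempted from masking. The real kernel also supports dense_window_blocks at block granularity, omitted here for brevity; the helper names and the exact interaction with group boundaries are assumptions, not the kernel's code.

```python
import numpy as np

def nm_keep_mask(scores, n, m):
    """Boolean keep-mask: True for the top-n entries in each m-group."""
    q, k = scores.shape
    groups = scores.reshape(q, k // m, m)
    top = np.argsort(groups, axis=-1)[..., -n:]
    keep = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(keep, top, True, axis=-1)
    return keep.reshape(q, k)

def mask_with_sinks(scores, n, m, num_sink_tokens=0):
    keep = nm_keep_mask(scores, n, m)
    keep[:, :num_sink_tokens] = True  # sink keys always attend densely
    return np.where(keep, scores, -np.inf)

scores = np.arange(8.0).reshape(1, 8)  # monotonically increasing scores
out = mask_with_sinks(scores, n=1, m=4, num_sink_tokens=2)
# 1:4 keeps positions 3 and 7; the sinks additionally keep positions 0 and 1
```

Exempting the sink keys keeps the first tokens' attention mass intact, which is the usual motivation for attention-sink handling in sparse attention variants.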

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added N:M structured sparse softmax support to the flash attention kernel, enabling selective computation where only top-N attention scores are retained within each group of M tokens, with configurable controls for attention sinks and dense window regions.
  • Tests

    • Significantly expanded test coverage for sparse attention operations and kernel correctness validation.

@copy-pr-bot

copy-pr-bot bot commented Mar 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 19, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fb5da2f5-8cf7-4acd-93d1-2af90199751c

📥 Commits

Reviewing files that changed from the base of the PR and between 839fa3d and 7aa6960.

📒 Files selected for processing (5)
  • CHANGELOG.rst
  • modelopt/torch/kernels/triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

📝 Walkthrough

Walkthrough

Added N:M sparse softmax support to Triton flash attention kernel. Changelog entry documents the feature. Implementation adds sparsity masking helpers and extends forward/backward kernels with new parameters. Public API updated with sparsity configuration. Test infrastructure refactored with shared utilities and new sparsity tests.

Changes

Cohort / File(s) — Summary

Documentation — CHANGELOG.rst
Added an unreleased changelog entry for version 0.44 documenting N:M sparse softmax support for the Triton flash attention kernel.

Core Implementation — modelopt/torch/kernels/triton_fa.py
Added Triton JIT helpers (_sparse_nm_masks_m4, _apply_sparse_nm_to_qk_tile) for N:M structured sparsity masking. Extended the forward attention kernel with sparsity parameters (SPARSITY_N, SPARSITY_M, NUM_SINK_TOKENS, DENSE_WINDOW_BLOCKS) and conditional sparsity application. Mirrored the sparsity logic in the backward kernels. Updated the autograd wrapper and public attention(...) API to accept and propagate the sparsity configuration.

Test Infrastructure — tests/gpu/torch/sparsity/attention_sparsity/conftest.py
Added shared test utilities: make_qkv for synthetic tensor creation, make_varlen_meta for attention metadata, sdpa_reference for a PyTorch reference implementation, and a tiny_llama_dir fixture for model setup.

Test Refactoring — tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
Refactored tests to use shared utilities from conftest. Renamed test classes (TestTritonFaVsSdpa → TestForward; backward tests prefixed with test_dense_). Added a test_sparse_disabled_matches_dense assertion. Reorganized HF integration tests into a separate TestHFIntegration class.

New Sparsity Tests — tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
Added a new GPU test module for N:M sparsity validation: TestSparseNM for end-to-end attention outputs across sparsity configurations, TestSparseTileStructure for tile-level sparsity pattern verification, and TestSparseBackward for gradient sanity checks under sparsity.
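As a rough stand-in for what the tile-structure and backward-sanity tests assert, the expected invariants can be checked in NumPy: after masking, every M-group of attention probabilities has exactly N nonzero entries, and each row still sums to 1. This mirrors the intent of those tests, not their actual code.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 16))  # 4 queries, 16 keys
n, m = 2, 4

# NumPy stand-in for the kernel's N:M masking (illustrative only)
groups = scores.reshape(4, 16 // m, m)
top = np.argsort(groups, axis=-1)[..., -n:]
keep = np.zeros_like(groups, dtype=bool)
np.put_along_axis(keep, top, True, axis=-1)
masked = np.where(keep.reshape(4, 16), scores, -np.inf)

# softmax over each masked row
p = np.exp(masked - masked.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)

# exactly n of every m probabilities are nonzero, and rows normalize
assert np.all((p.reshape(4, 16 // m, m) > 0).sum(axis=-1) == n)
assert np.allclose(p.sum(axis=-1), 1.0)
```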

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Add 2:4 sparse softmax to the Triton flash attention kernel' accurately describes the main feature being introduced—N:M structured sparsity support with a specific 2:4 pattern example. It is clear, specific, and directly reflects the primary changes across the codebase.
Docstring Coverage ✅ Passed Docstring coverage is 88.24% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed Pull request contains no security anti-patterns as defined in SECURITY.md across all six criteria.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov bot commented Mar 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.28%. Comparing base (839fa3d) to head (7aa6960).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1078      +/-   ##
==========================================
- Coverage   70.30%   70.28%   -0.03%     
==========================================
  Files         227      227              
  Lines       25857    25857              
==========================================
- Hits        18179    18173       -6     
- Misses       7678     7684       +6     

☔ View full report in Codecov by Sentry.

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 7812e52 to a9430be Compare March 20, 2026 00:20
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 8ba6efe to 7aa6960 Compare March 20, 2026 04:47
@kaix-nv kaix-nv marked this pull request as ready for review March 20, 2026 05:16
@kaix-nv kaix-nv requested a review from a team as a code owner March 20, 2026 05:16
@kaix-nv kaix-nv requested review from ChenhanYu, Edwardf0t1, cjluo-nv, kevalmorabia97 and rohansjoshi and removed request for ChenhanYu and Edwardf0t1 March 20, 2026 05:16