Add 2:4 sparse softmax to the Triton flash attention kernel#1078
Codecov Report: ✅ All modified and coverable lines are covered by tests.

| Metric | main | #1078 | +/- |
| --- | --- | --- | --- |
| Coverage | 70.30% | 70.28% | -0.03% |
| Files | 227 | 227 | |
| Lines | 25857 | 25857 | |
| Hits | 18179 | 18173 | -6 |
| Misses | 7678 | 7684 | +6 |
Signed-off-by: Kai Xu <kaix@nvidia.com>
What does this PR do?
Type of change: New feature
Add N:M structured sparsity support to the Triton flash attention kernel (modelopt/torch/kernels/triton_fa.py). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax, so those positions receive zero attention weight. This is applied during prefill only.

Supported patterns: any N:M with M=4 (N=1, 2, 3) or M=8 (N=1..4).
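The masking rule above can be written as a small dense reference in plain PyTorch. This is a sketch of the semantics only, not the Triton implementation, and the helper name `nm_sparse_mask` is mine:

```python
import torch

def nm_sparse_mask(scores: torch.Tensor, n: int, m: int) -> torch.Tensor:
    """Dense reference for N:M sparse softmax masking: in every group of
    M consecutive key positions (last dim), keep the top-N scores and set
    the rest to -inf so softmax assigns them zero weight."""
    *lead, k = scores.shape
    assert k % m == 0, "key length must be divisible by M"
    groups = scores.reshape(*lead, k // m, m)
    # Boolean keep-mask: True at the top-N positions within each group of M
    topn = groups.topk(n, dim=-1).indices
    keep = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, topn, True)
    return groups.masked_fill(~keep, float("-inf")).reshape(*lead, k)

# 2:4 pattern on one row of 8 scores: exactly 2 weights survive per group of 4
row = torch.tensor([[0.1, 0.9, 0.5, 0.3, 2.0, -1.0, 0.0, 1.5]])
probs = torch.softmax(nm_sparse_mask(row, n=2, m=4), dim=-1)
```

In the kernel this masking happens per score tile before the online-softmax update; the reference above applies it to the full row, which gives the same weights.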
Performance (TFLOPS at seq_len=16384, RTX 6000):
Usage
Testing
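The PR's actual test code is not reproduced here, but the core correctness property is easy to state: with N=M the mask keeps everything and must reduce to dense softmax, and with N<M at most N weights per group of M are nonzero. A CPU-only sanity sketch of that property (not the repository's tests):

```python
import torch

def nm_sparse_mask(scores, n, m):
    # Keep top-N per group of M along the last dim; -inf out the rest.
    *lead, k = scores.shape
    g = scores.reshape(*lead, k // m, m)
    idx = g.topk(n, dim=-1).indices
    keep = torch.zeros_like(g, dtype=torch.bool).scatter_(-1, idx, True)
    return g.masked_fill(~keep, float("-inf")).reshape(*lead, k)

torch.manual_seed(0)
s = torch.randn(4, 16)

# N == M: the mask is a no-op, so sparse softmax equals dense softmax
dense = torch.softmax(s, dim=-1)
trivial = torch.softmax(nm_sparse_mask(s, 4, 4), dim=-1)

# 2:4: each row has 16/4 = 4 groups, so 2*4 = 8 nonzero weights per row
sparse = torch.softmax(nm_sparse_mask(s, 2, 4), dim=-1)
```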
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information