Skip softmax calibration via Triton kernel#1597

Open: rohansjoshi wants to merge 6 commits into main from rohjoshi/triton-ss-calib

@rohansjoshi rohansjoshi commented Jun 2, 2026

What does this PR do?

Adds skip-softmax calibration for LLMs via a Triton kernel, reusing the existing kernel used for diffusion.

Type of change: New feature

Usage

python hf_sa.py --pyt_ckpt_path Qwen/Qwen3-8B --sparse_attn skip_softmax_triton_calib

The Triton calibration matches the PyTorch calibration at every threshold, in both phases:

| threshold | prefill (Triton / PyTorch) | decode (Triton / PyTorch) |
|---|---|---|
| 0.30 | 0.0% / 0.0% | 12.5% / 12.5% |
| 0.50 | 0.0% / 0.0% | 37.5% / 37.5% |
| 0.70 | 10.0% / 10.0% | 62.5% / 62.5% |
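As a hedged illustration of the parity claim, the reported numbers can be compared programmatically. This is only a sketch: the two dictionaries transcribe the table above, and `max_abs_diff` is a hypothetical helper that is not part of the PR.

```python
# Values transcribed from the table above: threshold -> (prefill %, decode %).
triton_results = {0.30: (0.0, 12.5), 0.50: (0.0, 37.5), 0.70: (10.0, 62.5)}
pytorch_results = {0.30: (0.0, 12.5), 0.50: (0.0, 37.5), 0.70: (10.0, 62.5)}


def max_abs_diff(a, b):
    """Largest per-phase difference (in percentage points) across thresholds."""
    assert a.keys() == b.keys(), "both backends must report the same thresholds"
    return max(
        abs(x - y)
        for threshold in a
        for x, y in zip(a[threshold], b[threshold])
    )


print(max_abs_diff(triton_results, pytorch_results))  # 0.0
```

A real check would populate the two dictionaries from the calibration runs rather than from hard-coded values.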

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a Triton-based skip-softmax sparse-attention calibration option and a CLI flag to override the calibration data directory (defaults to adjacent RULER data).
  • Bug Fixes

    • Ensure calibration kernels run on the correct CUDA device; align measurement granularity and tile/block sizing; ignore padded query rows when counting skippable tiles.
  • Tests

    • Added GPU Triton calibration tests for end-to-end inference, multi-threshold stats, and decode-phase reporting.
  • Documentation

    • Updated changelog and example to expose the new option and flag.

@rohansjoshi rohansjoshi requested review from a team as code owners June 2, 2026 00:53

coderabbitai Bot commented Jun 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Integrates Triton skip-softmax calibration into HF attention: adds SKIP_SOFTMAX_TRITON_CALIB and CLI wiring, moves HF calibration state onto the Triton method, binds Triton kernel launches to the tensors' CUDA device, excludes padded Q rows from calibration decisions, and adds CUDA+Triton tests.

Changes

HF Triton Skip-Softmax Calibration Support

| Layer | File(s) | Summary |
|---|---|---|
| Calibration config, export, and changelog | modelopt/torch/sparsity/attention_sparsity/config.py, CHANGELOG.rst | Adds SKIP_SOFTMAX_TRITON_CALIB config dict with RULER calibration parameters, enables triton skip-softmax pattern, exports the constant, and documents the option in the changelog. |
| Example script: register config and calib data dir | examples/llm_sparsity/attention_sparsity/hf_sa.py | Imports and registers skip_softmax_triton_calib in SPARSE_ATTN_CFG_CHOICES, adds --calib_data_dir, and sets a default RULER data_dir to the script-adjacent data folder when calibration is a dict. |
| HF attention: calibration-mode routing | modelopt/torch/kernels/common/attention/hf_triton_attention.py | Store calibration config/counters on module._sparse_method_instance; when calibration mode is enabled, call attention_calibrate, accumulate HF counters/seq_k/is_decode on the method, and return calibrated HF-layout outputs early. |
| TritonSkipSoftmaxMethod HF wiring and stats | modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py | Add per-instance HF calibration accumulator fields, reset them per calibration forward, prefer HF-stored counters if other backends yield none, and record phase (prefill/decode) in emitted _last_stats. |
| Calibration kernel: padding exclusion and device binding | modelopt/torch/kernels/sparsity/attention/calibrate.py | Exclude padded Q rows from max-gap reduction (set their gap to -inf) to avoid padded rows inflating skip decisions; change calibration BLOCK_N to 128 and bind attention_calibrate Triton launch to q.device via torch.cuda.device. |
| triton_fa: device guards for forward/backward kernels | modelopt/torch/kernels/common/attention/triton_fa.py | Update measurement tile sizing and wrap Triton forward and backward kernel launches and the backward preprocess in `with torch.cuda.device(q.device):` so kernels run on the tensors' CUDA device. |
| Skip-softmax helper: padding-aware decision | modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py | Extend _skip_softmax_decision to accept q_pos and seq_len_q, mark padding rows as skippable, and update per-tile decision logic accordingly. |
| Decode calibration backend selection | modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py | Capture and restore the model's original sparse-attention backend for decode-time measurement (use SDPA for prefill then restore original backend for decode measurement) instead of forcing eager. |
| GPU tests: HF calibration end-to-end and padding checks | tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py, tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py | Add CUDA+Triton tests: end-to-end tiny-Llama calibration/inference checks and decode-branch reporting; add a calibrate kernel test validating padding-row handling (monotonicity and bounds) and update expected measured tile counts. |
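The padding-aware decision summarized in the table can be sketched in plain Python. This is an illustration under stated assumptions, not the PR's Triton code: the parameter names `q_pos` and `seq_len_q` come from the summary above, but the function `tile_is_skippable` and the skip criterion (a row may skip a tile when `tile_max - running_max < log(threshold)`) are assumptions made for the example.

```python
import math


def tile_is_skippable(tile_row_max, running_row_max, q_pos, seq_len_q, threshold):
    """Return True if no real (non-padded) query row needs this tile.

    Assumed criterion (not taken from the PR): a row may skip a tile when
    tile_max - running_max < log(threshold).
    """
    log_threshold = math.log(threshold)
    max_gap = -math.inf
    for pos, tile_max, run_max in zip(q_pos, tile_row_max, running_row_max):
        if pos >= seq_len_q:
            gap = -math.inf  # padded row: must never block a skip
        else:
            gap = tile_max - run_max
        max_gap = max(max_gap, gap)
    return max_gap < log_threshold


# Two real rows followed by two padded rows whose scores are arbitrary.
tile_max = [-5.0, -6.0, 10.0, 10.0]
running_max = [0.0, 0.0, 0.0, 0.0]
positions = [0, 1, 2, 3]

print(tile_is_skippable(tile_max, running_max, positions, 2, 0.1))  # True
print(tile_is_skippable(tile_max, running_max, positions, 4, 0.1))  # False
```

With `seq_len_q=2` the last two rows are padding and are ignored, so the tile counts as skippable. Treating all four rows as real lets the arbitrary values in the padded slots change the decision, which is the effect the padding exclusion is meant to prevent.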

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 77.42% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Skip softmax calibration via Triton kernel' accurately and concisely summarizes the main feature added: enabling skip-softmax attention calibration using a Triton kernel for sparse attention optimization. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns detected. New code avoids torch.load(weights_only=False), hardcoded trust_remote_code=True, eval/exec, pickle, nosec comments, and unsafe deserialization patterns. |

github-actions Bot commented Jun 2, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1597/

Built to branch gh-pages at 2026-06-03 01:20 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)

176-182: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix --target_sparse_ratio override to avoid enabling decode calibration when Triton skip-softmax is prefill-only.
calibration/calibrate.py gates decode calibration on target_sparse_ratio["decode"] > 0 (via calibrate_decode = target_dict.get("decode", 0.0) > 0.0), so the current override in examples/llm_sparsity/attention_sparsity/hf_sa.py will force decode to be enabled even when the selected Triton calibration config is meant to be prefill-only (SKIP_SOFTMAX_TRITON_CALIB). Preserve the existing target_sparse_ratio phase set instead of unconditionally writing both phases.

🔧 Proposed fix to only override phases already present
     if args.target_sparse_ratio is not None:
         calib = sparse_cfg.setdefault("calibration", {})
         assert isinstance(calib, dict)
-        calib["target_sparse_ratio"] = {
-            "prefill": args.target_sparse_ratio,
-            "decode": args.target_sparse_ratio,
-        }
+        existing = calib.get("target_sparse_ratio", {"prefill": 0.5, "decode": 0.5})
+        calib["target_sparse_ratio"] = {
+            phase: args.target_sparse_ratio for phase in existing
+        }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 176 - 182,
The current override always writes both "prefill" and "decode" into
calib["target_sparse_ratio"], which forces decode calibration on even for
prefill-only configs; instead, fetch the existing target dict (e.g., target =
calib.get("target_sparse_ratio", {})) and only update the phases that are
already present in that dict (or at least only set "prefill" if "decode" is not
present) when applying args.target_sparse_ratio; in practice modify the block
that uses sparse_cfg.setdefault("calibration", {}) so it merges
args.target_sparse_ratio into the existing target dict key-by-key (updating only
existing phase keys like "prefill" and "decode") rather than unconditionally
writing both phases.
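The behavior this comment asks for can be sketched as a standalone function. This is a minimal illustration, assuming the calibration config is a plain dict. The helper names `override_target_sparse_ratio` and `decode_calibration_enabled` are hypothetical, and the decode gate mirrors the `target_dict.get("decode", 0.0) > 0.0` expression quoted above.

```python
def override_target_sparse_ratio(calib, ratio, default_phases=("prefill", "decode")):
    """Apply a CLI ratio only to the phases the config already declares.

    Falls back to both phases when the config declares none (hypothetical default).
    """
    existing = calib.get("target_sparse_ratio")
    phases = list(existing) if existing else list(default_phases)
    calib["target_sparse_ratio"] = {phase: ratio for phase in phases}
    return calib


def decode_calibration_enabled(calib):
    """Mirror of the gate quoted in the comment: decode runs only if its ratio > 0."""
    target = calib.get("target_sparse_ratio", {})
    return target.get("decode", 0.0) > 0.0


prefill_only = {"target_sparse_ratio": {"prefill": 0.5}}
override_target_sparse_ratio(prefill_only, 0.7)
print(prefill_only["target_sparse_ratio"])       # {'prefill': 0.7}
print(decode_calibration_enabled(prefill_only))  # False
```

One caveat of this key-based merge: a config that lists `"decode": 0.0` to disable decode explicitly would still have that entry overwritten, so such a config would need a value-based check instead.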
🧹 Nitpick comments (3)
modelopt/torch/kernels/common/attention/hf_triton_attention.py (1)

176-176: 💤 Low value

Add a brief comment explaining the deferred import.

Per CONTRIBUTING.md, function-level imports should include a brief comment naming the reason (e.g., lazy/optional/circular). This import is deferred to the calibration path—add a comment such as # Lazy import: only needed during calibration.

📝 Suggested comment
-        from modelopt.torch.kernels.common.attention import attention_calibrate
+        # Lazy import: only needed during calibration.
+        from modelopt.torch.kernels.common.attention import attention_calibrate
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/hf_triton_attention.py` at line 176,
Add a short explanatory comment for the deferred import of attention_calibrate
to indicate it's performed lazily for calibration only; locate the import
statement "from modelopt.torch.kernels.common.attention import
attention_calibrate" in hf_triton_attention.py and add a brief comment such as
"# Lazy import: only needed during calibration" immediately above or inline with
that import.
tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py (2)

114-114: 💤 Low value

Prefer math.isfinite(a) over wrapping a Python scalar in a tensor.

a is already a Python float, so torch.isfinite(torch.tensor(a)) allocates a tensor unnecessarily. math.isfinite(a) is clearer.

♻️ Suggested change
-            assert a > 0 and torch.isfinite(torch.tensor(a))
+            assert a > 0 and math.isfinite(a)

(add import math at the top of the file)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
at line 114, Replace the tensor-based finiteness check with Python's
math.isfinite: add "import math" at the top of the test file and change the
assertion that currently reads "assert a > 0 and
torch.isfinite(torch.tensor(a))" to "assert a > 0 and math.isfinite(a)"; refer
to the variable "a" in the failing assertion to locate the line to update.
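For reference, a small sketch of the suggested check on a plain Python float. The helper name `is_positive_finite` is hypothetical and only wraps the assertion's condition so it can be exercised on edge cases.

```python
import math


def is_positive_finite(a):
    """Same condition as the suggested assertion, without allocating a tensor."""
    return a > 0 and math.isfinite(a)


print(is_positive_finite(1.5))           # True
print(is_positive_finite(float("inf")))  # False
print(is_positive_finite(float("nan")))  # False
```

The NaN case is rejected by the `a > 0` comparison already, since any comparison with NaN is false; `math.isfinite` is what rules out positive infinity.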

24-38: 💤 Low value

Move the IS_AVAILABLE import up with the other top-level imports.

Line 38 is a module-level import placed below pytestmark. It is used in the skipif decorators, so grouping it with lines 24-31 keeps import ordering conventional and avoids surprise. Minor placement nit.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 24 - 38, Move the module-level import "from
modelopt.torch.kernels.common.attention import IS_AVAILABLE as
TRITON_KERNEL_AVAILABLE" up into the top import block with the other imports
(near AutoModelForCausalLM and SparseAttentionModule) so it's defined before
pytestmark and available for the skipif decorators; update its position only (no
code changes) so skip conditions referencing TRITON_KERNEL_AVAILABLE remain
valid.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`:
- Around line 143-149: The in-function import of
modelopt.torch.kernels.common.attention.hf_triton_attention (and its symbols
clear_hf_triton_skip_softmax_config, get_calibration_counters,
get_calibration_seq_k, set_hf_triton_skip_softmax_config,
triton_attention_forward) must be either moved to module-level with the other
imports or the deferral must be explicitly justified; if this was done to avoid
a hard dependency on Triton at test-collection time, move the import to the top
of the file and/or add a concise comment like "deferred import to avoid hard
Triton dependency during test collection" immediately above the in-function
import so the reason for the optional-dependency deferral is clear and follows
coding guidelines.
- Around line 86-94: Move the local "import copy" out of the test method and
place it with the other top-level imports at the top of the file so import
errors are raised at collection time; specifically remove the inline import
inside test_sparsify_triton_calib_sets_params and add a single top-level "import
copy" that will be used by that function (and any other tests) to deepcopy
SKIP_SOFTMAX_TRITON_CALIB.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: dc3aa8c1-a287-45c1-8196-2702a9981726

📥 Commits

Reviewing files that changed from the base of the PR and between 905259f and 7802f9f.

📒 Files selected for processing (5)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/kernels/common/attention/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py

Comment thread tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py Outdated
Comment thread tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py Outdated

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 75.00000% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.68%. Comparing base (902d369) to head (61dc593).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...delopt/torch/kernels/common/attention/triton_fa.py | 45.45% | 6 Missing ⚠️ |
| ...lopt/torch/kernels/sparsity/attention/calibrate.py | 37.50% | 5 Missing ⚠️ |
| ...kernels/sparsity/attention/skip_softmax_helpers.py | 0.00% | 1 Missing ⚠️ |
| ...arsity/attention_sparsity/calibration/calibrate.py | 0.00% | 1 Missing ⚠️ |
| .../attention_sparsity/methods/triton_skip_softmax.py | 95.23% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1597      +/-   ##
==========================================
+ Coverage   74.63%   76.68%   +2.05%     
==========================================
  Files         481      481              
  Lines       52770    52808      +38     
==========================================
+ Hits        39383    40496    +1113     
+ Misses      13387    12312    -1075     
| Flag | Coverage Δ |
|---|---|
| examples | 40.23% <19.64%> (-2.27%) ⬇️ |
| gpu | 59.82% <71.42%> (+8.27%) ⬆️ |
| regression | 15.18% <5.35%> (+0.06%) ⬆️ |
| unit | 53.89% <25.00%> (-0.02%) ⬇️ |
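The figures in this report are internally consistent, which can be checked with a few lines of arithmetic. This is a sketch using only the numbers quoted above; the variable names are illustrative, and the patch size is inferred from the 75% patch coverage rather than stated by Codecov.

```python
# Line counts copied from the coverage diff above.
base = {"lines": 52770, "hits": 39383, "misses": 13387}
head = {"lines": 52808, "hits": 40496, "misses": 12312}

# Hits and misses should add up to the total line count on each side.
for side in (base, head):
    assert side["hits"] + side["misses"] == side["lines"]


def coverage_pct(side):
    """Project coverage as a percentage of lines hit."""
    return 100.0 * side["hits"] / side["lines"]


base_pct = coverage_pct(base)
head_pct = coverage_pct(head)

# Per-file missing lines from the table above.
missing_by_file = [6, 5, 1, 1, 1]
patch_missing = sum(missing_by_file)
# Inferred: if 25% of the patch is uncovered, the patch has missing / 0.25 lines.
patch_lines = round(patch_missing / (1 - 0.75))

print(patch_missing, patch_lines)
print(f"{base_pct:.3f} {head_pct:.3f} {head_pct - base_pct:+.3f}")
```

The per-file missing counts sum to the 14 lines quoted in the report, and the computed base, head, and delta percentages agree with the reported 74.63%, 76.68%, and +2.05% to within 0.01 percentage points.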

Comment thread modelopt/torch/kernels/common/attention/hf_triton_attention.py Outdated
Comment thread modelopt/torch/sparsity/attention_sparsity/config.py
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@rohansjoshi rohansjoshi force-pushed the rohjoshi/triton-ss-calib branch from 7802f9f to 8490c42 Compare June 2, 2026 22:38
…PyTorch

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the rohjoshi/triton-ss-calib branch from 422a5f0 to f092f8c Compare June 2, 2026 23:53
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv self-requested a review June 3, 2026 00:26

2 participants