Skip softmax calibration via Triton kernel #1597
📝 Walkthrough
Integrates Triton skip-softmax calibration into HF attention: adds `SKIP_SOFTMAX_TRITON_CALIB` and CLI wiring, moves HF calibration state onto the Triton method, binds Triton kernel launches to the tensors' CUDA device, excludes padded Q rows from calibration decisions, and adds CUDA+Triton tests.
Changes
HF Triton Skip-Softmax Calibration Support
🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)
176-182: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Fix `--target_sparse_ratio` override to avoid enabling decode calibration when Triton skip-softmax is prefill-only.
`calibration/calibrate.py` gates decode calibration on `target_sparse_ratio["decode"] > 0` (via `calibrate_decode = target_dict.get("decode", 0.0) > 0.0`), so the current override in `examples/llm_sparsity/attention_sparsity/hf_sa.py` will force `decode` to be enabled even when the selected Triton calibration config is meant to be prefill-only (`SKIP_SOFTMAX_TRITON_CALIB`). Preserve the existing `target_sparse_ratio` phase set instead of unconditionally writing both phases.
🔧 Proposed fix to only override phases already present
```diff
 if args.target_sparse_ratio is not None:
     calib = sparse_cfg.setdefault("calibration", {})
     assert isinstance(calib, dict)
-    calib["target_sparse_ratio"] = {
-        "prefill": args.target_sparse_ratio,
-        "decode": args.target_sparse_ratio,
-    }
+    existing = calib.get("target_sparse_ratio", {"prefill": 0.5, "decode": 0.5})
+    calib["target_sparse_ratio"] = {
+        phase: args.target_sparse_ratio for phase in existing
+    }
```
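To make the effect of this finding concrete, here is a minimal standalone sketch of the phase-preserving override. The helper names `override_target_sparse_ratio` and `decode_calibration_enabled` are hypothetical, and the prefill-only dict is an assumed stand-in for the shape of `SKIP_SOFTMAX_TRITON_CALIB`; only the `decode > 0` gate is taken from the review text.

```python
def override_target_sparse_ratio(calib, ratio, default_phases=("prefill", "decode")):
    """Apply a CLI ratio only to the phases the config already declares.

    Hypothetical helper: falls back to both phases when no phase set exists,
    which mirrors the default used in the proposed fix.
    """
    existing = calib.get("target_sparse_ratio")
    phases = list(existing) if existing else list(default_phases)
    calib["target_sparse_ratio"] = {phase: ratio for phase in phases}
    return calib


def decode_calibration_enabled(calib):
    """Gate quoted in the review: decode runs only if its target is > 0."""
    return calib.get("target_sparse_ratio", {}).get("decode", 0.0) > 0.0


# Assumed config shapes, for illustration only.
prefill_only = {"target_sparse_ratio": {"prefill": 0.5}}
both_phases = {"target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}}

override_target_sparse_ratio(prefill_only, 0.3)
override_target_sparse_ratio(both_phases, 0.3)
```

With this shape, a prefill-only config stays prefill-only after the override, while a config that already declares `decode` still has it updated.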
🧹 Nitpick comments (3)
modelopt/torch/kernels/common/attention/hf_triton_attention.py (1)
176-176: 💤 Low value
Add a brief comment explaining the deferred import.
Per CONTRIBUTING.md, function-level imports should include a brief comment naming the reason (e.g., lazy/optional/circular). This import is deferred to the calibration path; add a comment such as `# Lazy import: only needed during calibration.`
📝 Suggested comment
```diff
-    from modelopt.torch.kernels.common.attention import attention_calibrate
+    # Lazy import: only needed during calibration.
+    from modelopt.torch.kernels.common.attention import attention_calibrate
```
tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py (2)
114-114: 💤 Low value
Prefer `math.isfinite(a)` over wrapping a Python scalar in a tensor.
`a` is already a Python float, so `torch.isfinite(torch.tensor(a))` allocates a tensor unnecessarily. `math.isfinite(a)` is clearer.
♻️ Suggested change
```diff
-    assert a > 0 and torch.isfinite(torch.tensor(a))
+    assert a > 0 and math.isfinite(a)
```
(add `import math` at the top of the file)
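As a quick illustration of the suggestion above, the check can be expressed with the standard library alone. The wrapper name `is_positive_finite` is hypothetical and exists only to show how the combined condition behaves on edge cases.

```python
import math


def is_positive_finite(a: float) -> bool:
    """Same condition as the suggested assertion, without allocating a tensor."""
    return a > 0 and math.isfinite(a)


# Edge cases: infinity and NaN are rejected, as are non-positive values.
checks = {
    "ordinary": is_positive_finite(1.5),
    "infinity": is_positive_finite(float("inf")),
    "nan": is_positive_finite(float("nan")),
    "negative": is_positive_finite(-2.0),
}
```

Note that NaN already fails the `a > 0` comparison, so `math.isfinite` mainly guards against positive infinity here.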
24-38: 💤 Low value
Move the `IS_AVAILABLE` import up with the other top-level imports.
Line 38 is a module-level import placed below `pytestmark`. It is used in the `skipif` decorators, so grouping it with lines 24-31 keeps import ordering conventional and avoids surprise. Minor placement nit.
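The two import-placement nitpicks above (a commented deferred import, and an availability flag defined with the top-level imports) can be sketched together using only the standard library. This is an illustration under assumptions, not the project's code: `TRITON_INSTALLED` is a hypothetical flag standing in for `IS_AVAILABLE`, and `statistics` stands in for a module that is only needed on one code path.

```python
import importlib.util

# Availability flag defined in the top import block, so anything declared
# later in the module (for example skip conditions) can rely on it.
# find_spec only probes for the top-level package; it does not import it.
TRITON_INSTALLED = importlib.util.find_spec("triton") is not None


def summarize(values):
    """Return (mean, population standard deviation) of a list of numbers."""
    # Lazy import: only needed when this function is called.
    import statistics

    return statistics.fmean(values), statistics.pstdev(values)


mean, spread = summarize([1.0, 3.0])
```

Keeping the flag at module level means a missing optional dependency shows up as a skipped test rather than an import error during collection, while the inline comment documents why the second import is deferred.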
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`:
- Around line 143-149: The in-function import of
modelopt.torch.kernels.common.attention.hf_triton_attention (and its symbols
clear_hf_triton_skip_softmax_config, get_calibration_counters,
get_calibration_seq_k, set_hf_triton_skip_softmax_config,
triton_attention_forward) must be either moved to module-level with the other
imports or the deferral must be explicitly justified; if this was done to avoid
a hard dependency on Triton at test-collection time, move the import to the top
of the file and/or add a concise comment like "deferred import to avoid hard
Triton dependency during test collection" immediately above the in-function
import so the reason for the optional-dependency deferral is clear and follows
coding guidelines.
- Around line 86-94: Move the local "import copy" out of the test method and
place it with the other top-level imports at the top of the file so import
errors are raised at collection time; specifically remove the inline import
inside test_sparsify_triton_calib_sets_params and add a single top-level "import
copy" that will be used by that function (and any other tests) to deepcopy
SKIP_SOFTMAX_TRITON_CALIB.
---
Outside diff comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 176-182: The current override always writes both "prefill" and
"decode" into calib["target_sparse_ratio"], which forces decode calibration on
even for prefill-only configs; instead, fetch the existing target dict (e.g.,
target = calib.get("target_sparse_ratio", {})) and only update the phases that
are already present in that dict (or at least only set "prefill" if "decode" is
not present) when applying args.target_sparse_ratio; in practice modify the
block that uses sparse_cfg.setdefault("calibration", {}) so it merges
args.target_sparse_ratio into the existing target dict key-by-key (updating only
existing phase keys like "prefill" and "decode") rather than unconditionally
writing both phases.
---
Nitpick comments:
In `@modelopt/torch/kernels/common/attention/hf_triton_attention.py`:
- Line 176: Add a short explanatory comment for the deferred import of
attention_calibrate to indicate it's performed lazily for calibration only;
locate the import statement "from modelopt.torch.kernels.common.attention import
attention_calibrate" in hf_triton_attention.py and add a brief comment such as
"# Lazy import: only needed during calibration" immediately above or inline with
that import.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`:
- Line 114: Replace the tensor-based finiteness check with Python's
math.isfinite: add "import math" at the top of the test file and change the
assertion that currently reads "assert a > 0 and
torch.isfinite(torch.tensor(a))" to "assert a > 0 and math.isfinite(a)"; refer
to the variable "a" in the failing assertion to locate the line to update.
- Around line 24-38: Move the module-level import "from
modelopt.torch.kernels.common.attention import IS_AVAILABLE as
TRITON_KERNEL_AVAILABLE" up into the top import block with the other imports
(near AutoModelForCausalLM and SparseAttentionModule) so it's defined before
pytestmark and available for the skipif decorators; update its position only (no
code changes) so skip conditions referencing TRITON_KERNEL_AVAILABLE remain
valid.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: dc3aa8c1-a287-45c1-8196-2702a9981726
📒 Files selected for processing (5)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/kernels/common/attention/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1597      +/-   ##
==========================================
+ Coverage   74.63%   76.68%   +2.05%
==========================================
  Files         481      481
  Lines       52770    52808      +38
==========================================
+ Hits        39383    40496    +1113
+ Misses      13387    12312    -1075
```
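The figures in the coverage table can be cross-checked with a few lines of arithmetic. This is only a sanity-check sketch: the helper name `coverage_pct` is hypothetical, and how Codecov rounds or truncates the displayed percentages is not stated in the report, so the comparison uses a small tolerance.

```python
def coverage_pct(hits: int, lines: int) -> float:
    """Line coverage as a percentage of tracked lines."""
    return 100.0 * hits / lines


# Values copied from the Coverage Diff table.
main_lines, main_hits, main_misses = 52770, 39383, 13387
pr_lines, pr_hits, pr_misses = 52808, 40496, 12312

main_cov = coverage_pct(main_hits, main_lines)
pr_cov = coverage_pct(pr_hits, pr_lines)

# Hits and misses should partition the tracked lines on both sides.
main_consistent = main_hits + main_misses == main_lines
pr_consistent = pr_hits + pr_misses == pr_lines

# The change in lines should equal the change in hits plus the change in misses.
delta_consistent = (pr_hits - main_hits) + (pr_misses - main_misses) == pr_lines - main_lines
```

Both columns are internally consistent, and the computed percentages agree with the reported 74.63% and 76.68% to within 0.01 percentage points.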
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
7802f9f to 8490c42
…PyTorch Signed-off-by: Kai Xu <kaix@nvidia.com>
422a5f0 to f092f8c
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
What does this PR do?
Adds skip-softmax calibration for LLMs via a Triton kernel, reusing the existing kernel already used for diffusion.
Type of change: New feature
Usage
The Triton calibration equals PyTorch at every threshold, for both phases:
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation