From a5772b3608e2c2e56d1e55f5e1dd92e8578aeebb Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Tue, 2 Jun 2026 00:19:40 +0000 Subject: [PATCH 1/6] First commit Signed-off-by: Rohan Joshi --- .../llm_sparsity/attention_sparsity/hf_sa.py | 19 ++ .../common/attention/hf_triton_attention.py | 84 ++++++++ .../sparsity/attention_sparsity/config.py | 31 +++ .../methods/triton_skip_softmax.py | 32 ++- .../test_triton_calibration_gpu.py | 186 ++++++++++++++++++ 5 files changed, 350 insertions(+), 2 deletions(-) create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 5eae54ba6ee..1eacc6f18b3 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_CALIB_SPARSE24, + SKIP_SOFTMAX_TRITON_CALIB, SPARSE_SOFTMAX_DEFAULT, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -44,6 +45,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax_calib": SKIP_SOFTMAX_CALIB, "skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24, + "skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB, "sparse_softmax": SPARSE_SOFTMAX_DEFAULT, } @@ -186,6 +188,15 @@ def main(args): calib["max_seqlen"] = args.calib_max_seqlen if args.calib_chunk_size is not None: calib["chunk_size"] = args.calib_chunk_size + # Point RULER calibration at the data downloaded by download_ruler_data.sh + # (next to this script) unless the user overrides it. The NIAH essay + # haystack requires this directory. + calib.setdefault( + "data_dir", + args.calib_data_dir + if args.calib_data_dir is not None + else str(Path(__file__).parent / "data"), + ) model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") @@ -302,6 +313,14 @@ def main(args): default=None, help="Chunk size for calibration prefill. Overrides config value.", ) + parser.add_argument( + "--calib_data_dir", + type=str, + default=None, + help="Path to RULER calibration data (contains an 'essays' subdir). " + "Defaults to the 'data' directory next to this script " + "(populated by download_ruler_data.sh).", + ) args = parser.parse_args() main(args) diff --git a/modelopt/torch/kernels/common/attention/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py index 860c65d6621..c458a0d1080 100644 --- a/modelopt/torch/kernels/common/attention/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -22,11 +22,71 @@ from __future__ import annotations +import threading + import torch import torch.nn as nn from modelopt.torch.kernels.common.attention.triton_fa import attention +# --------------------------------------------------------------------------- +# Thread-local skip-softmax calibration config for the HF (modelopt_triton) backend +# --------------------------------------------------------------------------- +# Mirrors the diffusers/LTX backends: during calibration the Triton calibration +# kernel measures multi-threshold tile-skip statistics without skipping any tiles. +# Inference-time config (skip threshold / scale factor) is still read from the +# module/method attributes in ``triton_attention_forward`` — only calibration +# state lives here. +_thread_local = threading.local() + + +def set_hf_triton_skip_softmax_config( + threshold: float | None = None, + calibration_mode: bool = False, + threshold_trials: list[float] | None = None, + scale_factor: float | None = None, + measure_sparsity: bool = False, +) -> None: + """Set thread-local skip-softmax calibration config for the next forward. + + Accepts the same keyword arguments as the diffusers/LTX backends so the + shared :class:`TritonSkipSoftmaxMethod` can configure all backends uniformly. + Only the calibration fields are consumed by the HF backend; the inference + fields (``threshold``/``scale_factor``/``measure_sparsity``) are accepted for + signature compatibility but ignored here, since the HF inference path reads + its threshold from the module/method attributes. + + Args: + threshold: Ignored by the HF backend (inference threshold comes from the module). + calibration_mode: If True, route prefill attention through the calibration kernel. + threshold_trials: Thresholds to measure sparsity for (used when calibration_mode=True). + scale_factor: Ignored by the HF backend. + measure_sparsity: Ignored by the HF backend. + """ + _thread_local.calibration_mode = calibration_mode + _thread_local.threshold_trials = threshold_trials + # Counters accumulated across all attention calls in one forward pass. + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + + +def clear_hf_triton_skip_softmax_config() -> None: + """Clear thread-local skip-softmax calibration config.""" + _thread_local.calibration_mode = False + _thread_local.threshold_trials = None + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + + +def get_calibration_counters() -> torch.Tensor | None: + """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" + return getattr(_thread_local, "calibration_counters", None) + + +def get_calibration_seq_k() -> int | None: + """Return KV sequence length observed during calibration, or None.""" + return getattr(_thread_local, "calibration_seq_k", None) + def _seq_lens_from_mask( attention_mask: torch.Tensor | None, @@ -105,6 +165,26 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k + # --- Calibration mode: collect multi-threshold tile-skip stats (prefill only) --- + # Run the calibration kernel, which computes full (non-skipped) attention while + # counting, per candidate threshold, how many KV tiles would be skipped. ``kw`` at + # this point holds only the base attention args that ``attention_calibrate`` accepts; + # the sparse-attention kwargs below are intentionally not added in this branch. + calib_mode = getattr(_thread_local, "calibration_mode", False) + if calib_mode and not is_decode: + trials = getattr(_thread_local, "threshold_trials", None) + from modelopt.torch.kernels.common.attention import attention_calibrate + + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) + + # Accumulate counters across all attention calls in this forward pass. + prev = getattr(_thread_local, "calibration_counters", None) + _thread_local.calibration_counters = counters if prev is None else prev + counters + _thread_local.calibration_seq_k = seq_k + + return (o.view(batch, seq_len, num_heads, head_dim), None) + # Sparse attention params method = getattr(module, "_sparse_method_instance", None) @@ -153,6 +233,10 @@ def register_triton_attention() -> bool: __all__ = [ + "clear_hf_triton_skip_softmax_config", + "get_calibration_counters", + "get_calibration_seq_k", "register_triton_attention", + "set_hf_triton_skip_softmax_config", "triton_attention_forward", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 32a49f02e34..70d606c51b0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -546,6 +546,36 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# RULER calibration via the fused Triton calibration kernel (prefill only). +# Computes the same exponential-model calibration as SKIP_SOFTMAX_CALIB but +# measures tile-skip statistics with the Triton ``attention_calibrate`` kernel +# (the way the Triton inference kernel actually skips tiles) instead of the +# PyTorch F.softmax-patching block path. Faster on GPU since it avoids +# materializing per-block tensors. +SKIP_SOFTMAX_TRITON_CALIB = { + "sparse_cfg": { + "calibration": { + # Prefill only: omitting "decode" leaves its target at 0.0, which + # skips decode calibration (the Triton calibration kernel is + # prefill-oriented). + "target_sparse_ratio": {"prefill": 0.5}, + "samples": 64, + "max_seqlen": 16384, + # Full prefill (seq_q == seq_k, uniform batch=1) — what + # attention_calibrate was validated against. Chunked prefill would + # exercise an untested KV-cache causal-offset path in the kernel. + "chunk_size": -1, + }, + "*attn*": { + "method": "triton_skip_softmax", + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + class VSAAttributeConfig(ModeloptBaseConfig): """Video Sparse Attention (VSA) attribute configuration. @@ -738,6 +768,7 @@ class VSAConfig(SparseAttentionConfig): "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_CALIB_SPARSE24", "SKIP_SOFTMAX_DEFAULT", + "SKIP_SOFTMAX_TRITON_CALIB", "SKIP_SOFTMAX_TRITON_DEFAULT", "SPARSE_SOFTMAX_DEFAULT", "VSA_DEFAULT", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index c0a183787dd..fc1b3d25dd3 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -170,7 +170,7 @@ def _get_diffusers_backend_context(): yield def _set_triton_backends(self, **kwargs): - """Set config on both diffusers and LTX Triton backends.""" + """Set config on the diffusers, LTX, and HF (modelopt_triton) Triton backends.""" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( set_triton_skip_softmax_config, @@ -187,9 +187,17 @@ def _set_triton_backends(self, **kwargs): set_ltx_triton_context(active=True, **kwargs) except ImportError: pass + try: + from modelopt.torch.kernels.common.attention.hf_triton_attention import ( + set_hf_triton_skip_softmax_config, + ) + + set_hf_triton_skip_softmax_config(**kwargs) + except ImportError: + pass def _clear_triton_backends(self): - """Clear config on both Triton backends.""" + """Clear config on the diffusers, LTX, and HF Triton backends.""" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( clear_triton_skip_softmax_config, @@ -206,6 +214,14 @@ def _clear_triton_backends(self): clear_ltx_triton_context() except ImportError: pass + try: + from modelopt.torch.kernels.common.attention.hf_triton_attention import ( + clear_hf_triton_skip_softmax_config, + ) + + clear_hf_triton_skip_softmax_config() + except ImportError: + pass def _collect_calibration_stats(self, module): """Read Triton calibration counters and store as stats on the module.""" @@ -235,6 +251,18 @@ def _collect_calibration_stats(self, module): except ImportError: pass + if counters is None: + try: + from modelopt.torch.kernels.common.attention.hf_triton_attention import ( + get_calibration_counters, + get_calibration_seq_k, + ) + + counters = get_calibration_counters() + seq_k = get_calibration_seq_k() + except ImportError: + pass + if counters is None or self._threshold_trials is None: return diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py new file mode 100644 index 00000000000..26f3adf3af5 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for skip-softmax calibration via the Triton backend on HF models. + +These exercise the HuggingFace (``modelopt_triton``) wiring that routes the +calibration forward pass through the fused ``attention_calibrate`` kernel and +feeds the collected multi-threshold tile-skip statistics into the same +exponential-model fit used by the PyTorch path. +""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_TRITON_CALIB +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), +] + +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +# Thresholds spanning a wide range so the collected sparsity covers the (10%, 90%) +# window the exponential fit relies on. +THRESHOLD_TRIALS = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create a minimal Llama model directory.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama_triton_calib"), + num_hidden_layers=2, + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=1024, + ) + + +def _load_eager(tiny_llama_dir): + return AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, attn_implementation="eager", device_map="cuda" + ) + + +def _make_forward_loop(vocab_size, lengths=(128, 256, 384, 512)): + """Forward loop that runs several full-prefill passes of varying length. + + Each pass triggers one ``attention_calibrate`` call per layer, producing one + per-sample calibration record per length. + """ + + def forward_loop(model): + torch.manual_seed(0) + for seq_len in lengths: + input_ids = torch.randint(0, vocab_size, (1, seq_len), device="cuda") + with torch.no_grad(): + model(input_ids, use_cache=False) + + return forward_loop + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestTritonCalibrationHF: + """End-to-end calibration via the Triton backend on a tiny HF model.""" + + def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir): + """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model.""" + import copy + + model = _load_eager(tiny_llama_dir) + + # Use the calibrator's default (dense) threshold trials so the collected + # sparsity densely covers the (10%, 90%) window the fit filters on. + config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + + forward_loop = _make_forward_loop(model.config.vocab_size) + sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + + # Backend dispatched to the Triton kernel. + assert sparse_model.config._attn_implementation == "modelopt_triton" + + sparse_modules = [ + m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ] + assert len(sparse_modules) == 2 + + # Calibration produced finite, in-bounds (a, b) for the prefill phase. + for module in sparse_modules: + method = module._sparse_method_instance + assert method.name == "triton_skip_softmax" + params = method.calibration_params + assert params is not None and "prefill" in params + a, b = params["prefill"]["a"], params["prefill"]["b"] + assert a > 0 and torch.isfinite(torch.tensor(a)) + assert 0.0 <= b <= 20.0 + # Prefill-only: decode must not be calibrated. + assert "decode" not in params + + def test_calibrated_model_inference(self, tiny_llama_dir): + """A model calibrated through the Triton path still runs inference cleanly.""" + import copy + + model = _load_eager(tiny_llama_dir) + config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + + forward_loop = _make_forward_loop(model.config.vocab_size) + sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + + sparse_model.eval() + input_ids = torch.randint(0, model.config.vocab_size, (1, 64), device="cuda") + with torch.no_grad(): + out = sparse_model(input_ids, use_cache=False) + assert out.logits is not None + assert not torch.isnan(out.logits).any() + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestHFBackendCalibrationCounters: + """Lower-level checks on the HF backend's calibration branch.""" + + def test_counters_monotonic_in_threshold(self): + """Skipped-tile counts are non-decreasing as the threshold grows.""" + from modelopt.torch.kernels.common.attention.hf_triton_attention import ( + clear_hf_triton_skip_softmax_config, + get_calibration_counters, + get_calibration_seq_k, + set_hf_triton_skip_softmax_config, + triton_attention_forward, + ) + + batch, num_heads, seq_len, head_dim = 1, 4, 256, 64 + torch.manual_seed(0) + q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn_like(q) + v = torch.randn_like(q) + + # A bare module stand-in; the calibration branch returns before touching + # any sparse-method attributes. + module = torch.nn.Module() + + set_hf_triton_skip_softmax_config( + calibration_mode=True, threshold_trials=THRESHOLD_TRIALS + ) + try: + out, _ = triton_attention_forward( + module, q, k, v, attention_mask=None, scaling=1.0 / (head_dim**0.5) + ) + counters = get_calibration_counters() + seq_k = get_calibration_seq_k() + finally: + clear_hf_triton_skip_softmax_config() + + assert out.shape == (batch, seq_len, num_heads, head_dim) + assert seq_k == seq_len + assert counters is not None + assert counters.shape == (len(THRESHOLD_TRIALS), 2) + + totals = counters[:, 0] + skipped = counters[:, 1] + assert torch.all(totals == totals[0]) # same tile count for every threshold + assert torch.all(skipped[1:] >= skipped[:-1]) # monotonic non-decreasing + assert torch.all(skipped <= totals) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 8490c42585b0a6f97d1eb9814bd822da6e3e8279 Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Tue, 2 Jun 2026 21:30:53 +0000 Subject: [PATCH 2/6] Added decode calibration Signed-off-by: Rohan Joshi --- CHANGELOG.rst | 1 + .../common/attention/hf_triton_attention.py | 100 +++------- .../kernels/common/attention/triton_fa.py | 177 ++++++++++-------- .../kernels/sparsity/attention/calibrate.py | 80 ++++---- .../calibration/calibrate.py | 20 +- .../sparsity/attention_sparsity/config.py | 9 +- .../methods/triton_skip_softmax.py | 58 +++--- .../attention/test_triton_fa_calibrate.py | 37 ++++ .../test_triton_calibration_gpu.py | 132 +++++-------- 9 files changed, 298 insertions(+), 316 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b6a3a979dd5..3d96e6a9fe1 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -43,6 +43,7 @@ Changelog - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). - Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md `_ for details. +- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data). **Bug Fixes** diff --git a/modelopt/torch/kernels/common/attention/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py index c458a0d1080..d30d5ad844a 100644 --- a/modelopt/torch/kernels/common/attention/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -22,70 +22,14 @@ from __future__ import annotations -import threading - import torch import torch.nn as nn from modelopt.torch.kernels.common.attention.triton_fa import attention -# --------------------------------------------------------------------------- -# Thread-local skip-softmax calibration config for the HF (modelopt_triton) backend -# --------------------------------------------------------------------------- -# Mirrors the diffusers/LTX backends: during calibration the Triton calibration -# kernel measures multi-threshold tile-skip statistics without skipping any tiles. -# Inference-time config (skip threshold / scale factor) is still read from the -# module/method attributes in ``triton_attention_forward`` — only calibration -# state lives here. -_thread_local = threading.local() - - -def set_hf_triton_skip_softmax_config( - threshold: float | None = None, - calibration_mode: bool = False, - threshold_trials: list[float] | None = None, - scale_factor: float | None = None, - measure_sparsity: bool = False, -) -> None: - """Set thread-local skip-softmax calibration config for the next forward. - - Accepts the same keyword arguments as the diffusers/LTX backends so the - shared :class:`TritonSkipSoftmaxMethod` can configure all backends uniformly. - Only the calibration fields are consumed by the HF backend; the inference - fields (``threshold``/``scale_factor``/``measure_sparsity``) are accepted for - signature compatibility but ignored here, since the HF inference path reads - its threshold from the module/method attributes. - - Args: - threshold: Ignored by the HF backend (inference threshold comes from the module). - calibration_mode: If True, route prefill attention through the calibration kernel. - threshold_trials: Thresholds to measure sparsity for (used when calibration_mode=True). - scale_factor: Ignored by the HF backend. - measure_sparsity: Ignored by the HF backend. - """ - _thread_local.calibration_mode = calibration_mode - _thread_local.threshold_trials = threshold_trials - # Counters accumulated across all attention calls in one forward pass. - _thread_local.calibration_counters = None - _thread_local.calibration_seq_k = None - - -def clear_hf_triton_skip_softmax_config() -> None: - """Clear thread-local skip-softmax calibration config.""" - _thread_local.calibration_mode = False - _thread_local.threshold_trials = None - _thread_local.calibration_counters = None - _thread_local.calibration_seq_k = None - - -def get_calibration_counters() -> torch.Tensor | None: - """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" - return getattr(_thread_local, "calibration_counters", None) - - -def get_calibration_seq_k() -> int | None: - """Return KV sequence length observed during calibration, or None.""" - return getattr(_thread_local, "calibration_seq_k", None) +# Skip-softmax calibration config and counters live on the module's +# ``_sparse_method_instance`` (HF passes the owning module to +# ``triton_attention_forward``), so no separate thread-local state is needed. def _seq_lens_from_mask( @@ -165,29 +109,35 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k - # --- Calibration mode: collect multi-threshold tile-skip stats (prefill only) --- - # Run the calibration kernel, which computes full (non-skipped) attention while - # counting, per candidate threshold, how many KV tiles would be skipped. ``kw`` at - # this point holds only the base attention args that ``attention_calibrate`` accepts; - # the sparse-attention kwargs below are intentionally not added in this branch. - calib_mode = getattr(_thread_local, "calibration_mode", False) - if calib_mode and not is_decode: - trials = getattr(_thread_local, "threshold_trials", None) + # Sparse-attention method instance. It carries the inference threshold and, + # during calibration, both the calibration config and the accumulated + # tile-skip counters. Available here because HF passes the owning module. + method = getattr(module, "_sparse_method_instance", None) + + # Calibration mode: run the calibration kernel, which computes full attention + # while counting, per candidate threshold, how many KV tiles would be skipped. + # The sparse-attention kwargs below are intentionally not added in this branch. + if method is not None and getattr(method, "_calibration_mode", False): + trials = getattr(method, "_threshold_trials", None) + # Deferred: the package __init__ imports this module, so importing + # attention_calibrate at module top would be circular. from modelopt.torch.kernels.common.attention import attention_calibrate if trials and attention_calibrate is not None: o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) # Accumulate counters across all attention calls in this forward pass. - prev = getattr(_thread_local, "calibration_counters", None) - _thread_local.calibration_counters = counters if prev is None else prev + counters - _thread_local.calibration_seq_k = seq_k + # The method instance is per-module so the accumulator stays on one + # device, but guard the add against a device mismatch just in case. + prev = getattr(method, "_hf_calibration_counters", None) + method._hf_calibration_counters = ( + counters if prev is None else prev + counters.to(prev.device) + ) + method._hf_calibration_seq_k = seq_k + method._hf_calibration_is_decode = is_decode return (o.view(batch, seq_len, num_heads, head_dim), None) - # Sparse attention params - method = getattr(module, "_sparse_method_instance", None) - # N:M sparse softmax: prefill only (no perf benefit for decode) if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False): kw["sparsity_n"] = method.sparsity_n @@ -233,10 +183,6 @@ def register_triton_attention() -> bool: __all__ = [ - "clear_hf_triton_skip_softmax_config", - "get_calibration_counters", - "get_calibration_seq_k", "register_triton_attention", - "set_hf_triton_skip_softmax_config", "triton_attention_forward", ] diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 0b481e93558..e930b75f0d1 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -919,23 +919,29 @@ def forward( def grid(META): return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) - if do_measure: - # Runtime counters mutate global tensors, so do not run them through - # autotune candidate trials. Use one stable config for measurement. - _attn_fwd.fn[grid]( - *fwd_args, - **fwd_kwargs, - BLOCK_M=_MEASURE_BLOCK_M, - BLOCK_N=_MEASURE_BLOCK_N, - num_warps=_MEASURE_NUM_WARPS, - num_stages=_MEASURE_NUM_STAGES, - ) - else: - _attn_fwd[grid]( - *fwd_args, - **fwd_kwargs, - # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune - ) + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device so the + # kernel dereferences the right pointers instead of triggering an + # illegal memory access. + with torch.cuda.device(q.device): + if do_measure: + # Runtime counters mutate global tensors, so do not run them through + # autotune candidate trials. Use one stable config for measurement. + _attn_fwd.fn[grid]( + *fwd_args, + **fwd_kwargs, + BLOCK_M=_MEASURE_BLOCK_M, + BLOCK_N=_MEASURE_BLOCK_N, + num_warps=_MEASURE_NUM_WARPS, + num_stages=_MEASURE_NUM_STAGES, + ) + else: + _attn_fwd[grid]( + *fwd_args, + **fwd_kwargs, + # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune + ) # Store sparsity counters on the output tensor for retrieval by callers if do_measure: @@ -970,23 +976,30 @@ def backward(ctx, grad_output): do = grad_output.contiguous() num_warps = 4 + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device for each + # launch so the kernels dereference the right pointers instead of + # triggering an illegal memory access. + # Phase 1: delta = rowsum(O * dO) delta = torch.empty_like(lse) - _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( - o, - do, - delta, - o.stride(0), - o.stride(1), - do.stride(0), - do.stride(1), - delta.stride(0), - delta.stride(1), - q.shape[0], - HEAD_DIM=HEAD_DIM, - BLOCK_D=BLOCK_D, - BLOCK_M=BLOCK, - ) + with torch.cuda.device(q.device): + _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( + o, + do, + delta, + o.stride(0), + o.stride(1), + do.stride(0), + do.stride(1), + delta.stride(0), + delta.stride(1), + q.shape[0], + HEAD_DIM=HEAD_DIM, + BLOCK_D=BLOCK_D, + BLOCK_M=BLOCK, + ) dq = torch.zeros_like(q) dk = torch.zeros_like(k) @@ -1016,57 +1029,59 @@ def backward(ctx, grad_output): ) # Phase 2: dK, dV - _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( - *bwd_args[:4], - dk, - dv, - *bwd_args[4:], - dk.stride(0), - dk.stride(1), - dv.stride(0), - dv.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( + *bwd_args[:4], + dk, + dv, + *bwd_args[4:], + dk.stride(0), + dk.stride(1), + dv.stride(0), + dv.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) # Phase 3: dQ - _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( - *bwd_args[:4], - dq, - *bwd_args[4:], - dq.stride(0), - dq.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( + *bwd_args[:4], + dq, + *bwd_args[4:], + dq.stride(0), + dq.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) return ( dq, diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 971c423f711..276d319d83c 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -132,7 +132,16 @@ def _attn_fwd_calibrate( # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row # must still be below threshold for the tile to be skippable). - max_gap = tl.max(tile_row_max - row_max) # scalar + # + # Exclude padding Q rows (q_pos >= seq_len_q) from the reduction. Their Q is + # loaded as zeros, so their tile_row_max is ~0 (not -inf), which would + # otherwise dominate the max and force max_gap >= 0 — making every tile + # un-skippable. This matters most for decode (seq_len_q == 1, so 127/128 + # rows are padding) and also fixes the last partial Q tile in prefill when + # seq_len_q is not a multiple of BLOCK_M. + gap = tile_row_max - row_max + gap = tl.where(q_pos < seq_len_q, gap, -float("inf")) + max_gap = tl.max(gap) # scalar skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] local_skipped += skip_mask num_tiles += 1 @@ -282,38 +291,43 @@ def attention_calibrate( num_programs * num_thresholds, dtype=torch.int32, device=q.device ) - _attn_fwd_calibrate[grid]( - q, - k, - v, - qk_scale, - b_start_loc, - b_seq_len, - b_start_loc_k, - b_seq_len_k, - o, - q.stride(0), - q.stride(1), - k.stride(0), - k.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - threshold_tensor, - per_program_totals, - per_program_skipped, - kv_group_num=kv_group_num, - BLOCK_M=BLOCK_M, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK_N, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - NUM_THRESHOLDS=num_thresholds, - PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), - num_warps=4, - num_stages=1, - ) + # Triton launches on torch.cuda.current_device(), which is not necessarily + # the device the tensors live on (e.g. under accelerate device_map="auto" + # sharding). Activate the tensor's device so the kernel dereferences the + # right pointers instead of triggering an illegal memory access. + with torch.cuda.device(q.device): + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + per_program_totals, + per_program_skipped, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), + num_warps=4, + num_stages=1, + ) # Reduce across programs: sum per-program counts → [num_thresholds] totals = per_program_totals.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 51df5bb4d4a..840f757a8c6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -153,9 +153,14 @@ def create_decode_calibration_forward_loop( ) -> Callable: """Create forward loop for decode phase calibration. - Uses SDPA for fast prefill, then switches to eager attention for decode - token generation with softmax hook measurement. (Previously used - ``flash_attention_2`` for prefill, but transformers>=5.0's FA2 path + Uses SDPA for fast prefill (no measurement), then switches to the model's + configured sparse-attention backend for the decode steps so measurement + happens there: ``eager`` for the pytorch backend (F.softmax hook) or + ``modelopt_triton`` for the triton backend (Triton calibration kernel). + The backend is read from ``model.config._attn_implementation``, which + ``sparsify`` already set for the chosen backend. + + (SDPA is used for prefill because transformers>=5.0's FA2 path unconditionally calls ``s_aux.to(query.dtype)`` on the attention-sinks tensor and crashes for models without sinks. SDPA is just as fast for prefill, has no softmax hook, and is version-stable.) @@ -179,7 +184,8 @@ def forward_loop(model: nn.Module) -> None: ) input_ids = inputs["input_ids"].to(device) - # Save original attention implementation + # Save original attention implementation (the sparse-attention backend + # set by sparsify: "eager" for pytorch, "modelopt_triton" for triton). original_attn_impl = getattr(model.config, "_attn_implementation", "eager") with torch.no_grad(): @@ -191,8 +197,10 @@ def forward_loop(model: nn.Module) -> None: next_token = outputs.logits[:, -1:, :].argmax(dim=-1) del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop - # Step 2: Switch to eager for decode (enables softmax hook) - model.config._attn_implementation = "eager" + # Step 2: Switch to the sparse backend for decode so measurement + # happens there (eager -> F.softmax hook; modelopt_triton -> + # Triton calibration kernel). + model.config._attn_implementation = original_attn_impl # Step 3: Manual decode loop for explicit control over token generation # model.generate() method is not used here because it doesn't allow explicit control over KV cache diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 70d606c51b0..c064fd0014d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -546,7 +546,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } -# RULER calibration via the fused Triton calibration kernel (prefill only). +# RULER calibration via the fused Triton calibration kernel (prefill + decode). # Computes the same exponential-model calibration as SKIP_SOFTMAX_CALIB but # measures tile-skip statistics with the Triton ``attention_calibrate`` kernel # (the way the Triton inference kernel actually skips tiles) instead of the @@ -555,10 +555,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): SKIP_SOFTMAX_TRITON_CALIB = { "sparse_cfg": { "calibration": { - # Prefill only: omitting "decode" leaves its target at 0.0, which - # skips decode calibration (the Triton calibration kernel is - # prefill-oriented). - "target_sparse_ratio": {"prefill": 0.5}, + # Prefill calibration uses full-prefill forwards; decode calibration + # runs SDPA prefill followed by Triton-backend decode steps. + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, "samples": 64, "max_seqlen": 16384, # Full prefill (seq_q == seq_k, uniform batch=1) — what diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index fc1b3d25dd3..7c1899e2776 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -49,6 +49,13 @@ def __init__(self, method_config=None): self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) # Calibration state self._threshold_trials: list[float] | None = None + # HF (modelopt_triton) backend calibration outputs, accumulated across + # attention calls in one forward pass and read back in + # ``_collect_calibration_stats``. The HF backend reads/writes these + # directly on the method instance (no thread-local needed). + self._hf_calibration_counters: torch.Tensor | None = None + self._hf_calibration_seq_k: int | None = None + self._hf_calibration_is_decode: bool = False # Runtime sparsity measurement self._measure_sparsity: bool = False self._sparsity_total: int = 0 @@ -111,6 +118,11 @@ def _triton_inference_context(self, module): def _triton_calibration_context(self, module): """Calibration: collect multi-threshold sparsity stats via Triton kernel.""" module._apply_skip_softmax = True + # Reset the HF-backend calibration accumulators for this forward pass. + # (The diffusers/LTX backends reset their own state in ``_set_triton_backends``.) + self._hf_calibration_counters = None + self._hf_calibration_seq_k = None + self._hf_calibration_is_decode = False self._set_triton_backends(calibration_mode=True, threshold_trials=self._threshold_trials) with self._get_diffusers_backend_context(): try: @@ -170,7 +182,12 @@ def _get_diffusers_backend_context(): yield def _set_triton_backends(self, **kwargs): - """Set config on the diffusers, LTX, and HF (modelopt_triton) Triton backends.""" + """Set config on the diffusers and LTX Triton backends. + + The HF (modelopt_triton) backend reads its calibration config directly + from this method instance during ``triton_attention_forward``, so it + needs no separate configuration here. + """ try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( set_triton_skip_softmax_config, @@ -187,17 +204,9 @@ def _set_triton_backends(self, **kwargs): set_ltx_triton_context(active=True, **kwargs) except ImportError: pass - try: - from modelopt.torch.kernels.common.attention.hf_triton_attention import ( - set_hf_triton_skip_softmax_config, - ) - - set_hf_triton_skip_softmax_config(**kwargs) - except ImportError: - pass def _clear_triton_backends(self): - """Clear config on the diffusers, LTX, and HF Triton backends.""" + """Clear config on the diffusers and LTX Triton backends.""" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( clear_triton_skip_softmax_config, @@ -214,19 +223,14 @@ def _clear_triton_backends(self): clear_ltx_triton_context() except ImportError: pass - try: - from modelopt.torch.kernels.common.attention.hf_triton_attention import ( - clear_hf_triton_skip_softmax_config, - ) - - clear_hf_triton_skip_softmax_config() - except ImportError: - pass def _collect_calibration_stats(self, module): """Read Triton calibration counters and store as stats on the module.""" counters = None seq_k = None + # Diffusers/LTX (video) backends are prefill-only; only the HF backend + # reports a phase, for decode-step calibration. + phase = "prefill" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( @@ -252,16 +256,12 @@ def _collect_calibration_stats(self, module): pass if counters is None: - try: - from modelopt.torch.kernels.common.attention.hf_triton_attention import ( - get_calibration_counters, - get_calibration_seq_k, - ) - - counters = get_calibration_counters() - seq_k = get_calibration_seq_k() - except ImportError: - pass + # HF (modelopt_triton) backend accumulates counters on this method + # instance (``module._sparse_method_instance is self``). + counters = self._hf_calibration_counters + seq_k = self._hf_calibration_seq_k + if counters is not None and self._hf_calibration_is_decode: + phase = "decode" if counters is None or self._threshold_trials is None: return @@ -279,7 +279,7 @@ def _collect_calibration_stats(self, module): module._last_stats = { "sparsity": sparsity_list, "sample_length": sample_length, - "phase": "prefill", + "phase": phase, } def get_threshold_info(self) -> dict: diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index fe16559a187..d13bf8d08ed 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -132,6 +132,43 @@ def test_different_seq_q_seq_k(self): assert out.shape == q.shape assert counters.shape == (2, 2) + def test_decode_skips_padding_rows(self): + """Decode (seq_q=1) skips real KV tiles once padding Q rows are excluded. + + With BLOCK_M=128, 127/128 query rows are padding. Before the padding-row + fix their ~0 gap forced zero skips; after it the largest threshold skips a + meaningful number of KV tiles. + """ + seq_q, seq_k, num_heads, head_dim = 1, 512, 4, 64 + scale = 1.0 / (head_dim**0.5) + torch.manual_seed(0) + q = torch.randn(seq_q, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + b_start_loc = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len = torch.ones(1, device="cuda", dtype=torch.int32) + b_start_loc_k = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len_k = torch.full((1,), seq_k, device="cuda", dtype=torch.int32) + + _, counters = attention_calibrate( + q, + k, + v, + b_start_loc, + b_seq_len, + seq_q, + softmax_scale=scale, + is_causal=False, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=seq_k, + threshold_trials=[1e-2, 1e-1, 5e-1, 9e-1], + ) + skipped = counters[:, 1] + assert (skipped[1:] >= skipped[:-1]).all() # monotonic non-decreasing + assert (skipped <= counters[:, 0]).all() + assert skipped[-1] > 0 # padding-row fix makes this non-zero + def test_threshold_order_doesnt_affect_counts(self): """Skipped counts at the same threshold are independent of trial ordering.""" q, k, v, locs, lens = self._make_inputs() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py index 26f3adf3af5..3240c177934 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py @@ -21,24 +21,26 @@ exponential-model fit used by the PyTorch path. """ +import copy + import pytest import torch from _test_utils.torch.transformers_models import create_tiny_llama_dir from transformers import AutoModelForCausalLM import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention.hf_triton_attention import triton_attention_forward from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_TRITON_CALIB -from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.methods.triton_skip_softmax import ( + TritonSkipSoftmaxMethod, +) pytestmark = [ pytest.mark.filterwarnings("ignore::UserWarning"), pytest.mark.filterwarnings("ignore::RuntimeWarning"), ] -from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE - -# Thresholds spanning a wide range so the collected sparsity covers the (10%, 90%) -# window the exponential fit relies on. THRESHOLD_TRIALS = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] @@ -79,52 +81,37 @@ def forward_loop(model): return forward_loop -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestTritonCalibrationHF: - """End-to-end calibration via the Triton backend on a tiny HF model.""" - - def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir): - """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model.""" - import copy +def _calibration_module(threshold_trials): + """Build a bare module whose ``_sparse_method_instance`` is in calibration mode. - model = _load_eager(tiny_llama_dir) - - # Use the calibrator's default (dense) threshold trials so the collected - # sparsity densely covers the (10%, 90%) window the fit filters on. - config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + The HF backend reads its calibration config from (and writes counters back + to) ``module._sparse_method_instance``, so this is the minimal stand-in for + driving ``triton_attention_forward`` through the calibration branch. + """ + method = TritonSkipSoftmaxMethod() + method.set_calibration_mode(True) + method._threshold_trials = threshold_trials - forward_loop = _make_forward_loop(model.config.vocab_size) - sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + module = torch.nn.Module() + module._sparse_method_instance = method + return module - # Backend dispatched to the Triton kernel. - assert sparse_model.config._attn_implementation == "modelopt_triton" - sparse_modules = [ - m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) - ] - assert len(sparse_modules) == 2 - - # Calibration produced finite, in-bounds (a, b) for the prefill phase. - for module in sparse_modules: - method = module._sparse_method_instance - assert method.name == "triton_skip_softmax" - params = method.calibration_params - assert params is not None and "prefill" in params - a, b = params["prefill"]["a"], params["prefill"]["b"] - assert a > 0 and torch.isfinite(torch.tensor(a)) - assert 0.0 <= b <= 20.0 - # Prefill-only: decode must not be calibrated. - assert "decode" not in params +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestTritonCalibrationHF: + """End-to-end calibration via the Triton backend on a tiny HF model.""" def test_calibrated_model_inference(self, tiny_llama_dir): - """A model calibrated through the Triton path still runs inference cleanly.""" - import copy - + """SKIP_SOFTMAX_TRITON_CALIB dispatches to the Triton backend and the + calibrated model runs inference cleanly.""" model = _load_eager(tiny_llama_dir) config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + # Prefill-only (custom forward_loop can't drive RULER decode calibration). + config["sparse_cfg"]["calibration"]["target_sparse_ratio"] = {"prefill": 0.5} forward_loop = _make_forward_loop(model.config.vocab_size) sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + assert sparse_model.config._attn_implementation == "modelopt_triton" sparse_model.eval() input_ids = torch.randint(0, model.config.vocab_size, (1, 64), device="cuda") @@ -133,53 +120,28 @@ def test_calibrated_model_inference(self, tiny_llama_dir): assert out.logits is not None assert not torch.isnan(out.logits).any() + def test_decode_branch_reports_decode_phase(self): + """The HF calibration branch routes decode-shaped calls through the kernel + and surfaces its counters as a ``decode``-phase stats record. -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestHFBackendCalibrationCounters: - """Lower-level checks on the HF backend's calibration branch.""" - - def test_counters_monotonic_in_threshold(self): - """Skipped-tile counts are non-decreasing as the threshold grows.""" - from modelopt.torch.kernels.common.attention.hf_triton_attention import ( - clear_hf_triton_skip_softmax_config, - get_calibration_counters, - get_calibration_seq_k, - set_hf_triton_skip_softmax_config, - triton_attention_forward, - ) - - batch, num_heads, seq_len, head_dim = 1, 4, 256, 64 + This is the HF-only counter path in ``_collect_calibration_stats``; the + kernel's skip-count behavior itself is covered in the kernel test suite. + """ + num_heads, seq_k, head_dim = 4, 512, 64 torch.manual_seed(0) - q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16) - k = torch.randn_like(q) - v = torch.randn_like(q) - - # A bare module stand-in; the calibration branch returns before touching - # any sparse-method attributes. - module = torch.nn.Module() - - set_hf_triton_skip_softmax_config( - calibration_mode=True, threshold_trials=THRESHOLD_TRIALS - ) - try: - out, _ = triton_attention_forward( - module, q, k, v, attention_mask=None, scaling=1.0 / (head_dim**0.5) - ) - counters = get_calibration_counters() - seq_k = get_calibration_seq_k() - finally: - clear_hf_triton_skip_softmax_config() - - assert out.shape == (batch, seq_len, num_heads, head_dim) - assert seq_k == seq_len - assert counters is not None - assert counters.shape == (len(THRESHOLD_TRIALS), 2) - - totals = counters[:, 0] - skipped = counters[:, 1] - assert torch.all(totals == totals[0]) # same tile count for every threshold - assert torch.all(skipped[1:] >= skipped[:-1]) # monotonic non-decreasing - assert torch.all(skipped <= totals) + q = torch.randn(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + assert method._hf_calibration_is_decode is True + assert method._hf_calibration_counters is not None + + method._collect_calibration_stats(module) + assert module._last_stats["phase"] == "decode" + assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS) if __name__ == "__main__": From f092f8cdc8b35cd69157a2ec9cbd8cc49b43dbfa Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 2 Jun 2026 16:41:51 -0700 Subject: [PATCH 3/6] Fix decode calibration: full-cache kv_bound + 128x128 block to match PyTorch Signed-off-by: Kai Xu --- .../torch/kernels/common/attention/triton_fa.py | 5 ++++- .../kernels/sparsity/attention/calibrate.py | 16 ++++++++++++++-- .../attention/test_triton_fa_calibrate.py | 4 +++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index e930b75f0d1..9f127e5ceea 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] _MEASURE_BLOCK_M = 128 -_MEASURE_BLOCK_N = 64 +# 128 so the kernel sparsity-measurement block matches the PyTorch +# flash_skip_softmax calibration block (br = bc = 128) and the Triton +# calibration kernel; otherwise the two measure at different granularities. +_MEASURE_BLOCK_N = 128 _MEASURE_NUM_STAGES = 1 _MEASURE_NUM_WARPS = 4 diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 276d319d83c..85c3279b4b2 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -111,7 +111,17 @@ def _attn_fwd_calibrate( local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) num_tiles = 0 - kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + # Causal bound: when Q is a suffix of KV (decode: seq_len_q == 1 against a + # long cache; or chunked prefill), the visible KV extends to + # causal_offset + (tile_q + 1) * BLOCK_M. Without the offset the loop stops + # at the first BLOCK_M KV tokens, so decode would only ever measure the + # start of the cache instead of the whole thing. + causal_offset = seq_len_kv - seq_len_q + kv_bound = ( + seq_len_kv + if not IS_CAUSAL + else tl.minimum(causal_offset + (tile_q + 1) * BLOCK_M, seq_len_kv) + ) for kv_start in range(0, kv_bound, BLOCK_N): kv_start = tl.multiple_of(kv_start, BLOCK_N) @@ -261,8 +271,10 @@ def attention_calibrate( sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale qk_scale = sm_scale * LOG2E BLOCK_D = triton.next_power_of_2(HEAD_DIM) + # 128x128 to match the PyTorch flash_skip_softmax calibration block (br = bc = 128), + # so Triton-kernel and PyTorch calibration measure sparsity at the same granularity. BLOCK_M = 128 - BLOCK_N = 64 + BLOCK_N = 128 if b_seq_len_k is None: b_seq_len_k = b_seq_len diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index d13bf8d08ed..7fca9218f64 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -319,7 +319,9 @@ def test_first_measured_call_has_real_tile_count_with_autotune(self): assert result.returncode == 0, result.stderr totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")] assert totals, result.stdout - assert int(totals[-1].split("=", maxsplit=1)[1]) == 8 + # seq_len=256, _MEASURE_BLOCK_M = _MEASURE_BLOCK_N = 128, non-causal: + # Q tiles = ceil(256/128) = 2, KV tiles = ceil(256/128) = 2, total = 4. + assert int(totals[-1].split("=", maxsplit=1)[1]) == 4 def test_measure_sparsity_without_skip_is_noop(self): """Without skip-softmax, measure_sparsity doesn't attach counters.""" From d847f633c3d782a400aaa7febc98c4a322c738a8 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 2 Jun 2026 17:13:30 -0700 Subject: [PATCH 4/6] Fix decode calibration: padded row in decode Signed-off-by: Kai Xu --- .../kernels/common/attention/triton_fa.py | 2 ++ .../attention/skip_softmax_helpers.py | 19 +++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 9f127e5ceea..8a1a521fea6 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -366,6 +366,8 @@ def _attn_fwd( skip_tile = _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped, diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index aa65fd50a12..044e54b2e8e 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -142,6 +142,8 @@ def _apply_sparse_nm_to_qk_tile( def _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2: tl.constexpr, Sparsity_total, Sparsity_skipped, @@ -159,16 +161,25 @@ def _skip_softmax_decision( The threshold is converted to the kernel's scaled log2 score space by the Python wrapper so it can be compared directly against ``scores``. + ``q_pos`` (``[BLOCK_M]`` absolute query positions) and the scalar + ``seq_len_q`` identify padding rows. When a tile has fewer than ``BLOCK_M`` + valid queries — decode has one valid query plus ``BLOCK_M - 1`` padding + rows, and the last prefill tile is partial when ``seq_q`` is not a multiple + of ``BLOCK_M`` — the padding rows carry zero scores that are never + negligible versus their own running max and would otherwise veto every + skip. They are forced skippable so the decision reflects only valid rows. + Returns: - True when *all* Q rows in the tile satisfy the skip criterion. + True when *all valid* Q rows in the tile satisfy the skip criterion. When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``. """ tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) - # Per-row: True if row's tile max is negligible vs running max - can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) - # Per-tile: skip entire tile only if ALL rows are negligible + # Per-row: True if the row's tile max is negligible vs running max, OR the + # row is padding (q_pos >= seq_len_q) so it must not veto the tile decision. + can_skip = (tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)) | (q_pos >= seq_len_q) + # Per-tile: skip entire tile only if ALL valid rows are negligible skip_tile = tl.min(can_skip.to(tl.int32)) == 1 if MEASURE_SPARSITY: From 419aca1fc87707d64dc4db4d6322dc283a1ed342 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 2 Jun 2026 17:36:36 -0700 Subject: [PATCH 5/6] Apply per-phase calibrated skip threshold at HF inference Signed-off-by: Kai Xu --- .../common/attention/hf_triton_attention.py | 9 ++++--- .../methods/triton_skip_softmax.py | 26 +++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/kernels/common/attention/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py index d30d5ad844a..10b77f60d1b 100644 --- a/modelopt/torch/kernels/common/attention/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -145,10 +145,13 @@ def triton_attention_forward( kw["dense_sink_tokens"] = method.dense_sink_tokens kw["dense_recent_tokens"] = method.dense_recent_tokens - # Skip-softmax: applies to both prefill and decode + # Skip-softmax: applies to both prefill and decode. Prefer the method's + # per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back + # to the static threshold when uncalibrated. if method is not None and getattr(module, "_apply_skip_softmax", False): - if method.skip_softmax_threshold: - kw["skip_softmax_threshold"] = method.skip_softmax_threshold + threshold = method.get_inference_threshold(seq_len, seq_k) + if threshold: + kw["skip_softmax_threshold"] = threshold o = attention(q, k, v, **kw) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index 7c1899e2776..a3109d56b73 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -133,20 +133,20 @@ def _triton_calibration_context(self, module): module._apply_skip_softmax = False self._clear_triton_backends() - def _get_scale_factor(self) -> float | None: - """Compute scale_factor from calibration params, or None if uncalibrated. + def _get_scale_factor(self, phase: str = "prefill") -> float | None: + """Compute the scale_factor for ``phase`` from calibration params, or None. - The scale_factor is sequence-length-independent. Backends divide by the + The scale_factor is sequence-length-independent. Callers divide by the actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``. """ if self.calibration_params and self.target_sparse_ratio: import math import warnings - params = self.calibration_params.get("prefill", {}) + params = self.calibration_params.get(phase, {}) a = params.get("a", 0) b = params.get("b", 0) - target = self.target_sparse_ratio.get("prefill", 0.5) + target = self.target_sparse_ratio.get(phase, 0.5) if a > 0 and b > 0: # Warn if target is outside the calibrated range min_s = params.get("min_observed_sparsity") @@ -167,6 +167,22 @@ def _get_scale_factor(self) -> float | None: return a * math.exp(b * target) return None + def get_inference_threshold(self, seq_q: int, seq_k: int) -> float | None: + """Return the skip threshold to apply for this call's phase. + + Picks the phase from the query length (``decode`` when ``seq_q == 1``, + else ``prefill``) and returns the calibrated dynamic threshold + ``scale_factor(phase) / seq_k`` when the phase is calibrated, otherwise + the static ``skip_softmax_threshold`` (or ``None`` to disable). This is + what the HF backend applies; it keeps prefill and decode on their own + calibrated ``(a, b)`` instead of forcing decode onto prefill's. + """ + phase = "decode" if seq_q <= 1 else "prefill" + scale_factor = self._get_scale_factor(phase) + if scale_factor is not None and seq_k > 0: + return scale_factor / seq_k + return self.skip_softmax_threshold or None + @staticmethod @contextmanager def _get_diffusers_backend_context(): From 61dc593412035608c53a518050ad222afa195f1e Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 2 Jun 2026 17:38:27 -0700 Subject: [PATCH 6/6] Add sink-pattern decode calibration test (full cache + nonzero sparsity) Signed-off-by: Kai Xu --- .../test_triton_calibration_gpu.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py index 3240c177934..949e67b2cd8 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py @@ -22,6 +22,7 @@ """ import copy +import itertools import pytest import torch @@ -143,6 +144,38 @@ def test_decode_branch_reports_decode_phase(self): assert module._last_stats["phase"] == "decode" assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS) + def test_decode_calibration_measures_full_cache_with_sink(self): + """Decode calibration must scan the whole KV cache and report real sparsity. + + A dominant sink at position 0 makes the distant KV tiles negligible, so a + correct decode measurement skips almost all of them. This guards the two + decode bugs that random inputs don't expose: + * causal-offset ``kv_bound`` — without it the loop stops after the first + ``BLOCK_M`` tokens, so ``total`` would be a fraction of the cache. + * padding-row exclusion — without it the 127 padding rows veto every + tile and sparsity is 0%. + """ + num_heads, seq_k, head_dim = 4, 2048, 64 + block_n = 128 # the calibration kernel measures at 128x128 + q = torch.ones(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + k[:, :, 0] = 20.0 # attention sink dominates every query + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + + counters = method._hf_calibration_counters + total = int(counters[0, 0]) + # Full cache scanned (not truncated to the first block). + assert total == num_heads * (seq_k // block_n), total + sparsity = (counters[:, 1].float() / counters[:, 0].clamp(min=1)).tolist() + # Sink => the vast majority of tiles are negligible and skippable (not 0%). + assert max(sparsity) > 0.8, sparsity + # Skipped-tile fraction is non-decreasing as the threshold grows. + assert all(later >= earlier for earlier, later in itertools.pairwise(sparsity)), sparsity + if __name__ == "__main__": pytest.main([__file__, "-v"])