diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b6a3a979dd5..3d96e6a9fe1 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -43,6 +43,7 @@ Changelog - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). - Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md `_ for details. +- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data). **Bug Fixes** diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 5eae54ba6ee..1eacc6f18b3 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_CALIB_SPARSE24, + SKIP_SOFTMAX_TRITON_CALIB, SPARSE_SOFTMAX_DEFAULT, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -44,6 +45,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax_calib": SKIP_SOFTMAX_CALIB, "skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24, + "skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB, "sparse_softmax": SPARSE_SOFTMAX_DEFAULT, } @@ -186,6 +188,15 @@ def main(args): calib["max_seqlen"] = args.calib_max_seqlen if args.calib_chunk_size is not None: calib["chunk_size"] = args.calib_chunk_size + # Point RULER calibration at the data downloaded by download_ruler_data.sh + # (next to this script) unless the user overrides it. The NIAH essay + # haystack requires this directory. + calib.setdefault( + "data_dir", + args.calib_data_dir + if args.calib_data_dir is not None + else str(Path(__file__).parent / "data"), + ) model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") @@ -302,6 +313,14 @@ def main(args): default=None, help="Chunk size for calibration prefill. Overrides config value.", ) + parser.add_argument( + "--calib_data_dir", + type=str, + default=None, + help="Path to RULER calibration data (contains an 'essays' subdir). " + "Defaults to the 'data' directory next to this script " + "(populated by download_ruler_data.sh).", + ) args = parser.parse_args() main(args) diff --git a/modelopt/torch/kernels/common/attention/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py index 860c65d6621..10b77f60d1b 100644 --- a/modelopt/torch/kernels/common/attention/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -27,6 +27,10 @@ from modelopt.torch.kernels.common.attention.triton_fa import attention +# Skip-softmax calibration config and counters live on the module's +# ``_sparse_method_instance`` (HF passes the owning module to +# ``triton_attention_forward``), so no separate thread-local state is needed. + def _seq_lens_from_mask( attention_mask: torch.Tensor | None, @@ -105,9 +109,35 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k - # Sparse attention params + # Sparse-attention method instance. It carries the inference threshold and, + # during calibration, both the calibration config and the accumulated + # tile-skip counters. Available here because HF passes the owning module. method = getattr(module, "_sparse_method_instance", None) + # Calibration mode: run the calibration kernel, which computes full attention + # while counting, per candidate threshold, how many KV tiles would be skipped. + # The sparse-attention kwargs below are intentionally not added in this branch. + if method is not None and getattr(method, "_calibration_mode", False): + trials = getattr(method, "_threshold_trials", None) + # Deferred: the package __init__ imports this module, so importing + # attention_calibrate at module top would be circular. + from modelopt.torch.kernels.common.attention import attention_calibrate + + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) + + # Accumulate counters across all attention calls in this forward pass. + # The method instance is per-module so the accumulator stays on one + # device, but guard the add against a device mismatch just in case. + prev = getattr(method, "_hf_calibration_counters", None) + method._hf_calibration_counters = ( + counters if prev is None else prev + counters.to(prev.device) + ) + method._hf_calibration_seq_k = seq_k + method._hf_calibration_is_decode = is_decode + + return (o.view(batch, seq_len, num_heads, head_dim), None) + # N:M sparse softmax: prefill only (no perf benefit for decode) if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False): kw["sparsity_n"] = method.sparsity_n @@ -115,10 +145,13 @@ def triton_attention_forward( kw["dense_sink_tokens"] = method.dense_sink_tokens kw["dense_recent_tokens"] = method.dense_recent_tokens - # Skip-softmax: applies to both prefill and decode + # Skip-softmax: applies to both prefill and decode. Prefer the method's + # per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back + # to the static threshold when uncalibrated. if method is not None and getattr(module, "_apply_skip_softmax", False): - if method.skip_softmax_threshold: - kw["skip_softmax_threshold"] = method.skip_softmax_threshold + threshold = method.get_inference_threshold(seq_len, seq_k) + if threshold: + kw["skip_softmax_threshold"] = threshold o = attention(q, k, v, **kw) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 0b481e93558..8a1a521fea6 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] _MEASURE_BLOCK_M = 128 -_MEASURE_BLOCK_N = 64 +# 128 so the kernel sparsity-measurement block matches the PyTorch +# flash_skip_softmax calibration block (br = bc = 128) and the Triton +# calibration kernel; otherwise the two measure at different granularities. +_MEASURE_BLOCK_N = 128 _MEASURE_NUM_STAGES = 1 _MEASURE_NUM_WARPS = 4 @@ -363,6 +366,8 @@ def _attn_fwd( skip_tile = _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped, @@ -919,23 +924,29 @@ def forward( def grid(META): return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) - if do_measure: - # Runtime counters mutate global tensors, so do not run them through - # autotune candidate trials. Use one stable config for measurement. - _attn_fwd.fn[grid]( - *fwd_args, - **fwd_kwargs, - BLOCK_M=_MEASURE_BLOCK_M, - BLOCK_N=_MEASURE_BLOCK_N, - num_warps=_MEASURE_NUM_WARPS, - num_stages=_MEASURE_NUM_STAGES, - ) - else: - _attn_fwd[grid]( - *fwd_args, - **fwd_kwargs, - # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune - ) + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device so the + # kernel dereferences the right pointers instead of triggering an + # illegal memory access. + with torch.cuda.device(q.device): + if do_measure: + # Runtime counters mutate global tensors, so do not run them through + # autotune candidate trials. Use one stable config for measurement. + _attn_fwd.fn[grid]( + *fwd_args, + **fwd_kwargs, + BLOCK_M=_MEASURE_BLOCK_M, + BLOCK_N=_MEASURE_BLOCK_N, + num_warps=_MEASURE_NUM_WARPS, + num_stages=_MEASURE_NUM_STAGES, + ) + else: + _attn_fwd[grid]( + *fwd_args, + **fwd_kwargs, + # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune + ) # Store sparsity counters on the output tensor for retrieval by callers if do_measure: @@ -970,23 +981,30 @@ def backward(ctx, grad_output): do = grad_output.contiguous() num_warps = 4 + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device for each + # launch so the kernels dereference the right pointers instead of + # triggering an illegal memory access. + # Phase 1: delta = rowsum(O * dO) delta = torch.empty_like(lse) - _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( - o, - do, - delta, - o.stride(0), - o.stride(1), - do.stride(0), - do.stride(1), - delta.stride(0), - delta.stride(1), - q.shape[0], - HEAD_DIM=HEAD_DIM, - BLOCK_D=BLOCK_D, - BLOCK_M=BLOCK, - ) + with torch.cuda.device(q.device): + _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( + o, + do, + delta, + o.stride(0), + o.stride(1), + do.stride(0), + do.stride(1), + delta.stride(0), + delta.stride(1), + q.shape[0], + HEAD_DIM=HEAD_DIM, + BLOCK_D=BLOCK_D, + BLOCK_M=BLOCK, + ) dq = torch.zeros_like(q) dk = torch.zeros_like(k) @@ -1016,57 +1034,59 @@ def backward(ctx, grad_output): ) # Phase 2: dK, dV - _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( - *bwd_args[:4], - dk, - dv, - *bwd_args[4:], - dk.stride(0), - dk.stride(1), - dv.stride(0), - dv.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( + *bwd_args[:4], + dk, + dv, + *bwd_args[4:], + dk.stride(0), + dk.stride(1), + dv.stride(0), + dv.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) # Phase 3: dQ - _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( - *bwd_args[:4], - dq, - *bwd_args[4:], - dq.stride(0), - dq.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( + *bwd_args[:4], + dq, + *bwd_args[4:], + dq.stride(0), + dq.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) return ( dq, diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 971c423f711..85c3279b4b2 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -111,7 +111,17 @@ def _attn_fwd_calibrate( local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) num_tiles = 0 - kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + # Causal bound: when Q is a suffix of KV (decode: seq_len_q == 1 against a + # long cache; or chunked prefill), the visible KV extends to + # causal_offset + (tile_q + 1) * BLOCK_M. Without the offset the loop stops + # at the first BLOCK_M KV tokens, so decode would only ever measure the + # start of the cache instead of the whole thing. + causal_offset = seq_len_kv - seq_len_q + kv_bound = ( + seq_len_kv + if not IS_CAUSAL + else tl.minimum(causal_offset + (tile_q + 1) * BLOCK_M, seq_len_kv) + ) for kv_start in range(0, kv_bound, BLOCK_N): kv_start = tl.multiple_of(kv_start, BLOCK_N) @@ -132,7 +142,16 @@ def _attn_fwd_calibrate( # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row # must still be below threshold for the tile to be skippable). - max_gap = tl.max(tile_row_max - row_max) # scalar + # + # Exclude padding Q rows (q_pos >= seq_len_q) from the reduction. Their Q is + # loaded as zeros, so their tile_row_max is ~0 (not -inf), which would + # otherwise dominate the max and force max_gap >= 0 — making every tile + # un-skippable. This matters most for decode (seq_len_q == 1, so 127/128 + # rows are padding) and also fixes the last partial Q tile in prefill when + # seq_len_q is not a multiple of BLOCK_M. + gap = tile_row_max - row_max + gap = tl.where(q_pos < seq_len_q, gap, -float("inf")) + max_gap = tl.max(gap) # scalar skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] local_skipped += skip_mask num_tiles += 1 @@ -252,8 +271,10 @@ def attention_calibrate( sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale qk_scale = sm_scale * LOG2E BLOCK_D = triton.next_power_of_2(HEAD_DIM) + # 128x128 to match the PyTorch flash_skip_softmax calibration block (br = bc = 128), + # so Triton-kernel and PyTorch calibration measure sparsity at the same granularity. BLOCK_M = 128 - BLOCK_N = 64 + BLOCK_N = 128 if b_seq_len_k is None: b_seq_len_k = b_seq_len @@ -282,38 +303,43 @@ def attention_calibrate( num_programs * num_thresholds, dtype=torch.int32, device=q.device ) - _attn_fwd_calibrate[grid]( - q, - k, - v, - qk_scale, - b_start_loc, - b_seq_len, - b_start_loc_k, - b_seq_len_k, - o, - q.stride(0), - q.stride(1), - k.stride(0), - k.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - threshold_tensor, - per_program_totals, - per_program_skipped, - kv_group_num=kv_group_num, - BLOCK_M=BLOCK_M, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK_N, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - NUM_THRESHOLDS=num_thresholds, - PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), - num_warps=4, - num_stages=1, - ) + # Triton launches on torch.cuda.current_device(), which is not necessarily + # the device the tensors live on (e.g. under accelerate device_map="auto" + # sharding). Activate the tensor's device so the kernel dereferences the + # right pointers instead of triggering an illegal memory access. + with torch.cuda.device(q.device): + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + per_program_totals, + per_program_skipped, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), + num_warps=4, + num_stages=1, + ) # Reduce across programs: sum per-program counts → [num_thresholds] totals = per_program_totals.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index aa65fd50a12..044e54b2e8e 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -142,6 +142,8 @@ def _apply_sparse_nm_to_qk_tile( def _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2: tl.constexpr, Sparsity_total, Sparsity_skipped, @@ -159,16 +161,25 @@ def _skip_softmax_decision( The threshold is converted to the kernel's scaled log2 score space by the Python wrapper so it can be compared directly against ``scores``. + ``q_pos`` (``[BLOCK_M]`` absolute query positions) and the scalar + ``seq_len_q`` identify padding rows. When a tile has fewer than ``BLOCK_M`` + valid queries — decode has one valid query plus ``BLOCK_M - 1`` padding + rows, and the last prefill tile is partial when ``seq_q`` is not a multiple + of ``BLOCK_M`` — the padding rows carry zero scores that are never + negligible versus their own running max and would otherwise veto every + skip. They are forced skippable so the decision reflects only valid rows. + Returns: - True when *all* Q rows in the tile satisfy the skip criterion. + True when *all valid* Q rows in the tile satisfy the skip criterion. When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``. """ tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) - # Per-row: True if row's tile max is negligible vs running max - can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) - # Per-tile: skip entire tile only if ALL rows are negligible + # Per-row: True if the row's tile max is negligible vs running max, OR the + # row is padding (q_pos >= seq_len_q) so it must not veto the tile decision. + can_skip = (tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)) | (q_pos >= seq_len_q) + # Per-tile: skip entire tile only if ALL valid rows are negligible skip_tile = tl.min(can_skip.to(tl.int32)) == 1 if MEASURE_SPARSITY: diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 51df5bb4d4a..840f757a8c6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -153,9 +153,14 @@ def create_decode_calibration_forward_loop( ) -> Callable: """Create forward loop for decode phase calibration. - Uses SDPA for fast prefill, then switches to eager attention for decode - token generation with softmax hook measurement. (Previously used - ``flash_attention_2`` for prefill, but transformers>=5.0's FA2 path + Uses SDPA for fast prefill (no measurement), then switches to the model's + configured sparse-attention backend for the decode steps so measurement + happens there: ``eager`` for the pytorch backend (F.softmax hook) or + ``modelopt_triton`` for the triton backend (Triton calibration kernel). + The backend is read from ``model.config._attn_implementation``, which + ``sparsify`` already set for the chosen backend. + + (SDPA is used for prefill because transformers>=5.0's FA2 path unconditionally calls ``s_aux.to(query.dtype)`` on the attention-sinks tensor and crashes for models without sinks. SDPA is just as fast for prefill, has no softmax hook, and is version-stable.) @@ -179,7 +184,8 @@ def forward_loop(model: nn.Module) -> None: ) input_ids = inputs["input_ids"].to(device) - # Save original attention implementation + # Save original attention implementation (the sparse-attention backend + # set by sparsify: "eager" for pytorch, "modelopt_triton" for triton). original_attn_impl = getattr(model.config, "_attn_implementation", "eager") with torch.no_grad(): @@ -191,8 +197,10 @@ def forward_loop(model: nn.Module) -> None: next_token = outputs.logits[:, -1:, :].argmax(dim=-1) del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop - # Step 2: Switch to eager for decode (enables softmax hook) - model.config._attn_implementation = "eager" + # Step 2: Switch to the sparse backend for decode so measurement + # happens there (eager -> F.softmax hook; modelopt_triton -> + # Triton calibration kernel). + model.config._attn_implementation = original_attn_impl # Step 3: Manual decode loop for explicit control over token generation # model.generate() method is not used here because it doesn't allow explicit control over KV cache diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 32a49f02e34..c064fd0014d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -546,6 +546,35 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# RULER calibration via the fused Triton calibration kernel (prefill + decode). +# Computes the same exponential-model calibration as SKIP_SOFTMAX_CALIB but +# measures tile-skip statistics with the Triton ``attention_calibrate`` kernel +# (the way the Triton inference kernel actually skips tiles) instead of the +# PyTorch F.softmax-patching block path. Faster on GPU since it avoids +# materializing per-block tensors. +SKIP_SOFTMAX_TRITON_CALIB = { + "sparse_cfg": { + "calibration": { + # Prefill calibration uses full-prefill forwards; decode calibration + # runs SDPA prefill followed by Triton-backend decode steps. + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, + "samples": 64, + "max_seqlen": 16384, + # Full prefill (seq_q == seq_k, uniform batch=1) — what + # attention_calibrate was validated against. Chunked prefill would + # exercise an untested KV-cache causal-offset path in the kernel. + "chunk_size": -1, + }, + "*attn*": { + "method": "triton_skip_softmax", + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + class VSAAttributeConfig(ModeloptBaseConfig): """Video Sparse Attention (VSA) attribute configuration. @@ -738,6 +767,7 @@ class VSAConfig(SparseAttentionConfig): "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_CALIB_SPARSE24", "SKIP_SOFTMAX_DEFAULT", + "SKIP_SOFTMAX_TRITON_CALIB", "SKIP_SOFTMAX_TRITON_DEFAULT", "SPARSE_SOFTMAX_DEFAULT", "VSA_DEFAULT", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index c0a183787dd..a3109d56b73 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -49,6 +49,13 @@ def __init__(self, method_config=None): self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) # Calibration state self._threshold_trials: list[float] | None = None + # HF (modelopt_triton) backend calibration outputs, accumulated across + # attention calls in one forward pass and read back in + # ``_collect_calibration_stats``. The HF backend reads/writes these + # directly on the method instance (no thread-local needed). + self._hf_calibration_counters: torch.Tensor | None = None + self._hf_calibration_seq_k: int | None = None + self._hf_calibration_is_decode: bool = False # Runtime sparsity measurement self._measure_sparsity: bool = False self._sparsity_total: int = 0 @@ -111,6 +118,11 @@ def _triton_inference_context(self, module): def _triton_calibration_context(self, module): """Calibration: collect multi-threshold sparsity stats via Triton kernel.""" module._apply_skip_softmax = True + # Reset the HF-backend calibration accumulators for this forward pass. + # (The diffusers/LTX backends reset their own state in ``_set_triton_backends``.) + self._hf_calibration_counters = None + self._hf_calibration_seq_k = None + self._hf_calibration_is_decode = False self._set_triton_backends(calibration_mode=True, threshold_trials=self._threshold_trials) with self._get_diffusers_backend_context(): try: @@ -121,20 +133,20 @@ def _triton_calibration_context(self, module): module._apply_skip_softmax = False self._clear_triton_backends() - def _get_scale_factor(self) -> float | None: - """Compute scale_factor from calibration params, or None if uncalibrated. + def _get_scale_factor(self, phase: str = "prefill") -> float | None: + """Compute the scale_factor for ``phase`` from calibration params, or None. - The scale_factor is sequence-length-independent. Backends divide by the + The scale_factor is sequence-length-independent. Callers divide by the actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``. """ if self.calibration_params and self.target_sparse_ratio: import math import warnings - params = self.calibration_params.get("prefill", {}) + params = self.calibration_params.get(phase, {}) a = params.get("a", 0) b = params.get("b", 0) - target = self.target_sparse_ratio.get("prefill", 0.5) + target = self.target_sparse_ratio.get(phase, 0.5) if a > 0 and b > 0: # Warn if target is outside the calibrated range min_s = params.get("min_observed_sparsity") @@ -155,6 +167,22 @@ def _get_scale_factor(self) -> float | None: return a * math.exp(b * target) return None + def get_inference_threshold(self, seq_q: int, seq_k: int) -> float | None: + """Return the skip threshold to apply for this call's phase. + + Picks the phase from the query length (``decode`` when ``seq_q == 1``, + else ``prefill``) and returns the calibrated dynamic threshold + ``scale_factor(phase) / seq_k`` when the phase is calibrated, otherwise + the static ``skip_softmax_threshold`` (or ``None`` to disable). This is + what the HF backend applies; it keeps prefill and decode on their own + calibrated ``(a, b)`` instead of forcing decode onto prefill's. + """ + phase = "decode" if seq_q <= 1 else "prefill" + scale_factor = self._get_scale_factor(phase) + if scale_factor is not None and seq_k > 0: + return scale_factor / seq_k + return self.skip_softmax_threshold or None + @staticmethod @contextmanager def _get_diffusers_backend_context(): @@ -170,7 +198,12 @@ def _get_diffusers_backend_context(): yield def _set_triton_backends(self, **kwargs): - """Set config on both diffusers and LTX Triton backends.""" + """Set config on the diffusers and LTX Triton backends. + + The HF (modelopt_triton) backend reads its calibration config directly + from this method instance during ``triton_attention_forward``, so it + needs no separate configuration here. + """ try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( set_triton_skip_softmax_config, @@ -189,7 +222,7 @@ def _set_triton_backends(self, **kwargs): pass def _clear_triton_backends(self): - """Clear config on both Triton backends.""" + """Clear config on the diffusers and LTX Triton backends.""" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( clear_triton_skip_softmax_config, @@ -211,6 +244,9 @@ def _collect_calibration_stats(self, module): """Read Triton calibration counters and store as stats on the module.""" counters = None seq_k = None + # Diffusers/LTX (video) backends are prefill-only; only the HF backend + # reports a phase, for decode-step calibration. + phase = "prefill" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( @@ -235,6 +271,14 @@ def _collect_calibration_stats(self, module): except ImportError: pass + if counters is None: + # HF (modelopt_triton) backend accumulates counters on this method + # instance (``module._sparse_method_instance is self``). + counters = self._hf_calibration_counters + seq_k = self._hf_calibration_seq_k + if counters is not None and self._hf_calibration_is_decode: + phase = "decode" + if counters is None or self._threshold_trials is None: return @@ -251,7 +295,7 @@ def _collect_calibration_stats(self, module): module._last_stats = { "sparsity": sparsity_list, "sample_length": sample_length, - "phase": "prefill", + "phase": phase, } def get_threshold_info(self) -> dict: diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index fe16559a187..7fca9218f64 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -132,6 +132,43 @@ def test_different_seq_q_seq_k(self): assert out.shape == q.shape assert counters.shape == (2, 2) + def test_decode_skips_padding_rows(self): + """Decode (seq_q=1) skips real KV tiles once padding Q rows are excluded. + + With BLOCK_M=128, 127/128 query rows are padding. Before the padding-row + fix their ~0 gap forced zero skips; after it the largest threshold skips a + meaningful number of KV tiles. + """ + seq_q, seq_k, num_heads, head_dim = 1, 512, 4, 64 + scale = 1.0 / (head_dim**0.5) + torch.manual_seed(0) + q = torch.randn(seq_q, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + b_start_loc = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len = torch.ones(1, device="cuda", dtype=torch.int32) + b_start_loc_k = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len_k = torch.full((1,), seq_k, device="cuda", dtype=torch.int32) + + _, counters = attention_calibrate( + q, + k, + v, + b_start_loc, + b_seq_len, + seq_q, + softmax_scale=scale, + is_causal=False, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=seq_k, + threshold_trials=[1e-2, 1e-1, 5e-1, 9e-1], + ) + skipped = counters[:, 1] + assert (skipped[1:] >= skipped[:-1]).all() # monotonic non-decreasing + assert (skipped <= counters[:, 0]).all() + assert skipped[-1] > 0 # padding-row fix makes this non-zero + def test_threshold_order_doesnt_affect_counts(self): """Skipped counts at the same threshold are independent of trial ordering.""" q, k, v, locs, lens = self._make_inputs() @@ -282,7 +319,9 @@ def test_first_measured_call_has_real_tile_count_with_autotune(self): assert result.returncode == 0, result.stderr totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")] assert totals, result.stdout - assert int(totals[-1].split("=", maxsplit=1)[1]) == 8 + # seq_len=256, _MEASURE_BLOCK_M = _MEASURE_BLOCK_N = 128, non-causal: + # Q tiles = ceil(256/128) = 2, KV tiles = ceil(256/128) = 2, total = 4. + assert int(totals[-1].split("=", maxsplit=1)[1]) == 4 def test_measure_sparsity_without_skip_is_noop(self): """Without skip-softmax, measure_sparsity doesn't attach counters.""" diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py new file mode 100644 index 00000000000..949e67b2cd8 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for skip-softmax calibration via the Triton backend on HF models. + +These exercise the HuggingFace (``modelopt_triton``) wiring that routes the +calibration forward pass through the fused ``attention_calibrate`` kernel and +feeds the collected multi-threshold tile-skip statistics into the same +exponential-model fit used by the PyTorch path. +""" + +import copy +import itertools + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention.hf_triton_attention import triton_attention_forward +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_TRITON_CALIB +from modelopt.torch.sparsity.attention_sparsity.methods.triton_skip_softmax import ( + TritonSkipSoftmaxMethod, +) + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), +] + +THRESHOLD_TRIALS = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create a minimal Llama model directory.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama_triton_calib"), + num_hidden_layers=2, + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=1024, + ) + + +def _load_eager(tiny_llama_dir): + return AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, attn_implementation="eager", device_map="cuda" + ) + + +def _make_forward_loop(vocab_size, lengths=(128, 256, 384, 512)): + """Forward loop that runs several full-prefill passes of varying length. + + Each pass triggers one ``attention_calibrate`` call per layer, producing one + per-sample calibration record per length. + """ + + def forward_loop(model): + torch.manual_seed(0) + for seq_len in lengths: + input_ids = torch.randint(0, vocab_size, (1, seq_len), device="cuda") + with torch.no_grad(): + model(input_ids, use_cache=False) + + return forward_loop + + +def _calibration_module(threshold_trials): + """Build a bare module whose ``_sparse_method_instance`` is in calibration mode. + + The HF backend reads its calibration config from (and writes counters back + to) ``module._sparse_method_instance``, so this is the minimal stand-in for + driving ``triton_attention_forward`` through the calibration branch. + """ + method = TritonSkipSoftmaxMethod() + method.set_calibration_mode(True) + method._threshold_trials = threshold_trials + + module = torch.nn.Module() + module._sparse_method_instance = method + return module + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestTritonCalibrationHF: + """End-to-end calibration via the Triton backend on a tiny HF model.""" + + def test_calibrated_model_inference(self, tiny_llama_dir): + """SKIP_SOFTMAX_TRITON_CALIB dispatches to the Triton backend and the + calibrated model runs inference cleanly.""" + model = _load_eager(tiny_llama_dir) + config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + # Prefill-only (custom forward_loop can't drive RULER decode calibration). + config["sparse_cfg"]["calibration"]["target_sparse_ratio"] = {"prefill": 0.5} + + forward_loop = _make_forward_loop(model.config.vocab_size) + sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + assert sparse_model.config._attn_implementation == "modelopt_triton" + + sparse_model.eval() + input_ids = torch.randint(0, model.config.vocab_size, (1, 64), device="cuda") + with torch.no_grad(): + out = sparse_model(input_ids, use_cache=False) + assert out.logits is not None + assert not torch.isnan(out.logits).any() + + def test_decode_branch_reports_decode_phase(self): + """The HF calibration branch routes decode-shaped calls through the kernel + and surfaces its counters as a ``decode``-phase stats record. + + This is the HF-only counter path in ``_collect_calibration_stats``; the + kernel's skip-count behavior itself is covered in the kernel test suite. + """ + num_heads, seq_k, head_dim = 4, 512, 64 + torch.manual_seed(0) + q = torch.randn(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + assert method._hf_calibration_is_decode is True + assert method._hf_calibration_counters is not None + + method._collect_calibration_stats(module) + assert module._last_stats["phase"] == "decode" + assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS) + + def test_decode_calibration_measures_full_cache_with_sink(self): + """Decode calibration must scan the whole KV cache and report real sparsity. + + A dominant sink at position 0 makes the distant KV tiles negligible, so a + correct decode measurement skips almost all of them. This guards the two + decode bugs that random inputs don't expose: + * causal-offset ``kv_bound`` — without it the loop stops after the first + ``BLOCK_M`` tokens, so ``total`` would be a fraction of the cache. + * padding-row exclusion — without it the 127 padding rows veto every + tile and sparsity is 0%. + """ + num_heads, seq_k, head_dim = 4, 2048, 64 + block_n = 128 # the calibration kernel measures at 128x128 + q = torch.ones(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + k[:, :, 0] = 20.0 # attention sink dominates every query + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + + counters = method._hf_calibration_counters + total = int(counters[0, 0]) + # Full cache scanned (not truncated to the first block). + assert total == num_heads * (seq_k // block_n), total + sparsity = (counters[:, 1].float() / counters[:, 0].clamp(min=1)).tolist() + # Sink => the vast majority of tiles are negligible and skippable (not 0%). + assert max(sparsity) > 0.8, sparsity + # Skipped-tile fraction is non-decreasing as the threshold grows. + assert all(later >= earlier for earlier, later in itertools.pairwise(sparsity)), sparsity + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])