Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Changelog
- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache.
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).
- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for details.
- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data).

**Bug Fixes**

Expand Down
19 changes: 19 additions & 0 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_CALIB,
SKIP_SOFTMAX_CALIB_SPARSE24,
SKIP_SOFTMAX_TRITON_CALIB,
SPARSE_SOFTMAX_DEFAULT,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
Expand All @@ -44,6 +45,7 @@
SPARSE_ATTN_CFG_CHOICES = {
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
"skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24,
"skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB,
"sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
}

Expand Down Expand Up @@ -186,6 +188,15 @@ def main(args):
calib["max_seqlen"] = args.calib_max_seqlen
if args.calib_chunk_size is not None:
calib["chunk_size"] = args.calib_chunk_size
# Point RULER calibration at the data downloaded by download_ruler_data.sh
# (next to this script) unless the user overrides it. The NIAH essay
# haystack requires this directory.
calib.setdefault(
"data_dir",
args.calib_data_dir
if args.calib_data_dir is not None
else str(Path(__file__).parent / "data"),
)

model = mtsa.sparsify(model, config=sparse_config)
print("Sparse attention applied successfully!")
Expand Down Expand Up @@ -302,6 +313,14 @@ def main(args):
default=None,
help="Chunk size for calibration prefill. Overrides config value.",
)
parser.add_argument(
"--calib_data_dir",
type=str,
default=None,
help="Path to RULER calibration data (contains an 'essays' subdir). "
"Defaults to the 'data' directory next to this script "
"(populated by download_ruler_data.sh).",
)

args = parser.parse_args()
main(args)
41 changes: 37 additions & 4 deletions modelopt/torch/kernels/common/attention/hf_triton_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

from modelopt.torch.kernels.common.attention.triton_fa import attention

# Skip-softmax calibration config and counters live on the module's
# ``_sparse_method_instance`` (HF passes the owning module to
# ``triton_attention_forward``), so no separate thread-local state is needed.


def _seq_lens_from_mask(
attention_mask: torch.Tensor | None,
Expand Down Expand Up @@ -105,20 +109,49 @@ def triton_attention_forward(
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["max_input_len_k"] = seq_k

# Sparse attention params
# Sparse-attention method instance. It carries the inference threshold and,
# during calibration, both the calibration config and the accumulated
# tile-skip counters. Available here because HF passes the owning module.
method = getattr(module, "_sparse_method_instance", None)

# Calibration mode: run the calibration kernel, which computes full attention
# while counting, per candidate threshold, how many KV tiles would be skipped.
# The sparse-attention kwargs below are intentionally not added in this branch.
if method is not None and getattr(method, "_calibration_mode", False):
trials = getattr(method, "_threshold_trials", None)
# Deferred: the package __init__ imports this module, so importing
# attention_calibrate at module top would be circular.
from modelopt.torch.kernels.common.attention import attention_calibrate

if trials and attention_calibrate is not None:
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

# Accumulate counters across all attention calls in this forward pass.
# The method instance is per-module so the accumulator stays on one
# device, but guard the add against a device mismatch just in case.
prev = getattr(method, "_hf_calibration_counters", None)
method._hf_calibration_counters = (
counters if prev is None else prev + counters.to(prev.device)
)
method._hf_calibration_seq_k = seq_k
method._hf_calibration_is_decode = is_decode

return (o.view(batch, seq_len, num_heads, head_dim), None)

# N:M sparse softmax: prefill only (no perf benefit for decode)
if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
kw["sparsity_n"] = method.sparsity_n
kw["sparsity_m"] = method.sparsity_m
kw["dense_sink_tokens"] = method.dense_sink_tokens
kw["dense_recent_tokens"] = method.dense_recent_tokens

# Skip-softmax: applies to both prefill and decode
# Skip-softmax: applies to both prefill and decode. Prefer the method's
# per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back
# to the static threshold when uncalibrated.
if method is not None and getattr(module, "_apply_skip_softmax", False):
if method.skip_softmax_threshold:
kw["skip_softmax_threshold"] = method.skip_softmax_threshold
threshold = method.get_inference_threshold(seq_len, seq_k)
if threshold:
kw["skip_softmax_threshold"] = threshold

o = attention(q, k, v, **kw)

Expand Down
184 changes: 102 additions & 82 deletions modelopt/torch/kernels/common/attention/triton_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None:
_FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)]

_MEASURE_BLOCK_M = 128
_MEASURE_BLOCK_N = 64
# 128 so the kernel sparsity-measurement block matches the PyTorch
# flash_skip_softmax calibration block (br = bc = 128) and the Triton
# calibration kernel; otherwise the two measure at different granularities.
_MEASURE_BLOCK_N = 128
_MEASURE_NUM_STAGES = 1
_MEASURE_NUM_WARPS = 4

Expand Down Expand Up @@ -363,6 +366,8 @@ def _attn_fwd(
skip_tile = _skip_softmax_decision(
scores,
row_max,
q_pos,
seq_len_q,
SKIP_THRESHOLD_LOG2,
Sparsity_total,
Sparsity_skipped,
Expand Down Expand Up @@ -919,23 +924,29 @@ def forward(
def grid(META):
return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"]))

if do_measure:
# Runtime counters mutate global tensors, so do not run them through
# autotune candidate trials. Use one stable config for measurement.
_attn_fwd.fn[grid](
*fwd_args,
**fwd_kwargs,
BLOCK_M=_MEASURE_BLOCK_M,
BLOCK_N=_MEASURE_BLOCK_N,
num_warps=_MEASURE_NUM_WARPS,
num_stages=_MEASURE_NUM_STAGES,
)
else:
_attn_fwd[grid](
*fwd_args,
**fwd_kwargs,
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
)
# Triton launches on torch.cuda.current_device(), which is not
# necessarily the device the tensors live on (e.g. under accelerate
# device_map="auto" sharding). Activate the tensor's device so the
# kernel dereferences the right pointers instead of triggering an
# illegal memory access.
with torch.cuda.device(q.device):
if do_measure:
# Runtime counters mutate global tensors, so do not run them through
# autotune candidate trials. Use one stable config for measurement.
_attn_fwd.fn[grid](
*fwd_args,
**fwd_kwargs,
BLOCK_M=_MEASURE_BLOCK_M,
BLOCK_N=_MEASURE_BLOCK_N,
num_warps=_MEASURE_NUM_WARPS,
num_stages=_MEASURE_NUM_STAGES,
)
else:
_attn_fwd[grid](
*fwd_args,
**fwd_kwargs,
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
)

# Store sparsity counters on the output tensor for retrieval by callers
if do_measure:
Expand Down Expand Up @@ -970,23 +981,30 @@ def backward(ctx, grad_output):
do = grad_output.contiguous()
num_warps = 4

# Triton launches on torch.cuda.current_device(), which is not
# necessarily the device the tensors live on (e.g. under accelerate
# device_map="auto" sharding). Activate the tensor's device for each
# launch so the kernels dereference the right pointers instead of
# triggering an illegal memory access.

# Phase 1: delta = rowsum(O * dO)
delta = torch.empty_like(lse)
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
o,
do,
delta,
o.stride(0),
o.stride(1),
do.stride(0),
do.stride(1),
delta.stride(0),
delta.stride(1),
q.shape[0],
HEAD_DIM=HEAD_DIM,
BLOCK_D=BLOCK_D,
BLOCK_M=BLOCK,
)
with torch.cuda.device(q.device):
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
o,
do,
delta,
o.stride(0),
o.stride(1),
do.stride(0),
do.stride(1),
delta.stride(0),
delta.stride(1),
q.shape[0],
HEAD_DIM=HEAD_DIM,
BLOCK_D=BLOCK_D,
BLOCK_M=BLOCK,
)

dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
Expand Down Expand Up @@ -1016,57 +1034,59 @@ def backward(ctx, grad_output):
)

# Phase 2: dK, dV
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
*bwd_args[:4],
dk,
dv,
*bwd_args[4:],
dk.stride(0),
dk.stride(1),
dv.stride(0),
dv.stride(1),
lse.stride(0),
lse.stride(1),
kv_group_num=ctx.kv_group_num,
BLOCK_M=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)
with torch.cuda.device(q.device):
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
*bwd_args[:4],
dk,
dv,
*bwd_args[4:],
dk.stride(0),
dk.stride(1),
dv.stride(0),
dv.stride(1),
lse.stride(0),
lse.stride(1),
kv_group_num=ctx.kv_group_num,
BLOCK_M=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)

# Phase 3: dQ
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
*bwd_args[:4],
dq,
*bwd_args[4:],
dq.stride(0),
dq.stride(1),
lse.stride(0),
lse.stride(1),
kv_group_num=ctx.kv_group_num,
BLOCK_M=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)
with torch.cuda.device(q.device):
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
*bwd_args[:4],
dq,
*bwd_args[4:],
dq.stride(0),
dq.stride(1),
lse.stride(0),
lse.stride(1),
kv_group_num=ctx.kv_group_num,
BLOCK_M=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_N=BLOCK,
IS_CAUSAL=ctx.is_causal,
HEAD_DIM=HEAD_DIM,
SPARSITY_N=ctx.sparsity_n,
SPARSITY_M=ctx.sparsity_m,
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)

return (
dq,
Expand Down
Loading
Loading