From 835f14b35d4c3b1676f20dffefb57fb15c557f30 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 13 Dec 2025 21:11:16 +0000 Subject: [PATCH 01/10] Add sparse attention integration to llm_eval Signed-off-by: Kai Xu --- .vscode/settings.json | 3 + examples/llm_eval/lm_eval_hf.py | 8 + examples/llm_eval/mmlu.py | 8 + .../attention_sparsity/requirements.txt | 2 + .../attention_sparsity/calibration/dataset.py | 546 ++++++++++++++++++ .../calibration/download_ruler_data.sh | 50 ++ .../calibration/ruler_utils.py | 487 ++++++++++++++++ tests/examples/llm_eval/test_llm_eval.py | 17 + 8 files changed, 1121 insertions(+) create mode 100644 examples/llm_sparsity/attention_sparsity/requirements.txt create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py create mode 100755 modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 0e8465ad3..1cff4a791 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,4 +45,7 @@ ], "git.alwaysSignOff": true, "git.enableCommitSigning": true, + "cursorpyright.analysis.extraPaths": [ + "./tests/" + ], } diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 405e8590a..24dcb28f6 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -68,6 +68,14 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} + # Force eager attention if sparse attention is requested + if sparse_cfg: + additional_config["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." 
+ ) + # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index 316f443bb..0bf47fcd3 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -269,6 +269,14 @@ def main( max_batch_size=1, ) else: + # Force eager attention if sparse attention is requested + if sparse_cfg: + kwargs["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt new file mode 100644 index 000000000..a3e0dfa17 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/requirements.txt @@ -0,0 +1,2 @@ +nltk +wonderwords diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..7603b4e1d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,546 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _generate_target_lengths(
    max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024
) -> list[int]:
    """Generate descending target lengths by repeatedly halving ``max_seqlen``.

    Args:
        max_seqlen: Maximum sequence length; it is the first bin whenever it is
            at least ``min_seqlen``.
        num_length_bins: Maximum number of length bins to generate.
        min_seqlen: Minimum sequence length threshold; generation stops as soon
            as the next length would fall below it.

    Returns:
        List of target lengths in descending order (empty when ``max_seqlen``
        is already below the threshold or ``num_length_bins`` is not positive).

    Examples:
        >>> _generate_target_lengths(32768, 4)
        [32768, 16384, 8192, 4096]
        >>> _generate_target_lengths(2048, 4)
        [2048, 1024]
        >>> _generate_target_lengths(4, 8, min_seqlen=0)
        [4, 2, 1]
    """
    # Integer halving eventually reaches 0; clamp the threshold so that a
    # non-positive min_seqlen can never produce zero-length bins.
    floor = max(min_seqlen, 1)

    target_lengths: list[int] = []
    current = max_seqlen
    for _ in range(num_length_bins):
        if current < floor:
            break
        target_lengths.append(current)
        current //= 2

    return target_lengths


@dataclass
class RulerTask:
    """Configuration for a RULER task."""

    # Task identifier; also used as the key in the task registry.
    name: str
    # One of: niah, variable_tracking, freq_words_extraction, qa
    task_type: str
    # Number of tokens reserved for the model's answer.
    tokens_to_generate: int
    # Prompt template; filled with str.format (e.g. {context}, {query}).
    template: str
    # Text appended after the formatted template to start the answer.
    answer_prefix: str
    # Task-type specific generation options.
    args: dict[str, Any]
+ ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." + ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" 
+ ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={"alpha": 2.0}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str | object, + num_length_bins: int = 4, + max_length_filter: int = 65536, + seed: int = 42, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] 
+ Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. + """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. 
+ + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. + + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Find optimal haystack size for target length + optimal_haystack = ruler_utils.find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + 
def _generate_vt_sample(
    self, task: "RulerTask", target_length: int, sample_idx: int
) -> dict[str, Any]:
    """Generate a variable tracking sample.

    Builds ``num_chains`` chains of ``num_hops + 1`` random names, renders each
    consecutive pair as ``VAR a = b``, pads the result towards
    ``target_length`` and asks which variables hold the terminal value of one
    chain.

    Args:
        task: Task configuration; ``task.args`` must provide ``num_chains`` and
            ``num_hops``.
        target_length: Target sequence length in tokens.
        sample_idx: Index of the sample (stored as metadata only).

    Returns:
        Dict with 'input', 'length', 'task', 'target_length' and 'sample_idx'.
    """
    num_chains = task.args["num_chains"]
    num_hops = task.args["num_hops"]

    # Same draw order as a chain-by-chain loop: all names of chain 0, then chain 1, ...
    chains = [
        [self._generate_random_variable() for _ in range(num_hops + 1)]
        for _ in range(num_chains)
    ]

    # Each consecutive pair in a chain becomes one assignment statement.
    assignments = [
        f"VAR {left} = {right}" for chain in chains for left, right in zip(chain, chain[1:])
    ]

    context = self._pad_context_with_text(
        "\n".join(assignments), target_length, "variable tracking context"
    )

    # The queried value is the terminal element of one of the chains.
    query_value = random.choice([chain[-1] for chain in chains])
    prompt = task.template.format(context=context, query=query_value)

    # Every name upstream of the terminal value resolves to it, so a chain
    # contributes len(chain) - 1 variables (previously this counted chains,
    # which made the answer prefix announce "1 variables" for a 4-hop chain).
    num_v = sum(len(chain) - 1 for chain in chains if chain[-1] == query_value)

    full_input = prompt + task.answer_prefix.format(num_v=num_v, query=query_value)

    # Tokenize to get the actual prompt length.
    tokens = self.tokenizer.encode(full_input, add_special_tokens=False)

    return {
        "input": full_input,
        "length": len(tokens),
        "task": task.name,
        "target_length": target_length,
        "sample_idx": sample_idx,
    }


def _generate_fwe_sample(
    self, task: "RulerTask", target_length: int, sample_idx: int
) -> dict[str, Any]:
    """Generate a frequency word extraction sample.

    Fifty coded words are drawn; the first three get a high repeat count
    (20-30) and the rest a low one (1-10). The repeated words are shuffled,
    joined with `` .... `` separators and padded towards ``target_length``.

    Args:
        task: Task configuration providing ``template`` and ``answer_prefix``.
        target_length: Target sequence length in tokens.
        sample_idx: Index of the sample (stored as metadata only).

    Returns:
        Dict with 'input', 'length', 'task', 'target_length' and 'sample_idx'.
    """
    num_unique_words = 50
    coded_words = [self._generate_coded_word() for _ in range(num_unique_words)]

    # Make the top 3 clearly more frequent than everything else.
    frequencies = {
        word: random.randint(20, 30) if rank < 3 else random.randint(1, 10)
        for rank, word in enumerate(coded_words)
    }

    # Expand each word by its frequency, then shuffle the occurrences.
    word_list = [word for word, freq in frequencies.items() for _ in range(freq)]
    random.shuffle(word_list)

    # Dots act as separators between coded words.
    coded_text = " .... ".join(word_list)

    context = self._pad_context_with_text(coded_text, target_length, "coded text padding")

    full_input = task.template.format(context=context) + task.answer_prefix

    # Tokenize to get the actual prompt length.
    tokens = self.tokenizer.encode(full_input, add_special_tokens=False)

    return {
        "input": full_input,
        "length": len(tokens),
        "task": task.name,
        "target_length": target_length,
        "sample_idx": sample_idx,
    }


def _generate_qa_sample(
    self, task: "RulerTask", target_length: int, sample_idx: int
) -> dict[str, Any]:
    """Generate a synthetic QA sample.

    Five filler documents are generated and a "special code" sentence is
    appended to exactly one of them; the question asks for the code in that
    same document.

    Args:
        task: Task configuration providing ``template`` and ``answer_prefix``.
        target_length: Target sequence length in tokens.
        sample_idx: Index of the sample (stored as metadata only).

    Returns:
        Dict with 'input', 'length', 'task', 'target_length' and 'sample_idx'.
    """
    num_docs = 5

    answer = self._generate_random_phrase()

    # Pick the document once and use it for BOTH the question and the
    # insertion point. Previously the question named a random document while
    # the answer was always placed in Document 3.
    answer_doc = random.randint(1, num_docs)
    question = f"What is the special code mentioned in document {answer_doc}?"

    documents = []
    for doc_number in range(1, num_docs + 1):
        doc_text = self._generate_document_text(200)  # Base document
        if doc_number == answer_doc:
            doc_text += f" The special code is {answer}. "
        documents.append(f"Document {doc_number}:\n{doc_text}\n")

    context_base = "\n".join(documents)

    # NOTE(review): the padding helper may also truncate the context; whether
    # the answer sentence survives for small target lengths depends on the
    # tokenizer - confirm against _pad_context_with_text.
    context = self._pad_context_with_text(
        context_base, target_length, "additional document text"
    )

    full_input = task.template.format(context=context, query=question) + task.answer_prefix

    # Tokenize to get the actual prompt length.
    tokens = self.tokenizer.encode(full_input, add_special_tokens=False)

    return {
        "input": full_input,
        "length": len(tokens),
        "task": task.name,
        "target_length": target_length,
        "sample_idx": sample_idx,
    }
" + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. 
" + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh new file mode 100755 index 000000000..54797f2a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER calibration data for attention sparsity. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +mkdir -p "${ESSAYS_DIR}" + +# Download URL list if not exists +if [ ! -f "${URLS_FILE}" ]; then + echo "Downloading URL list..." 
+ curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" +fi + +# Download essays from GitHub URLs +echo -n "Downloading essays" +count=0 +while IFS= read -r url || [ -n "$url" ]; do + if [[ "${url}" == https://github.com*.txt ]]; then + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + if [ ! -f "${filepath}" ]; then + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." + count=$((count + 1)) + fi + fi +done < "${URLS_FILE}" +echo " done" + +echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py new file mode 100644 index 000000000..70d4da81b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Official RULER dataset generation utilities adapted for Model Optimizer. + +This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) +adapted to work as a library for calibration purposes. The generation logic closely follows +the official RULER implementation to ensure dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import logging +import random +import re +import uuid +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + +# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) +DATA_DIR = Path(__file__).parent / "data" +RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" +ESSAYS_DIR = DATA_DIR / "essays" + + +def _get_data_dir() -> Path: + """Get data directory for RULER data. 
# Whitespace-normalized essay corpus, populated on first successful load so
# repeated sample generation does not re-read and re-normalize every file.
_ESSAY_TEXT_CACHE = None


def _load_paul_graham_essays_from_files() -> str:
    """Load Paul Graham essays from local files.

    Reads the ``*.txt`` files in ``ESSAYS_DIR`` in sorted filename order and
    joins them with single spaces. Files must be downloaded first using
    download_ruler_data.sh.

    Returns:
        Combined essay text

    Raises:
        RuntimeError: If essays directory doesn't exist or is empty
    """
    if not ESSAYS_DIR.exists():
        raise RuntimeError(
            f"Essays directory not found at {ESSAYS_DIR}.\n"
            "Please run the download script first:\n"
            "  bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh"
        )

    # glob() yields entries in filesystem order, which differs across
    # machines; sort so the same seed always produces the same haystack.
    essay_files = sorted(ESSAYS_DIR.glob("*.txt"))
    if not essay_files:
        raise RuntimeError(
            f"No essay files found in {ESSAYS_DIR}.\n"
            "Please run the download script first:\n"
            "  bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh"
        )

    logger.info("Loading %d Paul Graham essays from local files...", len(essay_files))

    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on Windows).
    all_essays = [filepath.read_text(encoding="utf-8") for filepath in essay_files]

    combined_text = " ".join(all_essays)
    logger.info("Loaded %d essays successfully", len(all_essays))

    return combined_text


def _load_paul_graham_essays() -> str:
    """Load the whitespace-normalized Paul Graham essay corpus.

    Essay files must be downloaded first using download_ruler_data.sh. The
    normalized text is cached at module level after the first successful load;
    a failed load caches nothing, so it can be retried after downloading.

    Returns:
        Essay text as a single string with every whitespace run collapsed to
        one space

    Raises:
        RuntimeError: If essays directory doesn't exist or is empty
    """
    global _ESSAY_TEXT_CACHE
    if _ESSAY_TEXT_CACHE is None:
        _ESSAY_TEXT_CACHE = re.sub(r"\s+", " ", _load_paul_graham_essays_from_files())
    return _ESSAY_TEXT_CACHE
# Global word list (loaded once)
_WORD_LIST = None


def generate_random_number(num_digits: int = 7) -> str:
    """Generate a random number with exactly ``num_digits`` digits (from official RULER)."""
    lower_bound = 10 ** (num_digits - 1)
    upper_bound = 10**num_digits - 1
    return str(random.randint(lower_bound, upper_bound))


def generate_random_word() -> str:
    """Generate random word (from official RULER).

    The word list is loaded lazily on first use and kept in ``_WORD_LIST``.
    """
    global _WORD_LIST
    if _WORD_LIST is None:
        _WORD_LIST = _load_word_lists()
    return random.choice(_WORD_LIST)


def generate_random_uuid() -> str:
    """Generate random UUID (from official RULER).

    Drawn from the ``random`` module (not ``uuid.uuid4``) so the value is
    reproducible under ``random.seed``.
    """
    return str(uuid.UUID(int=random.getrandbits(128), version=4))


def generate_random(type_needle: str) -> str:
    """Generate random needle value based on type (from official RULER).

    Args:
        type_needle: Type of needle ('numbers', 'words', 'uuids')

    Returns:
        Random value as string

    Raises:
        ValueError: If ``type_needle`` is not one of the supported types
    """
    if type_needle == "numbers":
        return generate_random_number()
    elif type_needle == "words":
        return generate_random_word()
    elif type_needle == "uuids":
        return generate_random_uuid()
    else:
        raise ValueError(f"Unknown needle type: {type_needle}")


def _singularize_niah_text(text: str) -> str:
    """Apply the singular-form rewrites to an unformatted template, in order."""
    for plural, singular in (
        ("Some", "A"),
        ("are all", "is"),
        ("are", "is"),
        ("answers", "answer"),
    ):
        text = text.replace(plural, singular)
    return text


def _build_essay_context(num_haystack: int, needles: list) -> str:
    """Build an essay haystack of ``num_haystack`` words with needles at random depths."""
    # nltk is only needed for sentence splitting of the essay corpus, so it is
    # imported (and punkt fetched) here rather than for every haystack type.
    import nltk
    from nltk.tokenize import sent_tokenize

    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", quiet=True)
        nltk.download("punkt_tab", quiet=True)

    if len(needles) > len(DEPTHS):
        raise ValueError(
            f"Cannot place {len(needles)} needles: only {len(DEPTHS)} insertion depths available"
        )

    haystack = _load_paul_graham_essays().split(" ")

    if num_haystack <= len(haystack):
        text = " ".join(haystack[:num_haystack])
    else:
        # Repeat haystack as needed
        repeats = (num_haystack + len(haystack) - 1) // len(haystack)
        text = " ".join((haystack * repeats)[:num_haystack])

    # Insert needles at various depths (percentages of the sentence count)
    document_sents = sent_tokenize(text.strip())
    insertion_positions = [
        0,
        *sorted(
            int(len(document_sents) * (depth / 100))
            for depth in random.sample(DEPTHS, len(needles))
        ),
        len(document_sents),
    ]

    pieces = []
    for idx, (last_pos, next_pos) in enumerate(
        zip(insertion_positions, insertion_positions[1:])
    ):
        pieces.append(" ".join(document_sents[last_pos:next_pos]))
        if idx < len(needles):
            pieces.append(needles[idx])

    return " ".join(pieces)


def _build_line_context(
    num_haystack: int,
    needles: list,
    type_haystack: str,
    type_needle_k: str,
    type_needle_v: str,
) -> str:
    """Build a 'noise' or 'needle' haystack: one item per line with needles spliced in."""
    if num_haystack < len(needles):
        raise ValueError(
            f"num_haystack ({num_haystack}) must be >= number of needles ({len(needles)})"
        )

    if type_haystack == "noise":
        haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
        sentences = [haystack_sent] * num_haystack
    else:
        # Distractors have the same surface form as real needles.
        sentences = [
            NEEDLE_TEMPLATE.format(
                type_needle_v=type_needle_v,
                key=generate_random(type_needle_k),
                value=generate_random(type_needle_v),
            )
            for _ in range(num_haystack)
        ]

    # Insert from the highest index down so earlier insertions do not shift
    # the positions still to be used.
    indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True)
    for index, element in zip(indexes, needles):
        sentences.insert(index, element)
    return "\n".join(sentences)


def generate_niah_sample(
    num_haystack: int,
    tokenizer,
    template: str,
    answer_prefix: str,
    tokens_to_generate: int = 128,
    type_haystack: str = "essay",
    type_needle_k: str = "words",
    type_needle_v: str = "numbers",
    num_needle_k: int = 1,
    num_needle_v: int = 1,
    num_needle_q: int = 1,
    random_seed: int = 42,
) -> dict[str, Any]:
    """Generate a single NIAH (Needle in a Haystack) sample.

    This function implements the core generation logic from official RULER's niah.py,
    adapted to work as a library function.

    Args:
        num_haystack: Number of haystack items/words
        tokenizer: Object with an ``encode(text, add_special_tokens=False)``
            method; without one, length falls back to a whitespace word count
        template: NIAH question template
        answer_prefix: Answer prefix template
        tokens_to_generate: Expected number of generation tokens
        type_haystack: Type of haystack ('essay', 'noise', 'needle')
        type_needle_k: Type of needle keys ('numbers', 'words', 'uuids')
        type_needle_v: Type of needle values ('numbers', 'words', 'uuids')
        num_needle_k: Number of needle keys
        num_needle_v: Number of needle values per key
        num_needle_q: Number of needles to query
        random_seed: Random seed for this sample; re-seeds the global
            ``random`` module unless None

    Returns:
        Dictionary with 'input', 'outputs', 'length' keys

    Raises:
        ValueError: If ``type_haystack`` or a needle type is unknown, or the
            haystack is too small to hold all needles
    """
    # Reject unknown haystack types up front; previously they fell through
    # every branch and failed later with an unbound ``context``.
    if type_haystack not in ("essay", "noise", "needle"):
        raise ValueError(f"Unknown haystack type: {type_haystack}")

    if random_seed is not None:
        random.seed(random_seed)

    # Ensure num_needle_k >= num_needle_q
    num_needle_k = max(num_needle_k, num_needle_q)

    # Generate needles (keys and values)
    keys, values, needles = [], [], []
    for _ in range(num_needle_k):
        keys.append(generate_random(type_needle_k))
        value = []
        for _ in range(num_needle_v):
            value.append(generate_random(type_needle_v))
            needles.append(
                NEEDLE_TEMPLATE.format(
                    type_needle_v=type_needle_v,
                    key=keys[-1],
                    value=value[-1],
                )
            )
        values.append(value)

    random.shuffle(needles)

    # Generate context based on haystack type
    if type_haystack == "essay":
        context = _build_essay_context(num_haystack, needles)
    else:
        context = _build_line_context(
            num_haystack, needles, type_haystack, type_needle_k, type_needle_v
        )

    # Generate query and answer
    indices = random.sample(range(num_needle_k), num_needle_q)
    queries = [keys[i] for i in indices]
    answers = [a for i in indices for a in values[i]]
    query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0]

    # Format template (adjust for singular vs plural). The answer prefix gets
    # the same rewrite so a single-needle prompt ends in "... is", not "... are".
    type_needle_v_display = type_needle_v
    formatted_template = template
    prefix_template = answer_prefix
    if num_needle_q * num_needle_v == 1:
        formatted_template = _singularize_niah_text(template)
        prefix_template = _singularize_niah_text(answer_prefix)
        type_needle_v_display = type_needle_v[:-1]  # remove "s"

    input_text = formatted_template.format(
        type_needle_v=type_needle_v_display,
        context=context,
        query=query,
    )

    # Add answer prefix
    input_text += prefix_template.format(
        type_needle_v=type_needle_v_display,
        query=query,
    )

    # Calculate actual length (prompt tokens plus the generation budget)
    if hasattr(tokenizer, "encode"):
        # HuggingFace tokenizer
        tokens = tokenizer.encode(input_text, add_special_tokens=False)
        length = len(tokens) + tokens_to_generate
    else:
        # Fallback
        length = len(input_text.split()) + tokens_to_generate

    return {
        "input": input_text,
        "outputs": answers,
        "length": length,
    }
+ + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + 
logger.info(f"Optimal haystack size: {final_size}") + + return final_size diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py index 0abf78b53..88d29dedc 100644 --- a/tests/examples/llm_eval/test_llm_eval.py +++ b/tests/examples/llm_eval/test_llm_eval.py @@ -36,3 +36,20 @@ def test_llama_eval_fp8(): finally: # Force kill llm-serve if it's still running subprocess.run(["pkill", "-f", "llm-serve"], check=False) + + +def test_llama_eval_sparse_attention(tiny_llama_path): + """Test sparse attention with llm_eval integration.""" + try: + # Test with default sparse attention config (no quantization) + run_llm_ptq_command( + model=tiny_llama_path, + quant="none", # No quantization, only sparse attention + tasks="lm_eval", + lm_eval_tasks="hellaswag", + lm_eval_limit=0.05, # Small limit for fast test + sparse_cfg="SKIP_SOFTMAX_DEFAULT", + batch=4, + ) + finally: + subprocess.run(["pkill", "-f", "llm-serve"], check=False) From 32ec9caac29904d691e8f5ceec67bc7def80d952 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 15 Dec 2025 07:47:13 +0000 Subject: [PATCH 02/10] Add hf unified checkpoint export for sparse attention Signed-off-by: Kai Xu --- .vscode/settings.json | 3 --- .../torch/sparsity/attention_sparsity/calibration/dataset.py | 5 +++-- .../sparsity/attention_sparsity/calibration/ruler_utils.py | 5 ++++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 1cff4a791..0e8465ad3 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,7 +45,4 @@ ], "git.alwaysSignOff": true, "git.enableCommitSigning": true, - "cursorpyright.analysis.extraPaths": [ - "./tests/" - ], } diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index 7603b4e1d..74a4f3aa3 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ 
b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -435,11 +435,12 @@ def _generate_qa_sample( # Create a simple QA pair answer = self._generate_random_phrase() - question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + answer_doc_idx = random.randint(0, num_docs - 1) + question = f"What is the special code mentioned in document {answer_doc_idx + 1}?" for i in range(num_docs): doc_text = self._generate_document_text(200) # Base document - if i == 2: # Insert answer in one document + if i == answer_doc_idx: # Insert answer in the correct document doc_text += f" The special code is {answer}. " documents.append(f"Document {i + 1}:\n{doc_text}\n") diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py index 70d4da81b..9de75c02a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -331,7 +331,7 @@ def generate_niah_sample( context = " ".join(document_sents_list) - if type_haystack == "noise": + elif type_haystack == "noise": haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
sentences = [haystack_sent] * num_haystack indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) @@ -354,6 +354,9 @@ def generate_niah_sample( sentences.insert(index, element) context = "\n".join(sentences) + else: + raise ValueError(f"Unknown haystack type: {type_haystack}") + # Generate query and answer indices = random.sample(range(num_needle_k), num_needle_q) queries = [keys[i] for i in indices] From 12c9b141aca118164ccbd9301a8b7dacd05fcabf Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 Dec 2025 08:29:10 +0000 Subject: [PATCH 03/10] Address feedbacks Signed-off-by: Kai Xu --- examples/llm_eval/lm_eval_hf.py | 8 -------- examples/llm_eval/mmlu.py | 8 -------- examples/llm_sparsity/attention_sparsity/requirements.txt | 2 -- .../sparsity/attention_sparsity/calibration/dataset.py | 7 ++++--- 4 files changed, 4 insertions(+), 21 deletions(-) delete mode 100644 examples/llm_sparsity/attention_sparsity/requirements.txt diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 24dcb28f6..405e8590a 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -68,14 +68,6 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} - # Force eager attention if sparse attention is requested - if sparse_cfg: - additional_config["attn_implementation"] = "eager" - warnings.warn( - "Sparse attention requires attn_implementation='eager'. " - "Forcing eager attention implementation." 
- ) - # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index 0bf47fcd3..316f443bb 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -269,14 +269,6 @@ def main( max_batch_size=1, ) else: - # Force eager attention if sparse attention is requested - if sparse_cfg: - kwargs["attn_implementation"] = "eager" - warnings.warn( - "Sparse attention requires attn_implementation='eager'. " - "Forcing eager attention implementation." - ) - model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt deleted file mode 100644 index a3e0dfa17..000000000 --- a/examples/llm_sparsity/attention_sparsity/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -nltk -wonderwords diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index 74a4f3aa3..dc46413c1 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -20,9 +20,6 @@ from dataclasses import dataclass from typing import Any -from tqdm import tqdm -from transformers import AutoTokenizer - from . 
import ruler_utils @@ -232,6 +229,8 @@ def __init__( # Initialize tokenizer if isinstance(tokenizer_name_or_path, str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) else: self.tokenizer = tokenizer_name_or_path @@ -243,6 +242,8 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]: Returns: List of calibration samples with 'input' and 'length' fields """ + from tqdm import tqdm + all_samples = [] # Generate calibration samples From e60760bd81e248b8056d2e200cc985ffd187e0d2 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 27 Jan 2026 15:26:41 -0800 Subject: [PATCH 04/10] Move the data folder under example Signed-off-by: Kai Xu --- .../calibration/download_ruler_data.sh | 50 ------------------- .../calibration/ruler_utils.py | 17 ++++--- 2 files changed, 9 insertions(+), 58 deletions(-) delete mode 100755 modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh deleted file mode 100755 index 54797f2a5..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# Download RULER calibration data for attention sparsity. - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -DATA_DIR="${SCRIPT_DIR}/data" -ESSAYS_DIR="${DATA_DIR}/essays" -URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" -URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" - -mkdir -p "${ESSAYS_DIR}" - -# Download URL list if not exists -if [ ! -f "${URLS_FILE}" ]; then - echo "Downloading URL list..." - curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" -fi - -# Download essays from GitHub URLs -echo -n "Downloading essays" -count=0 -while IFS= read -r url || [ -n "$url" ]; do - if [[ "${url}" == https://github.com*.txt ]]; then - filename=$(basename "${url}") - filepath="${ESSAYS_DIR}/${filename}" - if [ ! -f "${filepath}" ]; then - raw_url="${url/github.com/raw.githubusercontent.com}" - raw_url="${raw_url/\/raw\//\/}" - curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." 
- count=$((count + 1)) - fi - fi -done < "${URLS_FILE}" -echo " done" - -echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py index 9de75c02a..741b621f5 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -99,8 +99,10 @@ 100, ] -# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) -DATA_DIR = Path(__file__).parent / "data" +# Data directory for RULER calibration files (in examples folder) +# Downloaded via examples/llm_sparsity/attention_sparsity/download_ruler_data.sh +_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent.parent +DATA_DIR = _REPO_ROOT / "examples" / "llm_sparsity" / "attention_sparsity" / "data" RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" ESSAYS_DIR = DATA_DIR / "essays" @@ -109,11 +111,10 @@ def _get_data_dir() -> Path: """Get data directory for RULER data. 
Returns: - Path to data directory under calibration/ (created if doesn't exist) + Path to data directory under examples/llm_sparsity/attention_sparsity/ (created if doesn't exist) """ - data_dir = Path(__file__).parent / "data" - data_dir.mkdir(parents=True, exist_ok=True) - return data_dir + DATA_DIR.mkdir(parents=True, exist_ok=True) + return DATA_DIR def _load_paul_graham_essays_from_files() -> str: @@ -132,7 +133,7 @@ def _load_paul_graham_essays_from_files() -> str: raise RuntimeError( f"Essays directory not found at {ESSAYS_DIR}.\n" "Please run the download script first:\n" - " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" ) essay_files = list(ESSAYS_DIR.glob("*.txt")) @@ -140,7 +141,7 @@ def _load_paul_graham_essays_from_files() -> str: raise RuntimeError( f"No essay files found in {ESSAYS_DIR}.\n" "Please run the download script first:\n" - " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" ) logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") From 252cb982c3a384a25447e55e53cf5b7cd602d427 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 27 Jan 2026 16:27:05 -0800 Subject: [PATCH 05/10] Implement Inverse Power calibration for sparse attention Signed-off-by: Kai Xu --- .../attention_sparsity/calibration/dataset.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index dc46413c1..221ea2344 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -20,6 +20,8 @@ from dataclasses import dataclass from typing import Any +from tqdm import tqdm + from . 
import ruler_utils @@ -242,26 +244,27 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]: Returns: List of calibration samples with 'input' and 'length' fields """ - from tqdm import tqdm - all_samples = [] - # Generate calibration samples - for num_samples, target_length in tqdm( - zip(self.samples_per_length, self.target_lengths), - desc="Generating RULER calibration samples", - total=len(self.target_lengths), - ): - samples_per_task = max(num_samples // len(self.subtasks), 1) - - # Generate equal samples for each task - for task_name in self.subtasks: - for sample_idx in range(samples_per_task): - sample = self._generate_sample(task_name, target_length, sample_idx) - if sample and sample["length"] <= self.max_length_filter: - all_samples.append(sample) + print( + f"Generating {self.total_samples} calibration samples " + f"across {len(self.target_lengths)} length bins: {self.target_lengths}" + ) + + # Generate calibration samples with sample-level progress + with tqdm(total=self.total_samples, desc="Generating RULER samples") as pbar: + for num_samples, target_length in zip(self.samples_per_length, self.target_lengths): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + pbar.update(1) random.shuffle(all_samples) + print(f"Generated {len(all_samples)} valid samples") return all_samples def _generate_sample( From 4d1bf5f7bc73cd57a4d1dc18d9a70a7fccb51d83 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 6 Feb 2026 20:15:33 -0800 Subject: [PATCH 06/10] Add Video Sparse Attention (VSA) Support for Video Diffusion Models Signed-off-by: Kai Xu --- examples/video_diffusion/vsa/README.md | 177 +++++ .../vsa/test_ltx2_vsa_integration.py | 612 ++++++++++++++++++ .../sparsity/attention_sparsity/config.py | 158 
++++- .../attention_sparsity/methods/__init__.py | 5 + .../attention_sparsity/methods/registry.py | 2 + .../attention_sparsity/methods/vsa.py | 373 +++++++++++ .../attention_sparsity/methods/vsa_utils.py | 155 +++++ .../attention_sparsity/plugins/huggingface.py | 6 +- .../attention_sparsity/plugins/ltx2.py | 459 +++++++++++++ .../attention_sparsity/sparse_attention.py | 37 +- 10 files changed, 1972 insertions(+), 12 deletions(-) create mode 100644 examples/video_diffusion/vsa/README.md create mode 100644 examples/video_diffusion/vsa/test_ltx2_vsa_integration.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/vsa.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py diff --git a/examples/video_diffusion/vsa/README.md b/examples/video_diffusion/vsa/README.md new file mode 100644 index 000000000..dccdadd4f --- /dev/null +++ b/examples/video_diffusion/vsa/README.md @@ -0,0 +1,177 @@ +# Video Sparse Attention (VSA) Example + +This example demonstrates how to apply Video Sparse Attention (VSA) optimization to video diffusion models for faster inference. + +## Overview + +VSA is a two-branch sparse attention architecture designed specifically for video diffusion models: + +1. **Compression Branch**: Averages tokens within 3D video blocks (default 4x4x4 = 64 tokens) and computes coarse-grained attention for global context. + +2. **Sparse Branch**: Selects the top-K most important blocks based on attention scores and computes fine-grained attention only for those blocks. 
+ +The branches are combined using learned gating: `output = compression * gate_compress + sparse` + +## Requirements + +```bash +pip install torch>=2.0 +pip install modelopt +# Optional: pip install diffusers # For real video diffusion models +``` + +## Quick Start + +### Using LTX-2 Trainer (Recommended) + +```bash +# Full video generation with VSA vs baseline comparison +python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --compare + +# Generate video with custom sparsity +python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --top-k-ratio 0.3 --output my_video.mp4 +``` + +## Configuration Options + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--top_k_ratio` | 0.5 | Ratio of blocks to keep (0.0-1.0). Lower = more sparse | +| `--block_size` | 4 4 4 | 3D block size (T H W) for video tiling | +| `--video_shape` | 16 28 48 | Video dimensions (T H W) after patchification | +| `--batch_size` | 1 | Batch size for inference | +| `--device` | cuda | Device (cuda/cpu) | +| `--dtype` | bfloat16 | Data type (float32/float16/bfloat16) | + +## Examples + +## API Usage + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT + +# Load your video diffusion model +model = load_video_diffusion_model() + +# Apply VSA with default settings +model = mtsa.sparsify(model, config=VSA_DEFAULT) + +# Or with custom configuration +custom_config = { + "sparse_cfg": { + "*attention*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.3, # 70% sparsity + "video_shape": (16, 28, 48), + "enable": True, + }, + "default": {"enable": False}, + }, +} +model = mtsa.sparsify(model, config=custom_config) + +# Run inference +output = model(video_latents) +``` + +## Model Requirements + +For optimal VSA performance, video 
diffusion models should expose a `gate_compress` parameter in their attention layers. This is a learned parameter that controls the balance between the compression and sparse branches. + +Example attention layer interface: + +```python +class VideoAttention(nn.Module): + def __init__(self, hidden_dim, num_heads): + super().__init__() + self.to_q = nn.Linear(hidden_dim, hidden_dim) + self.to_k = nn.Linear(hidden_dim, hidden_dim) + self.to_v = nn.Linear(hidden_dim, hidden_dim) + # VSA-specific: learned gating + self.to_gate_compress = nn.Linear(hidden_dim, hidden_dim) +``` + +If `gate_compress` is not available, VSA will use equal weighting (sum of both branches). + +## Expected Performance + +| Top-K Ratio | Sparsity | Typical Speedup | +|-------------|----------|-----------------| +| 0.5 | 50% | 1.5-2x | +| 0.3 | 70% | 2-3x | +| 0.2 | 80% | 3-4x | + +*Actual speedup depends on model architecture, video resolution, and hardware.* + +## Troubleshooting + +### "video_shape must be set" error + +Make sure to provide `video_shape` in the configuration matching your video dimensions after patchification. + +### Low speedup + +- VSA is most effective for long sequences (high video resolution or many frames) +- For short sequences, the overhead of block operations may reduce gains +- Ensure you're using GPU with CUDA + +### Quality degradation + +- Increase `top_k_ratio` to keep more blocks +- Ensure your model has `gate_compress` for optimal branch balancing + +## LTX-2 Integration + +LTX-2 is a state-of-the-art video diffusion model that is well-suited for VSA optimization due to its high token count. 
+ +### LTX-2 Architecture Summary + +| Component | Description | +|-----------|-------------| +| **Transformer** | 48 layers, 32 heads x 128 dim = 4096 hidden | +| **Compression** | 1:8192 pixels-to-tokens (aggressive) | +| **Attention Types** | Self-attn (attn1), Cross-attn (attn2), Audio attn, Cross-modal | + +### Example Scripts + +| Script | Purpose | +|--------|---------| +| `test_ltx2_vsa_integration.py` | Test VSA with LTX-2 trainer pipeline | + +### VSA Targets for LTX-2 + +VSA is applied only to **self-attention (attn1)** modules: + +```python +vsa_config = { + "sparse_cfg": { + "*.attn1": { # [OK] Self-attention - VSA enabled + "method": "vsa", + "top_k_ratio": 0.5, + "block_size_3d": [4, 4, 4], + }, + "*.attn2": {"enable": False}, # [NO] Text cross-attention + "*.audio_attn*": {"enable": False}, # [NO] Audio attention + "*.audio_to_video*": {"enable": False}, # [NO] Cross-modal + "*.video_to_audio*": {"enable": False}, # [NO] Cross-modal + }, +} +``` + +### Expected Token Counts for LTX-2 + +| Resolution | Frames | Tokens | VSA Tiles | Recommendation | +|------------|--------|--------|-----------|----------------| +| 512x768 | 121 | ~5,808 | 91 | Excellent for VSA | +| 384x384 | 49 | ~907 | 14 | Marginal | +| 256x256 | 25 | ~200 | 3 | Too small | + +For best VSA performance, use **121+ frames @ 512x768+** resolution. diff --git a/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py b/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py new file mode 100644 index 000000000..1257b31a2 --- /dev/null +++ b/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test VSA integration with LTX-2 video generation. + +This script tests Video Sparse Attention (VSA) on the full LTX-2 pipeline, +measuring performance improvements and validating output quality. + +Usage: + # Test with VSA enabled (default) + python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" + + # Test without VSA (baseline) + python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" \ + --no-vsa + + # Compare both (recommended) + python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" \ + --compare + + # Custom VSA parameters + python test_ltx2_vsa_integration.py \ + --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" \ + --top-k-ratio 0.5 \ + --num-frames 121 --height 512 --width 768 + +VSA improves attention performance by using 3D tile-based sparsity: +- Automatically adapts to LTX-2's compressed token sequence +""" + +import argparse +import time +from pathlib import Path + +import torch +from ltx_trainer.model_loader import load_model +from ltx_trainer.progress import StandaloneSamplingProgress +from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler +from ltx_trainer.video_utils import save_video + +from modelopt.torch.sparsity.attention_sparsity 
import sparsify + + +def calculate_expected_tokens(num_frames: int, height: int, width: int) -> int: + """Calculate expected token count for LTX-2. + + LTX-2 uses 1:8192 pixels-to-tokens compression ratio. + """ + pixels = num_frames * height * width + tokens = pixels // 8192 + return tokens + + +def is_vsa_compatible(num_frames: int, height: int, width: int) -> tuple[bool, str]: + """Check if input size is compatible with VSA. + + Args: + num_frames: Number of video frames. + height: Video height in pixels. + width: Video width in pixels. + + Returns: + Tuple of (is_compatible, reason_message). + """ + tokens = calculate_expected_tokens(num_frames, height, width) + tiles = tokens // 64 # VSA tile size: 4x4x4 = 64 + + if tiles >= 90: + return True, f"Excellent: {tokens} tokens ({tiles} tiles)" + elif tiles >= 16: + return True, f"Marginal: {tokens} tokens ({tiles} tiles)" + else: + return False, f"Too small: {tokens} tokens ({tiles} tiles, need 16+ for VSA)" + + +def apply_vsa_to_transformer( + transformer: torch.nn.Module, + num_frames: int, + height: int, + width: int, + top_k_ratio: float = 0.5, +) -> torch.nn.Module: + """Apply VSA to the LTX-2 transformer. + + Args: + transformer: The transformer model. + num_frames: Number of frames (for compatibility checking). + height: Video height (for compatibility checking). + width: Video width (for compatibility checking). + top_k_ratio: Sparsity ratio (0.5 = 50% sparsity). + + Returns: + Modified transformer with VSA enabled. 
+ """ + print("\nConfiguring VSA for LTX-2...") + + # Check compatibility + tokens = calculate_expected_tokens(num_frames, height, width) + tiles = tokens // 64 + compatible, reason = is_vsa_compatible(num_frames, height, width) + + print(f" Expected sequence: {tokens} tokens ({tiles} tiles)") + print(f" VSA compatibility: {reason}") + + if not compatible: + print(" [WARNING] Input size may be too small for VSA to provide significant benefit.") + print(" Consider using larger inputs (121+ frames @ 512x768+) for best results.") + + # Configure VSA + # NOTE: LTX-2 uses "attn1", "attn2", "audio_attn1", "audio_attn2" naming + # Pattern must be "*attn*" not "*attention*" to match these module names + sparse_config = { + "sparse_cfg": { + "*attn*": { + "method": "vsa", + "video_shape": None, # Auto-infer from LTX-2's compressed tokens + "block_size_3d": (4, 4, 4), # Standard VSA tile size + "top_k_ratio": top_k_ratio, + } + } + } + + # Apply VSA to transformer + print(" Applying VSA to attention modules...") + transformer = sparsify(transformer, sparse_config) + + print(f" [OK] VSA enabled with {int(top_k_ratio * 100)}% sparsity") + print(" Expected: 2-6x attention speedup, 1.5-2x end-to-end speedup") + + return transformer + + +def run_generation( + sampler: ValidationSampler, + config: GenerationConfig, + device: str, + num_inference_steps: int, + label: str = "", +) -> tuple[torch.Tensor, torch.Tensor | None, float]: + """Run video generation and return timing information. + + Args: + sampler: ValidationSampler instance. + config: Generation configuration. + device: Device to run on. + num_inference_steps: Number of denoising steps. + label: Label for logging (e.g., "BASELINE", "WITH VSA"). 
+ + Returns: + Tuple of (video, audio, elapsed_time) + """ + if label: + print(f"\n{label}") + + print(f"Generating video ({num_inference_steps} steps)...") + start_time = time.time() + + with StandaloneSamplingProgress(num_steps=num_inference_steps) as progress: + sampler.sampling_context = progress + video, audio = sampler.generate(config=config, device=device) + + elapsed = time.time() - start_time + print(f"[OK] Generation completed in {elapsed:.2f}s") + + return video, audio, elapsed + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Test VSA integration with LTX-2 video generation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Model arguments + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.safetensors)", + ) + parser.add_argument( + "--text-encoder-path", + type=str, + required=True, + help="Path to Gemma text encoder directory", + ) + + # Generation arguments + parser.add_argument( + "--prompt", + type=str, + default="A serene mountain landscape with a flowing river, golden hour lighting", + help="Text prompt for generation", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default="", + help="Negative prompt", + ) + parser.add_argument( + "--height", + type=int, + default=512, + help="Video height (must be divisible by 32)", + ) + parser.add_argument( + "--width", + type=int, + default=768, + help="Video width (must be divisible by 32)", + ) + parser.add_argument( + "--num-frames", + type=int, + default=121, + help="Number of video frames (must be k*8 + 1)", + ) + parser.add_argument( + "--frame-rate", + type=float, + default=25.0, + help="Video frame rate", + ) + parser.add_argument( + "--num-inference-steps", + type=int, + default=30, + help="Number of denoising steps", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=4.0, + help="Classifier-free guidance scale (CFG)", + ) + parser.add_argument( + "--seed", + 
type=int, + default=42, + help="Random seed for reproducibility", + ) + + # VSA arguments + parser.add_argument( + "--no-vsa", + action="store_true", + help="Disable VSA (for baseline comparison)", + ) + parser.add_argument( + "--compare", + action="store_true", + help="Run both with and without VSA for comparison", + ) + parser.add_argument( + "--top-k-ratio", + type=float, + default=0.5, + help="VSA sparsity ratio (0.5 = 50%% sparsity)", + ) + + # Audio arguments + parser.add_argument( + "--skip-audio", + action="store_true", + help="Skip audio generation (faster testing)", + ) + + # Output arguments + parser.add_argument( + "--output", + type=str, + default="output_vsa.mp4", + help="Output video path (.mp4)", + ) + parser.add_argument( + "--output-baseline", + type=str, + default="output_baseline.mp4", + help="Baseline output path (used with --compare)", + ) + + # Device arguments + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run on (cuda/cpu)", + ) + + args = parser.parse_args() + + # Validate arguments + generate_audio = not args.skip_audio + + print("=" * 80) + print("LTX-2 + VSA Integration Test") + print("=" * 80) + + # Check VSA compatibility + tokens = calculate_expected_tokens(args.num_frames, args.height, args.width) + tiles = tokens // 64 + compatible, reason = is_vsa_compatible(args.num_frames, args.height, args.width) + + print("\nInput Configuration:") + print(f" Resolution: {args.width}x{args.height}") + print(f" Frames: {args.num_frames} @ {args.frame_rate} fps") + print(f" Expected tokens: {tokens} ({tiles} tiles)") + print(f" VSA compatibility: {reason}") + + if not compatible and not args.no_vsa and not args.compare: + print("\n[WARNING] WARNING: Input size may be too small for VSA benefit") + print(" Recommended: 121+ frames @ 512x768+ for optimal VSA performance") + print(" Use --no-vsa to disable VSA for small inputs") + + # Load model components + print("\nLoading LTX-2 model components...") + 
components = load_model( + checkpoint_path=args.checkpoint, + device="cpu", # Load to CPU first + dtype=torch.bfloat16, + with_video_vae_encoder=False, + with_video_vae_decoder=True, + with_audio_vae_decoder=generate_audio, + with_vocoder=generate_audio, + with_text_encoder=True, + text_encoder_path=args.text_encoder_path, + ) + print("[OK] Model components loaded") + + # Create generation config + gen_config = GenerationConfig( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed, + condition_image=None, + reference_video=None, + generate_audio=generate_audio, + include_reference_in_output=False, + ) + + print("\n" + "=" * 80) + print("Generation Parameters") + print("=" * 80) + print(f"Prompt: {args.prompt}") + if args.negative_prompt: + print(f"Negative prompt: {args.negative_prompt}") + print(f"Resolution: {args.width}x{args.height}") + print(f"Frames: {args.num_frames} @ {args.frame_rate} fps") + print(f"Inference steps: {args.num_inference_steps}") + print(f"CFG scale: {args.guidance_scale}") + print(f"Seed: {args.seed}") + if generate_audio: + video_duration = args.num_frames / args.frame_rate + print(f"Audio: Enabled (duration: {video_duration:.2f}s)") + else: + print("Audio: Disabled (skip-audio mode)") + print("=" * 80) + + # Test scenarios + results = {} + + if args.compare: + # ====================================================================== + # Run BASELINE (no VSA) + # ====================================================================== + print("\n" + "=" * 80) + print("TEST 1/2: BASELINE (no VSA)") + print("=" * 80) + + # Create sampler without VSA + sampler_baseline = ValidationSampler( + transformer=components.transformer, + vae_decoder=components.video_vae_decoder, + vae_encoder=components.video_vae_encoder, + 
text_encoder=components.text_encoder, + audio_decoder=components.audio_vae_decoder if generate_audio else None, + vocoder=components.vocoder if generate_audio else None, + ) + + try: + video_baseline, audio_baseline, time_baseline = run_generation( + sampler_baseline, + gen_config, + args.device, + args.num_inference_steps, + ) + results["baseline"] = time_baseline + + # Save baseline video + output_baseline_path = Path(args.output_baseline) + output_baseline_path.parent.mkdir(parents=True, exist_ok=True) + + audio_sample_rate = None + if audio_baseline is not None and components.vocoder is not None: + audio_sample_rate = components.vocoder.output_sample_rate + + save_video( + video_tensor=video_baseline, + output_path=output_baseline_path, + fps=args.frame_rate, + audio=audio_baseline, + audio_sample_rate=audio_sample_rate, + ) + print(f"[OK] Baseline video saved: {args.output_baseline}") + except Exception as e: + print(f"[FAIL] Baseline generation failed: {e}") + import traceback + + traceback.print_exc() + return + + # ====================================================================== + # Run WITH VSA + # ====================================================================== + print("\n" + "=" * 80) + print("TEST 2/2: WITH VSA") + print("=" * 80) + + # Reload transformer for VSA test + print("\nReloading transformer for VSA test...") + components_vsa = load_model( + checkpoint_path=args.checkpoint, + device="cpu", + dtype=torch.bfloat16, + with_video_vae_encoder=False, + with_video_vae_decoder=True, + with_audio_vae_decoder=generate_audio, + with_vocoder=generate_audio, + with_text_encoder=True, + text_encoder_path=args.text_encoder_path, + ) + + # Apply VSA + components_vsa.transformer = apply_vsa_to_transformer( + components_vsa.transformer, + args.num_frames, + args.height, + args.width, + top_k_ratio=args.top_k_ratio, + ) + + # Create sampler with VSA + sampler_vsa = ValidationSampler( + transformer=components_vsa.transformer, + 
vae_decoder=components_vsa.video_vae_decoder, + vae_encoder=components_vsa.video_vae_encoder, + text_encoder=components_vsa.text_encoder, + audio_decoder=components_vsa.audio_vae_decoder if generate_audio else None, + vocoder=components_vsa.vocoder if generate_audio else None, + ) + + try: + video_vsa, audio_vsa, time_vsa = run_generation( + sampler_vsa, + gen_config, + args.device, + args.num_inference_steps, + ) + results["vsa"] = time_vsa + + # Save VSA video + output_vsa_path = Path(args.output) + output_vsa_path.parent.mkdir(parents=True, exist_ok=True) + + audio_sample_rate = None + if audio_vsa is not None and components_vsa.vocoder is not None: + audio_sample_rate = components_vsa.vocoder.output_sample_rate + + save_video( + video_tensor=video_vsa, + output_path=output_vsa_path, + fps=args.frame_rate, + audio=audio_vsa, + audio_sample_rate=audio_sample_rate, + ) + print(f"[OK] VSA video saved: {args.output}") + except Exception as e: + print(f"[FAIL] VSA generation failed: {e}") + import traceback + + traceback.print_exc() + return + + else: + # ====================================================================== + # Single test (with or without VSA) + # ====================================================================== + print("\n" + "=" * 80) + print(f"TEST: {'WITH VSA' if not args.no_vsa else 'WITHOUT VSA'}") + print("=" * 80) + + transformer = components.transformer + + # Apply VSA if enabled + if not args.no_vsa: + transformer = apply_vsa_to_transformer( + transformer, + args.num_frames, + args.height, + args.width, + top_k_ratio=args.top_k_ratio, + ) + + # Create sampler + sampler = ValidationSampler( + transformer=transformer, + vae_decoder=components.video_vae_decoder, + vae_encoder=components.video_vae_encoder, + text_encoder=components.text_encoder, + audio_decoder=components.audio_vae_decoder if generate_audio else None, + vocoder=components.vocoder if generate_audio else None, + ) + + try: + video, audio, elapsed = run_generation( + 
sampler, + gen_config, + args.device, + args.num_inference_steps, + ) + results["single"] = elapsed + + # Save video + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + audio_sample_rate = None + if audio is not None and components.vocoder is not None: + audio_sample_rate = components.vocoder.output_sample_rate + + save_video( + video_tensor=video, + output_path=output_path, + fps=args.frame_rate, + audio=audio, + audio_sample_rate=audio_sample_rate, + ) + print(f"[OK] Video saved: {args.output}") + except Exception as e: + print(f"[FAIL] Generation failed: {e}") + import traceback + + traceback.print_exc() + return + + # ========================================================================== + # Results Summary + # ========================================================================== + print("\n" + "=" * 80) + print("TEST COMPLETE") + print("=" * 80) + + if args.compare: + speedup = results["baseline"] / results["vsa"] + print("\nPerformance Comparison:") + print(f" Baseline (no VSA): {results['baseline']:.2f}s") + print(f" With VSA: {results['vsa']:.2f}s") + print(f" Speedup: {speedup:.2f}x") + print() + print(f" Baseline video: {args.output_baseline}") + print(f" VSA video: {args.output}") + print() + if speedup >= 1.5: + print("[OK] Excellent speedup achieved!") + elif speedup >= 1.2: + print("[OK] Good speedup achieved") + else: + print("[WARNING] Speedup lower than expected (input may be too small for VSA)") + print() + print("Compare videos to verify quality is preserved with VSA.") + else: + print(f"\nGeneration time: {results['single']:.2f}s") + print(f"Output: {args.output}") + + print("\n[OK] VSA integration test successful!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 4baf5bbe6..2c7e7252a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ 
b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -411,7 +411,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): # Configuration with RULER calibration -# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# Note: threshold field is omitted - calibration determines dynamic threshold lambda = a / length # The calibrated threshold adapts to sequence length for optimal sparsity SKIP_SOFTMAX_CALIB = { "sparse_cfg": { @@ -434,13 +434,169 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +class VSAAttributeConfig(ModeloptBaseConfig): + """Video Sparse Attention (VSA) attribute configuration. + + VSA uses a two-branch architecture optimized for video diffusion models: + 1. Compression branch: Block-averaged coarse attention + 2. Sparse branch: Top-K block selection for fine-grained attention + """ + + method: str = ModeloptField( + default="vsa", + title="Sparse attention method.", + description="Must be 'vsa' for Video Sparse Attention.", + ) + + enable: bool = ModeloptField( + default=True, + title="Enable VSA.", + description="If True, enables Video Sparse Attention. If False, bypasses sparsity.", + ) + + block_size_3d: tuple[int, int, int] | list[int] = ModeloptField( + default=(4, 4, 4), + title="3D block size.", + description=( + "Video block dimensions (T, H, W) for spatial-temporal tiling. " + "Default (4, 4, 4) creates 64-token blocks." + ), + ) + + top_k_ratio: float = ModeloptField( + default=0.5, + title="Top-K selection ratio.", + description=( + "Ratio of blocks to keep in sparse branch (0.0 to 1.0). " + "Lower values mean more sparsity. Default 0.5 keeps 50% of blocks." + ), + ) + + video_shape: tuple[int, int, int] | list[int] | None = ModeloptField( + default=None, + title="Video shape.", + description=( + "Video dimensions (T, H, W) after patchification. Required unless a " + "model-specific plugin (e.g., the LTX-2 plugin) computes it from the " + "model's patchifier. 
If None and no plugin provides a value, VSA will " + "raise an error at forward time." + ), + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass.", + ) + + @field_validator("method") + @classmethod + def validate_vsa_method(cls, v): + """Validate method is 'vsa'.""" + if v != "vsa": + raise ValueError(f"VSAAttributeConfig method must be 'vsa', got '{v}'") + return v + + @field_validator("block_size_3d") + @classmethod + def validate_block_size_3d(cls, v): + """Validate 3D block size.""" + if isinstance(v, list): + v = tuple(v) + if len(v) != 3: + raise ValueError(f"block_size_3d must have 3 elements (T, H, W), got {len(v)}") + if any(x <= 0 for x in v): + raise ValueError(f"All block_size_3d values must be positive, got {v}") + return v + + @field_validator("top_k_ratio") + @classmethod + def validate_top_k_ratio(cls, v): + """Validate top-K ratio is in valid range.""" + if not 0.0 < v <= 1.0: + raise ValueError(f"top_k_ratio must be in range (0, 1], got {v}") + return v + + @field_validator("video_shape") + @classmethod + def validate_video_shape(cls, v): + """Validate video shape if provided.""" + if v is None: + return v + if isinstance(v, list): + v = tuple(v) + if len(v) != 3: + raise ValueError(f"video_shape must have 3 elements (T, H, W), got {len(v)}") + if any(x <= 0 for x in v): + raise ValueError(f"All video_shape values must be positive, got {v}") + return v + + +class VSAConfig(SparseAttentionConfig): + """Configuration for Video Sparse Attention optimization. + + VSA is designed for video diffusion models with learned gate_compress + parameters in attention layers. 
+ """ + + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.5, + "enable": True, + }, + "default": {"enable": False}, + }, + title="VSA configuration", + description=( + "Pattern-based configuration for Video Sparse Attention. " + "Keys are patterns to match module names, values are VSA configs." + ), + validate_default=True, + ) + + +# Pre-defined VSA Configuration for video diffusion models +VSA_DEFAULT = { + "sparse_cfg": { + "*attention*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.5, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# High sparsity VSA configuration (70% of blocks pruned) +VSA_HIGH_SPARSITY = { + "sparse_cfg": { + "*attention*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.3, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "VSA_DEFAULT", + "VSA_HIGH_SPARSITY", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", "SparseAttributeConfig", + "VSAAttributeConfig", + "VSAConfig", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..31a281f5f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -15,6 +15,8 @@ """Sparse attention methods package.""" +from modelopt.torch.utils import import_plugin + from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method __all__ = [ @@ -25,3 +27,6 @@ # Import method implementations to trigger registration from . import flash_skip_softmax + +with import_plugin("vsa"): + from . 
import vsa # Video Sparse Attention (requires fastvideo_kernel) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 3f3e78db6..17c3e92d8 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -31,6 +31,8 @@ def __init__(self): # Flag to indicate calibration mode (set by calibrator) # Instance attribute to prevent shared state across multiple models self._calibration_mode: bool = False + # Last computed statistics (set by subclass forward methods, read by SparseAttentionModule) + self._last_stats: dict[str, Any] | None = None # Calibration parameters set by the calibrator after calibration. # Exponential model params per phase: {"prefill": {"a": ..., "b": ...}, ...} diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py new file mode 100644 index 000000000..9fbe233fa --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -0,0 +1,373 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Video Sparse Attention (VSA) method for video diffusion models. + +VSA implements a two-branch sparse attention architecture: +1. 
Compression Branch: Averages tokens within 3D video blocks and computes coarse attention +2. Sparse Branch: Selects top-K blocks based on importance and computes fine-grained attention + +This method requires model modification to expose gate_compress for optimal quality. + +Uses the optimized Triton kernel from fastvideo_kernel for 2-6x speedup. + +The data flow mirrors FastVideo's VideoSparseAttentionImpl: + tile(Q,K,V,gate) -> Triton kernel -> untile(output) +""" + +import logging +import math +from typing import Any + +import torch + +from . import SparseAttentionMethod, register_sparse_method +from .vsa_utils import ( + construct_variable_block_sizes, + get_non_pad_index, + get_reverse_tile_partition_indices, + get_tile_partition_indices, +) + +logger = logging.getLogger(__name__) + + +@register_sparse_method("vsa") +class VSA(SparseAttentionMethod): + """Video Sparse Attention with two-branch architecture. + + VSA combines a compression branch (coarse-grained block attention) with + a sparse branch (fine-grained attention on top-K selected blocks). + + The final output is: output = out_compression * gate_compress + out_sparse + + where gate_compress is a learned parameter from the model layer that + controls the balance between compression and sparse branches. + + Configuration Parameters: + - block_size_3d: 3D tile dimensions (T, H, W), default (4, 4, 4) + - top_k_ratio: Ratio of blocks to keep (0.0-1.0), default 0.5 + - video_shape: Video dimensions (T, H, W) after patchification + + Requirements: + - Model must expose gate_compress parameter in attention layers + - Input tensors must be 4D: [batch, heads, seq_len, dim] + """ + + def __init__(self, method_config: dict | None = None): + """Initialize VSA method. + + Args: + method_config: Configuration dict with VSA parameters. 
+ """ + super().__init__() + config = method_config or {} + + # Block configuration + block_size = config.get("block_size_3d", (4, 4, 4)) + if isinstance(block_size, list): + block_size = tuple(block_size) + self.block_size_3d = block_size + self.block_elements = block_size[0] * block_size[1] * block_size[2] + + # Sparsity configuration + self.top_k_ratio = config.get("top_k_ratio", 0.5) + + # Video shape (can be set dynamically) + self.video_shape = config.get("video_shape", None) + + # Track last computed statistics + self._last_stats: dict = {} + + # Metadata cache: avoids recomputing tile indices on every forward pass. + # Matches FastVideo's @lru_cache on utility functions. + self._cached_metadata: dict[str, Any] | None = None + self._cached_metadata_key: tuple | None = None + + def set_video_shape(self, video_shape: tuple[int, int, int]): + """Set video shape for current forward pass. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + """ + self.video_shape = video_shape + + def _compute_metadata(self, seq_len: int, device: torch.device) -> dict[str, Any]: + """Compute block metadata from video shape. + + Results are cached and reused when called with the same (seq_len, video_shape) + to avoid recomputing tile indices on every denoising step, matching FastVideo's + ``@functools.lru_cache`` on the underlying utility functions. + + Args: + seq_len: Sequence length (should equal T * H * W). + device: Device for tensors. + + Returns: + Metadata dict with tile indices, variable sizes, etc. + """ + if self.video_shape is None: + raise ValueError( + f"video_shape must be provided for VSA but is None (seq_len={seq_len}). " + f"Set it via the VSA config ('video_shape' key), call set_video_shape(), " + f"or use a model-specific plugin (e.g., LTX-2 plugin) that computes it " + f"from the model's patchifier." 
+ ) + + # Return cached metadata if inputs haven't changed + cache_key = (seq_len, self.video_shape) + if self._cached_metadata is not None and self._cached_metadata_key == cache_key: + return self._cached_metadata + + vid_t, vid_h, vid_w = self.video_shape + ts_t, ts_h, ts_w = self.block_size_3d + + # Validate sequence length matches video shape + expected_seq_len = vid_t * vid_h * vid_w + if seq_len != expected_seq_len: + raise ValueError( + f"Sequence length {seq_len} does not match video shape {self.video_shape} " + f"(expected {expected_seq_len})" + ) + + # Calculate number of tiles + num_tiles = ( + math.ceil(vid_t / ts_t), + math.ceil(vid_h / ts_h), + math.ceil(vid_w / ts_w), + ) + total_tiles = num_tiles[0] * num_tiles[1] * num_tiles[2] + + # Get partitioning indices + tile_indices = get_tile_partition_indices(self.video_shape, self.block_size_3d, device) + reverse_indices = get_reverse_tile_partition_indices( + self.video_shape, self.block_size_3d, device + ) + variable_sizes = construct_variable_block_sizes( + self.video_shape, num_tiles, self.block_size_3d, device + ) + non_pad_index = get_non_pad_index(variable_sizes, self.block_elements) + + # Calculate padded sizes + t_padded = num_tiles[0] * ts_t + h_padded = num_tiles[1] * ts_h + w_padded = num_tiles[2] * ts_w + padded_seq_len = t_padded * h_padded * w_padded + + metadata = { + "video_shape": self.video_shape, + "tile_size": self.block_size_3d, + "num_tiles": num_tiles, + "total_tiles": total_tiles, + "tile_indices": tile_indices, + "reverse_indices": reverse_indices, + "variable_sizes": variable_sizes, + "non_pad_index": non_pad_index, + "padded_seq_len": padded_seq_len, + } + + # Cache for reuse across denoising steps + self._cached_metadata = metadata + self._cached_metadata_key = cache_key + + return metadata + + def _tile_tensor(self, tensor: torch.Tensor, metadata: dict) -> torch.Tensor: + """Rearrange tensor into tile layout with padding. 
+ + Args: + tensor: Input tensor [batch, heads, seq_len, dim]. + metadata: Metadata from _compute_metadata. + + Returns: + Tiled tensor [batch, heads, padded_seq_len, dim]. + """ + batch, heads, seq_len, dim = tensor.shape + device = tensor.device + dtype = tensor.dtype + + tile_indices = metadata["tile_indices"] + non_pad_index = metadata["non_pad_index"] + padded_seq_len = metadata["padded_seq_len"] + + # Create padded tensor + padded = torch.zeros((batch, heads, padded_seq_len, dim), device=device, dtype=dtype) + + # Rearrange to tile order and place in padded positions + padded[:, :, non_pad_index] = tensor[:, :, tile_indices] + + return padded + + def _untile_tensor(self, tensor: torch.Tensor, metadata: dict, seq_len: int) -> torch.Tensor: + """Reverse tile layout back to original order. + + Args: + tensor: Tiled tensor [batch, heads, padded_seq_len, dim]. + metadata: Metadata from _compute_metadata. + seq_len: Original sequence length. + + Returns: + Output tensor [batch, heads, seq_len, dim]. + """ + non_pad_index = metadata["non_pad_index"] + reverse_indices = metadata["reverse_indices"] + + # Extract non-padded tokens and reverse order + return tensor[:, :, non_pad_index][:, :, reverse_indices] + + def forward_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + gate_compress: torch.Tensor | None = None, + video_shape: tuple[int, int, int] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, dict]: + """Compute VSA two-branch sparse attention. + + Data flow (mirrors FastVideo's VideoSparseAttentionImpl): + 1. Compute tile metadata from video_shape + 2. Tile Q, K, V, gate_compress into padded tile order + 3. Run Triton VSA kernel on tiled tensors + 4. Untile output back to original token order + + Args: + query: Query tensor [batch, heads, seq_len, dim]. + key: Key tensor [batch, heads, seq_len, dim]. + value: Value tensor [batch, heads, seq_len, dim]. + gate_compress: Learned gating weights [batch, heads, seq_len, dim]. 
+ If None, uses equal weighting (0.5) for both branches. + video_shape: Video dimensions (T, H, W). If None, uses self.video_shape. + **kwargs: Additional arguments (ignored). + + Returns: + Tuple of (attention_output, stats) where: + - attention_output: [batch, heads, seq_len, dim] + - stats: Dict with sparsity statistics + """ + if video_shape is not None: + self.video_shape = video_shape + + batch, heads, seq_len, dim = query.shape + device = query.device + + # Compute block metadata (cached across denoising steps) + metadata = self._compute_metadata(seq_len, device) + total_tiles = metadata["total_tiles"] + variable_sizes = metadata["variable_sizes"] + + # Calculate top-K based on ratio + top_k = max(1, int(self.top_k_ratio * total_tiles)) + + # ========== TILE: rearrange tokens into tile order ========== + # Mirrors FastVideo's VideoSparseAttentionImpl.preprocess_qkv (tile) + query_tiled = self._tile_tensor(query, metadata) + key_tiled = self._tile_tensor(key, metadata) + value_tiled = self._tile_tensor(value, metadata) + gate_tiled = ( + self._tile_tensor(gate_compress, metadata) if gate_compress is not None else None + ) + + # ========== TRITON VSA KERNEL ========== + # Kernel operates on tiled tensors in [batch, heads, padded_seq, dim] format + try: + from fastvideo_kernel import video_sparse_attn as triton_vsa_kernel + except ModuleNotFoundError: + raise ModuleNotFoundError( + "VSA requires the 'fastvideo_kernel' package for its Triton sparse attention kernel. " + "Please install it before using the VSA method." 
+ ) from None + output_tiled = triton_vsa_kernel( + query_tiled, + key_tiled, + value_tiled, + variable_sizes, # q_variable_sizes + variable_sizes, # kv_variable_sizes + top_k, + block_size=self.block_size_3d, + compress_attn_weight=gate_tiled, + ) + + # ========== UNTILE: restore original token order ========== + # Mirrors FastVideo's VideoSparseAttentionImpl.postprocess_output (untile) + output = self._untile_tensor(output_tiled, metadata, seq_len) + + # Compute statistics + actual_sparsity = 1.0 - (top_k / total_tiles) + stats = { + "sparsity": actual_sparsity, + "phase": "vsa_triton", + "total_blocks": total_tiles, + "sparse_blocks": total_tiles - top_k, + "top_k": top_k, + "video_shape": self.video_shape, + } + self._last_stats = stats + + return output, stats + + def calculate_sparsity( + self, + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Not used by VSA. Required stub for the abstract base class. + + VSA replaces the entire attention computation via ``forward_attention()``, + which is called directly by model-specific plugins (e.g., ``_LTX2SparseAttention``). + The softmax-patching path that calls this method is never reached in the VSA flow. + + Raises: + NotImplementedError: Always. Use ``forward_attention()`` instead. + """ + raise NotImplementedError( + "VSA does not use the softmax-patching path. " + "Use forward_attention() via a model-specific plugin instead." + ) + + def apply_sparsity( + self, + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Not used by VSA. Required stub for the abstract base class. + + See ``calculate_sparsity`` for details. + + Raises: + NotImplementedError: Always. Use ``forward_attention()`` instead. + """ + raise NotImplementedError( + "VSA does not use the softmax-patching path. " + "Use forward_attention() via a model-specific plugin instead." + ) + + def get_threshold_info(self) -> dict[str, Any]: + """Get VSA configuration info. 
+ + Returns: + Dictionary with VSA configuration. + """ + return { + "type": "vsa", + "block_size_3d": self.block_size_3d, + "top_k_ratio": self.top_k_ratio, + "video_shape": self.video_shape, + } + + @property + def name(self) -> str: + """Method identifier.""" + return "vsa" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py new file mode 100644 index 000000000..affed79f7 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Video Sparse Attention (VSA). + +This module provides 3D block operations for video sparse attention, +including reshaping tensors into video blocks and variable block size computation. +""" + +import functools +import math + +import torch + + +@functools.lru_cache(maxsize=10) +def get_tile_partition_indices( + video_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Get indices to partition video tokens into tiles. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. 
+ + Returns: + LongTensor of indices to rearrange tokens into tile order. + """ + vid_t, vid_h, vid_w = video_shape + ts, hs, ws = tile_size + indices = torch.arange(vid_t * vid_h * vid_w, device=device, dtype=torch.long).reshape( + vid_t, vid_h, vid_w + ) + + tiles = [] + for t in range(math.ceil(vid_t / ts)): + for h in range(math.ceil(vid_h / hs)): + for w in range(math.ceil(vid_w / ws)): + tile = indices[ + t * ts : min(t * ts + ts, vid_t), + h * hs : min(h * hs + hs, vid_h), + w * ws : min(w * ws + ws, vid_w), + ] + tiles.append(tile.flatten()) + + return torch.cat(tiles, dim=0) + + +@functools.lru_cache(maxsize=10) +def get_reverse_tile_partition_indices( + video_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Get indices to reverse tile partitioning back to original order. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. + + Returns: + LongTensor of indices to reverse the tile rearrangement. + """ + forward_indices = get_tile_partition_indices(video_shape, tile_size, device) + return torch.argsort(forward_indices) + + +@functools.lru_cache(maxsize=10) +def construct_variable_block_sizes( + video_shape: tuple[int, int, int], + num_tiles: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Compute valid (non-padded) token count for each tile. + + Since video dimensions may not divide evenly by tile size, edge tiles + will have fewer valid tokens. This function computes the actual valid + token count for each tile. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + num_tiles: Number of tiles in each dimension (n_T, n_H, n_W). + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. 
+ + Returns: + LongTensor of shape [num_tiles_total] with valid tokens per tile. + """ + t, h, w = video_shape + ts_t, ts_h, ts_w = tile_size + n_t, n_h, n_w = num_tiles + + def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor: + """Compute size of each tile along one dimension.""" + sizes = torch.full((n_tiles,), tile, dtype=torch.long, device=device) + remainder = dim_len - (n_tiles - 1) * tile + sizes[-1] = remainder if remainder > 0 else tile + return sizes + + t_sizes = _sizes(t, ts_t, n_t) # [n_t] + h_sizes = _sizes(h, ts_h, n_h) # [n_h] + w_sizes = _sizes(w, ts_w, n_w) # [n_w] + + # Broadcast multiply to get tokens per tile + block_sizes = ( + t_sizes[:, None, None] * h_sizes[None, :, None] * w_sizes[None, None, :] + ).reshape(-1) + + return block_sizes + + +@functools.lru_cache(maxsize=10) +def get_non_pad_index( + variable_block_sizes: torch.LongTensor, + max_block_size: int, +) -> torch.LongTensor: + """Get indices of non-padded tokens in the padded layout. + + When tiles have variable sizes, we pad to max_block_size. This function + returns indices to extract only valid (non-padded) tokens. + + Args: + variable_block_sizes: Tensor of valid token counts per tile. + max_block_size: Maximum tile size (usually tile_T * tile_H * tile_W). + + Returns: + LongTensor of indices for valid tokens. 
+ """ + n_win = variable_block_sizes.shape[0] + device = variable_block_sizes.device + + starts_pad = torch.arange(n_win, device=device) * max_block_size + index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :] + index_mask = ( + torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None] + ) + + return index_pad[index_mask] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 599832943..f61e59afd 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -18,6 +18,8 @@ import torch.nn as nn import transformers +logger = logging.getLogger(__name__) + from modelopt.torch.opt.dynamic import DynamicModule from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry @@ -91,10 +93,10 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - print(f"Registered {type_name} for sparse attention optimization") + logger.info(f"Registered {type_name} for sparse attention optimization") if registered_count > 0: - print(f"Dynamically registered {registered_count} attention module types for sparsity") + logger.info(f"Dynamically registered {registered_count} attention module types for sparsity") return registered_count > 0 diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py new file mode 100644 index 000000000..030187b4d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py @@ -0,0 +1,459 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin for LTX-2 video diffusion models with VSA support. + +LTX-2 uses a specific Attention module structure that differs from standard +HuggingFace/Diffusers attention. This plugin provides: + +1. Detection of LTX-2's native Attention modules +2. Q/K/V projection, RMSNorm, and RoPE handling +3. Support for trainable gate_compress for VSA quality optimization +""" + +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from . import CUSTOM_MODEL_PLUGINS + +logger = logging.getLogger(__name__) + +# Module-level storage for video_shape extracted by the forward pre-hook. +_current_vsa_video_shape: tuple[int, int, int] | None = None + + +def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: + """Forward pre-hook on LTXModel to extract dit_seq_shape from Modality.positions. + + Mirrors FastVideo's ``VideoSparseAttentionMetadataBuilder.build()`` which + computes ``dit_seq_shape = raw_latent_shape // patch_size``. Here we derive + the same shape by counting unique position values per dimension in the + ``Modality.positions`` tensor, which is available at the LTXModel entry + point (before ``TransformerArgsPreprocessor`` converts it to RoPE embeddings). 
+ + The result is stored in the module-level ``_current_vsa_video_shape`` so + that ``_LTX2SparseAttention._resolve_video_shape()`` can read it. + """ + global _current_vsa_video_shape + + # LTXModel.forward(self, video: Modality | None, audio, perturbations) + video = args[0] if len(args) > 0 else None + if video is None or not hasattr(video, "positions") or video.positions is None: + return + + positions = video.positions # (B, 3, T) or (B, 3, T, 2) + + try: + if positions.ndim == 4: + # (B, 3, T, 2) -- take start coordinates + pos_per_dim = positions[0, :, :, 0] # (3, T) + elif positions.ndim == 3: + # (B, 3, T) + pos_per_dim = positions[0] # (3, T) + else: + return + + t_dim = pos_per_dim[0].unique().numel() + h_dim = pos_per_dim[1].unique().numel() + w_dim = pos_per_dim[2].unique().numel() + seq_len = positions.shape[2] + + if t_dim * h_dim * w_dim == seq_len: + _current_vsa_video_shape = (t_dim, h_dim, w_dim) + logger.debug( + f"Extracted dit_seq_shape={_current_vsa_video_shape} from " + f"Modality.positions (seq_len={seq_len})" + ) + else: + logger.debug( + f"Position-derived shape {(t_dim, h_dim, w_dim)} product " + f"({t_dim * h_dim * w_dim}) != seq_len ({seq_len}), skipping" + ) + except Exception: + # Silently skip -- _resolve_video_shape will fall back to config + pass + + +def _is_ltx2_model(model: nn.Module) -> bool: + """Check if model is an LTX-2 model. + + Uses LTXModel / LTXSelfAttention class names to avoid false positives + from other DiTs (e.g., LongCat) that share similar attribute patterns. + + Args: + model: PyTorch model to check. + + Returns: + True if model is LTX-2 (root class LTXModel or contains LTXSelfAttention). + """ + if type(model).__name__ == "LTXModel": + return True + return any( + type(m).__name__ == "LTXSelfAttention" for m in model.modules() + ) + + +def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: + """Check if a module is an LTX-2 Attention module by class name or structure.
+ + Primary: class name is LTXSelfAttention. Fallback: has to_q/k/v, q_norm, + and k_norm attributes (rope_type is not checked). + + Args: + module: Module to check. + name: Module name in model hierarchy. + + Returns: + True if module is an LTX-2 attention module. + """ + class_name = type(module).__name__ + if class_name == "LTXSelfAttention": + return True + # Fallback for subclasses or renamed variants: match on projection/norm attributes + return ( + hasattr(module, "to_q") + and hasattr(module, "to_k") + and hasattr(module, "to_v") + and hasattr(module, "q_norm") + and hasattr(module, "k_norm") + ) + + +class _LTX2SparseAttention(SparseAttentionModule): + """Sparse attention wrapper for LTX-2 Attention modules. + + This plugin handles all LTX-2 specific logic: + - Argument mapping (x -> hidden_states, context -> encoder_hidden_states) + - Q/K/V projection and normalization + - RoPE application + - Trainable gate_compress for VSA quality optimization + + The plugin computes Q, K, V directly and calls VSA.forward_attention(), + keeping VSA as a pure algorithm without module-specific knowledge.
+ """ + + def _setup(self): + """Setup the VSA wrapper with trainable gate_compress.""" + super()._setup() + + # Check if we need to add gate_compress projection + if not hasattr(self, "to_gate_compress"): + to_q = self.to_q + in_features = to_q.in_features + out_features = to_q.out_features + + # Create gate_compress projection (zero-initialized) + self.to_gate_compress = nn.Linear(in_features, out_features, bias=True) + nn.init.zeros_(self.to_gate_compress.weight) + nn.init.zeros_(self.to_gate_compress.bias) + + # Move to same device/dtype as to_q + self.to_gate_compress = self.to_gate_compress.to( + device=to_q.weight.device, + dtype=to_q.weight.dtype, + ) + + def _compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor], + pe: Optional[torch.Tensor] = None, + k_pe: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute Q, K, V projections with LTX-2 specific processing. + + Args: + x: Input tensor [batch, seq, hidden_dim]. + context: Context for cross-attention, or None for self-attention. + pe: Positional embeddings for RoPE. + k_pe: Optional separate positional embeddings for keys. + + Returns: + Tuple of (query, key, value) tensors in [batch, seq, hidden_dim] format. + """ + # For self-attention, use x for K, V + context = context if context is not None else x + + # Project to Q, K, V + query = self.to_q(x) + key = self.to_k(context) + value = self.to_v(context) + + # Apply Q/K norms (LTX-2 specific) + if hasattr(self, "q_norm"): + query = self.q_norm(query) + if hasattr(self, "k_norm"): + key = self.k_norm(key) + + # Apply RoPE if provided (LTX-2 specific) + if pe is not None and hasattr(self, "rope_type"): + try: + from ltx_core.model.transformer.rope import apply_rotary_emb + except ModuleNotFoundError: + raise ModuleNotFoundError( + "LTX-2 VSA plugin requires the 'ltx_core' package for RoPE support. " + "Please install it before using VSA with LTX-2 models." 
+ ) from None + + query = apply_rotary_emb(query, pe, self.rope_type) + key = apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type) + + return query, key, value + + def _reshape_for_vsa(self, tensor: torch.Tensor, num_heads: int) -> torch.Tensor: + """Reshape tensor from [batch, seq, hidden_dim] to [batch, heads, seq, head_dim]. + + Args: + tensor: Input tensor [batch, seq, hidden_dim]. + num_heads: Number of attention heads. + + Returns: + Reshaped tensor [batch, heads, seq, head_dim]. + """ + batch, seq_len, hidden_dim = tensor.shape + head_dim = hidden_dim // num_heads + return tensor.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) + + def _reshape_from_vsa(self, tensor: torch.Tensor) -> torch.Tensor: + """Reshape tensor from [batch, heads, seq, head_dim] to [batch, seq, hidden_dim]. + + Args: + tensor: Input tensor [batch, heads, seq, head_dim]. + + Returns: + Reshaped tensor [batch, seq, hidden_dim]. + """ + batch, heads, seq_len, head_dim = tensor.shape + return tensor.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim) + + def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: + """Resolve video_shape for the current forward pass. + + Resolution order (mirrors FastVideo's metadata flow): + 1. ``_current_vsa_video_shape`` -- set by the forward pre-hook from + ``Modality.positions`` (analogous to ``get_forward_context().attn_metadata``) + 2. ``method.video_shape`` -- explicitly set via the sparsify config + + Args: + seq_len: Current sequence length (for validation). + + Returns: + Tuple (T, H, W) or None if not determinable. + """ + # 1. Primary: video_shape extracted by forward pre-hook + if _current_vsa_video_shape is not None: + t, h, w = _current_vsa_video_shape + if t * h * w == seq_len: + return _current_vsa_video_shape + + # 2. 
Fallback: explicit video_shape from sparsify config + method = getattr(self, "_sparse_method_instance", None) + if method is not None and method.video_shape is not None: + t, h, w = method.video_shape + if t * h * w == seq_len: + return method.video_shape + + return None + + def forward(self, *args, **kwargs): + """Forward pass computing Q/K/V directly and calling VSA.forward_attention(). + + This method handles all LTX-2 specific logic: + 1. Extract arguments (uses LTX-2 native names: x, context, pe, k_pe) + 2. Compute Q, K, V projections with norms and RoPE + 3. Compute gate_compress + 4. Resolve video_shape from hook or config + 5. Check compatibility and call VSA or fallback + 6. Apply output projection + """ + x = kwargs.get("x") + if x is None and len(args) > 0: + x = args[0] + + if x is None: + return self._call_original_forward(*args, **kwargs) + + context = kwargs.get("context") + pe = kwargs.get("pe") + k_pe = kwargs.get("k_pe") + + # === Check cross-attention === + if context is not None: + if x.shape[1] != context.shape[1]: + # NOTE: skip VSA for Cross-attention, use original attention + return self._call_original_forward(*args, **kwargs) + + # === Check VSA method availability === + if not hasattr(self, "_sparse_method_instance") or self._sparse_method_instance is None: + return self._call_original_forward(*args, **kwargs) + + method = self._sparse_method_instance # VSA instance + + # === Compute Q, K, V === + query, key, value = self._compute_qkv(x, context, pe, k_pe) + + # === Check sequence length compatibility === + seq_len = query.shape[1] + block_size_3d = method.block_size_3d # type: ignore[attr-defined] + block_elements = block_size_3d[0] * block_size_3d[1] * block_size_3d[2] + + if seq_len < block_elements: + # Incompatible sequence length (e.g., audio attention with seq_len=32) + logger.debug(f"VSA skipped: seq_len={seq_len} < block_elements={block_elements}") + return self._call_original_forward(*args, **kwargs) + + # === Resolve video_shape 
=== + video_shape = self._resolve_video_shape(seq_len) + if video_shape is None: + logger.debug(f"VSA skipped: no matching video_shape for seq_len={seq_len}") + return self._call_original_forward(*args, **kwargs) + + # === Compute gate_compress === + gate_compress = None + if hasattr(self, "to_gate_compress"): + gate_compress = self.to_gate_compress(x) + + # === Reshape for VSA: [batch, seq, hidden] -> [batch, heads, seq, head_dim] === + query = self._reshape_for_vsa(query, self.heads) + key = self._reshape_for_vsa(key, self.heads) + value = self._reshape_for_vsa(value, self.heads) + if gate_compress is not None: + gate_compress = self._reshape_for_vsa(gate_compress, self.heads) + + # === Call VSA forward_attention directly === + output, stats = method.forward_attention( # type: ignore[attr-defined] + query=query, + key=key, + value=value, + gate_compress=gate_compress, + video_shape=video_shape, + ) + + # Store stats for collection + method._last_stats = stats + + # === Reshape output: [batch, heads, seq, head_dim] -> [batch, seq, hidden] === + output = self._reshape_from_vsa(output) + + # === Apply output projection === + if hasattr(self, "to_out"): + output = self.to_out(output) + + return output + + def _call_original_forward(self, *args, **kwargs): + """Call the original module's forward method, bypassing VSA. + + Temporarily disables sparse attention so SparseAttentionModule.forward() + passes through to the original module. 
+ """ + # Temporarily disable sparse attention to bypass sparse logic + # SparseAttentionModule.forward() checks is_enabled and passes through if False + was_enabled = getattr(self, "_enabled", True) + self._enabled = False + try: + # This goes through SparseAttentionModule.forward() which checks is_enabled, + # sees it's disabled, and calls DynamicModule.forward() -> original module + result = SparseAttentionModule.forward(self, *args, **kwargs) + finally: + self._enabled = was_enabled + return result + + def get_gate_compress_parameters(self): + """Get trainable gate_compress parameters. + + Returns: + Iterator of gate_compress parameters for optimization. + """ + if hasattr(self, "to_gate_compress"): + return self.to_gate_compress.parameters() + return iter([]) # Empty iterator + + +def register_ltx2_attention(model: nn.Module) -> int: + """Register LTX-2 Attention modules for VSA wrapping. + + This function detects LTX-2 Attention modules and registers them with + the SparseAttentionRegistry. It also handles unregistering any generic + wrappers that may have been registered first. + + Args: + model: LTX-2 model to process. + + Returns: + Number of module types registered. 
+ """ + if not _is_ltx2_model(model): + return 0 + + registered_types = set() + num_modules = 0 + + for name, module in model.named_modules(): + if not _is_ltx2_attention_module(module, name): + continue + + num_modules += 1 + module_type = type(module) + + if module_type in registered_types: + continue + + # Unregister any existing generic wrapper + if module_type in SparseAttentionRegistry: + logger.debug(f"Unregistering generic wrapper for {module_type.__name__}") + SparseAttentionRegistry.unregister(module_type) + + # Register LTX-2 specific wrapper + SparseAttentionRegistry.register({module_type: module_type.__name__})(_LTX2SparseAttention) + registered_types.add(module_type) + logger.info(f"Registered LTX-2 attention: {module_type.__name__}") + + if num_modules > 0: + logger.info(f"Found {num_modules} LTX-2 Attention modules in model") + + # Register forward pre-hook to extract video_shape from Modality.positions + # before each forward pass -- analogous to FastVideo's + # set_forward_context(attn_metadata=builder.build(...)) + model.register_forward_pre_hook(_extract_video_shape_hook) + logger.debug("Registered VSA video_shape extraction hook on model") + + return len(registered_types) + + +def register_ltx2_on_the_fly(model: nn.Module) -> bool: + """Plugin entry point for LTX-2 VSA registration. + + Args: + model: PyTorch model to process. + + Returns: + True if any LTX-2 modules were registered. 
+ """ + num_registered = register_ltx2_attention(model) + + if num_registered > 0: + logger.info(f"Registered {num_registered} LTX-2 attention types for VSA") + return True + + return False + + +# Add to plugin set (order-independent: guards against re-registration internally) +CUSTOM_MODEL_PLUGINS.add(register_ltx2_on_the_fly) diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 4333d1243..d6bb2cf0e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -61,9 +61,26 @@ def set_from_attribute_config( Args: attribute_cfg: Sparse attention attribute configuration. """ + from .config import VSAAttributeConfig + + # Determine which config class to use based on method + config_dict = attribute_cfg or {} + if isinstance(attribute_cfg, dict): + method = config_dict.get("method", "flash_skip_softmax") + elif attribute_cfg is not None and hasattr(attribute_cfg, "method"): + method = attribute_cfg.method + else: + method = "flash_skip_softmax" + + # Select appropriate config class based on method + if method == "vsa": + config_class = VSAAttributeConfig + else: + config_class = SparseAttentionAttributeConfig + # Ensure config is validated through Pydantic - if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): - attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) + if not isinstance(attribute_cfg, (SparseAttentionAttributeConfig, VSAAttributeConfig)): + attribute_cfg = config_class(**(config_dict)) # Store raw config for method initialization self._method_config = {} @@ -80,10 +97,10 @@ def set_from_attribute_config( # Process each attribute from validated config for attribute, val in attribute_cfg.model_dump().items(): - # Validate attribute if using config class - if hasattr(SparseAttentionAttributeConfig, "model_fields"): - assert attribute in 
SparseAttentionAttributeConfig.model_fields, ( - f"{attribute} is not a valid SparseAttentionModule attribute" + # Validate attribute against the appropriate config class + if hasattr(config_class, "model_fields"): + assert attribute in config_class.model_fields, ( + f"{attribute} is not a valid {config_class.__name__} attribute" ) if attribute in _module_attributes: @@ -159,14 +176,16 @@ def _setup(self): def forward(self, *args, **kwargs): """Forward with selected sparse attention method. - This method dispatches to the appropriate sparse attention implementation - based on the configured method and backend. + Methods that replace the full attention computation (e.g., VSA) override + ``forward()`` in their model-specific plugin (e.g., ``_LTX2SparseAttention``) + and never reach this path. This method handles the softmax-patching path + used by methods like ``flash_skip_softmax``. """ # Pass through if sparse attention is disabled if not self.is_enabled: return super().forward(*args, **kwargs) - # Get the appropriate context manager for this configuration + # Standard path: softmax patching context = self._get_sparse_context() # Apply sparse attention through the context From d609c72b423056f6d7e2409618232124d9b03ffd Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 7 Feb 2026 17:03:08 -0800 Subject: [PATCH 07/10] Add unit test Signed-off-by: Kai Xu --- .../sparsity/attention_sparsity/test_vsa.py | 320 ++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_vsa.py diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py new file mode 100644 index 000000000..8ebb8fb99 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only unit tests for Video Sparse Attention (VSA). + +Tests cover: +- vsa_utils.py: tile/untile index logic, variable block sizes +- vsa.py: VSA method init, metadata computation, validation, caching +- config.py: VSAAttributeConfig validation +- plugins/ltx2.py: model/module detection helpers +""" + +import math + +import pytest +import torch +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity.config import VSAAttributeConfig, VSAConfig +from modelopt.torch.sparsity.attention_sparsity.methods.vsa import VSA +from modelopt.torch.sparsity.attention_sparsity.methods.vsa_utils import ( + construct_variable_block_sizes, + get_non_pad_index, + get_reverse_tile_partition_indices, + get_tile_partition_indices, +) + +# --------------------------------------------------------------------------- +# vsa_utils: tile partition indices +# --------------------------------------------------------------------------- + + +class TestTilePartitionIndices: + """Tests for get_tile_partition_indices.""" + + def test_evenly_divisible(self): + """Tiles cover full volume with no remainder.""" + video_shape = (8, 8, 8) + tile_size = (4, 4, 4) + idx = get_tile_partition_indices(video_shape, tile_size, torch.device("cpu")) + assert idx.shape == (8 * 8 * 8,) + # Every original index appears exactly once + assert torch.equal(idx.sort().values, torch.arange(512)) + + 
def test_non_divisible(self): + """Edge tiles are smaller when dims don't divide evenly.""" + video_shape = (5, 6, 7) + tile_size = (4, 4, 4) + seq_len = 5 * 6 * 7 + idx = get_tile_partition_indices(video_shape, tile_size, torch.device("cpu")) + assert idx.shape == (seq_len,) + assert torch.equal(idx.sort().values, torch.arange(seq_len)) + + def test_round_trip(self): + """tile then reverse_tile is identity.""" + video_shape = (6, 10, 8) + tile_size = (4, 4, 4) + device = torch.device("cpu") + fwd = get_tile_partition_indices(video_shape, tile_size, device) + rev = get_reverse_tile_partition_indices(video_shape, tile_size, device) + # Applying forward then reverse should yield the original order + assert torch.equal(fwd[rev], torch.arange(6 * 10 * 8)) + + +# --------------------------------------------------------------------------- +# vsa_utils: variable block sizes +# --------------------------------------------------------------------------- + + +class TestVariableBlockSizes: + """Tests for construct_variable_block_sizes.""" + + def test_evenly_divisible(self): + """All tiles have full size when dims divide evenly.""" + video_shape = (8, 8, 8) + tile_size = (4, 4, 4) + num_tiles = (2, 2, 2) + sizes = construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + assert sizes.shape == (8,) # 2*2*2 tiles + assert (sizes == 64).all() # every tile is full 4*4*4 + + def test_non_divisible_sum(self): + """Sum of variable sizes equals original sequence length.""" + video_shape = (5, 6, 7) + tile_size = (4, 4, 4) + num_tiles = ( + math.ceil(5 / 4), + math.ceil(6 / 4), + math.ceil(7 / 4), + ) + sizes = construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + assert sizes.sum().item() == 5 * 6 * 7 + + def test_partial_tile_smaller(self): + """Last tile along a non-divisible dim should be smaller.""" + video_shape = (5, 4, 4) + tile_size = (4, 4, 4) + num_tiles = (2, 1, 1) + sizes = 
construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + # First tile: 4*4*4=64, second tile: 1*4*4=16 + assert sizes[0].item() == 64 + assert sizes[1].item() == 16 + + +# --------------------------------------------------------------------------- +# vsa_utils: non-pad index +# --------------------------------------------------------------------------- + + +class TestNonPadIndex: + """Tests for get_non_pad_index.""" + + def test_full_blocks(self): + """All blocks full size → non_pad covers everything.""" + sizes = torch.tensor([64, 64, 64]) + npi = get_non_pad_index(sizes, 64) + assert npi.shape == (192,) # 3 * 64 + + def test_partial_blocks(self): + """Partial blocks → non_pad skips padding positions.""" + sizes = torch.tensor([64, 16]) + npi = get_non_pad_index(sizes, 64) + assert npi.shape == (80,) # 64 + 16 + + +# --------------------------------------------------------------------------- +# VSA method: init and config +# --------------------------------------------------------------------------- + + +class TestVSAInit: + """Tests for VSA.__init__ and basic properties.""" + + def test_defaults(self): + vsa = VSA() + assert vsa.block_size_3d == (4, 4, 4) + assert vsa.block_elements == 64 + assert vsa.top_k_ratio == 0.5 + assert vsa.video_shape is None + assert vsa.name == "vsa" + + def test_custom_config(self): + vsa = VSA({"block_size_3d": [2, 2, 2], "top_k_ratio": 0.3, "video_shape": (8, 8, 8)}) + assert vsa.block_size_3d == (2, 2, 2) + assert vsa.block_elements == 8 + assert vsa.top_k_ratio == 0.3 + assert vsa.video_shape == (8, 8, 8) + + def test_set_video_shape(self): + vsa = VSA() + vsa.set_video_shape((4, 8, 12)) + assert vsa.video_shape == (4, 8, 12) + + def test_get_threshold_info(self): + vsa = VSA({"top_k_ratio": 0.7, "video_shape": (4, 4, 4)}) + info = vsa.get_threshold_info() + assert info["type"] == "vsa" + assert info["top_k_ratio"] == 0.7 + + +# 
--------------------------------------------------------------------------- +# VSA method: metadata computation and validation +# --------------------------------------------------------------------------- + + +class TestVSAMetadata: + """Tests for VSA._compute_metadata validation and caching.""" + + def test_no_video_shape_raises(self): + vsa = VSA() + with pytest.raises(ValueError, match="video_shape must be provided"): + vsa._compute_metadata(100, torch.device("cpu")) + + def test_seq_len_mismatch_raises(self): + vsa = VSA({"video_shape": (4, 4, 4)}) + with pytest.raises(ValueError, match="does not match video shape"): + vsa._compute_metadata(100, torch.device("cpu")) # expected 64 + + def test_valid_metadata(self): + vsa = VSA({"video_shape": (8, 8, 8)}) + meta = vsa._compute_metadata(512, torch.device("cpu")) + assert meta["video_shape"] == (8, 8, 8) + assert meta["num_tiles"] == (2, 2, 2) + assert meta["total_tiles"] == 8 + + def test_metadata_caching(self): + vsa = VSA({"video_shape": (8, 8, 8)}) + m1 = vsa._compute_metadata(512, torch.device("cpu")) + m2 = vsa._compute_metadata(512, torch.device("cpu")) + assert m1 is m2 # same object, not recomputed + + +# --------------------------------------------------------------------------- +# VSA method: abstract stubs raise +# --------------------------------------------------------------------------- + + +class TestVSAStubs: + """calculate_sparsity and apply_sparsity should raise NotImplementedError.""" + + def test_calculate_sparsity_raises(self): + vsa = VSA() + with pytest.raises(NotImplementedError, match="softmax-patching"): + vsa.calculate_sparsity(torch.zeros(1)) + + def test_apply_sparsity_raises(self): + vsa = VSA() + with pytest.raises(NotImplementedError, match="softmax-patching"): + vsa.apply_sparsity(torch.zeros(1)) + + +# --------------------------------------------------------------------------- +# VSAAttributeConfig validation +# 
--------------------------------------------------------------------------- + + +class TestVSAAttributeConfig: + """Tests for VSAAttributeConfig pydantic validation.""" + + def test_valid_defaults(self): + cfg = VSAAttributeConfig() + assert cfg.method == "vsa" + assert cfg.block_size_3d == (4, 4, 4) + assert cfg.top_k_ratio == 0.5 + + def test_top_k_ratio_out_of_range(self): + with pytest.raises(ValidationError, match="top_k_ratio"): + VSAAttributeConfig(top_k_ratio=0.0) + with pytest.raises(ValidationError, match="top_k_ratio"): + VSAAttributeConfig(top_k_ratio=1.5) + + def test_video_shape_wrong_length(self): + with pytest.raises(ValidationError, match="3 elements"): + VSAAttributeConfig(video_shape=(4, 4)) + + def test_video_shape_negative(self): + with pytest.raises(ValidationError, match="positive"): + VSAAttributeConfig(video_shape=(4, -1, 4)) + + def test_video_shape_none_allowed(self): + cfg = VSAAttributeConfig(video_shape=None) + assert cfg.video_shape is None + + def test_vsa_config_defaults(self): + cfg = VSAConfig() + assert "*attention*" in cfg.sparse_cfg + assert cfg.sparse_cfg["*attention*"]["method"] == "vsa" + + +# --------------------------------------------------------------------------- +# LTX-2 plugin: detection helpers +# --------------------------------------------------------------------------- + + +class TestLTX2Detection: + """Tests for _is_ltx2_model and _is_ltx2_attention_module.""" + + def test_non_ltx2_model(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import _is_ltx2_model + + model = torch.nn.Linear(10, 10) + assert _is_ltx2_model(model) is False + + def test_ltx2_model_by_class_name(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import _is_ltx2_model + + # Fake a class named LTXModel + class LTXModel(torch.nn.Module): + pass + + assert _is_ltx2_model(LTXModel()) is True + + def test_ltx2_attention_by_class_name(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 
import ( + _is_ltx2_attention_module, + ) + + class LTXSelfAttention(torch.nn.Module): + pass + + assert _is_ltx2_attention_module(LTXSelfAttention()) is True + + def test_ltx2_attention_by_structure(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( + _is_ltx2_attention_module, + ) + + # Module with LTX-2 attribute signature + m = torch.nn.Module() + m.to_q = torch.nn.Linear(8, 8) + m.to_k = torch.nn.Linear(8, 8) + m.to_v = torch.nn.Linear(8, 8) + m.q_norm = torch.nn.LayerNorm(8) + m.k_norm = torch.nn.LayerNorm(8) + assert _is_ltx2_attention_module(m) is True + + def test_non_attention_module(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( + _is_ltx2_attention_module, + ) + + assert _is_ltx2_attention_module(torch.nn.Linear(10, 10)) is False From 547fb5cad8dbf7db66886c72c0f477b6289de1ca Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 7 Feb 2026 23:46:47 -0800 Subject: [PATCH 08/10] Fix VSA config Signed-off-by: Kai Xu --- examples/video_diffusion/vsa/README.md | 21 +- .../vsa/test_ltx2_vsa_integration.py | 50 +- .../attention_sparsity/calibration/dataset.py | 551 ------------------ .../calibration/ruler_utils.py | 491 ---------------- .../sparsity/attention_sparsity/config.py | 26 +- .../attention_sparsity/methods/registry.py | 2 - .../attention_sparsity/methods/vsa.py | 11 +- .../attention_sparsity/plugins/huggingface.py | 4 +- .../attention_sparsity/plugins/ltx2.py | 67 ++- .../sparsity/attention_sparsity/test_vsa.py | 21 +- 10 files changed, 101 insertions(+), 1143 deletions(-) delete mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py delete mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py diff --git a/examples/video_diffusion/vsa/README.md b/examples/video_diffusion/vsa/README.md index dccdadd4f..9470ddd46 100644 --- a/examples/video_diffusion/vsa/README.md +++ b/examples/video_diffusion/vsa/README.md @@ -42,14 +42,19 @@ python 
test_ltx2_vsa_integration.py \ | Parameter | Default | Description | |-----------|---------|-------------| -| `--top_k_ratio` | 0.5 | Ratio of blocks to keep (0.0-1.0). Lower = more sparse | -| `--block_size` | 4 4 4 | 3D block size (T H W) for video tiling | -| `--video_shape` | 16 28 48 | Video dimensions (T H W) after patchification | -| `--batch_size` | 1 | Batch size for inference | +| `--checkpoint` | (required) | Path to model checkpoint (.safetensors) | +| `--text-encoder-path` | (required) | Path to Gemma text encoder directory | +| `--prompt` | A serene mountain... | Text prompt for generation | +| `--top-k-ratio` | 0.5 | Ratio of blocks to keep (0.0-1.0). Lower = more sparse | +| `--num-frames` | 121 | Number of video frames (must be k*8 + 1) | +| `--height` | 512 | Video height (must be divisible by 32) | +| `--width` | 768 | Video width (must be divisible by 32) | +| `--num-inference-steps` | 30 | Number of denoising steps | +| `--guidance-scale` | 4.0 | Classifier-free guidance scale | +| `--seed` | 42 | Random seed for reproducibility | | `--device` | cuda | Device (cuda/cpu) | -| `--dtype` | bfloat16 | Data type (float32/float16/bfloat16) | - -## Examples +| `--compare` | off | Run both baseline and VSA for comparison | +| `--no-vsa` | off | Disable VSA (baseline only) | ## API Usage @@ -66,7 +71,7 @@ model = mtsa.sparsify(model, config=VSA_DEFAULT) # Or with custom configuration custom_config = { "sparse_cfg": { - "*attention*": { + "*attn*": { "method": "vsa", "block_size_3d": (4, 4, 4), "top_k_ratio": 0.3, # 70% sparsity diff --git a/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py b/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py index 1257b31a2..b8ea94052 100644 --- a/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py +++ b/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py @@ -54,6 +54,7 @@ """ import argparse +import copy import time from pathlib import Path @@ -64,6 +65,7 @@ from ltx_trainer.video_utils 
import save_video from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT def calculate_expected_tokens(num_frames: int, height: int, width: int) -> int: @@ -131,27 +133,17 @@ def apply_vsa_to_transformer( print(" [WARNING] Input size may be too small for VSA to provide significant benefit.") print(" Consider using larger inputs (121+ frames @ 512x768+) for best results.") - # Configure VSA - # NOTE: LTX-2 uses "attn1", "attn2", "audio_attn1", "audio_attn2" naming - # Pattern must be "*attn*" not "*attention*" to match these module names - sparse_config = { - "sparse_cfg": { - "*attn*": { - "method": "vsa", - "video_shape": None, # Auto-infer from LTX-2's compressed tokens - "block_size_3d": (4, 4, 4), # Standard VSA tile size - "top_k_ratio": top_k_ratio, - } - } - } + # Configure VSA using the standard preset, overriding top_k_ratio if needed + sparse_config = copy.deepcopy(VSA_DEFAULT) + # Find the attn pattern key and override top_k_ratio + for cfg in sparse_config["sparse_cfg"].values(): + if isinstance(cfg, dict) and cfg.get("method") == "vsa": + cfg["top_k_ratio"] = top_k_ratio # Apply VSA to transformer print(" Applying VSA to attention modules...") transformer = sparsify(transformer, sparse_config) - print(f" [OK] VSA enabled with {int(top_k_ratio * 100)}% sparsity") - print(" Expected: 2-6x attention speedup, 1.5-2x end-to-end speedup") - return transformer @@ -185,7 +177,7 @@ def run_generation( video, audio = sampler.generate(config=config, device=device) elapsed = time.time() - start_time - print(f"[OK] Generation completed in {elapsed:.2f}s") + print(f"Generation completed in {elapsed:.2f}s") return video, audio, elapsed @@ -351,7 +343,7 @@ def main() -> None: with_text_encoder=True, text_encoder_path=args.text_encoder_path, ) - print("[OK] Model components loaded") + print("Model components loaded") # Create generation config gen_config = GenerationConfig( @@ -433,9 
+425,9 @@ def main() -> None: audio=audio_baseline, audio_sample_rate=audio_sample_rate, ) - print(f"[OK] Baseline video saved: {args.output_baseline}") + print(f"Baseline video saved: {args.output_baseline}") except Exception as e: - print(f"[FAIL] Baseline generation failed: {e}") + print(f"Baseline generation failed: {e}") import traceback traceback.print_exc() @@ -505,9 +497,9 @@ def main() -> None: audio=audio_vsa, audio_sample_rate=audio_sample_rate, ) - print(f"[OK] VSA video saved: {args.output}") + print(f"VSA video saved: {args.output}") except Exception as e: - print(f"[FAIL] VSA generation failed: {e}") + print(f"VSA generation failed: {e}") import traceback traceback.print_exc() @@ -567,9 +559,9 @@ def main() -> None: audio=audio, audio_sample_rate=audio_sample_rate, ) - print(f"[OK] Video saved: {args.output}") + print(f"Video saved: {args.output}") except Exception as e: - print(f"[FAIL] Generation failed: {e}") + print(f"Generation failed: {e}") import traceback traceback.print_exc() @@ -592,19 +584,11 @@ def main() -> None: print(f" Baseline video: {args.output_baseline}") print(f" VSA video: {args.output}") print() - if speedup >= 1.5: - print("[OK] Excellent speedup achieved!") - elif speedup >= 1.2: - print("[OK] Good speedup achieved") - else: - print("[WARNING] Speedup lower than expected (input may be too small for VSA)") - print() - print("Compare videos to verify quality is preserved with VSA.") else: print(f"\nGeneration time: {results['single']:.2f}s") print(f"Output: {args.output}") - print("\n[OK] VSA integration test successful!") + print("\nVSA integration test successful!") print("=" * 80) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py deleted file mode 100644 index 221ea2344..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ /dev/null @@ -1,551 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RULER dataset builder for sparse attention calibration.""" - -import random -import string -from dataclasses import dataclass -from typing import Any - -from tqdm import tqdm - -from . import ruler_utils - - -def _generate_target_lengths( - max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 -) -> list[int]: - """Generate target lengths as descending powers of 2. 
- - Args: - max_seqlen: Maximum sequence length - num_length_bins: Maximum number of length bins to generate - min_seqlen: Minimum sequence length threshold - - Returns: - List of target lengths in descending order - - Examples: - >>> _generate_target_lengths(32768, 4) - [32768, 16384, 8192, 4096] - >>> _generate_target_lengths(2048, 4) - [2048, 1024] - """ - target_lengths = [] - current = max_seqlen - - for _ in range(num_length_bins): - if current < min_seqlen: - break - target_lengths.append(current) - current = current // 2 - - return target_lengths - - -@dataclass -class RulerTask: - """Configuration for a RULER task.""" - - name: str - task_type: str # niah, variable_tracking, freq_words_extraction, qa - tokens_to_generate: int - template: str - answer_prefix: str - args: dict[str, Any] - - -# Task configurations based on RULER benchmark -RULER_TASKS = { - "niah_multikey_2": RulerTask( - name="niah_multikey_2", - task_type="niah", - tokens_to_generate=128, - template=( - "Some special magic {type_needle_v} are hidden within the following text. " - "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" - "{context}\n" - "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" - ), - answer_prefix=( - " The special magic {type_needle_v} for {query} mentioned in the provided text are" - ), - args={ - "type_haystack": "needle", - "type_needle_k": "words", - "type_needle_v": "numbers", - "num_needle_k": 1, - "num_needle_v": 1, - "num_needle_q": 1, - }, - ), - "niah_multikey_3": RulerTask( - name="niah_multikey_3", - task_type="niah", - tokens_to_generate=128, - template=( - "Some special magic {type_needle_v} are hidden within the following text. " - "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" - "{context}\n" - "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" 
- ), - answer_prefix=( - " The special magic {type_needle_v} for {query} mentioned in the provided text are" - ), - args={ - "type_haystack": "needle", - "type_needle_k": "uuids", - "type_needle_v": "uuids", - "num_needle_k": 1, - "num_needle_v": 1, - "num_needle_q": 1, - }, - ), - "vt": RulerTask( - name="vt", - task_type="variable_tracking", - tokens_to_generate=30, - template=( - "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" - "{context}\n" - "Question: Find all variables that are assigned the value {query} in the text above." - ), - answer_prefix=( - " Answer: According to the chain(s) of variable assignment in the text above, " - "{num_v} variables are assgined the value {query}, they are: " - ), - args={"num_chains": 1, "num_hops": 4}, - ), - "fwe": RulerTask( - name="fwe", - task_type="freq_words_extraction", - tokens_to_generate=50, - template=( - "Read the following coded text and track the frequency of each coded word. " - "Find the three most frequently appeared coded words. {context}\n" - "Question: Do not provide any explanation. Please ignore the dots '....'. " - "What are the three most frequently appeared words in the above coded text?" - ), - answer_prefix=( - " Answer: According to the coded text above, " - "the three most frequently appeared words are:" - ), - args={"alpha": 2.0}, - ), - "qa_1": RulerTask( - name="qa_1", - task_type="qa", - tokens_to_generate=32, - template=( - "Answer the question based on the given documents. " - "Only give me the answer and do not output any other words.\n\n" - "The following are given documents.\n\n{context}\n\n" - "Answer the question based on the given documents. " - "Only give me the answer and do not output any other words.\n\n" - "Question: {query}" - ), - answer_prefix=" Answer:", - args={"dataset": "squad"}, - ), - "qa_2": RulerTask( - name="qa_2", - task_type="qa", - tokens_to_generate=32, - template=( - "Answer the question based on the given documents. 
" - "Only give me the answer and do not output any other words.\n\n" - "The following are given documents.\n\n{context}\n\n" - "Answer the question based on the given documents. " - "Only give me the answer and do not output any other words.\n\n" - "Question: {query}" - ), - answer_prefix=" Answer:", - args={"dataset": "hotpotqa"}, - ), -} - - -class RulerDatasetBuilder: - """Builder for RULER calibration datasets.""" - - def __init__( - self, - samples: int, - max_seqlen: int, - tokenizer_name_or_path: str | object, - num_length_bins: int = 4, - max_length_filter: int = 65536, - seed: int = 42, - ): - """Initialize RULER dataset builder. - - Args: - samples: Total number of samples to generate (distributed evenly across length bins) - max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) - tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object - seed: Random seed for reproducibility - num_length_bins: Number of length bins to generate (default: 4) - max_length_filter: Maximum sequence length to keep (default: 65536) - - Note: - Length bins are auto-generated as descending powers of 2: - [max_seqlen, max_seqlen/2, max_seqlen/4, ...] - Generation stops when num_length_bins is reached or length < 1024. - Subtasks are set to all the difficult tasks defined in RULER_TASKS. 
- """ - # Validate inputs - if samples <= 0: - raise ValueError(f"samples must be positive, got {samples}") - if max_seqlen < 1024: - raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") - - # Store parameters - self.total_samples = samples - self.max_seqlen = max_seqlen - self.num_length_bins = num_length_bins - self.subtasks = list(RULER_TASKS.keys()) - self.tokenizer_name_or_path = tokenizer_name_or_path - self.seed = seed - self.max_length_filter = max_length_filter - - # Generate target lengths and validate - self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) - if not self.target_lengths: - raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") - - # Distribute samples evenly across lengths - self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) - - # Initialize tokenizer - if isinstance(tokenizer_name_or_path, str): - from transformers import AutoTokenizer - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - else: - self.tokenizer = tokenizer_name_or_path - random.seed(seed) - - def build_calibration_dataset(self) -> list[dict[str, Any]]: - """Build the complete calibration dataset. 
- - Returns: - List of calibration samples with 'input' and 'length' fields - """ - all_samples = [] - - print( - f"Generating {self.total_samples} calibration samples " - f"across {len(self.target_lengths)} length bins: {self.target_lengths}" - ) - - # Generate calibration samples with sample-level progress - with tqdm(total=self.total_samples, desc="Generating RULER samples") as pbar: - for num_samples, target_length in zip(self.samples_per_length, self.target_lengths): - samples_per_task = max(num_samples // len(self.subtasks), 1) - - for task_name in self.subtasks: - for sample_idx in range(samples_per_task): - sample = self._generate_sample(task_name, target_length, sample_idx) - if sample and sample["length"] <= self.max_length_filter: - all_samples.append(sample) - pbar.update(1) - - random.shuffle(all_samples) - print(f"Generated {len(all_samples)} valid samples") - return all_samples - - def _generate_sample( - self, task_name: str, target_length: int, sample_idx: int - ) -> dict[str, Any]: - """Generate a single RULER sample. 
- - Args: - task_name: Name of the RULER task - target_length: Target sequence length in tokens - sample_idx: Index of the sample (for uniqueness) - - Returns: - Dict with 'input', 'length', and metadata fields - """ - task = RULER_TASKS[task_name] - - if task.task_type == "niah": - return self._generate_niah_sample(task, target_length, sample_idx) - elif task.task_type == "variable_tracking": - return self._generate_vt_sample(task, target_length, sample_idx) - elif task.task_type == "freq_words_extraction": - return self._generate_fwe_sample(task, target_length, sample_idx) - elif task.task_type == "qa": - return self._generate_qa_sample(task, target_length, sample_idx) - else: - raise ValueError(f"Unknown task type: {task.task_type}") - - def _generate_niah_sample( - self, task: RulerTask, target_length: int, sample_idx: int - ) -> dict[str, Any]: - """Generate a needle-in-haystack sample.""" - args = task.args - - # Find optimal haystack size for target length - optimal_haystack = ruler_utils.find_optimal_haystack_size( - tokenizer=self.tokenizer, - max_seq_length=target_length, - template=task.template, - answer_prefix=task.answer_prefix, - tokens_to_generate=task.tokens_to_generate, - type_haystack=args.get("type_haystack", "essay"), - type_needle_k=args.get("type_needle_k", "words"), - type_needle_v=args.get("type_needle_v", "numbers"), - num_needle_k=args.get("num_needle_k", 1), - num_needle_v=args.get("num_needle_v", 1), - num_needle_q=args.get("num_needle_q", 1), - ) - - # Generate sample using official RULER implementation - sample = ruler_utils.generate_niah_sample( - num_haystack=optimal_haystack, - tokenizer=self.tokenizer, - template=task.template, - answer_prefix=task.answer_prefix, - tokens_to_generate=task.tokens_to_generate, - type_haystack=args.get("type_haystack", "essay"), - type_needle_k=args.get("type_needle_k", "words"), - type_needle_v=args.get("type_needle_v", "numbers"), - num_needle_k=args.get("num_needle_k", 1), - 
num_needle_v=args.get("num_needle_v", 1), - num_needle_q=args.get("num_needle_q", 1), - random_seed=self.seed + sample_idx, - ) - - # Add task metadata - sample["task"] = task.name - sample["target_length"] = target_length - sample["sample_idx"] = sample_idx - - return sample - - def _generate_vt_sample( - self, task: RulerTask, target_length: int, sample_idx: int - ) -> dict[str, Any]: - """Generate a variable tracking sample.""" - args = task.args - num_chains = args["num_chains"] - num_hops = args["num_hops"] - - # Generate variable chains - variables = [] - chains = [] - for _ in range(num_chains): - chain = [self._generate_random_variable() for _ in range(num_hops + 1)] - variables.extend(chain) - chains.append(chain) - - # Generate assignments - assignments = [ - f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) - ] - - # Create context with padding - context = self._pad_context_with_text( - "\n".join(assignments), target_length, "variable tracking context" - ) - - # Select a query value - query_value = random.choice([chain[-1] for chain in chains]) - - # Format template - template = task.template.format(context=context, query=query_value) - - # Count variables with the query value - num_v = sum(1 for chain in chains if chain[-1] == query_value) - - # Add answer prefix - full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) - - # Tokenize to get actual length - tokens = self.tokenizer.encode(full_input, add_special_tokens=False) - - return { - "input": full_input, - "length": len(tokens), - "task": task.name, - "target_length": target_length, - "sample_idx": sample_idx, - } - - def _generate_fwe_sample( - self, task: RulerTask, target_length: int, sample_idx: int - ) -> dict[str, Any]: - """Generate a frequency word extraction sample.""" - # Generate coded words with frequencies - num_unique_words = 50 - coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] - - # Assign 
frequencies (make top 3 clearly more frequent) - frequencies = {} - for i, word in enumerate(coded_words): - if i < 3: - frequencies[word] = random.randint(20, 30) # High frequency - else: - frequencies[word] = random.randint(1, 10) # Low frequency - - # Generate the coded text - word_list = [] - for word, freq in frequencies.items(): - word_list.extend([word] * freq) - random.shuffle(word_list) - - # Add dots for separation - coded_text = " .... ".join(word_list) - - # Pad to target length - context = self._pad_context_with_text(coded_text, target_length, "coded text padding") - - # Format template - template = task.template.format(context=context) - full_input = template + task.answer_prefix - - # Tokenize to get actual length - tokens = self.tokenizer.encode(full_input, add_special_tokens=False) - - return { - "input": full_input, - "length": len(tokens), - "task": task.name, - "target_length": target_length, - "sample_idx": sample_idx, - } - - def _generate_qa_sample( - self, task: RulerTask, target_length: int, sample_idx: int - ) -> dict[str, Any]: - """Generate a QA sample.""" - # Generate synthetic documents - num_docs = 5 - documents = [] - - # Create a simple QA pair - answer = self._generate_random_phrase() - answer_doc_idx = random.randint(0, num_docs - 1) - question = f"What is the special code mentioned in document {answer_doc_idx + 1}?" - - for i in range(num_docs): - doc_text = self._generate_document_text(200) # Base document - if i == answer_doc_idx: # Insert answer in the correct document - doc_text += f" The special code is {answer}. 
" - documents.append(f"Document {i + 1}:\n{doc_text}\n") - - # Combine documents - context_base = "\n".join(documents) - - # Pad to target length - context = self._pad_context_with_text( - context_base, target_length, "additional document text" - ) - - # Format template - template = task.template.format(context=context, query=question) - full_input = template + task.answer_prefix - - # Tokenize to get actual length - tokens = self.tokenizer.encode(full_input, add_special_tokens=False) - - return { - "input": full_input, - "length": len(tokens), - "task": task.name, - "target_length": target_length, - "sample_idx": sample_idx, - } - - def _pad_context_with_text( - self, base_context: str, target_length: int, padding_type: str - ) -> str: - """Pad context to approach target length.""" - tokens = self.tokenizer.encode(base_context, add_special_tokens=False) - - while len(tokens) < target_length * 0.7: # Leave room for template - if padding_type == "variable tracking context": - padding = ( - f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." - ) - elif padding_type == "coded text padding": - padding = f" .... {self._generate_coded_word()} .... 
" - else: - padding = " " + self._generate_essay_text(50) - - base_context += padding - tokens = self.tokenizer.encode(base_context, add_special_tokens=False) - - if len(tokens) > target_length * 0.9: - # Truncate if too long - base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) - - return base_context - - def _generate_random_word(self) -> str: - """Generate a random word.""" - return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) - - def _generate_random_variable(self) -> str: - """Generate a random variable name.""" - return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( - random.choices(string.digits, k=3) - ) - - def _generate_coded_word(self) -> str: - """Generate a coded word.""" - return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) - - def _generate_random_phrase(self) -> str: - """Generate a random phrase.""" - words = [self._generate_random_word() for _ in range(random.randint(2, 4))] - return " ".join(words) - - def _generate_essay_text(self, num_words: int) -> str: - """Generate essay-like text.""" - topics = [ - "technology", - "science", - "nature", - "history", - "culture", - "education", - "health", - "economics", - "politics", - "philosophy", - "art", - "literature", - ] - - sentences = [] - words_generated = 0 - - while words_generated < num_words: - topic = random.choice(topics) - word1 = self._generate_random_word() - word2 = self._generate_random_word() - word3 = self._generate_random_word() - sentence = f"The {topic} of {word1} is {word2} and {word3}. 
" - sentences.append(sentence) - words_generated += len(sentence.split()) - - return " ".join(sentences) - - def _generate_document_text(self, num_words: int) -> str: - """Generate document-like text.""" - return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py deleted file mode 100644 index 741b621f5..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ /dev/null @@ -1,491 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copied and Adapted from https://github.com/NVIDIA/RULER -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License - -"""Official RULER dataset generation utilities adapted for Model Optimizer. - -This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) -adapted to work as a library for calibration purposes. The generation logic closely follows -the official RULER implementation to ensure dataset consistency. - -Key adaptations from official RULER: -- Converted from CLI scripts to library functions -- Works with HuggingFace tokenizers directly -- Removed file I/O, returns data structures -- Simplified for calibration use case (primarily NIAH tasks) -""" - -import logging -import random -import re -import uuid -from pathlib import Path -from typing import Any - -logger = logging.getLogger(__name__) - - -# Needle/Haystack template from official RULER -NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." - -# Depth positions for needle insertion (from official RULER) -DEPTHS = [ - 0, - 2, - 5, - 7, - 10, - 12, - 15, - 18, - 20, - 23, - 25, - 28, - 30, - 33, - 35, - 38, - 40, - 43, - 45, - 48, - 50, - 53, - 55, - 58, - 60, - 62, - 65, - 67, - 70, - 72, - 75, - 77, - 80, - 82, - 85, - 87, - 90, - 92, - 95, - 97, - 100, -] - -# Data directory for RULER calibration files (in examples folder) -# Downloaded via examples/llm_sparsity/attention_sparsity/download_ruler_data.sh -_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent.parent -DATA_DIR = _REPO_ROOT / "examples" / "llm_sparsity" / "attention_sparsity" / "data" -RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" -ESSAYS_DIR = DATA_DIR / "essays" - - -def _get_data_dir() -> Path: - """Get data directory for RULER data. 
- - Returns: - Path to data directory under examples/llm_sparsity/attention_sparsity/ (created if doesn't exist) - """ - DATA_DIR.mkdir(parents=True, exist_ok=True) - return DATA_DIR - - -def _load_paul_graham_essays_from_files() -> str: - """Load Paul Graham essays from local files. - - Reads essay .txt files from the data/essays directory. - Files must be downloaded first using download_ruler_data.sh. - - Returns: - Combined essay text - - Raises: - RuntimeError: If essays directory doesn't exist or is empty - """ - if not ESSAYS_DIR.exists(): - raise RuntimeError( - f"Essays directory not found at {ESSAYS_DIR}.\n" - "Please run the download script first:\n" - " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" - ) - - essay_files = list(ESSAYS_DIR.glob("*.txt")) - if not essay_files: - raise RuntimeError( - f"No essay files found in {ESSAYS_DIR}.\n" - "Please run the download script first:\n" - " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" - ) - - logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") - - all_essays = [] - for filepath in essay_files: - text = filepath.read_text() - all_essays.append(text) - - combined_text = " ".join(all_essays) - logger.info(f"Loaded {len(all_essays)} essays successfully") - - return combined_text - - -def _load_paul_graham_essays() -> str: - """Load Paul Graham essays from local files. - - Essay files must be downloaded first using download_ruler_data.sh. - - Returns: - Essay text as string - """ - essay_text = _load_paul_graham_essays_from_files() - return re.sub(r"\s+", " ", essay_text) - - -def _load_word_lists(): - """Load word lists for random word generation. 
- - Returns: - List of words (adj-noun combinations) - """ - import wonderwords - - # Load wonderwords lists (same as official RULER) - nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") - adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") - words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] - words = sorted(set(words)) - return words - - -# Global word list (loaded once) -_WORD_LIST = None - - -def generate_random_number(num_digits=7) -> str: - """Generate random number (from official RULER).""" - lower_bound = 10 ** (num_digits - 1) - upper_bound = 10**num_digits - 1 - return str(random.randint(lower_bound, upper_bound)) - - -def generate_random_word() -> str: - """Generate random word (from official RULER).""" - global _WORD_LIST - if _WORD_LIST is None: - _WORD_LIST = _load_word_lists() - return random.choice(_WORD_LIST) - - -def generate_random_uuid() -> str: - """Generate random UUID (from official RULER).""" - return str(uuid.UUID(int=random.getrandbits(128), version=4)) - - -def generate_random(type_needle: str) -> str: - """Generate random needle value based on type (from official RULER). - - Args: - type_needle: Type of needle ('numbers', 'words', 'uuids') - - Returns: - Random value as string - """ - if type_needle == "numbers": - return generate_random_number() - elif type_needle == "words": - return generate_random_word() - elif type_needle == "uuids": - return generate_random_uuid() - else: - raise ValueError(f"Unknown needle type: {type_needle}") - - -def generate_niah_sample( - num_haystack: int, - tokenizer, - template: str, - answer_prefix: str, - tokens_to_generate: int = 128, - type_haystack: str = "essay", - type_needle_k: str = "words", - type_needle_v: str = "numbers", - num_needle_k: int = 1, - num_needle_v: int = 1, - num_needle_q: int = 1, - random_seed: int = 42, -) -> dict[str, Any]: - """Generate a single NIAH (Needle in a Haystack) sample. 
- - This function implements the core generation logic from official RULER's niah.py, - adapted to work as a library function. - - Args: - num_haystack: Number of haystack items/words - tokenizer: HuggingFace tokenizer (AutoTokenizer instance) - template: NIAH question template - answer_prefix: Answer prefix template - tokens_to_generate: Expected number of generation tokens - type_haystack: Type of haystack ('essay', 'noise', 'needle') - type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') - type_needle_v: Type of needle values ('numbers', 'words', 'uuids') - num_needle_k: Number of needle keys - num_needle_v: Number of needle values per key - num_needle_q: Number of needles to query - random_seed: Random seed for this sample - - Returns: - Dictionary with 'input', 'outputs', 'length' keys - """ - import nltk - from nltk.tokenize import sent_tokenize - - try: - nltk.data.find("tokenizers/punkt") - except LookupError: - nltk.download("punkt", quiet=True) - nltk.download("punkt_tab", quiet=True) - - if random_seed is not None: - random.seed(random_seed) - - # Ensure num_needle_k >= num_needle_q - num_needle_k = max(num_needle_k, num_needle_q) - - # Generate needles (keys and values) - keys, values, needles = [], [], [] - for _ in range(num_needle_k): - keys.append(generate_random(type_needle_k)) - value = [] - for _ in range(num_needle_v): - value.append(generate_random(type_needle_v)) - needles.append( - NEEDLE_TEMPLATE.format( - type_needle_v=type_needle_v, - key=keys[-1], - value=value[-1], - ) - ) - values.append(value) - - random.shuffle(needles) - - # Generate context based on haystack type - if type_haystack == "essay": - # Load essay corpus - essay_text = _load_paul_graham_essays() - haystack = essay_text.split(" ") - - # Create text from haystack - if num_haystack <= len(haystack): - text = " ".join(haystack[:num_haystack]) - else: - # Repeat haystack as needed - repeats = (num_haystack + len(haystack) - 1) // len(haystack) - text = " 
".join((haystack * repeats)[:num_haystack]) - - # Insert needles at various depths - document_sents = sent_tokenize(text.strip()) - insertion_positions = [ - 0, - *sorted( - int(len(document_sents) * (depth / 100)) - for depth in random.sample(DEPTHS, len(needles)) - ), - len(document_sents), - ] - - document_sents_list = [] - for i in range(1, len(insertion_positions)): - last_pos = insertion_positions[i - 1] - next_pos = insertion_positions[i] - document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) - if i - 1 < len(needles): - document_sents_list.append(needles[i - 1]) - - context = " ".join(document_sents_list) - - elif type_haystack == "noise": - haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." - sentences = [haystack_sent] * num_haystack - indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) - for index, element in zip(indexes, needles): - sentences.insert(index, element) - context = "\n".join(sentences) - - elif type_haystack == "needle": - sentences = [ - NEEDLE_TEMPLATE.format( - type_needle_v=type_needle_v, - key=generate_random(type_needle_k), - value=generate_random(type_needle_v), - ) - for _ in range(num_haystack) - ] - - indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) - for index, element in zip(indexes, needles): - sentences.insert(index, element) - context = "\n".join(sentences) - - else: - raise ValueError(f"Unknown haystack type: {type_haystack}") - - # Generate query and answer - indices = random.sample(range(num_needle_k), num_needle_q) - queries = [keys[i] for i in indices] - answers = [a for i in indices for a in values[i]] - query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] - - # Format template (adjust for singular vs plural) - type_needle_v_display = type_needle_v - formatted_template = template - if num_needle_q * num_needle_v == 1: - formatted_template = 
formatted_template.replace("Some", "A") - formatted_template = formatted_template.replace("are all", "is") - formatted_template = formatted_template.replace("are", "is") - formatted_template = formatted_template.replace("answers", "answer") - type_needle_v_display = type_needle_v[:-1] # remove "s" - - input_text = formatted_template.format( - type_needle_v=type_needle_v_display, - context=context, - query=query, - ) - - # Add answer prefix - formatted_answer_prefix = answer_prefix.format( - type_needle_v=type_needle_v_display, - query=query, - ) - input_text = input_text + formatted_answer_prefix - - # Calculate actual length - if hasattr(tokenizer, "encode"): - # HuggingFace tokenizer - tokens = tokenizer.encode(input_text, add_special_tokens=False) - length = len(tokens) + tokens_to_generate - else: - # Fallback - length = len(input_text.split()) + tokens_to_generate - - return { - "input": input_text, - "outputs": answers, - "length": length, - } - - -def find_optimal_haystack_size( - tokenizer, - max_seq_length: int, - template: str, - answer_prefix: str, - tokens_to_generate: int = 128, - type_haystack: str = "essay", - **kwargs, -) -> int: - """Find optimal haystack size using binary search (from official RULER). 
- - Args: - tokenizer: HuggingFace tokenizer - max_seq_length: Maximum sequence length - tokens_to_generate: Expected generation tokens - type_haystack: Type of haystack - template: NIAH question template - answer_prefix: Answer prefix template - **kwargs: Additional arguments for generate_niah_sample - - Returns: - Optimal number of haystack items - """ - # Determine incremental step based on haystack type - if type_haystack == "essay": - incremental = 500 - elif type_haystack in ["noise", "needle"]: - incremental = 25 - else: - incremental = 100 - - if max_seq_length < 4096 and type_haystack != "essay": - incremental = 5 - - # Estimate tokens per haystack item - sample = generate_niah_sample( - incremental, - tokenizer, - template, - answer_prefix, - tokens_to_generate, - type_haystack=type_haystack, - **kwargs, - ) - - if hasattr(tokenizer, "encode"): - sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) - else: - sample_tokens = len(sample["input"].split()) - - tokens_per_haystack = sample_tokens / incremental - estimated_max = int((max_seq_length / tokens_per_haystack) * 3) - - # Binary search for optimal size - lower_bound = incremental - upper_bound = max(estimated_max, incremental * 2) - optimal_num_haystack = None - - logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") - logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") - - while lower_bound <= upper_bound: - mid = (lower_bound + upper_bound) // 2 - sample = generate_niah_sample( - mid, - tokenizer, - template, - answer_prefix, - tokens_to_generate, - type_haystack=type_haystack, - **kwargs, - ) - total_tokens = sample["length"] - - logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") - - if total_tokens <= max_seq_length: - optimal_num_haystack = mid - lower_bound = mid + 1 - else: - upper_bound = mid - 1 - - final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental - 
logger.info(f"Optimal haystack size: {final_size}") - - return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 2c7e7252a..7c0ba0842 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -541,7 +541,7 @@ class VSAConfig(SparseAttentionConfig): sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ - "*attention*": { + "*attn*": { "method": "vsa", "block_size_3d": (4, 4, 4), "top_k_ratio": 0.5, @@ -558,10 +558,14 @@ class VSAConfig(SparseAttentionConfig): ) -# Pre-defined VSA Configuration for video diffusion models +# Pre-defined VSA Configuration for video diffusion models. +# Pattern "*attn*" matches all LTX-2 attention module names: +# - Video self-attention: attn1, attn2 +# - Audio self-attention: audio_attn1, audio_attn2 +# - Cross-modal attention: audio_to_video_attn, video_to_audio_attn VSA_DEFAULT = { "sparse_cfg": { - "*attention*": { + "*attn*": { "method": "vsa", "block_size_3d": (4, 4, 4), "top_k_ratio": 0.5, @@ -571,26 +575,10 @@ class VSAConfig(SparseAttentionConfig): }, } - -# High sparsity VSA configuration (70% of blocks pruned) -VSA_HIGH_SPARSITY = { - "sparse_cfg": { - "*attention*": { - "method": "vsa", - "block_size_3d": (4, 4, 4), - "top_k_ratio": 0.3, - "enable": True, - }, - "default": {"enable": False}, - }, -} - - __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", "VSA_DEFAULT", - "VSA_HIGH_SPARSITY", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 17c3e92d8..3f3e78db6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -31,8 +31,6 @@ def __init__(self): # Flag to indicate calibration mode (set 
by calibrator) # Instance attribute to prevent shared state across multiple models self._calibration_mode: bool = False - # Last computed statistics (set by subclass forward methods, read by SparseAttentionModule) - self._last_stats: dict[str, Any] | None = None # Calibration parameters set by the calibrator after calibration. # Exponential model params per phase: {"prefill": {"a": ..., "b": ...}, ...} diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py index 9fbe233fa..b42f0a111 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -27,7 +27,6 @@ tile(Q,K,V,gate) -> Triton kernel -> untile(output) """ -import logging import math from typing import Any @@ -41,8 +40,6 @@ get_tile_partition_indices, ) -logger = logging.getLogger(__name__) - @register_sparse_method("vsa") class VSA(SparseAttentionMethod): @@ -287,8 +284,12 @@ def forward_attention( from fastvideo_kernel import video_sparse_attn as triton_vsa_kernel except ModuleNotFoundError: raise ModuleNotFoundError( - "VSA requires the 'fastvideo_kernel' package for its Triton sparse attention kernel. " - "Please install it before using the VSA method." + "VSA requires the 'fastvideo_kernel' package for its Triton sparse attention " + "kernel. The VSA method registered successfully, but the kernel is needed at " + "runtime. Install it with:\n" + " git clone https://github.com/FastVideo/FastVideo.git\n" + " cd FastVideo/fastvideo-kernel && ./build.sh\n" + "See https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo-kernel for details." 
) from None output_tiled = triton_vsa_kernel( query_tiled, diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index f61e59afd..51191b304 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -18,13 +18,13 @@ import torch.nn as nn import transformers -logger = logging.getLogger(__name__) - from modelopt.torch.opt.dynamic import DynamicModule from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry from . import CUSTOM_MODEL_PLUGINS +logger = logging.getLogger(__name__) + class _GenericSparseAttention(SparseAttentionModule): """Generic sparse attention that works with any HF attention module. diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py index 030187b4d..41179913a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py @@ -24,7 +24,7 @@ """ import logging -from typing import Optional +import weakref import torch import torch.nn as nn @@ -34,9 +34,6 @@ logger = logging.getLogger(__name__) -# Module-level storage for video_shape extracted by the forward pre-hook. -_current_vsa_video_shape: tuple[int, int, int] | None = None - def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: """Forward pre-hook on LTXModel to extract dit_seq_shape from Modality.positions. @@ -47,11 +44,11 @@ def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: ``Modality.positions`` tensor, which is available at the LTXModel entry point (before ``TransformerArgsPreprocessor`` converts it to RoPE embeddings). - The result is stored in the module-level ``_current_vsa_video_shape`` so - that ``_LTX2SparseAttention._resolve_video_shape()`` can read it``. 
+ The result is stored on the model instance as ``module._vsa_video_shape`` + so that ``_LTX2SparseAttention._resolve_video_shape()`` can read it via + its ``_vsa_root_model`` reference. This avoids module-level global state + and is safe for concurrent models. """ - global _current_vsa_video_shape - # LTXModel.forward(self, video: Modality | None, audio, perturbations) video = args[0] if len(args) > 0 else None if video is None or not hasattr(video, "positions") or video.positions is None: @@ -75,9 +72,9 @@ def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: seq_len = positions.shape[2] if t_dim * h_dim * w_dim == seq_len: - _current_vsa_video_shape = (t_dim, h_dim, w_dim) + module._vsa_video_shape = (t_dim, h_dim, w_dim) logger.debug( - f"Extracted dit_seq_shape={_current_vsa_video_shape} from " + f"Extracted dit_seq_shape={module._vsa_video_shape} from " f"Modality.positions (seq_len={seq_len})" ) else: @@ -86,8 +83,7 @@ def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: f"({t_dim * h_dim * w_dim}) != seq_len ({seq_len}), skipping" ) except Exception: - # Silently skip -- _resolve_video_shape will fall back to config - pass + logger.debug("Failed to extract video_shape from Modality.positions", exc_info=True) def _is_ltx2_model(model: nn.Module) -> bool: @@ -104,9 +100,7 @@ def _is_ltx2_model(model: nn.Module) -> bool: """ if type(model).__name__ == "LTXModel": return True - return any( - type(m).__name__ == "LTXSelfAttention" for m in model.modules() - ) + return any(type(m).__name__ == "LTXSelfAttention" for m in model.modules()) def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: @@ -117,7 +111,7 @@ def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: Args: module: Module to check. - name: Module name in model hierarchyx. + name: Module name in model hierarchy. Returns: True if module is an LTX-2 attention module. 
@@ -132,6 +126,7 @@ def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: and hasattr(module, "to_v") and hasattr(module, "q_norm") and hasattr(module, "k_norm") + and hasattr(module, "rope_type") ) @@ -139,9 +134,8 @@ class _LTX2SparseAttention(SparseAttentionModule): """Sparse attention wrapper for LTX-2 Attention modules. This plugin handles all LTX-2 specific logic: - - Argument mapping (x -> hidden_states, context -> encoder_hidden_states) - - Q/K/V projection and normalization - - RoPE application + - Q/K/V projection and normalization (using native LTX-2 args: x, context, pe, k_pe) + - RoPE application via ltx_core - Trainable gate_compress for VSA quality optimization The plugin computes Q, K, V directly and calls VSA.forward_attention(), @@ -172,9 +166,9 @@ def _setup(self): def _compute_qkv( self, x: torch.Tensor, - context: Optional[torch.Tensor], - pe: Optional[torch.Tensor] = None, - k_pe: Optional[torch.Tensor] = None, + context: torch.Tensor | None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute Q, K, V projections with LTX-2 specific processing. @@ -208,7 +202,8 @@ def _compute_qkv( except ModuleNotFoundError: raise ModuleNotFoundError( "LTX-2 VSA plugin requires the 'ltx_core' package for RoPE support. " - "Please install it before using VSA with LTX-2 models." + "The plugin registered successfully, but 'ltx_core' is needed at runtime. " + "Install it with: pip install ltx-core" ) from None query = apply_rotary_emb(query, pe, self.rope_type) @@ -246,7 +241,7 @@ def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: """Resolve video_shape for the current forward pass. Resolution order (mirrors FastVideo's metadata flow): - 1. ``_current_vsa_video_shape`` -- set by the forward pre-hook from + 1. 
``root_model._vsa_video_shape`` -- set by the forward pre-hook from ``Modality.positions`` (analogous to ``get_forward_context().attn_metadata``) 2. ``method.video_shape`` -- explicitly set via the sparsify config @@ -256,11 +251,15 @@ def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: Returns: Tuple (T, H, W) or None if not determinable. """ - # 1. Primary: video_shape extracted by forward pre-hook - if _current_vsa_video_shape is not None: - t, h, w = _current_vsa_video_shape - if t * h * w == seq_len: - return _current_vsa_video_shape + # 1. Primary: video_shape extracted by forward pre-hook on root model + root_ref = getattr(self, "_vsa_root_model_ref", None) + root = root_ref() if root_ref is not None else None + if root is not None: + shape = getattr(root, "_vsa_video_shape", None) + if shape is not None: + t, h, w = shape + if t * h * w == seq_len: + return shape # 2. Fallback: explicit video_shape from sparsify config method = getattr(self, "_sparse_method_instance", None) @@ -428,6 +427,16 @@ def register_ltx2_attention(model: nn.Module) -> int: if num_modules > 0: logger.info(f"Found {num_modules} LTX-2 Attention modules in model") + # Store a weak reference to the root model on each attention module so + # _resolve_video_shape() can read model._vsa_video_shape without globals. + # Using weakref avoids circular module registration (nn.Module.__setattr__ + # would register a plain Module reference as a submodule, causing infinite + # recursion in named_children()). 
+ root_ref = weakref.ref(model) + for _, module in model.named_modules(): + if _is_ltx2_attention_module(module): + object.__setattr__(module, "_vsa_root_model_ref", root_ref) + # Register forward pre-hook to extract video_shape from Modality.positions # before each forward pass -- analogous to FastVideo's # set_forward_context(attn_metadata=builder.build(...)) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py index 8ebb8fb99..bf6c8f25a 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py @@ -261,8 +261,8 @@ def test_video_shape_none_allowed(self): def test_vsa_config_defaults(self): cfg = VSAConfig() - assert "*attention*" in cfg.sparse_cfg - assert cfg.sparse_cfg["*attention*"]["method"] == "vsa" + assert "*attn*" in cfg.sparse_cfg + assert cfg.sparse_cfg["*attn*"]["method"] == "vsa" # --------------------------------------------------------------------------- @@ -303,15 +303,30 @@ def test_ltx2_attention_by_structure(self): _is_ltx2_attention_module, ) - # Module with LTX-2 attribute signature + # Module with LTX-2 attribute signature (includes rope_type) m = torch.nn.Module() m.to_q = torch.nn.Linear(8, 8) m.to_k = torch.nn.Linear(8, 8) m.to_v = torch.nn.Linear(8, 8) m.q_norm = torch.nn.LayerNorm(8) m.k_norm = torch.nn.LayerNorm(8) + m.rope_type = "interleaved" assert _is_ltx2_attention_module(m) is True + def test_ltx2_attention_missing_rope_type(self): + from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( + _is_ltx2_attention_module, + ) + + # Module with to_q/k/v + norms but NO rope_type — should NOT match + m = torch.nn.Module() + m.to_q = torch.nn.Linear(8, 8) + m.to_k = torch.nn.Linear(8, 8) + m.to_v = torch.nn.Linear(8, 8) + m.q_norm = torch.nn.LayerNorm(8) + m.k_norm = torch.nn.LayerNorm(8) + assert _is_ltx2_attention_module(m) is False + def test_non_attention_module(self): from 
modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( _is_ltx2_attention_module, From 220a8ca4b481f1ac88d8d70cca5762f71f26a47a Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 17 Mar 2026 11:42:24 -0700 Subject: [PATCH 09/10] Remove ltx-2 integration Signed-off-by: Kai Xu --- examples/video_diffusion/vsa/README.md | 182 ------ .../vsa/test_ltx2_vsa_integration.py | 596 ------------------ .../sparsity/attention_sparsity/config.py | 5 +- .../attention_sparsity/plugins/huggingface.py | 6 +- .../attention_sparsity/plugins/ltx2.py | 468 -------------- tests/examples/llm_eval/test_llm_eval.py | 55 -- .../sparsity/attention_sparsity/test_vsa.py | 71 --- 7 files changed, 3 insertions(+), 1380 deletions(-) delete mode 100644 examples/video_diffusion/vsa/README.md delete mode 100644 examples/video_diffusion/vsa/test_ltx2_vsa_integration.py delete mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py delete mode 100644 tests/examples/llm_eval/test_llm_eval.py diff --git a/examples/video_diffusion/vsa/README.md b/examples/video_diffusion/vsa/README.md deleted file mode 100644 index 9470ddd46..000000000 --- a/examples/video_diffusion/vsa/README.md +++ /dev/null @@ -1,182 +0,0 @@ -# Video Sparse Attention (VSA) Example - -This example demonstrates how to apply Video Sparse Attention (VSA) optimization to video diffusion models for faster inference. - -## Overview - -VSA is a two-branch sparse attention architecture designed specifically for video diffusion models: - -1. **Compression Branch**: Averages tokens within 3D video blocks (default 4x4x4 = 64 tokens) and computes coarse-grained attention for global context. - -2. **Sparse Branch**: Selects the top-K most important blocks based on attention scores and computes fine-grained attention only for those blocks. 
- -The branches are combined using learned gating: `output = compression * gate_compress + sparse` - -## Requirements - -```bash -pip install torch>=2.0 -pip install modelopt -# Optional: pip install diffusers # For real video diffusion models -``` - -## Quick Start - -### Using LTX-2 Trainer (Recommended) - -```bash -# Full video generation with VSA vs baseline comparison -python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --compare - -# Generate video with custom sparsity -python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --top-k-ratio 0.3 --output my_video.mp4 -``` - -## Configuration Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `--checkpoint` | (required) | Path to model checkpoint (.safetensors) | -| `--text-encoder-path` | (required) | Path to Gemma text encoder directory | -| `--prompt` | A serene mountain... | Text prompt for generation | -| `--top-k-ratio` | 0.5 | Ratio of blocks to keep (0.0-1.0). 
Lower = more sparse | -| `--num-frames` | 121 | Number of video frames (must be k*8 + 1) | -| `--height` | 512 | Video height (must be divisible by 32) | -| `--width` | 768 | Video width (must be divisible by 32) | -| `--num-inference-steps` | 30 | Number of denoising steps | -| `--guidance-scale` | 4.0 | Classifier-free guidance scale | -| `--seed` | 42 | Random seed for reproducibility | -| `--device` | cuda | Device (cuda/cpu) | -| `--compare` | off | Run both baseline and VSA for comparison | -| `--no-vsa` | off | Disable VSA (baseline only) | - -## API Usage - -```python -import modelopt.torch.sparsity.attention_sparsity as mtsa -from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT - -# Load your video diffusion model -model = load_video_diffusion_model() - -# Apply VSA with default settings -model = mtsa.sparsify(model, config=VSA_DEFAULT) - -# Or with custom configuration -custom_config = { - "sparse_cfg": { - "*attn*": { - "method": "vsa", - "block_size_3d": (4, 4, 4), - "top_k_ratio": 0.3, # 70% sparsity - "video_shape": (16, 28, 48), - "enable": True, - }, - "default": {"enable": False}, - }, -} -model = mtsa.sparsify(model, config=custom_config) - -# Run inference -output = model(video_latents) -``` - -## Model Requirements - -For optimal VSA performance, video diffusion models should expose a `gate_compress` parameter in their attention layers. This is a learned parameter that controls the balance between the compression and sparse branches. - -Example attention layer interface: - -```python -class VideoAttention(nn.Module): - def __init__(self, hidden_dim, num_heads): - super().__init__() - self.to_q = nn.Linear(hidden_dim, hidden_dim) - self.to_k = nn.Linear(hidden_dim, hidden_dim) - self.to_v = nn.Linear(hidden_dim, hidden_dim) - # VSA-specific: learned gating - self.to_gate_compress = nn.Linear(hidden_dim, hidden_dim) -``` - -If `gate_compress` is not available, VSA will use equal weighting (sum of both branches). 
- -## Expected Performance - -| Top-K Ratio | Sparsity | Typical Speedup | -|-------------|----------|-----------------| -| 0.5 | 50% | 1.5-2x | -| 0.3 | 70% | 2-3x | -| 0.2 | 80% | 3-4x | - -*Actual speedup depends on model architecture, video resolution, and hardware.* - -## Troubleshooting - -### "video_shape must be set" error - -Make sure to provide `video_shape` in the configuration matching your video dimensions after patchification. - -### Low speedup - -- VSA is most effective for long sequences (high video resolution or many frames) -- For short sequences, the overhead of block operations may reduce gains -- Ensure you're using GPU with CUDA - -### Quality degradation - -- Increase `top_k_ratio` to keep more blocks -- Ensure your model has `gate_compress` for optimal branch balancing - -## LTX-2 Integration - -LTX-2 is a state-of-the-art video diffusion model that is well-suited for VSA optimization due to its high token count. - -### LTX-2 Architecture Summary - -| Component | Description | -|-----------|-------------| -| **Transformer** | 48 layers, 32 heads x 128 dim = 4096 hidden | -| **Compression** | 1:8192 pixels-to-tokens (aggressive) | -| **Attention Types** | Self-attn (attn1), Cross-attn (attn2), Audio attn, Cross-modal | - -### Example Scripts - -| Script | Purpose | -|--------|---------| -| `test_ltx2_vsa_integration.py` | Test VSA with LTX-2 trainer pipeline | - -### VSA Targets for LTX-2 - -VSA is applied only to **self-attention (attn1)** modules: - -```python -vsa_config = { - "sparse_cfg": { - "*.attn1": { # [OK] Self-attention - VSA enabled - "method": "vsa", - "top_k_ratio": 0.5, - "block_size_3d": [4, 4, 4], - }, - "*.attn2": {"enable": False}, # [NO] Text cross-attention - "*.audio_attn*": {"enable": False}, # [NO] Audio attention - "*.audio_to_video*": {"enable": False}, # [NO] Cross-modal - "*.video_to_audio*": {"enable": False}, # [NO] Cross-modal - }, -} -``` - -### Expected Token Counts for LTX-2 - -| Resolution | Frames | 
Tokens | VSA Tiles | Recommendation | -|------------|--------|--------|-----------|----------------| -| 512x768 | 121 | ~5,808 | 91 | Excellent for VSA | -| 384x384 | 49 | ~907 | 14 | Marginal | -| 256x256 | 25 | ~200 | 3 | Too small | - -For best VSA performance, use **121+ frames @ 512x768+** resolution. diff --git a/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py b/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py deleted file mode 100644 index b8ea94052..000000000 --- a/examples/video_diffusion/vsa/test_ltx2_vsa_integration.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test VSA integration with LTX-2 video generation. - -This script tests Video Sparse Attention (VSA) on the full LTX-2 pipeline, -measuring performance improvements and validating output quality. 
- -Usage: - # Test with VSA enabled (default) - python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --prompt "A cat playing with a ball" - - # Test without VSA (baseline) - python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --prompt "A cat playing with a ball" \ - --no-vsa - - # Compare both (recommended) - python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --prompt "A cat playing with a ball" \ - --compare - - # Custom VSA parameters - python test_ltx2_vsa_integration.py \ - --checkpoint path/to/model.safetensors \ - --text-encoder-path path/to/gemma \ - --prompt "A cat playing with a ball" \ - --top-k-ratio 0.5 \ - --num-frames 121 --height 512 --width 768 - -VSA improves attention performance by using 3D tile-based sparsity: -- Automatically adapts to LTX-2's compressed token sequence -""" - -import argparse -import copy -import time -from pathlib import Path - -import torch -from ltx_trainer.model_loader import load_model -from ltx_trainer.progress import StandaloneSamplingProgress -from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler -from ltx_trainer.video_utils import save_video - -from modelopt.torch.sparsity.attention_sparsity import sparsify -from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT - - -def calculate_expected_tokens(num_frames: int, height: int, width: int) -> int: - """Calculate expected token count for LTX-2. - - LTX-2 uses 1:8192 pixels-to-tokens compression ratio. - """ - pixels = num_frames * height * width - tokens = pixels // 8192 - return tokens - - -def is_vsa_compatible(num_frames: int, height: int, width: int) -> tuple[bool, str]: - """Check if input size is compatible with VSA. - - Args: - num_frames: Number of video frames. - height: Video height in pixels. 
- width: Video width in pixels. - - Returns: - Tuple of (is_compatible, reason_message). - """ - tokens = calculate_expected_tokens(num_frames, height, width) - tiles = tokens // 64 # VSA tile size: 4x4x4 = 64 - - if tiles >= 90: - return True, f"Excellent: {tokens} tokens ({tiles} tiles)" - elif tiles >= 16: - return True, f"Marginal: {tokens} tokens ({tiles} tiles)" - else: - return False, f"Too small: {tokens} tokens ({tiles} tiles, need 16+ for VSA)" - - -def apply_vsa_to_transformer( - transformer: torch.nn.Module, - num_frames: int, - height: int, - width: int, - top_k_ratio: float = 0.5, -) -> torch.nn.Module: - """Apply VSA to the LTX-2 transformer. - - Args: - transformer: The transformer model. - num_frames: Number of frames (for compatibility checking). - height: Video height (for compatibility checking). - width: Video width (for compatibility checking). - top_k_ratio: Sparsity ratio (0.5 = 50% sparsity). - - Returns: - Modified transformer with VSA enabled. - """ - print("\nConfiguring VSA for LTX-2...") - - # Check compatibility - tokens = calculate_expected_tokens(num_frames, height, width) - tiles = tokens // 64 - compatible, reason = is_vsa_compatible(num_frames, height, width) - - print(f" Expected sequence: {tokens} tokens ({tiles} tiles)") - print(f" VSA compatibility: {reason}") - - if not compatible: - print(" [WARNING] Input size may be too small for VSA to provide significant benefit.") - print(" Consider using larger inputs (121+ frames @ 512x768+) for best results.") - - # Configure VSA using the standard preset, overriding top_k_ratio if needed - sparse_config = copy.deepcopy(VSA_DEFAULT) - # Find the attn pattern key and override top_k_ratio - for cfg in sparse_config["sparse_cfg"].values(): - if isinstance(cfg, dict) and cfg.get("method") == "vsa": - cfg["top_k_ratio"] = top_k_ratio - - # Apply VSA to transformer - print(" Applying VSA to attention modules...") - transformer = sparsify(transformer, sparse_config) - - return transformer 
- - -def run_generation( - sampler: ValidationSampler, - config: GenerationConfig, - device: str, - num_inference_steps: int, - label: str = "", -) -> tuple[torch.Tensor, torch.Tensor | None, float]: - """Run video generation and return timing information. - - Args: - sampler: ValidationSampler instance. - config: Generation configuration. - device: Device to run on. - num_inference_steps: Number of denoising steps. - label: Label for logging (e.g., "BASELINE", "WITH VSA"). - - Returns: - Tuple of (video, audio, elapsed_time) - """ - if label: - print(f"\n{label}") - - print(f"Generating video ({num_inference_steps} steps)...") - start_time = time.time() - - with StandaloneSamplingProgress(num_steps=num_inference_steps) as progress: - sampler.sampling_context = progress - video, audio = sampler.generate(config=config, device=device) - - elapsed = time.time() - start_time - print(f"Generation completed in {elapsed:.2f}s") - - return video, audio, elapsed - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Test VSA integration with LTX-2 video generation", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Model arguments - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to model checkpoint (.safetensors)", - ) - parser.add_argument( - "--text-encoder-path", - type=str, - required=True, - help="Path to Gemma text encoder directory", - ) - - # Generation arguments - parser.add_argument( - "--prompt", - type=str, - default="A serene mountain landscape with a flowing river, golden hour lighting", - help="Text prompt for generation", - ) - parser.add_argument( - "--negative-prompt", - type=str, - default="", - help="Negative prompt", - ) - parser.add_argument( - "--height", - type=int, - default=512, - help="Video height (must be divisible by 32)", - ) - parser.add_argument( - "--width", - type=int, - default=768, - help="Video width (must be divisible by 32)", - ) - parser.add_argument( - 
"--num-frames", - type=int, - default=121, - help="Number of video frames (must be k*8 + 1)", - ) - parser.add_argument( - "--frame-rate", - type=float, - default=25.0, - help="Video frame rate", - ) - parser.add_argument( - "--num-inference-steps", - type=int, - default=30, - help="Number of denoising steps", - ) - parser.add_argument( - "--guidance-scale", - type=float, - default=4.0, - help="Classifier-free guidance scale (CFG)", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for reproducibility", - ) - - # VSA arguments - parser.add_argument( - "--no-vsa", - action="store_true", - help="Disable VSA (for baseline comparison)", - ) - parser.add_argument( - "--compare", - action="store_true", - help="Run both with and without VSA for comparison", - ) - parser.add_argument( - "--top-k-ratio", - type=float, - default=0.5, - help="VSA sparsity ratio (0.5 = 50%% sparsity)", - ) - - # Audio arguments - parser.add_argument( - "--skip-audio", - action="store_true", - help="Skip audio generation (faster testing)", - ) - - # Output arguments - parser.add_argument( - "--output", - type=str, - default="output_vsa.mp4", - help="Output video path (.mp4)", - ) - parser.add_argument( - "--output-baseline", - type=str, - default="output_baseline.mp4", - help="Baseline output path (used with --compare)", - ) - - # Device arguments - parser.add_argument( - "--device", - type=str, - default="cuda", - help="Device to run on (cuda/cpu)", - ) - - args = parser.parse_args() - - # Validate arguments - generate_audio = not args.skip_audio - - print("=" * 80) - print("LTX-2 + VSA Integration Test") - print("=" * 80) - - # Check VSA compatibility - tokens = calculate_expected_tokens(args.num_frames, args.height, args.width) - tiles = tokens // 64 - compatible, reason = is_vsa_compatible(args.num_frames, args.height, args.width) - - print("\nInput Configuration:") - print(f" Resolution: {args.width}x{args.height}") - print(f" Frames: {args.num_frames} @ 
{args.frame_rate} fps") - print(f" Expected tokens: {tokens} ({tiles} tiles)") - print(f" VSA compatibility: {reason}") - - if not compatible and not args.no_vsa and not args.compare: - print("\n[WARNING] WARNING: Input size may be too small for VSA benefit") - print(" Recommended: 121+ frames @ 512x768+ for optimal VSA performance") - print(" Use --no-vsa to disable VSA for small inputs") - - # Load model components - print("\nLoading LTX-2 model components...") - components = load_model( - checkpoint_path=args.checkpoint, - device="cpu", # Load to CPU first - dtype=torch.bfloat16, - with_video_vae_encoder=False, - with_video_vae_decoder=True, - with_audio_vae_decoder=generate_audio, - with_vocoder=generate_audio, - with_text_encoder=True, - text_encoder_path=args.text_encoder_path, - ) - print("Model components loaded") - - # Create generation config - gen_config = GenerationConfig( - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - frame_rate=args.frame_rate, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - seed=args.seed, - condition_image=None, - reference_video=None, - generate_audio=generate_audio, - include_reference_in_output=False, - ) - - print("\n" + "=" * 80) - print("Generation Parameters") - print("=" * 80) - print(f"Prompt: {args.prompt}") - if args.negative_prompt: - print(f"Negative prompt: {args.negative_prompt}") - print(f"Resolution: {args.width}x{args.height}") - print(f"Frames: {args.num_frames} @ {args.frame_rate} fps") - print(f"Inference steps: {args.num_inference_steps}") - print(f"CFG scale: {args.guidance_scale}") - print(f"Seed: {args.seed}") - if generate_audio: - video_duration = args.num_frames / args.frame_rate - print(f"Audio: Enabled (duration: {video_duration:.2f}s)") - else: - print("Audio: Disabled (skip-audio mode)") - print("=" * 80) - - # Test scenarios - results = {} - - if args.compare: - # 
====================================================================== - # Run BASELINE (no VSA) - # ====================================================================== - print("\n" + "=" * 80) - print("TEST 1/2: BASELINE (no VSA)") - print("=" * 80) - - # Create sampler without VSA - sampler_baseline = ValidationSampler( - transformer=components.transformer, - vae_decoder=components.video_vae_decoder, - vae_encoder=components.video_vae_encoder, - text_encoder=components.text_encoder, - audio_decoder=components.audio_vae_decoder if generate_audio else None, - vocoder=components.vocoder if generate_audio else None, - ) - - try: - video_baseline, audio_baseline, time_baseline = run_generation( - sampler_baseline, - gen_config, - args.device, - args.num_inference_steps, - ) - results["baseline"] = time_baseline - - # Save baseline video - output_baseline_path = Path(args.output_baseline) - output_baseline_path.parent.mkdir(parents=True, exist_ok=True) - - audio_sample_rate = None - if audio_baseline is not None and components.vocoder is not None: - audio_sample_rate = components.vocoder.output_sample_rate - - save_video( - video_tensor=video_baseline, - output_path=output_baseline_path, - fps=args.frame_rate, - audio=audio_baseline, - audio_sample_rate=audio_sample_rate, - ) - print(f"Baseline video saved: {args.output_baseline}") - except Exception as e: - print(f"Baseline generation failed: {e}") - import traceback - - traceback.print_exc() - return - - # ====================================================================== - # Run WITH VSA - # ====================================================================== - print("\n" + "=" * 80) - print("TEST 2/2: WITH VSA") - print("=" * 80) - - # Reload transformer for VSA test - print("\nReloading transformer for VSA test...") - components_vsa = load_model( - checkpoint_path=args.checkpoint, - device="cpu", - dtype=torch.bfloat16, - with_video_vae_encoder=False, - with_video_vae_decoder=True, - 
with_audio_vae_decoder=generate_audio, - with_vocoder=generate_audio, - with_text_encoder=True, - text_encoder_path=args.text_encoder_path, - ) - - # Apply VSA - components_vsa.transformer = apply_vsa_to_transformer( - components_vsa.transformer, - args.num_frames, - args.height, - args.width, - top_k_ratio=args.top_k_ratio, - ) - - # Create sampler with VSA - sampler_vsa = ValidationSampler( - transformer=components_vsa.transformer, - vae_decoder=components_vsa.video_vae_decoder, - vae_encoder=components_vsa.video_vae_encoder, - text_encoder=components_vsa.text_encoder, - audio_decoder=components_vsa.audio_vae_decoder if generate_audio else None, - vocoder=components_vsa.vocoder if generate_audio else None, - ) - - try: - video_vsa, audio_vsa, time_vsa = run_generation( - sampler_vsa, - gen_config, - args.device, - args.num_inference_steps, - ) - results["vsa"] = time_vsa - - # Save VSA video - output_vsa_path = Path(args.output) - output_vsa_path.parent.mkdir(parents=True, exist_ok=True) - - audio_sample_rate = None - if audio_vsa is not None and components_vsa.vocoder is not None: - audio_sample_rate = components_vsa.vocoder.output_sample_rate - - save_video( - video_tensor=video_vsa, - output_path=output_vsa_path, - fps=args.frame_rate, - audio=audio_vsa, - audio_sample_rate=audio_sample_rate, - ) - print(f"VSA video saved: {args.output}") - except Exception as e: - print(f"VSA generation failed: {e}") - import traceback - - traceback.print_exc() - return - - else: - # ====================================================================== - # Single test (with or without VSA) - # ====================================================================== - print("\n" + "=" * 80) - print(f"TEST: {'WITH VSA' if not args.no_vsa else 'WITHOUT VSA'}") - print("=" * 80) - - transformer = components.transformer - - # Apply VSA if enabled - if not args.no_vsa: - transformer = apply_vsa_to_transformer( - transformer, - args.num_frames, - args.height, - args.width, - 
top_k_ratio=args.top_k_ratio, - ) - - # Create sampler - sampler = ValidationSampler( - transformer=transformer, - vae_decoder=components.video_vae_decoder, - vae_encoder=components.video_vae_encoder, - text_encoder=components.text_encoder, - audio_decoder=components.audio_vae_decoder if generate_audio else None, - vocoder=components.vocoder if generate_audio else None, - ) - - try: - video, audio, elapsed = run_generation( - sampler, - gen_config, - args.device, - args.num_inference_steps, - ) - results["single"] = elapsed - - # Save video - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - - audio_sample_rate = None - if audio is not None and components.vocoder is not None: - audio_sample_rate = components.vocoder.output_sample_rate - - save_video( - video_tensor=video, - output_path=output_path, - fps=args.frame_rate, - audio=audio, - audio_sample_rate=audio_sample_rate, - ) - print(f"Video saved: {args.output}") - except Exception as e: - print(f"Generation failed: {e}") - import traceback - - traceback.print_exc() - return - - # ========================================================================== - # Results Summary - # ========================================================================== - print("\n" + "=" * 80) - print("TEST COMPLETE") - print("=" * 80) - - if args.compare: - speedup = results["baseline"] / results["vsa"] - print("\nPerformance Comparison:") - print(f" Baseline (no VSA): {results['baseline']:.2f}s") - print(f" With VSA: {results['vsa']:.2f}s") - print(f" Speedup: {speedup:.2f}x") - print() - print(f" Baseline video: {args.output_baseline}") - print(f" VSA video: {args.output}") - print() - else: - print(f"\nGeneration time: {results['single']:.2f}s") - print(f"Output: {args.output}") - - print("\nVSA integration test successful!") - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py 
b/modelopt/torch/sparsity/attention_sparsity/config.py index 7c0ba0842..6d85812c1 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -559,10 +559,7 @@ class VSAConfig(SparseAttentionConfig): # Pre-defined VSA Configuration for video diffusion models. -# Pattern "*attn*" matches all LTX-2 attention module names: -# - Video self-attention: attn1, attn2 -# - Audio self-attention: audio_attn1, audio_attn2 -# - Cross-modal attention: audio_to_video_attn, video_to_audio_attn +# Pattern "*attn*" matches attention module names by convention. VSA_DEFAULT = { "sparse_cfg": { "*attn*": { diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 51191b304..599832943 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -23,8 +23,6 @@ from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry from . import CUSTOM_MODEL_PLUGINS -logger = logging.getLogger(__name__) - class _GenericSparseAttention(SparseAttentionModule): """Generic sparse attention that works with any HF attention module. 
@@ -93,10 +91,10 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - logger.info(f"Registered {type_name} for sparse attention optimization") + print(f"Registered {type_name} for sparse attention optimization") if registered_count > 0: - logger.info(f"Dynamically registered {registered_count} attention module types for sparsity") + print(f"Dynamically registered {registered_count} attention module types for sparsity") return registered_count > 0 diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py deleted file mode 100644 index 41179913a..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py +++ /dev/null @@ -1,468 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Plugin for LTX-2 video diffusion models with VSA support. - -LTX-2 uses a specific Attention module structure that differs from standard -HuggingFace/Diffusers attention. This plugin provides: - -1. Detection of LTX-2's native Attention modules -2. Q/K/V projection, RMSNorm, and RoPE handling -3. 
Support for trainable gate_compress for VSA quality optimization -""" - -import logging -import weakref - -import torch -import torch.nn as nn - -from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry -from . import CUSTOM_MODEL_PLUGINS - -logger = logging.getLogger(__name__) - - -def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: - """Forward pre-hook on LTXModel to extract dit_seq_shape from Modality.positions. - - Mirrors FastVideo's ``VideoSparseAttentionMetadataBuilder.build()`` which - computes ``dit_seq_shape = raw_latent_shape // patch_size``. Here we derive - the same shape by counting unique position values per dimension in the - ``Modality.positions`` tensor, which is available at the LTXModel entry - point (before ``TransformerArgsPreprocessor`` converts it to RoPE embeddings). - - The result is stored on the model instance as ``module._vsa_video_shape`` - so that ``_LTX2SparseAttention._resolve_video_shape()`` can read it via - its ``_vsa_root_model`` reference. This avoids module-level global state - and is safe for concurrent models. 
- """ - # LTXModel.forward(self, video: Modality | None, audio, perturbations) - video = args[0] if len(args) > 0 else None - if video is None or not hasattr(video, "positions") or video.positions is None: - return - - positions = video.positions # (B, 3, T) or (B, 3, T, 2) - - try: - if positions.ndim == 4: - # (B, 3, T, 2) -- take start coordinates - pos_per_dim = positions[0, :, :, 0] # (3, T) - elif positions.ndim == 3: - # (B, 3, T) - pos_per_dim = positions[0] # (3, T) - else: - return - - t_dim = pos_per_dim[0].unique().numel() - h_dim = pos_per_dim[1].unique().numel() - w_dim = pos_per_dim[2].unique().numel() - seq_len = positions.shape[2] - - if t_dim * h_dim * w_dim == seq_len: - module._vsa_video_shape = (t_dim, h_dim, w_dim) - logger.debug( - f"Extracted dit_seq_shape={module._vsa_video_shape} from " - f"Modality.positions (seq_len={seq_len})" - ) - else: - logger.debug( - f"Position-derived shape {(t_dim, h_dim, w_dim)} product " - f"({t_dim * h_dim * w_dim}) != seq_len ({seq_len}), skipping" - ) - except Exception: - logger.debug("Failed to extract video_shape from Modality.positions", exc_info=True) - - -def _is_ltx2_model(model: nn.Module) -> bool: - """Check if model is an LTX-2 model. - - Uses LTXModel / LTXSelfAttention class names to avoid false positives - from other DiTs (e.g., LongCat) that share similar attribute patterns. - - Args: - model: PyTorch model to check. - - Returns: - True if model is LTX-2 (root class LTXModel or contains LTXSelfAttention). - """ - if type(model).__name__ == "LTXModel": - return True - return any(type(m).__name__ == "LTXSelfAttention" for m in model.modules()) - - -def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: - """Check if a module is an LTX-2 Attention module by class name or structure. - - Primary: class name is LTXSelfAttention. Fallback: has to_q/k/v, q_norm, - k_norm, and rope_type (unique to LTX-2 among DiTs). - - Args: - module: Module to check. 
- name: Module name in model hierarchy. - - Returns: - True if module is an LTX-2 attention module. - """ - class_name = type(module).__name__ - if class_name == "LTXSelfAttention": - return True - # Fallback for subclasses or renamed variants: must have rope_type (LTX-2 only) - return ( - hasattr(module, "to_q") - and hasattr(module, "to_k") - and hasattr(module, "to_v") - and hasattr(module, "q_norm") - and hasattr(module, "k_norm") - and hasattr(module, "rope_type") - ) - - -class _LTX2SparseAttention(SparseAttentionModule): - """Sparse attention wrapper for LTX-2 Attention modules. - - This plugin handles all LTX-2 specific logic: - - Q/K/V projection and normalization (using native LTX-2 args: x, context, pe, k_pe) - - RoPE application via ltx_core - - Trainable gate_compress for VSA quality optimization - - The plugin computes Q, K, V directly and calls VSA.forward_attention(), - keeping VSA as a pure algorithm without module-specific knowledge. - """ - - def _setup(self): - """Setup the VSA wrapper with trainable gate_compress.""" - super()._setup() - - # Check if we need to add gate_compress projection - if not hasattr(self, "to_gate_compress"): - to_q = self.to_q - in_features = to_q.in_features - out_features = to_q.out_features - - # Create gate_compress projection (zero-initialized) - self.to_gate_compress = nn.Linear(in_features, out_features, bias=True) - nn.init.zeros_(self.to_gate_compress.weight) - nn.init.zeros_(self.to_gate_compress.bias) - - # Move to same device/dtype as to_q - self.to_gate_compress = self.to_gate_compress.to( - device=to_q.weight.device, - dtype=to_q.weight.dtype, - ) - - def _compute_qkv( - self, - x: torch.Tensor, - context: torch.Tensor | None, - pe: torch.Tensor | None = None, - k_pe: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute Q, K, V projections with LTX-2 specific processing. - - Args: - x: Input tensor [batch, seq, hidden_dim]. 
- context: Context for cross-attention, or None for self-attention. - pe: Positional embeddings for RoPE. - k_pe: Optional separate positional embeddings for keys. - - Returns: - Tuple of (query, key, value) tensors in [batch, seq, hidden_dim] format. - """ - # For self-attention, use x for K, V - context = context if context is not None else x - - # Project to Q, K, V - query = self.to_q(x) - key = self.to_k(context) - value = self.to_v(context) - - # Apply Q/K norms (LTX-2 specific) - if hasattr(self, "q_norm"): - query = self.q_norm(query) - if hasattr(self, "k_norm"): - key = self.k_norm(key) - - # Apply RoPE if provided (LTX-2 specific) - if pe is not None and hasattr(self, "rope_type"): - try: - from ltx_core.model.transformer.rope import apply_rotary_emb - except ModuleNotFoundError: - raise ModuleNotFoundError( - "LTX-2 VSA plugin requires the 'ltx_core' package for RoPE support. " - "The plugin registered successfully, but 'ltx_core' is needed at runtime. " - "Install it with: pip install ltx-core" - ) from None - - query = apply_rotary_emb(query, pe, self.rope_type) - key = apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type) - - return query, key, value - - def _reshape_for_vsa(self, tensor: torch.Tensor, num_heads: int) -> torch.Tensor: - """Reshape tensor from [batch, seq, hidden_dim] to [batch, heads, seq, head_dim]. - - Args: - tensor: Input tensor [batch, seq, hidden_dim]. - num_heads: Number of attention heads. - - Returns: - Reshaped tensor [batch, heads, seq, head_dim]. - """ - batch, seq_len, hidden_dim = tensor.shape - head_dim = hidden_dim // num_heads - return tensor.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) - - def _reshape_from_vsa(self, tensor: torch.Tensor) -> torch.Tensor: - """Reshape tensor from [batch, heads, seq, head_dim] to [batch, seq, hidden_dim]. - - Args: - tensor: Input tensor [batch, heads, seq, head_dim]. - - Returns: - Reshaped tensor [batch, seq, hidden_dim]. 
- """ - batch, heads, seq_len, head_dim = tensor.shape - return tensor.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim) - - def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: - """Resolve video_shape for the current forward pass. - - Resolution order (mirrors FastVideo's metadata flow): - 1. ``root_model._vsa_video_shape`` -- set by the forward pre-hook from - ``Modality.positions`` (analogous to ``get_forward_context().attn_metadata``) - 2. ``method.video_shape`` -- explicitly set via the sparsify config - - Args: - seq_len: Current sequence length (for validation). - - Returns: - Tuple (T, H, W) or None if not determinable. - """ - # 1. Primary: video_shape extracted by forward pre-hook on root model - root_ref = getattr(self, "_vsa_root_model_ref", None) - root = root_ref() if root_ref is not None else None - if root is not None: - shape = getattr(root, "_vsa_video_shape", None) - if shape is not None: - t, h, w = shape - if t * h * w == seq_len: - return shape - - # 2. Fallback: explicit video_shape from sparsify config - method = getattr(self, "_sparse_method_instance", None) - if method is not None and method.video_shape is not None: - t, h, w = method.video_shape - if t * h * w == seq_len: - return method.video_shape - - return None - - def forward(self, *args, **kwargs): - """Forward pass computing Q/K/V directly and calling VSA.forward_attention(). - - This method handles all LTX-2 specific logic: - 1. Extract arguments (uses LTX-2 native names: x, context, pe, k_pe) - 2. Compute Q, K, V projections with norms and RoPE - 3. Compute gate_compress - 4. Resolve video_shape from hook or config - 5. Check compatibility and call VSA or fallback - 6. 
Apply output projection - """ - x = kwargs.get("x") - if x is None and len(args) > 0: - x = args[0] - - if x is None: - return self._call_original_forward(*args, **kwargs) - - context = kwargs.get("context") - pe = kwargs.get("pe") - k_pe = kwargs.get("k_pe") - - # === Check cross-attention === - if context is not None: - if x.shape[1] != context.shape[1]: - # NOTE: skip VSA for Cross-attention, use original attention - return self._call_original_forward(*args, **kwargs) - - # === Check VSA method availability === - if not hasattr(self, "_sparse_method_instance") or self._sparse_method_instance is None: - return self._call_original_forward(*args, **kwargs) - - method = self._sparse_method_instance # VSA instance - - # === Compute Q, K, V === - query, key, value = self._compute_qkv(x, context, pe, k_pe) - - # === Check sequence length compatibility === - seq_len = query.shape[1] - block_size_3d = method.block_size_3d # type: ignore[attr-defined] - block_elements = block_size_3d[0] * block_size_3d[1] * block_size_3d[2] - - if seq_len < block_elements: - # Incompatible sequence length (e.g., audio attention with seq_len=32) - logger.debug(f"VSA skipped: seq_len={seq_len} < block_elements={block_elements}") - return self._call_original_forward(*args, **kwargs) - - # === Resolve video_shape === - video_shape = self._resolve_video_shape(seq_len) - if video_shape is None: - logger.debug(f"VSA skipped: no matching video_shape for seq_len={seq_len}") - return self._call_original_forward(*args, **kwargs) - - # === Compute gate_compress === - gate_compress = None - if hasattr(self, "to_gate_compress"): - gate_compress = self.to_gate_compress(x) - - # === Reshape for VSA: [batch, seq, hidden] -> [batch, heads, seq, head_dim] === - query = self._reshape_for_vsa(query, self.heads) - key = self._reshape_for_vsa(key, self.heads) - value = self._reshape_for_vsa(value, self.heads) - if gate_compress is not None: - gate_compress = self._reshape_for_vsa(gate_compress, self.heads) - - 
# === Call VSA forward_attention directly === - output, stats = method.forward_attention( # type: ignore[attr-defined] - query=query, - key=key, - value=value, - gate_compress=gate_compress, - video_shape=video_shape, - ) - - # Store stats for collection - method._last_stats = stats - - # === Reshape output: [batch, heads, seq, head_dim] -> [batch, seq, hidden] === - output = self._reshape_from_vsa(output) - - # === Apply output projection === - if hasattr(self, "to_out"): - output = self.to_out(output) - - return output - - def _call_original_forward(self, *args, **kwargs): - """Call the original module's forward method, bypassing VSA. - - Temporarily disables sparse attention so SparseAttentionModule.forward() - passes through to the original module. - """ - # Temporarily disable sparse attention to bypass sparse logic - # SparseAttentionModule.forward() checks is_enabled and passes through if False - was_enabled = getattr(self, "_enabled", True) - self._enabled = False - try: - # This goes through SparseAttentionModule.forward() which checks is_enabled, - # sees it's disabled, and calls DynamicModule.forward() -> original module - result = SparseAttentionModule.forward(self, *args, **kwargs) - finally: - self._enabled = was_enabled - return result - - def get_gate_compress_parameters(self): - """Get trainable gate_compress parameters. - - Returns: - Iterator of gate_compress parameters for optimization. - """ - if hasattr(self, "to_gate_compress"): - return self.to_gate_compress.parameters() - return iter([]) # Empty iterator - - -def register_ltx2_attention(model: nn.Module) -> int: - """Register LTX-2 Attention modules for VSA wrapping. - - This function detects LTX-2 Attention modules and registers them with - the SparseAttentionRegistry. It also handles unregistering any generic - wrappers that may have been registered first. - - Args: - model: LTX-2 model to process. - - Returns: - Number of module types registered. 
- """ - if not _is_ltx2_model(model): - return 0 - - registered_types = set() - num_modules = 0 - - for name, module in model.named_modules(): - if not _is_ltx2_attention_module(module, name): - continue - - num_modules += 1 - module_type = type(module) - - if module_type in registered_types: - continue - - # Unregister any existing generic wrapper - if module_type in SparseAttentionRegistry: - logger.debug(f"Unregistering generic wrapper for {module_type.__name__}") - SparseAttentionRegistry.unregister(module_type) - - # Register LTX-2 specific wrapper - SparseAttentionRegistry.register({module_type: module_type.__name__})(_LTX2SparseAttention) - registered_types.add(module_type) - logger.info(f"Registered LTX-2 attention: {module_type.__name__}") - - if num_modules > 0: - logger.info(f"Found {num_modules} LTX-2 Attention modules in model") - - # Store a weak reference to the root model on each attention module so - # _resolve_video_shape() can read model._vsa_video_shape without globals. - # Using weakref avoids circular module registration (nn.Module.__setattr__ - # would register a plain Module reference as a submodule, causing infinite - # recursion in named_children()). - root_ref = weakref.ref(model) - for _, module in model.named_modules(): - if _is_ltx2_attention_module(module): - object.__setattr__(module, "_vsa_root_model_ref", root_ref) - - # Register forward pre-hook to extract video_shape from Modality.positions - # before each forward pass -- analogous to FastVideo's - # set_forward_context(attn_metadata=builder.build(...)) - model.register_forward_pre_hook(_extract_video_shape_hook) - logger.debug("Registered VSA video_shape extraction hook on model") - - return len(registered_types) - - -def register_ltx2_on_the_fly(model: nn.Module) -> bool: - """Plugin entry point for LTX-2 VSA registration. - - Args: - model: PyTorch model to process. - - Returns: - True if any LTX-2 modules were registered. 
- """ - num_registered = register_ltx2_attention(model) - - if num_registered > 0: - logger.info(f"Registered {num_registered} LTX-2 attention types for VSA") - return True - - return False - - -# Add to plugin set (order-independent: guards against re-registration internally) -CUSTOM_MODEL_PLUGINS.add(register_ltx2_on_the_fly) diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py deleted file mode 100644 index 88d29dedc..000000000 --- a/tests/examples/llm_eval/test_llm_eval.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import subprocess - -from _test_utils.examples.models import TINY_LLAMA_PATH -from _test_utils.examples.run_command import run_llm_ptq_command -from _test_utils.torch.misc import minimum_sm - - -@minimum_sm(89) -def test_llama_eval_fp8(): - try: - run_llm_ptq_command( - model=TINY_LLAMA_PATH, - quant="fp8", - tasks="mmlu,lm_eval,simple_eval", - calib=64, - lm_eval_tasks="hellaswag,gsm8k", - simple_eval_tasks="humaneval", - lm_eval_limit=0.1, - batch=8, - ) - finally: - # Force kill llm-serve if it's still running - subprocess.run(["pkill", "-f", "llm-serve"], check=False) - - -def test_llama_eval_sparse_attention(tiny_llama_path): - """Test sparse attention with llm_eval integration.""" - try: - # Test with default sparse attention config (no quantization) - run_llm_ptq_command( - model=tiny_llama_path, - quant="none", # No quantization, only sparse attention - tasks="lm_eval", - lm_eval_tasks="hellaswag", - lm_eval_limit=0.05, # Small limit for fast test - sparse_cfg="SKIP_SOFTMAX_DEFAULT", - batch=4, - ) - finally: - subprocess.run(["pkill", "-f", "llm-serve"], check=False) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py index bf6c8f25a..84768e77a 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py @@ -19,7 +19,6 @@ - vsa_utils.py: tile/untile index logic, variable block sizes - vsa.py: VSA method init, metadata computation, validation, caching - config.py: VSAAttributeConfig validation -- plugins/ltx2.py: model/module detection helpers """ import math @@ -263,73 +262,3 @@ def test_vsa_config_defaults(self): cfg = VSAConfig() assert "*attn*" in cfg.sparse_cfg assert cfg.sparse_cfg["*attn*"]["method"] == "vsa" - - -# --------------------------------------------------------------------------- -# LTX-2 plugin: detection helpers -# 
--------------------------------------------------------------------------- - - -class TestLTX2Detection: - """Tests for _is_ltx2_model and _is_ltx2_attention_module.""" - - def test_non_ltx2_model(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import _is_ltx2_model - - model = torch.nn.Linear(10, 10) - assert _is_ltx2_model(model) is False - - def test_ltx2_model_by_class_name(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import _is_ltx2_model - - # Fake a class named LTXModel - class LTXModel(torch.nn.Module): - pass - - assert _is_ltx2_model(LTXModel()) is True - - def test_ltx2_attention_by_class_name(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( - _is_ltx2_attention_module, - ) - - class LTXSelfAttention(torch.nn.Module): - pass - - assert _is_ltx2_attention_module(LTXSelfAttention()) is True - - def test_ltx2_attention_by_structure(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( - _is_ltx2_attention_module, - ) - - # Module with LTX-2 attribute signature (includes rope_type) - m = torch.nn.Module() - m.to_q = torch.nn.Linear(8, 8) - m.to_k = torch.nn.Linear(8, 8) - m.to_v = torch.nn.Linear(8, 8) - m.q_norm = torch.nn.LayerNorm(8) - m.k_norm = torch.nn.LayerNorm(8) - m.rope_type = "interleaved" - assert _is_ltx2_attention_module(m) is True - - def test_ltx2_attention_missing_rope_type(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( - _is_ltx2_attention_module, - ) - - # Module with to_q/k/v + norms but NO rope_type — should NOT match - m = torch.nn.Module() - m.to_q = torch.nn.Linear(8, 8) - m.to_k = torch.nn.Linear(8, 8) - m.to_v = torch.nn.Linear(8, 8) - m.q_norm = torch.nn.LayerNorm(8) - m.k_norm = torch.nn.LayerNorm(8) - assert _is_ltx2_attention_module(m) is False - - def test_non_attention_module(self): - from modelopt.torch.sparsity.attention_sparsity.plugins.ltx2 import ( - _is_ltx2_attention_module, - ) - - 
assert _is_ltx2_attention_module(torch.nn.Linear(10, 10)) is False From e75f42f2fc7b657005bdd781e619bdf2e0a5ad95 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 17 Mar 2026 16:41:42 -0700 Subject: [PATCH 10/10] Add unit test for vsa Signed-off-by: Kai Xu --- .../sparsity/attention_sparsity/config.py | 6 +- .../attention_sparsity/methods/__init__.py | 7 +- .../attention_sparsity/methods/registry.py | 4 + .../attention_sparsity/methods/vsa.py | 19 +-- .../attention_sparsity/sparse_attention.py | 6 +- tests/examples/llm_eval/test_llm_eval.py | 38 +++++ .../sparsity/attention_sparsity/test_vsa.py | 138 ++++++++++++++++++ 7 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 tests/examples/llm_eval/test_llm_eval.py diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 6d85812c1..8b25ab792 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -477,9 +477,9 @@ class VSAAttributeConfig(ModeloptBaseConfig): title="Video shape.", description=( "Video dimensions (T, H, W) after patchification. Required unless a " - "model-specific plugin (e.g., the LTX-2 plugin) computes it from the " - "model's patchifier. If None and no plugin provides a value, VSA will " - "raise an error at forward time." + "model-specific plugin computes it from the model's patchifier. " + "If None and no plugin provides a value, VSA will raise an error at " + "forward time." 
), ) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 31a281f5f..209561cf2 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -15,8 +15,6 @@ """Sparse attention methods package.""" -from modelopt.torch.utils import import_plugin - from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method __all__ = [ @@ -26,7 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax - -with import_plugin("vsa"): - from . import vsa # Video Sparse Attention (requires fastvideo_kernel) +from . import flash_skip_softmax, vsa diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 3f3e78db6..cf63bd388 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -94,6 +94,10 @@ def get_threshold_info(self) -> dict[str, Any]: """ return {"type": "none", "value": None} + def set_calibration_mode(self, enabled: bool): + """Set calibration mode. 
Override in subclasses that support calibration.""" + self._calibration_mode = enabled + @property @abstractmethod def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py index b42f0a111..7cfc86f42 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -85,9 +85,6 @@ def __init__(self, method_config: dict | None = None): # Video shape (can be set dynamically) self.video_shape = config.get("video_shape", None) - # Track last computed statistics - self._last_stats: dict = {} - # Metadata cache: avoids recomputing tile indices on every forward pass. # Matches FastVideo's @lru_cache on utility functions. self._cached_metadata: dict[str, Any] | None = None @@ -119,12 +116,12 @@ def _compute_metadata(self, seq_len: int, device: torch.device) -> dict[str, Any raise ValueError( f"video_shape must be provided for VSA but is None (seq_len={seq_len}). " f"Set it via the VSA config ('video_shape' key), call set_video_shape(), " - f"or use a model-specific plugin (e.g., LTX-2 plugin) that computes it " - f"from the model's patchifier." + f"or use a model-specific plugin that computes it from the model's " + f"patchifier." 
) # Return cached metadata if inputs haven't changed - cache_key = (seq_len, self.video_shape) + cache_key = (seq_len, self.video_shape, device) if self._cached_metadata is not None and self._cached_metadata_key == cache_key: return self._cached_metadata @@ -309,15 +306,13 @@ def forward_attention( # Compute statistics actual_sparsity = 1.0 - (top_k / total_tiles) stats = { - "sparsity": actual_sparsity, - "phase": "vsa_triton", + "sparsity": [actual_sparsity], + "phase": "prefill", "total_blocks": total_tiles, - "sparse_blocks": total_tiles - top_k, + "sparse_blocks": [total_tiles - top_k], "top_k": top_k, "video_shape": self.video_shape, } - self._last_stats = stats - return output, stats def calculate_sparsity( @@ -327,7 +322,7 @@ def calculate_sparsity( """Not used by VSA. Required stub for the abstract base class. VSA replaces the entire attention computation via ``forward_attention()``, - which is called directly by model-specific plugins (e.g., ``_LTX2SparseAttention``). + which is called directly by model-specific plugins. The softmax-patching path that calls this method is never reached in the VSA flow. Raises: diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index d6bb2cf0e..94e8b8a98 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -177,9 +177,9 @@ def forward(self, *args, **kwargs): """Forward with selected sparse attention method. Methods that replace the full attention computation (e.g., VSA) override - ``forward()`` in their model-specific plugin (e.g., ``_LTX2SparseAttention``) - and never reach this path. This method handles the softmax-patching path - used by methods like ``flash_skip_softmax``. + ``forward()`` in their model-specific plugin and never reach this path. 
+ This method handles the softmax-patching path used by methods like + ``flash_skip_softmax``. """ # Pass through if sparse attention is disabled if not self.is_enabled: diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py new file mode 100644 index 000000000..0abf78b53 --- /dev/null +++ b/tests/examples/llm_eval/test_llm_eval.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import subprocess + +from _test_utils.examples.models import TINY_LLAMA_PATH +from _test_utils.examples.run_command import run_llm_ptq_command +from _test_utils.torch.misc import minimum_sm + + +@minimum_sm(89) +def test_llama_eval_fp8(): + try: + run_llm_ptq_command( + model=TINY_LLAMA_PATH, + quant="fp8", + tasks="mmlu,lm_eval,simple_eval", + calib=64, + lm_eval_tasks="hellaswag,gsm8k", + simple_eval_tasks="humaneval", + lm_eval_limit=0.1, + batch=8, + ) + finally: + # Force kill llm-serve if it's still running + subprocess.run(["pkill", "-f", "llm-serve"], check=False) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py index 84768e77a..a548bfc6d 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py @@ -19,14 +19,21 @@ - vsa_utils.py: tile/untile index logic, variable block sizes - vsa.py: VSA method init, metadata computation, validation, caching - config.py: VSAAttributeConfig validation +- ModelOpt integration: sparsify() with VSA config, save/restore """ import math import pytest + +pytest.importorskip("transformers") + import torch +from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel from pydantic import ValidationError +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn from modelopt.torch.sparsity.attention_sparsity.config import VSAAttributeConfig, VSAConfig from modelopt.torch.sparsity.attention_sparsity.methods.vsa import VSA from modelopt.torch.sparsity.attention_sparsity.methods.vsa_utils import ( @@ -35,6 +42,7 @@ get_reverse_tile_partition_indices, get_tile_partition_indices, ) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule # --------------------------------------------------------------------------- # vsa_utils: tile partition indices @@ -140,6 +148,33 @@ def 
test_partial_blocks(self): assert npi.shape == (80,) # 64 + 16 +# --------------------------------------------------------------------------- +# VSA: tile/untile round-trip +# --------------------------------------------------------------------------- + + +class TestTileUntileRoundTrip: + """Test _tile_tensor / _untile_tensor preserve data.""" + + @pytest.mark.parametrize( + "video_shape", + [(8, 8, 8), (5, 6, 7), (4, 4, 4)], + ids=["even", "non-divisible", "single-tile"], + ) + def test_round_trip(self, video_shape): + """tile then untile recovers the original tensor.""" + seq_len = video_shape[0] * video_shape[1] * video_shape[2] + vsa = VSA({"video_shape": video_shape}) + meta = vsa._compute_metadata(seq_len, torch.device("cpu")) + + x = torch.randn(2, 4, seq_len, 16) # [batch, heads, seq, dim] + tiled = vsa._tile_tensor(x, meta) + recovered = vsa._untile_tensor(tiled, meta, seq_len) + + assert recovered.shape == x.shape + assert torch.allclose(recovered, x) + + # --------------------------------------------------------------------------- # VSA method: init and config # --------------------------------------------------------------------------- @@ -262,3 +297,106 @@ def test_vsa_config_defaults(self): cfg = VSAConfig() assert "*attn*" in cfg.sparse_cfg assert cfg.sparse_cfg["*attn*"]["method"] == "vsa" + + +# --------------------------------------------------------------------------- +# ModelOpt integration: sparsify() with VSA config +# --------------------------------------------------------------------------- + +VSA_TEST_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.5, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +class TestVSASparsifyIntegration: + """Test VSA integration with modelopt sparsify() API.""" + + def test_sparsify_creates_sparse_modules(self): + """sparsify() with VSA config replaces attention modules.""" + model = SimpleAttentionModel() + sparse_model = 
sparse_attn.sparsify(model, VSA_TEST_CFG) + + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + def test_sparse_module_has_vsa_method(self): + """Replaced modules are configured with VSA method.""" + model = SimpleAttentionModel() + sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG) + + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + assert module._method == "vsa" + assert isinstance(module._sparse_method_instance, VSA) + assert module._sparse_method_instance.block_size_3d == (4, 4, 4) + assert module._sparse_method_instance.top_k_ratio == 0.5 + + def test_enable_disable(self): + """Enable/disable works on VSA sparse modules.""" + model = SimpleAttentionModel() + sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG) + + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + module.disable() + assert not module.is_enabled + module.enable() + assert module.is_enabled + + def test_threshold_info(self): + """VSA sparse modules report correct threshold info.""" + model = SimpleAttentionModel() + sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG) + + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + info = module.get_threshold_info() + assert info["type"] == "vsa" + assert info["top_k_ratio"] == 0.5 + + def test_save_restore(self): + """VSA modelopt_state can be saved and restored.""" + model = SimpleAttentionModel() + sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG) + + state = mto.modelopt_state(sparse_model) + + # Restore to a fresh model + model_restored = SimpleAttentionModel() + mto.restore_from_modelopt_state(model_restored, state) + + # Verify VSA method is restored + for module in model_restored.modules(): + if isinstance(module, SparseAttentionModule): + assert module._method == "vsa" + assert 
isinstance(module._sparse_method_instance, VSA) + + def test_pattern_matching(self): + """Pattern-based config selectively applies VSA.""" + model = SimpleAttentionModel() + + # Pattern that won't match anything + config = { + "sparse_cfg": { + "*nonexistent*": { + "method": "vsa", + "enable": True, + }, + "default": {"enable": False}, + }, + } + sparse_model = sparse_attn.sparsify(model, config) + + # No modules should have VSA enabled + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled