diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 4baf5bbe6..8b25ab792 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -411,7 +411,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): # Configuration with RULER calibration -# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# Note: threshold field is omitted - calibration determines dynamic threshold lambda = a / length # The calibrated threshold adapts to sequence length for optimal sparsity SKIP_SOFTMAX_CALIB = { "sparse_cfg": { @@ -434,13 +434,154 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +class VSAAttributeConfig(ModeloptBaseConfig): + """Video Sparse Attention (VSA) attribute configuration. + + VSA uses a two-branch architecture optimized for video diffusion models: + 1. Compression branch: Block-averaged coarse attention + 2. Sparse branch: Top-K block selection for fine-grained attention + """ + + method: str = ModeloptField( + default="vsa", + title="Sparse attention method.", + description="Must be 'vsa' for Video Sparse Attention.", + ) + + enable: bool = ModeloptField( + default=True, + title="Enable VSA.", + description="If True, enables Video Sparse Attention. If False, bypasses sparsity.", + ) + + block_size_3d: tuple[int, int, int] | list[int] = ModeloptField( + default=(4, 4, 4), + title="3D block size.", + description=( + "Video block dimensions (T, H, W) for spatial-temporal tiling. " + "Default (4, 4, 4) creates 64-token blocks." + ), + ) + + top_k_ratio: float = ModeloptField( + default=0.5, + title="Top-K selection ratio.", + description=( + "Ratio of blocks to keep in sparse branch (0.0 to 1.0). " + "Lower values mean more sparsity. Default 0.5 keeps 50% of blocks." 
+ ), + ) + + video_shape: tuple[int, int, int] | list[int] | None = ModeloptField( + default=None, + title="Video shape.", + description=( + "Video dimensions (T, H, W) after patchification. Required unless a " + "model-specific plugin computes it from the model's patchifier. " + "If None and no plugin provides a value, VSA will raise an error at " + "forward time." + ), + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass.", + ) + + @field_validator("method") + @classmethod + def validate_vsa_method(cls, v): + """Validate method is 'vsa'.""" + if v != "vsa": + raise ValueError(f"VSAAttributeConfig method must be 'vsa', got '{v}'") + return v + + @field_validator("block_size_3d") + @classmethod + def validate_block_size_3d(cls, v): + """Validate 3D block size.""" + if isinstance(v, list): + v = tuple(v) + if len(v) != 3: + raise ValueError(f"block_size_3d must have 3 elements (T, H, W), got {len(v)}") + if any(x <= 0 for x in v): + raise ValueError(f"All block_size_3d values must be positive, got {v}") + return v + + @field_validator("top_k_ratio") + @classmethod + def validate_top_k_ratio(cls, v): + """Validate top-K ratio is in valid range.""" + if not 0.0 < v <= 1.0: + raise ValueError(f"top_k_ratio must be in range (0, 1], got {v}") + return v + + @field_validator("video_shape") + @classmethod + def validate_video_shape(cls, v): + """Validate video shape if provided.""" + if v is None: + return v + if isinstance(v, list): + v = tuple(v) + if len(v) != 3: + raise ValueError(f"video_shape must have 3 elements (T, H, W), got {len(v)}") + if any(x <= 0 for x in v): + raise ValueError(f"All video_shape values must be positive, got {v}") + return v + + +class VSAConfig(SparseAttentionConfig): + """Configuration for Video Sparse Attention optimization. 
+ + VSA is designed for video diffusion models with learned gate_compress + parameters in attention layers. + """ + + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attn*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.5, + "enable": True, + }, + "default": {"enable": False}, + }, + title="VSA configuration", + description=( + "Pattern-based configuration for Video Sparse Attention. " + "Keys are patterns to match module names, values are VSA configs." + ), + validate_default=True, + ) + + +# Pre-defined VSA Configuration for video diffusion models. +# Pattern "*attn*" matches attention module names by convention. +VSA_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), + "top_k_ratio": 0.5, + "enable": True, + }, + "default": {"enable": False}, + }, +} + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "VSA_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", "SparseAttributeConfig", + "VSAAttributeConfig", + "VSAConfig", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..209561cf2 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax +from . 
import flash_skip_softmax, vsa diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 3f3e78db6..cf63bd388 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -94,6 +94,10 @@ def get_threshold_info(self) -> dict[str, Any]: """ return {"type": "none", "value": None} + def set_calibration_mode(self, enabled: bool): + """Set calibration mode. Override in subclasses that support calibration.""" + self._calibration_mode = enabled + @property @abstractmethod def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py new file mode 100644 index 000000000..7cfc86f42 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Video Sparse Attention (VSA) method for video diffusion models. + +VSA implements a two-branch sparse attention architecture: +1. Compression Branch: Averages tokens within 3D video blocks and computes coarse attention +2. 
Sparse Branch: Selects top-K blocks based on importance and computes fine-grained attention + +This method requires model modification to expose gate_compress for optimal quality. + +Uses the optimized Triton kernel from fastvideo_kernel for 2-6x speedup. + +The data flow mirrors FastVideo's VideoSparseAttentionImpl: + tile(Q,K,V,gate) -> Triton kernel -> untile(output) +""" + +import math +from typing import Any + +import torch + +from . import SparseAttentionMethod, register_sparse_method +from .vsa_utils import ( + construct_variable_block_sizes, + get_non_pad_index, + get_reverse_tile_partition_indices, + get_tile_partition_indices, +) + + +@register_sparse_method("vsa") +class VSA(SparseAttentionMethod): + """Video Sparse Attention with two-branch architecture. + + VSA combines a compression branch (coarse-grained block attention) with + a sparse branch (fine-grained attention on top-K selected blocks). + + The final output is: output = out_compression * gate_compress + out_sparse + + where gate_compress is a learned parameter from the model layer that + controls the balance between compression and sparse branches. + + Configuration Parameters: + - block_size_3d: 3D tile dimensions (T, H, W), default (4, 4, 4) + - top_k_ratio: Ratio of blocks to keep (0.0-1.0), default 0.5 + - video_shape: Video dimensions (T, H, W) after patchification + + Requirements: + - Model must expose gate_compress parameter in attention layers + - Input tensors must be 4D: [batch, heads, seq_len, dim] + """ + + def __init__(self, method_config: dict | None = None): + """Initialize VSA method. + + Args: + method_config: Configuration dict with VSA parameters. 
+ """ + super().__init__() + config = method_config or {} + + # Block configuration + block_size = config.get("block_size_3d", (4, 4, 4)) + if isinstance(block_size, list): + block_size = tuple(block_size) + self.block_size_3d = block_size + self.block_elements = block_size[0] * block_size[1] * block_size[2] + + # Sparsity configuration + self.top_k_ratio = config.get("top_k_ratio", 0.5) + + # Video shape (can be set dynamically) + self.video_shape = config.get("video_shape", None) + + # Metadata cache: avoids recomputing tile indices on every forward pass. + # Matches FastVideo's @lru_cache on utility functions. + self._cached_metadata: dict[str, Any] | None = None + self._cached_metadata_key: tuple | None = None + + def set_video_shape(self, video_shape: tuple[int, int, int]): + """Set video shape for current forward pass. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + """ + self.video_shape = video_shape + + def _compute_metadata(self, seq_len: int, device: torch.device) -> dict[str, Any]: + """Compute block metadata from video shape. + + Results are cached and reused when called with the same (seq_len, video_shape) + to avoid recomputing tile indices on every denoising step, matching FastVideo's + ``@functools.lru_cache`` on the underlying utility functions. + + Args: + seq_len: Sequence length (should equal T * H * W). + device: Device for tensors. + + Returns: + Metadata dict with tile indices, variable sizes, etc. + """ + if self.video_shape is None: + raise ValueError( + f"video_shape must be provided for VSA but is None (seq_len={seq_len}). " + f"Set it via the VSA config ('video_shape' key), call set_video_shape(), " + f"or use a model-specific plugin that computes it from the model's " + f"patchifier." 
+ ) + + # Return cached metadata if inputs haven't changed + cache_key = (seq_len, self.video_shape, device) + if self._cached_metadata is not None and self._cached_metadata_key == cache_key: + return self._cached_metadata + + vid_t, vid_h, vid_w = self.video_shape + ts_t, ts_h, ts_w = self.block_size_3d + + # Validate sequence length matches video shape + expected_seq_len = vid_t * vid_h * vid_w + if seq_len != expected_seq_len: + raise ValueError( + f"Sequence length {seq_len} does not match video shape {self.video_shape} " + f"(expected {expected_seq_len})" + ) + + # Calculate number of tiles + num_tiles = ( + math.ceil(vid_t / ts_t), + math.ceil(vid_h / ts_h), + math.ceil(vid_w / ts_w), + ) + total_tiles = num_tiles[0] * num_tiles[1] * num_tiles[2] + + # Get partitioning indices + tile_indices = get_tile_partition_indices(self.video_shape, self.block_size_3d, device) + reverse_indices = get_reverse_tile_partition_indices( + self.video_shape, self.block_size_3d, device + ) + variable_sizes = construct_variable_block_sizes( + self.video_shape, num_tiles, self.block_size_3d, device + ) + non_pad_index = get_non_pad_index(variable_sizes, self.block_elements) + + # Calculate padded sizes + t_padded = num_tiles[0] * ts_t + h_padded = num_tiles[1] * ts_h + w_padded = num_tiles[2] * ts_w + padded_seq_len = t_padded * h_padded * w_padded + + metadata = { + "video_shape": self.video_shape, + "tile_size": self.block_size_3d, + "num_tiles": num_tiles, + "total_tiles": total_tiles, + "tile_indices": tile_indices, + "reverse_indices": reverse_indices, + "variable_sizes": variable_sizes, + "non_pad_index": non_pad_index, + "padded_seq_len": padded_seq_len, + } + + # Cache for reuse across denoising steps + self._cached_metadata = metadata + self._cached_metadata_key = cache_key + + return metadata + + def _tile_tensor(self, tensor: torch.Tensor, metadata: dict) -> torch.Tensor: + """Rearrange tensor into tile layout with padding. 
+ + Args: + tensor: Input tensor [batch, heads, seq_len, dim]. + metadata: Metadata from _compute_metadata. + + Returns: + Tiled tensor [batch, heads, padded_seq_len, dim]. + """ + batch, heads, seq_len, dim = tensor.shape + device = tensor.device + dtype = tensor.dtype + + tile_indices = metadata["tile_indices"] + non_pad_index = metadata["non_pad_index"] + padded_seq_len = metadata["padded_seq_len"] + + # Create padded tensor + padded = torch.zeros((batch, heads, padded_seq_len, dim), device=device, dtype=dtype) + + # Rearrange to tile order and place in padded positions + padded[:, :, non_pad_index] = tensor[:, :, tile_indices] + + return padded + + def _untile_tensor(self, tensor: torch.Tensor, metadata: dict, seq_len: int) -> torch.Tensor: + """Reverse tile layout back to original order. + + Args: + tensor: Tiled tensor [batch, heads, padded_seq_len, dim]. + metadata: Metadata from _compute_metadata. + seq_len: Original sequence length. + + Returns: + Output tensor [batch, heads, seq_len, dim]. + """ + non_pad_index = metadata["non_pad_index"] + reverse_indices = metadata["reverse_indices"] + + # Extract non-padded tokens and reverse order + return tensor[:, :, non_pad_index][:, :, reverse_indices] + + def forward_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + gate_compress: torch.Tensor | None = None, + video_shape: tuple[int, int, int] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, dict]: + """Compute VSA two-branch sparse attention. + + Data flow (mirrors FastVideo's VideoSparseAttentionImpl): + 1. Compute tile metadata from video_shape + 2. Tile Q, K, V, gate_compress into padded tile order + 3. Run Triton VSA kernel on tiled tensors + 4. Untile output back to original token order + + Args: + query: Query tensor [batch, heads, seq_len, dim]. + key: Key tensor [batch, heads, seq_len, dim]. + value: Value tensor [batch, heads, seq_len, dim]. + gate_compress: Learned gating weights [batch, heads, seq_len, dim]. 
+ If None, uses equal weighting (0.5) for both branches. + video_shape: Video dimensions (T, H, W). If None, uses self.video_shape. + **kwargs: Additional arguments (ignored). + + Returns: + Tuple of (attention_output, stats) where: + - attention_output: [batch, heads, seq_len, dim] + - stats: Dict with sparsity statistics + """ + if video_shape is not None: + self.video_shape = video_shape + + batch, heads, seq_len, dim = query.shape + device = query.device + + # Compute block metadata (cached across denoising steps) + metadata = self._compute_metadata(seq_len, device) + total_tiles = metadata["total_tiles"] + variable_sizes = metadata["variable_sizes"] + + # Calculate top-K based on ratio + top_k = max(1, int(self.top_k_ratio * total_tiles)) + + # ========== TILE: rearrange tokens into tile order ========== + # Mirrors FastVideo's VideoSparseAttentionImpl.preprocess_qkv (tile) + query_tiled = self._tile_tensor(query, metadata) + key_tiled = self._tile_tensor(key, metadata) + value_tiled = self._tile_tensor(value, metadata) + gate_tiled = ( + self._tile_tensor(gate_compress, metadata) if gate_compress is not None else None + ) + + # ========== TRITON VSA KERNEL ========== + # Kernel operates on tiled tensors in [batch, heads, padded_seq, dim] format + try: + from fastvideo_kernel import video_sparse_attn as triton_vsa_kernel + except ModuleNotFoundError: + raise ModuleNotFoundError( + "VSA requires the 'fastvideo_kernel' package for its Triton sparse attention " + "kernel. The VSA method registered successfully, but the kernel is needed at " + "runtime. Install it with:\n" + " git clone https://github.com/FastVideo/FastVideo.git\n" + " cd FastVideo/fastvideo-kernel && ./build.sh\n" + "See https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo-kernel for details." 
+ ) from None + output_tiled = triton_vsa_kernel( + query_tiled, + key_tiled, + value_tiled, + variable_sizes, # q_variable_sizes + variable_sizes, # kv_variable_sizes + top_k, + block_size=self.block_size_3d, + compress_attn_weight=gate_tiled, + ) + + # ========== UNTILE: restore original token order ========== + # Mirrors FastVideo's VideoSparseAttentionImpl.postprocess_output (untile) + output = self._untile_tensor(output_tiled, metadata, seq_len) + + # Compute statistics + actual_sparsity = 1.0 - (top_k / total_tiles) + stats = { + "sparsity": [actual_sparsity], + "phase": "prefill", + "total_blocks": total_tiles, + "sparse_blocks": [total_tiles - top_k], + "top_k": top_k, + "video_shape": self.video_shape, + } + return output, stats + + def calculate_sparsity( + self, + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Not used by VSA. Required stub for the abstract base class. + + VSA replaces the entire attention computation via ``forward_attention()``, + which is called directly by model-specific plugins. + The softmax-patching path that calls this method is never reached in the VSA flow. + + Raises: + NotImplementedError: Always. Use ``forward_attention()`` instead. + """ + raise NotImplementedError( + "VSA does not use the softmax-patching path. " + "Use forward_attention() via a model-specific plugin instead." + ) + + def apply_sparsity( + self, + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Not used by VSA. Required stub for the abstract base class. + + See ``calculate_sparsity`` for details. + + Raises: + NotImplementedError: Always. Use ``forward_attention()`` instead. + """ + raise NotImplementedError( + "VSA does not use the softmax-patching path. " + "Use forward_attention() via a model-specific plugin instead." + ) + + def get_threshold_info(self) -> dict[str, Any]: + """Get VSA configuration info. + + Returns: + Dictionary with VSA configuration. 
+ """ + return { + "type": "vsa", + "block_size_3d": self.block_size_3d, + "top_k_ratio": self.top_k_ratio, + "video_shape": self.video_shape, + } + + @property + def name(self) -> str: + """Method identifier.""" + return "vsa" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py new file mode 100644 index 000000000..affed79f7 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Video Sparse Attention (VSA). + +This module provides 3D block operations for video sparse attention, +including reshaping tensors into video blocks and variable block size computation. +""" + +import functools +import math + +import torch + + +@functools.lru_cache(maxsize=10) +def get_tile_partition_indices( + video_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Get indices to partition video tokens into tiles. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. + + Returns: + LongTensor of indices to rearrange tokens into tile order. 
+ """ + vid_t, vid_h, vid_w = video_shape + ts, hs, ws = tile_size + indices = torch.arange(vid_t * vid_h * vid_w, device=device, dtype=torch.long).reshape( + vid_t, vid_h, vid_w + ) + + tiles = [] + for t in range(math.ceil(vid_t / ts)): + for h in range(math.ceil(vid_h / hs)): + for w in range(math.ceil(vid_w / ws)): + tile = indices[ + t * ts : min(t * ts + ts, vid_t), + h * hs : min(h * hs + hs, vid_h), + w * ws : min(w * ws + ws, vid_w), + ] + tiles.append(tile.flatten()) + + return torch.cat(tiles, dim=0) + + +@functools.lru_cache(maxsize=10) +def get_reverse_tile_partition_indices( + video_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Get indices to reverse tile partitioning back to original order. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. + + Returns: + LongTensor of indices to reverse the tile rearrangement. + """ + forward_indices = get_tile_partition_indices(video_shape, tile_size, device) + return torch.argsort(forward_indices) + + +@functools.lru_cache(maxsize=10) +def construct_variable_block_sizes( + video_shape: tuple[int, int, int], + num_tiles: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """Compute valid (non-padded) token count for each tile. + + Since video dimensions may not divide evenly by tile size, edge tiles + will have fewer valid tokens. This function computes the actual valid + token count for each tile. + + Args: + video_shape: Video dimensions (T, H, W) after patchification. + num_tiles: Number of tiles in each dimension (n_T, n_H, n_W). + tile_size: Tile dimensions (tile_T, tile_H, tile_W). + device: Device for the output tensor. + + Returns: + LongTensor of shape [num_tiles_total] with valid tokens per tile. 
+ """ + t, h, w = video_shape + ts_t, ts_h, ts_w = tile_size + n_t, n_h, n_w = num_tiles + + def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor: + """Compute size of each tile along one dimension.""" + sizes = torch.full((n_tiles,), tile, dtype=torch.long, device=device) + remainder = dim_len - (n_tiles - 1) * tile + sizes[-1] = remainder if remainder > 0 else tile + return sizes + + t_sizes = _sizes(t, ts_t, n_t) # [n_t] + h_sizes = _sizes(h, ts_h, n_h) # [n_h] + w_sizes = _sizes(w, ts_w, n_w) # [n_w] + + # Broadcast multiply to get tokens per tile + block_sizes = ( + t_sizes[:, None, None] * h_sizes[None, :, None] * w_sizes[None, None, :] + ).reshape(-1) + + return block_sizes + + +@functools.lru_cache(maxsize=10) +def get_non_pad_index( + variable_block_sizes: torch.LongTensor, + max_block_size: int, +) -> torch.LongTensor: + """Get indices of non-padded tokens in the padded layout. + + When tiles have variable sizes, we pad to max_block_size. This function + returns indices to extract only valid (non-padded) tokens. + + Args: + variable_block_sizes: Tensor of valid token counts per tile. + max_block_size: Maximum tile size (usually tile_T * tile_H * tile_W). + + Returns: + LongTensor of indices for valid tokens. 
+ """ + n_win = variable_block_sizes.shape[0] + device = variable_block_sizes.device + + starts_pad = torch.arange(n_win, device=device) * max_block_size + index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :] + index_mask = ( + torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None] + ) + + return index_pad[index_mask] diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 4333d1243..94e8b8a98 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -61,9 +61,26 @@ def set_from_attribute_config( Args: attribute_cfg: Sparse attention attribute configuration. """ + from .config import VSAAttributeConfig + + # Determine which config class to use based on method + config_dict = attribute_cfg or {} + if isinstance(attribute_cfg, dict): + method = config_dict.get("method", "flash_skip_softmax") + elif attribute_cfg is not None and hasattr(attribute_cfg, "method"): + method = attribute_cfg.method + else: + method = "flash_skip_softmax" + + # Select appropriate config class based on method + if method == "vsa": + config_class = VSAAttributeConfig + else: + config_class = SparseAttentionAttributeConfig + # Ensure config is validated through Pydantic - if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): - attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) + if not isinstance(attribute_cfg, (SparseAttentionAttributeConfig, VSAAttributeConfig)): + attribute_cfg = config_class(**(config_dict)) # Store raw config for method initialization self._method_config = {} @@ -80,10 +97,10 @@ def set_from_attribute_config( # Process each attribute from validated config for attribute, val in attribute_cfg.model_dump().items(): - # Validate attribute if using config class - if 
hasattr(SparseAttentionAttributeConfig, "model_fields"): - assert attribute in SparseAttentionAttributeConfig.model_fields, ( - f"{attribute} is not a valid SparseAttentionModule attribute" + # Validate attribute against the appropriate config class + if hasattr(config_class, "model_fields"): + assert attribute in config_class.model_fields, ( + f"{attribute} is not a valid {config_class.__name__} attribute" ) if attribute in _module_attributes: @@ -159,14 +176,16 @@ def _setup(self): def forward(self, *args, **kwargs): """Forward with selected sparse attention method. - This method dispatches to the appropriate sparse attention implementation - based on the configured method and backend. + Methods that replace the full attention computation (e.g., VSA) override + ``forward()`` in their model-specific plugin and never reach this path. + This method handles the softmax-patching path used by methods like + ``flash_skip_softmax``. """ # Pass through if sparse attention is disabled if not self.is_enabled: return super().forward(*args, **kwargs) - # Get the appropriate context manager for this configuration + # Standard path: softmax patching context = self._get_sparse_context() # Apply sparse attention through the context diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py new file mode 100644 index 000000000..a548bfc6d --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_vsa.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only unit tests for Video Sparse Attention (VSA). + +Tests cover: +- vsa_utils.py: tile/untile index logic, variable block sizes +- vsa.py: VSA method init, metadata computation, validation, caching +- config.py: VSAAttributeConfig validation +- ModelOpt integration: sparsify() with VSA config, save/restore +""" + +import math + +import pytest + +pytest.importorskip("transformers") + +import torch +from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel +from pydantic import ValidationError + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.config import VSAAttributeConfig, VSAConfig +from modelopt.torch.sparsity.attention_sparsity.methods.vsa import VSA +from modelopt.torch.sparsity.attention_sparsity.methods.vsa_utils import ( + construct_variable_block_sizes, + get_non_pad_index, + get_reverse_tile_partition_indices, + get_tile_partition_indices, +) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# --------------------------------------------------------------------------- +# vsa_utils: tile partition indices +# --------------------------------------------------------------------------- + + +class TestTilePartitionIndices: + """Tests for get_tile_partition_indices.""" + + def test_evenly_divisible(self): + """Tiles cover full volume with no remainder.""" + video_shape = (8, 8, 8) + tile_size = (4, 4, 4) + idx = get_tile_partition_indices(video_shape, tile_size, 
torch.device("cpu")) + assert idx.shape == (8 * 8 * 8,) + # Every original index appears exactly once + assert torch.equal(idx.sort().values, torch.arange(512)) + + def test_non_divisible(self): + """Edge tiles are smaller when dims don't divide evenly.""" + video_shape = (5, 6, 7) + tile_size = (4, 4, 4) + seq_len = 5 * 6 * 7 + idx = get_tile_partition_indices(video_shape, tile_size, torch.device("cpu")) + assert idx.shape == (seq_len,) + assert torch.equal(idx.sort().values, torch.arange(seq_len)) + + def test_round_trip(self): + """tile then reverse_tile is identity.""" + video_shape = (6, 10, 8) + tile_size = (4, 4, 4) + device = torch.device("cpu") + fwd = get_tile_partition_indices(video_shape, tile_size, device) + rev = get_reverse_tile_partition_indices(video_shape, tile_size, device) + # Applying forward then reverse should yield the original order + assert torch.equal(fwd[rev], torch.arange(6 * 10 * 8)) + + +# --------------------------------------------------------------------------- +# vsa_utils: variable block sizes +# --------------------------------------------------------------------------- + + +class TestVariableBlockSizes: + """Tests for construct_variable_block_sizes.""" + + def test_evenly_divisible(self): + """All tiles have full size when dims divide evenly.""" + video_shape = (8, 8, 8) + tile_size = (4, 4, 4) + num_tiles = (2, 2, 2) + sizes = construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + assert sizes.shape == (8,) # 2*2*2 tiles + assert (sizes == 64).all() # every tile is full 4*4*4 + + def test_non_divisible_sum(self): + """Sum of variable sizes equals original sequence length.""" + video_shape = (5, 6, 7) + tile_size = (4, 4, 4) + num_tiles = ( + math.ceil(5 / 4), + math.ceil(6 / 4), + math.ceil(7 / 4), + ) + sizes = construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + assert sizes.sum().item() == 5 * 6 * 7 + + def test_partial_tile_smaller(self): + 
"""Last tile along a non-divisible dim should be smaller.""" + video_shape = (5, 4, 4) + tile_size = (4, 4, 4) + num_tiles = (2, 1, 1) + sizes = construct_variable_block_sizes( + video_shape, num_tiles, tile_size, torch.device("cpu") + ) + # First tile: 4*4*4=64, second tile: 1*4*4=16 + assert sizes[0].item() == 64 + assert sizes[1].item() == 16 + + +# --------------------------------------------------------------------------- +# vsa_utils: non-pad index +# --------------------------------------------------------------------------- + + +class TestNonPadIndex: + """Tests for get_non_pad_index.""" + + def test_full_blocks(self): + """All blocks full size → non_pad covers everything.""" + sizes = torch.tensor([64, 64, 64]) + npi = get_non_pad_index(sizes, 64) + assert npi.shape == (192,) # 3 * 64 + + def test_partial_blocks(self): + """Partial blocks → non_pad skips padding positions.""" + sizes = torch.tensor([64, 16]) + npi = get_non_pad_index(sizes, 64) + assert npi.shape == (80,) # 64 + 16 + + +# --------------------------------------------------------------------------- +# VSA: tile/untile round-trip +# --------------------------------------------------------------------------- + + +class TestTileUntileRoundTrip: + """Test _tile_tensor / _untile_tensor preserve data.""" + + @pytest.mark.parametrize( + "video_shape", + [(8, 8, 8), (5, 6, 7), (4, 4, 4)], + ids=["even", "non-divisible", "single-tile"], + ) + def test_round_trip(self, video_shape): + """tile then untile recovers the original tensor.""" + seq_len = video_shape[0] * video_shape[1] * video_shape[2] + vsa = VSA({"video_shape": video_shape}) + meta = vsa._compute_metadata(seq_len, torch.device("cpu")) + + x = torch.randn(2, 4, seq_len, 16) # [batch, heads, seq, dim] + tiled = vsa._tile_tensor(x, meta) + recovered = vsa._untile_tensor(tiled, meta, seq_len) + + assert recovered.shape == x.shape + assert torch.allclose(recovered, x) + + +# 
# ---------------------------------------------------------------------------
# VSA method: init and config
# ---------------------------------------------------------------------------


class TestVSAInit:
    """Tests for VSA.__init__ and basic properties."""

    def test_defaults(self):
        """Default construction uses 4x4x4 blocks, 50% top-k, no video shape."""
        vsa = VSA()
        assert vsa.block_size_3d == (4, 4, 4)
        assert vsa.block_elements == 64
        assert vsa.top_k_ratio == 0.5
        assert vsa.video_shape is None
        assert vsa.name == "vsa"

    def test_custom_config(self):
        """Config dict overrides block size, ratio, and video shape."""
        vsa = VSA({"block_size_3d": [2, 2, 2], "top_k_ratio": 0.3, "video_shape": (8, 8, 8)})
        assert vsa.block_size_3d == (2, 2, 2)
        assert vsa.block_elements == 8
        assert vsa.top_k_ratio == 0.3
        assert vsa.video_shape == (8, 8, 8)

    def test_set_video_shape(self):
        """set_video_shape updates the stored video shape."""
        vsa = VSA()
        vsa.set_video_shape((4, 8, 12))
        assert vsa.video_shape == (4, 8, 12)

    def test_get_threshold_info(self):
        """Threshold info reports the method type and configured ratio."""
        vsa = VSA({"top_k_ratio": 0.7, "video_shape": (4, 4, 4)})
        info = vsa.get_threshold_info()
        assert info["type"] == "vsa"
        assert info["top_k_ratio"] == 0.7


# ---------------------------------------------------------------------------
# VSA method: metadata computation and validation
# ---------------------------------------------------------------------------


class TestVSAMetadata:
    """Tests for VSA._compute_metadata validation and caching."""

    def test_no_video_shape_raises(self):
        """Metadata computation requires a video shape."""
        vsa = VSA()
        with pytest.raises(ValueError, match="video_shape must be provided"):
            vsa._compute_metadata(100, torch.device("cpu"))

    def test_seq_len_mismatch_raises(self):
        """Sequence length must equal the product of the video dims."""
        vsa = VSA({"video_shape": (4, 4, 4)})
        with pytest.raises(ValueError, match="does not match video shape"):
            vsa._compute_metadata(100, torch.device("cpu"))  # expected 64

    def test_valid_metadata(self):
        """Valid shape/seq_len yields tile counts and totals."""
        vsa = VSA({"video_shape": (8, 8, 8)})
        meta = vsa._compute_metadata(512, torch.device("cpu"))
        assert meta["video_shape"] == (8, 8, 8)
        assert meta["num_tiles"] == (2, 2, 2)
        assert meta["total_tiles"] == 8

    def test_metadata_caching(self):
        """Repeated calls with the same seq_len reuse the cached metadata."""
        vsa = VSA({"video_shape": (8, 8, 8)})
        m1 = vsa._compute_metadata(512, torch.device("cpu"))
        m2 = vsa._compute_metadata(512, torch.device("cpu"))
        assert m1 is m2  # same object, not recomputed


# ---------------------------------------------------------------------------
# VSA method: abstract stubs raise
# ---------------------------------------------------------------------------


class TestVSAStubs:
    """calculate_sparsity and apply_sparsity should raise NotImplementedError."""

    def test_calculate_sparsity_raises(self):
        """VSA does not support the softmax-patching sparsity path."""
        vsa = VSA()
        with pytest.raises(NotImplementedError, match="softmax-patching"):
            vsa.calculate_sparsity(torch.zeros(1))

    def test_apply_sparsity_raises(self):
        """VSA does not support the softmax-patching sparsity path."""
        vsa = VSA()
        with pytest.raises(NotImplementedError, match="softmax-patching"):
            vsa.apply_sparsity(torch.zeros(1))


# ---------------------------------------------------------------------------
# VSAAttributeConfig validation
# ---------------------------------------------------------------------------


class TestVSAAttributeConfig:
    """Tests for VSAAttributeConfig pydantic validation."""

    def test_valid_defaults(self):
        """Defaults validate and match the documented values."""
        cfg = VSAAttributeConfig()
        assert cfg.method == "vsa"
        assert cfg.block_size_3d == (4, 4, 4)
        assert cfg.top_k_ratio == 0.5

    def test_top_k_ratio_out_of_range(self):
        """top_k_ratio outside (0, 1] is rejected."""
        with pytest.raises(ValidationError, match="top_k_ratio"):
            VSAAttributeConfig(top_k_ratio=0.0)
        with pytest.raises(ValidationError, match="top_k_ratio"):
            VSAAttributeConfig(top_k_ratio=1.5)

    def test_video_shape_wrong_length(self):
        """video_shape must be a 3-tuple (T, H, W)."""
        with pytest.raises(ValidationError, match="3 elements"):
            VSAAttributeConfig(video_shape=(4, 4))

    def test_video_shape_negative(self):
        """Non-positive video dims are rejected."""
        with pytest.raises(ValidationError, match="positive"):
            VSAAttributeConfig(video_shape=(4, -1, 4))

    def test_video_shape_none_allowed(self):
        """None is a valid video_shape (plugin may supply it later)."""
        cfg = VSAAttributeConfig(video_shape=None)
        assert cfg.video_shape is None

    def test_vsa_config_defaults(self):
        """VSAConfig's default sparse_cfg targets *attn* with the vsa method."""
        cfg = VSAConfig()
        assert "*attn*" in cfg.sparse_cfg
        assert cfg.sparse_cfg["*attn*"]["method"] == "vsa"


# ---------------------------------------------------------------------------
# ModelOpt integration: sparsify() with VSA config
# ---------------------------------------------------------------------------

VSA_TEST_CFG = {
    "sparse_cfg": {
        "*attention*": {
            "method": "vsa",
            "block_size_3d": (4, 4, 4),
            "top_k_ratio": 0.5,
            "enable": True,
        },
        "default": {"enable": False},
    },
}


def _vsa_sparse_modules(model):
    """Return all SparseAttentionModule instances found in *model*."""
    return [m for m in model.modules() if isinstance(m, SparseAttentionModule)]


class TestVSASparsifyIntegration:
    """Test VSA integration with modelopt sparsify() API."""

    def test_sparsify_creates_sparse_modules(self):
        """sparsify() with VSA config replaces attention modules."""
        model = SimpleAttentionModel()
        sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG)

        assert len(_vsa_sparse_modules(sparse_model)) > 0

    def test_sparse_module_has_vsa_method(self):
        """Replaced modules are configured with VSA method."""
        model = SimpleAttentionModel()
        sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG)

        modules = _vsa_sparse_modules(sparse_model)
        # Guard against a vacuous pass: the per-module checks below only run
        # if sparsify actually wrapped at least one attention module.
        assert modules
        for module in modules:
            assert module._method == "vsa"
            assert isinstance(module._sparse_method_instance, VSA)
            assert module._sparse_method_instance.block_size_3d == (4, 4, 4)
            assert module._sparse_method_instance.top_k_ratio == 0.5

    def test_enable_disable(self):
        """Enable/disable works on VSA sparse modules."""
        model = SimpleAttentionModel()
        sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG)

        modules = _vsa_sparse_modules(sparse_model)
        assert modules  # guard against a vacuous pass
        for module in modules:
            assert module.is_enabled
            module.disable()
            assert not module.is_enabled
            module.enable()
            assert module.is_enabled

    def test_threshold_info(self):
        """VSA sparse modules report correct threshold info."""
        model = SimpleAttentionModel()
        sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG)

        modules = _vsa_sparse_modules(sparse_model)
        assert modules  # guard against a vacuous pass
        for module in modules:
            info = module.get_threshold_info()
            assert info["type"] == "vsa"
            assert info["top_k_ratio"] == 0.5

    def test_save_restore(self):
        """VSA modelopt_state can be saved and restored."""
        model = SimpleAttentionModel()
        sparse_model = sparse_attn.sparsify(model, VSA_TEST_CFG)

        state = mto.modelopt_state(sparse_model)

        # Restore to a fresh model
        model_restored = SimpleAttentionModel()
        mto.restore_from_modelopt_state(model_restored, state)

        # Verify VSA method is restored (and that restore re-created modules)
        modules = _vsa_sparse_modules(model_restored)
        assert modules
        for module in modules:
            assert module._method == "vsa"
            assert isinstance(module._sparse_method_instance, VSA)

    def test_pattern_matching(self):
        """Pattern-based config selectively applies VSA."""
        model = SimpleAttentionModel()

        # Pattern that won't match anything
        config = {
            "sparse_cfg": {
                "*nonexistent*": {
                    "method": "vsa",
                    "enable": True,
                },
                "default": {"enable": False},
            },
        }
        sparse_model = sparse_attn.sparsify(model, config)

        # No modules should have VSA enabled
        for module in _vsa_sparse_modules(sparse_model):
            assert not module.is_enabled