20 changes: 15 additions & 5 deletions src/diffusers/models/attention_dispatch.py
@@ -38,6 +38,7 @@
     is_flash_attn_available,
     is_flash_attn_version,
     is_kernels_available,
+    is_kernels_version,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -265,6 +266,7 @@ class _HubKernelConfig:
     repo_id: str
     function_attr: str
     revision: str | None = None
+    version: int | None = None
     kernel_fn: Callable | None = None
     wrapped_forward_attr: str | None = None
     wrapped_backward_attr: str | None = None
@@ -274,27 +276,31 @@ class _HubKernelConfig:
 
 
 # Registry for hub-based attention kernels
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
-    # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
         function_attr="flash_attn_varlen_func",
-        # revision="fake-ops-return-probs",
+        version=1,
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_func",
+        version=1,
         revision=None,
         wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
         wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+        repo_id="kernels-community/flash-attn2",
+        function_attr="flash_attn_varlen_func",
+        version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
-        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+        repo_id="kernels-community/sage-attention",
+        function_attr="sageattn",
+        version=1,
     ),
 }

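For context, a minimal sketch of how the new `version` field could be consumed when a hub kernel is resolved. This is illustrative, not the actual dispatch code: `_load_hub_kernel` is a hypothetical helper, and it assumes `kernels.get_kernel` accepts a `version` argument in kernels >= 0.12 (the floor enforced in the check below).

from kernels import get_kernel

def _load_hub_kernel(config):
    # Hypothetical helper, not part of this diff: prefer the integer
    # `version`, which the `kernels` package would resolve against version
    # tags on the Hub repo; fall back to a pinned `revision` when unset.
    if config.version is not None:
        kernel = get_kernel(config.repo_id, version=config.version)
    else:
        kernel = get_kernel(config.repo_id, revision=config.revision)
    return getattr(kernel, config.function_attr)
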
@@ -464,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
+        if not is_kernels_version(">=", "0.12"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
+            )
 
     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN:
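
With the new guard, hub kernel backends fail fast on a stale `kernels` install instead of erroring later during kernel resolution. A hypothetical repro (the helper is private and imported here only to illustrate the failure mode):

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
)

# With kernels < 0.12 installed, this now raises:
#   RuntimeError: Backend '<name>' needs to be used with a `kernels` version
#   of at least 0.12. Please update with `pip install -U kernels`.
_check_attention_backend_requirements(AttentionBackendName.FLASH_HUB)
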
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -86,6 +86,7 @@
     is_inflect_available,
     is_invisible_watermark_available,
     is_kernels_available,
+    is_kernels_version,
     is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
16 changes: 16 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)
 
 
+@cache
+def is_kernels_version(operation: str, version: str):
+    """
+    Compares the current Kernels version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _kernels_available:
+        return False
+    return compare_versions(parse(_kernels_version), operation, version)
+
+
 @cache
 def is_hf_hub_version(operation: str, version: str):
     """
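
Usage of the new helper mirrors the other `is_*_version` utilities, and the early return means callers get False rather than an exception when `kernels` is missing:

from diffusers.utils import is_kernels_version  # re-exported via the __init__.py change above

if is_kernels_version(">=", "0.12"):
    ...  # hub kernel backends are safe to use
# With `kernels` not installed at all, any comparison returns False.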