diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 90ffcac80dc5..07cfc54c2284 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -38,6 +38,7 @@
     is_flash_attn_available,
     is_flash_attn_version,
     is_kernels_available,
+    is_kernels_version,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -265,6 +266,7 @@ class _HubKernelConfig:
     repo_id: str
     function_attr: str
     revision: str | None = None
+    version: int | None = None
     kernel_fn: Callable | None = None
     wrapped_forward_attr: str | None = None
     wrapped_backward_attr: str | None = None
@@ -274,27 +276,31 @@ class _HubKernelConfig:
 
 
 # Registry for hub-based attention kernels
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
-    # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
         function_attr="flash_attn_varlen_func",
-        # revision="fake-ops-return-probs",
+        version=1,
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_func",
+        version=1,
         revision=None,
         wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
         wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+        repo_id="kernels-community/flash-attn2",
+        function_attr="flash_attn_varlen_func",
+        version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
-        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+        repo_id="kernels-community/sage-attention",
+        function_attr="sageattn",
+        version=1,
     ),
 }
@@ -464,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
+        if not is_kernels_version(">=", "0.12"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
+            )
     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN:
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index dd50405c74b2..23d7ac7c6c2d 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -86,6 +86,7 @@
     is_inflect_available,
     is_invisible_watermark_available,
     is_kernels_available,
+    is_kernels_version,
     is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 8fb481946ebf..551fa358a28d 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)
 
 
+@cache
+def is_kernels_version(operation: str, version: str):
+    """
+    Compares the current Kernels version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _kernels_available:
+        return False
+    return compare_versions(parse(_kernels_version), operation, version)
+
+
 @cache
 def is_hf_hub_version(operation: str, version: str):
     """
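
A minimal usage sketch of the new helper (illustrative only, not part of the diff): `is_kernels_version` is re-exported from `diffusers.utils` alongside `is_kernels_available`, and the `0.12` floor mirrors the check added to `_check_attention_backend_requirements`. Because the helper returns `False` when `kernels` is not installed, it is typically paired with an availability check:

from diffusers.utils import is_kernels_available, is_kernels_version

# Fail early if the `kernels` package is missing or too old for the hub attention backends.
if not is_kernels_available():
    raise RuntimeError("The `kernels` package isn't available. Please install it with `pip install kernels`.")
if not is_kernels_version(">=", "0.12"):
    raise RuntimeError("A `kernels` version of at least 0.12 is required. Please update with `pip install -U kernels`.")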