diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e68d317bc140..9b1911da7ed5 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -340,6 +340,8 @@ class _HubKernelConfig:
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
         function_attr="flash_attn_varlen_func",
+        wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
+        wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
         version=1,
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
@@ -1612,6 +1614,227 @@ def _flash_attention_3_hub_backward_op(
     return grad_query, grad_key, grad_value
 
 
+def _flash_attention_3_varlen_hub_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float | None = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: "ParallelConfig" | None = None,
+    *,
+    window_size: tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    num_splits: int = 1,
+    pack_gqa: bool | None = None,
+    deterministic: bool = False,
+    sm_margin: int = 0,
+):
+    """Forward op for the flash-attn 3 varlen hub kernel, used by context-parallel attention.
+
+    Flattens `query`, `key` and `value` over their first two dimensions (keeping only the key/value positions
+    selected by `attn_mask` when a mask is given), runs the hub kernel's wrapped forward function on the packed
+    tensors and, when `_save_ctx` is set, stores on `ctx` what the matching backward op reads.
+
+    Returns:
+        The attention output viewed as `(batch_size, seq_len_q, ...)`, or `(out, lse)` when `return_lse` is set,
+        with `lse` laid out as `(batch_size, seq_len_q, num_heads)`.
+
+    Raises:
+        ValueError: If `dropout_p` is non-zero or `enable_gqa` is set.
+        RuntimeError: If the registered hub kernel lacks the wrapped forward or backward function.
+    """
+    if dropout_p != 0.0:
+        raise ValueError("`dropout_p` is not yet supported for flash-attn 3 varlen hub kernels.")
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 varlen hub kernels.")
+
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_forward_fn is None or wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_forward` and "
+            "`flash_attn_interface._flash_attn_backward` for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    batch_size, seq_len_q, num_heads, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+        )
+        # Flat indices of the non-zero mask entries; the backward op reuses them via `ctx.indices_k`.
+        indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
+        query_packed = query.flatten(0, 1)
+        key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+        value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        max_seqlen_q = seq_len_q
+    else:
+        (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+        )
+        query_packed = query.flatten(0, 1)
+        key_packed = key.flatten(0, 1)
+        value_packed = value.flatten(0, 1)
+        seqlens_k = None
+
+    # Arguments are positional; their order has to match the kernel's `_flash_attn_forward` signature.
+    out_packed, softmax_lse, *_ = wrapped_forward_fn(
+        query_packed,
+        key_packed,
+        value_packed,
+        None,
+        None,
+        None,
+        None,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,
+        None,
+        None,
+        max_seqlen_q,
+        max_seqlen_k,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        scale,
+        is_causal,
+        window_size[0],
+        window_size[1],
+        0,
+        softcap,
+        True,
+        None,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+
+    out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])
+
+    if _save_ctx:
+        ctx.save_for_backward(
+            query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k
+        )
+        ctx.seqlens_k = seqlens_k  # None if unmasked
+        ctx.indices_k = indices_k if attn_mask is not None else None
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.batch_size = batch_size
+        ctx.seq_len_q = seq_len_q
+        ctx.seq_len_kv = seq_len_kv
+        ctx.num_heads = num_heads
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.sm_margin = sm_margin
+
+    if not return_lse:
+        # Skip the permute + contiguous copy of the LSE when the caller does not need it.
+        return out
+
+    # softmax_lse in varlen mode: (num_heads, total_q) -> (batch_size, seq_len_q, num_heads)
+    lse_sp = softmax_lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()
+    return out, lse_sp
+
+
+def _flash_attention_3_varlen_hub_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    """Backward op paired with `_flash_attention_3_varlen_hub_forward_op`.
+
+    Calls the hub kernel's wrapped backward function on the packed tensors saved by the forward op, then restores
+    the `(batch_size, seq_len, ...)` layout of the query, key and value gradients.
+
+    Returns:
+        The tuple `(grad_query, grad_key, grad_value)`.
+
+    Raises:
+        RuntimeError: If the registered hub kernel lacks the wrapped backward function.
+    """
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_backward` "
+            "for context parallel execution."
+        )
+
+    query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+
+    grad_out_packed = grad_out.flatten(0, 1)
+    # The kernel call below discards its return value, so gradients are read from these buffers.
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query_packed),
+        torch.empty_like(key_packed),
+        torch.empty_like(value_packed),
+    )
+
+    wrapped_backward_fn(
+        grad_out_packed,
+        query_packed,
+        key_packed,
+        value_packed,
+        out_packed,
+        softmax_lse,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,
+        None,
+        ctx.max_seqlen_q,
+        ctx.max_seqlen_k,
+        grad_query,
+        grad_key,
+        grad_value,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.deterministic,
+        ctx.sm_margin,
+    )
+
+    grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])
+
+    if ctx.seqlens_k is not None:
+        # Masked case: `_unpad_to_padded` is expected to restore the padded key/value layout (confirm in helper).
+        grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+        grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+    else:
+        grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
+        grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])
+
+    # Trim the last dimension of each gradient to the size of `grad_out`'s last dimension.
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]
+
+    return grad_query, grad_key, grad_value
+
+
 def _sage_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -2986,7 +3209,7 @@ def _flash_attention_3_hub(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_VARLEN_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=False,
+    supports_context_parallel=True,
 )
 def _flash_attention_3_varlen_hub(
     query: torch.Tensor,
@@ -2998,41 +3221,84 @@ def _flash_attention_3_varlen_hub(
     return_lse: bool = False,
     _parallel_config: "ParallelConfig" | None = None,
 ) -> torch.Tensor:
+    """Run flash-attn 3 varlen attention through the hub kernel.
+
+    Without a `_parallel_config`, the inputs are packed here and the kernel function is called directly. With
+    one, the call goes through `_templated_context_parallel_attention` using the varlen forward and backward ops.
+
+    Returns:
+        The attention output, or `(out, lse)` when `return_lse` is set.
+
+    Raises:
+        NotImplementedError: If `_parallel_config` requests a `ring_degree` greater than 1.
+    """
+    if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
+        raise NotImplementedError("`ring_degree > 1` is not yet supported for the _FLASH_3_VARLEN_HUB backend.")
+
+    lse = None
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape
 
-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
+    if _parallel_config is None:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+            )
+            indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
+            key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+            value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        else:
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+            )
+            key_packed = key.flatten(0, 1)
+            value_packed = value.flatten(0, 1)
 
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+        query_packed = query.flatten(0, 1)
 
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
-    out, lse, *_ = func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=scale,
-        causal=is_causal,
-    )
-    out = out.unflatten(0, (batch_size, -1))
+        func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+        out = func(
+            q=query_packed,
+            k=key_packed,
+            v=value_packed,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+        out = out.unflatten(0, (batch_size, -1))
+    else:
+        forward_op = functools.partial(
+            _flash_attention_3_varlen_hub_forward_op,
+            window_size=(-1, -1),
+            softcap=0.0,
+            num_splits=1,
+            pack_gqa=None,
+            deterministic=False,
+            sm_margin=0,
+        )
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            0.0,
+            is_causal,
+            scale,
+            False,
+            return_lse,
+            forward_op=forward_op,
+            backward_op=_flash_attention_3_varlen_hub_backward_op,
+            _parallel_config=_parallel_config,
+        )
+        if return_lse:
+            out, lse = out
 
     return (out, lse) if return_lse else out
 
diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index d4f5e99d6763..160247412851 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -393,6 +393,10 @@ class ContextParallelAttentionBackendsTesterMixin:
                 "_flash_3_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
             ),
+            pytest.param(
+                "_flash_3_varlen_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
         ],
     )
     @pytest.mark.parametrize("ulysses_anything", [True, False])
@@ -410,7 +414,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
         if cp_type == "ring_degree":
             if attention_backend == AttentionBackendName.NATIVE:
                 pytest.skip("Skipping test because ring isn't supported with native attention backend.")
-            elif attention_backend in ("flash_varlen_hub"):
+            elif attention_backend in ("flash_varlen_hub", "_flash_3_varlen_hub"):
                 pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")
 
         if ulysses_anything and "ulysses" not in cp_type:
diff --git a/tests/models/testing_utils/utils.py b/tests/models/testing_utils/utils.py
index 8877755377e0..3beb59ed1a66 100644
--- a/tests/models/testing_utils/utils.py
+++ b/tests/models/testing_utils/utils.py
@@ -8,6 +8,7 @@
     AttentionBackendName.FLASH_HUB,
     AttentionBackendName.FLASH_VARLEN_HUB,
     AttentionBackendName._FLASH_3_HUB,
+    AttentionBackendName._FLASH_3_VARLEN_HUB,
 }
 
 