Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 253 additions & 31 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ class _HubKernelConfig:
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
Expand Down Expand Up @@ -1612,6 +1614,194 @@ def _flash_attention_3_hub_backward_op(
return grad_query, grad_key, grad_value


def _flash_attention_3_varlen_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 varlen hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 varlen hub kernels.")

config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_forward` and "
"`flash_attn_interface._flash_attn_backward` for context parallel execution."
)

if scale is None:
scale = query.shape[-1] ** (-0.5)

batch_size, seq_len_q, num_heads, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
query_packed = query.flatten(0, 1)
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
max_seqlen_q = seq_len_q
else:
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
query_packed = query.flatten(0, 1)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)
seqlens_k = None

out_packed, softmax_lse, *_ = wrapped_forward_fn(
query_packed,
key_packed,
value_packed,
None,
None,
None,
None,
cu_seqlens_q,
cu_seqlens_k,
None,
None,
None,
max_seqlen_q,
max_seqlen_k,
None,
None,
None,
None,
None,
None,
None,
None,
None,
scale,
is_causal,
window_size[0],
window_size[1],
0,
softcap,
True,
None,
num_splits,
pack_gqa,
sm_margin,
)

out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])

if _save_ctx:
ctx.save_for_backward(
query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k
)
ctx.seqlens_k = seqlens_k # None if unmasked
ctx.indices_k = indices_k if attn_mask is not None else None
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.batch_size = batch_size
ctx.seq_len_q = seq_len_q
ctx.seq_len_kv = seq_len_kv
ctx.num_heads = num_heads
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin

# softmax_lse in varlen mode: (num_heads, total_q) -> (batch_size, seq_len_q, num_heads)
lse_sp = softmax_lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()

return (out, lse_sp) if return_lse else out


def _flash_attention_3_varlen_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_backward` "
"for context parallel execution."
)

query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

grad_out_packed = grad_out.flatten(0, 1)
grad_query, grad_key, grad_value = (
torch.empty_like(query_packed),
torch.empty_like(key_packed),
torch.empty_like(value_packed),
)

wrapped_backward_fn(
grad_out_packed,
query_packed,
key_packed,
value_packed,
out_packed,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
None,
None,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
grad_query,
grad_key,
grad_value,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
)

grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])

if ctx.seqlens_k is not None:
grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
else:
grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])

grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]

return grad_query, grad_key, grad_value


def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
Expand Down Expand Up @@ -2986,7 +3176,7 @@ def _flash_attention_3_hub(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_VARLEN_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_varlen_hub(
query: torch.Tensor,
Expand All @@ -2998,41 +3188,73 @@ def _flash_attention_3_varlen_hub(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
raise NotImplementedError("`ring_degree > 1` is not yet supported for the _FLASH_3_VARLEN_HUB backend.")

lse = None
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)

key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
if _parallel_config is None:
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
else:
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)

query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
query_packed = query.flatten(0, 1)

func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
out, lse, *_ = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
)
out = out.unflatten(0, (batch_size, -1))
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
out = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
out = out.unflatten(0, (batch_size, -1))
else:
forward_op = functools.partial(
_flash_attention_3_varlen_hub_forward_op,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=forward_op,
backward_op=_flash_attention_3_varlen_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out

return (out, lse) if return_lse else out

Expand Down
6 changes: 5 additions & 1 deletion tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ class ContextParallelAttentionBackendsTesterMixin:
"_flash_3_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"_flash_3_varlen_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
],
)
@pytest.mark.parametrize("ulysses_anything", [True, False])
Expand All @@ -410,7 +414,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
if cp_type == "ring_degree":
if attention_backend == AttentionBackendName.NATIVE:
pytest.skip("Skipping test because ring isn't supported with native attention backend.")
elif attention_backend in ("flash_varlen_hub"):
elif attention_backend in ("flash_varlen_hub", "_flash_3_varlen_hub"):
pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

if ulysses_anything and "ulysses" not in cp_type:
Expand Down
1 change: 1 addition & 0 deletions tests/models/testing_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AttentionBackendName.FLASH_HUB,
AttentionBackendName.FLASH_VARLEN_HUB,
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
}


Expand Down
Loading