
Conversation

@DefTruth (Contributor) commented Jan 4, 2026

Follows up #12660 and also fixes vipshop/cache-dit#622.

  1. We need to disable the splitting of encoder_hidden_states because the image_encoder always produces 257 tokens for image_embed. After concatenation with the 512 text tokens, encoder_hidden_states therefore always has 769 (512 + 257) tokens, which is not divisible by the number of devices in the CP mesh (see the sketch after this list).

  2. Since the key/value in cross-attention depends solely on encoder_hidden_states (text or img), the (q_chunk * k) * v computation can be parallelized independently on each rank. Thus, there is no need to pass the parallel_config for cross-attention. This change removes redundant all-to-all communications, specifically (3 + 1) × 2 = 8 for the two cross-attention operations (text and img), i.e., q/k/v plus the attention output for each, thereby improving Wan's performance under context parallelism.
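To make (1) concrete, here is a minimal sketch of the equipartition check that trips, simplified from diffusers' EquipartitionSharder (the helper name, signature, and the inner dim of 4096 are illustrative; the assertion message is the one in the traceback below):

import torch

def equipartition_shard(tensor: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # Context parallelism splits the tensor evenly along `dim`, one chunk per rank,
    # so its size along that dim must be divisible by the number of ranks.
    assert tensor.size(dim) % world_size == 0, (
        "Tensor size along dimension to be sharded must be divisible by mesh size"
    )
    return tensor.chunk(world_size, dim=dim)[0]  # chunk owned by rank 0

# Wan I2V: 512 text tokens + 257 image tokens = 769, a prime, so no CP degree > 1 divides it.
encoder_hidden_states = torch.randn(1, 512 + 257, 4096)
equipartition_shard(encoder_hidden_states, dim=1, world_size=2)  # AssertionError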

@sayakpaul @yiyixuxu @DN6

reproduce

import os
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
import torch.distributed as dist
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import ContextParallelConfig

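# Initialize torch.distributed only when launched via torchrun (WORLD_SIZE set);
# otherwise fall back to a single device.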
if dist.is_available() and "WORLD_SIZE" in os.environ:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, 
    image_encoder=image_encoder, 
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder", "transformer"],
        )
    ),
).to(device)
pipe.vae.enable_tiling()

image = load_image(
  "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
)
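# Resize so height and width are multiples of the VAE spatial stride times the transformer patch size.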
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
    "Summer beach vacation style, a white cat wearing sunglasses sits on a "
    "surfboard. The fluffy-furred feline gazes directly at the camera with "
    "a relaxed expression. Blurred beach scenery forms the background featuring "
    "crystal-clear waters, distant green hills, and a blue sky dotted with white "
    "clouds. The cat assumes a naturally relaxed posture, as if savoring the sea "
    "breeze and warm sunlight. A close-up shot highlights the feline's intricate "
    "details and the refreshing atmosphere of the seaside."
)
# Standard Wan negative prompt (Chinese): "vivid colors, overexposed, static, blurry
# details, subtitles, style, artwork, painting, still image, overall gray, worst quality,
# low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn
# hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
# motionless frame, cluttered background, three legs, many people in background,
# walking backwards"
negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
    "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
    "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)

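# Native SDPA attention backend; shard the video token sequence across ranks with Ulysses CP.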
pipe.transformer.set_attention_backend("native")
if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)

output = pipe(
    image=image, prompt=prompt, negative_prompt=negative_prompt, 
    height=height, width=width, num_frames=21, guidance_scale=5.0,
    num_inference_steps=16, generator=torch.Generator("cpu").manual_seed(42),
).frames[0]

if rank == 0:
    export_to_video(output, "output.mp4", fps=9)

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

python3 test_wan_i2v.py # baseline, single GPU
torchrun --nproc_per_node=2 --local-ranks-filter=0 test_wan_i2v.py # CP2 w/ ulysses

w/o this pr:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/dev/vipshop/cache-dit/examples/tmp/test_wan_i2v.py", line 72, in <module>
[rank0]:     output = pipe(
[rank0]:              ^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
[rank0]:     noise_pred = current_model(
[rank0]:                  ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 684, in forward
[rank0]:     hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank0]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank0]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank0]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank0]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
[rank0]:[W104 07:18:46.479685596 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

w/ this pr:

[W105 02:06:22.318276131 socket.cpp:767] [c10d] The client socket cannot be initialized to connect to [10.189.108.254]:29500 (errno: 97 - Address family not supported by protocol).
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
[W105 02:06:32.697211097 socket.cpp:767] [c10d] The client socket cannot be initialized to connect to [10.189.108.254]:29500 (errno: 97 - Address family not supported by protocol).
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:34<00:00,  2.46s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:10<00:00,  2.12s/it]
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:46<00:00,  6.61s/it]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
  0%|                                                                                                                                                                                                                              | 0/16 [00:00<?, ?it/s]Expected input tensor to have 2 dimensions, but got 1 dimensions, split will not be applied.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:05<00:00,  4.12s/it]
Baseline (1x L20): 118s
Ulysses degree 2 (2x L20): 65s

(output comparison: baseline vs. ulysses2)

@DefTruth (Contributor, Author) commented Jan 5, 2026

@sayakpaul @DN6 PTAL~

@sayakpaul (Member)

Since the key/value in cross-attention depends solely on encoder_hidden_states (text or img), the (q_chunk * k) * v computation can be parallelized independently on each rank. Thus, there is no need to pass the parallel_config for cross-attention. This change removes redundant all-to-all communications, specifically (3 + 1) × 2 = 8 for the two cross-attention operations (text and img), thereby improving Wan's performance under context parallelism.

Would it have any memory impact?

"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Reference: https://github.com/huggingface/diffusers/pull/12909
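For reference, a hypothetical before/after sketch of this plan entry (the real _cp_plan in transformer_wan.py contains more entries; this only illustrates the delta):

# Before this PR: encoder_hidden_states was sharded along the token dim (split_dim=1):
#     "blocks.*": {
#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
#     },
# After: the entry is dropped, so encoder_hidden_states stays replicated on every rank
# and cross-attention runs without the all-to-alls described above.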
@sayakpaul (Member) commented Jan 5, 2026

Is this specific to I2V only? If so, then this change is probably a little too intrusive, no?

@DefTruth (Contributor, Author) commented Jan 5, 2026

@sayakpaul This is theoretically applicable to all Wan series models and offers better performance. I tested the Wan 2.1/2.2 T2V/I2V models, and the results were all correct. You can do a quick verification with the examples in cache-dit; this patch is already used there to fix some precision issues. vipshop/cache-dit#639

pip3 install torch==2.9.1 transformers accelerate torchao bitsandbytes torchvision 
pip3 install opencv-python-headless einops imageio-ffmpeg ftfy 
pip3 install git+https://github.com/huggingface/diffusers.git # latest 
pip3 install git+https://github.com/vipshop/cache-dit.git # latest

git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples

# use `--cpu-offload` and `--parallel-text-encoder` for low-VRAM devices, e.g., < 48 GiB
torchrun --nproc_per_node=4 generate.py wan2.1_t2v --parallel ulysses --parallel-text-encoder
torchrun --nproc_per_node=4 generate.py wan2.2_t2v --parallel ulysses --parallel-text-encoder --cpu-offload
torchrun --nproc_per_node=2 generate.py wan2.1_i2v --parallel ulysses --parallel-text-encoder --steps 16 --frames 21 --vae-tiling
torchrun --nproc_per_node=2 generate.py wan2.2_i2v --parallel ulysses --parallel-text-encoder --cpu-offload --steps 16 --frames 21 --vae-tiling

@DefTruth (Contributor, Author) commented Jan 5, 2026

Since the key/value in cross-attention depends solely on encoder_hidden_states (text or img), the (q_chunk * k) * v computation can be parallelized independently on each rank. Thus, there is no need to pass the parallel_config for cross-attention. This change removes redundant all-to-all communications, specifically (3 + 1) × 2 = 8 for the two cross-attention operations (text and img), thereby improving Wan's performance under context parallelism.

Would it have any memory impact?

  1. no memory impact for hidden_states (since hidden_states are still split)
  2. negligible memory impact for encoder_hidden_states (since the seq_len of encoder_hidden_states is generally relatively short, e.g., 512; see the back-of-envelope sketch below)
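A back-of-envelope check of (2), assuming Wan 14B's inner dim of 5120 and bf16 activations (both assumptions, for illustration only):

# Memory of one replicated encoder_hidden_states tensor per rank:
batch, seq_len, dim, bytes_per_elem = 1, 512 + 257, 5120, 2  # bf16 = 2 bytes
replicated_mib = batch * seq_len * dim * bytes_per_elem / 2**20
print(f"{replicated_mib:.1f} MiB")  # ~7.5 MiB, negligible next to the sharded video tokens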

@sayakpaul sayakpaul requested a review from DN6 January 5, 2026 05:09
@yiyixuxu (Collaborator) commented Jan 6, 2026

ohh really cool!
I think this could be applicable to other models beyond wan?

@DefTruth (Contributor, Author) commented Jan 6, 2026

ohh really cool! I think this could be applicable to other models beyond wan?

I'm not sure. It depends on whether other models' attention implementations are similar to Wan's. I haven't tested it on other models yet.

@yiyixuxu (Collaborator) left a review comment

thanks!

@yiyixuxu (Collaborator) commented Jan 6, 2026

@bot /style

@github-actions (bot) commented Jan 6, 2026

Style fix runs successfully without any file modified.

@yiyixuxu (Collaborator) commented Jan 6, 2026

@DefTruth can you run make fix-copies?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DefTruth (Contributor, Author) commented Jan 6, 2026

@DefTruth can you run make fix-copies?

done

@sayakpaul sayakpaul merged commit 3138e37 into huggingface:main Jan 6, 2026
9 of 11 checks passed
@sayakpaul (Member)

Thanks for your contributions!


Successfully merging this pull request may close these issues.

Low quality output (snow-like noise) when using Cache-DiT with Wan2.1-I2V
