From d769495347132a2db6cc1ecb664cfcaaa0f5b979 Mon Sep 17 00:00:00 2001 From: daqiege <44255948+daqiege@users.noreply.github.com> Date: Wed, 27 May 2026 11:03:57 +0800 Subject: [PATCH 1/2] fix(wan_s2v_dit): accept and propagate attn_kwargs WanS2VDiT.forward() does not accept the attn_kwargs argument that WanSpeech2VideoPipeline.predict_noise passes (added in 4ae8f2c0), causing TypeError on every speech-to-video inference call. This mirrors what was done in WanDiT/DiTBlock (wan_dit.py): add an optional attn_kwargs parameter and forward it down through DiTBlockS2V to self_attn / cross_attn, which already accept attn_kwargs. Fixes #221 --- diffsynth_engine/models/wan/wan_s2v_dit.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/diffsynth_engine/models/wan/wan_s2v_dit.py b/diffsynth_engine/models/wan/wan_s2v_dit.py index d0d21c5a..540b91af 100644 --- a/diffsynth_engine/models/wan/wan_s2v_dit.py +++ b/diffsynth_engine/models/wan/wan_s2v_dit.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -277,7 +277,7 @@ def __init__(self, dit_block: DiTBlock): self.ffn = dit_block.ffn self.modulation = dit_block.modulation - def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs): + def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs, attn_kwargs=None): # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ t for t in (self.modulation + t_mod).chunk(6, dim=1) @@ -293,9 +293,9 @@ def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs): ], dim=1, ) - self_attn_x = self.self_attn(input_x, freqs) + self_attn_x = self.self_attn(input_x, freqs, attn_kwargs) x += torch.cat([self_attn_x[:, :x_seq_len] * gate_msa, self_attn_x[:, x_seq_len:] * gate_msa_0], dim=1) - x += self.cross_attn(self.norm3(x), context) + x += self.cross_attn(self.norm3(x), context, attn_kwargs) norm2_x = self.norm2(x) input_x = torch.cat( [ @@ -411,6 +411,7 @@ def forward( drop_motion_frames: bool = False, # !(ref_as_first_frame || clip_idx) audio_mask: Optional[torch.Tensor] = None, # b c tx h w void_audio_input: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, ): fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False) use_cfg = x.shape[0] > 1 @@ -479,7 +480,13 @@ def forward( freqs = torch.concat([freqs_img, freqs_ref_motion], dim=1) for idx, block in enumerate(self.blocks): x = block( - x=x, x_seq_len=x_seq_len_local, context=context, t_mod=t_mod, t_mod_0=t_mod_0, freqs=freqs + x=x, + x_seq_len=x_seq_len_local, + context=context, + t_mod=t_mod, + t_mod_0=t_mod_0, + freqs=freqs, + attn_kwargs=attn_kwargs, ) if idx in self.audio_injector.injected_block_id.keys(): x = self.inject_audio( From e99c715cb1925e6f8d206d3b2bd5a92ceecda517 Mon Sep 17 00:00:00 2001 From: daqiege <44255948+daqiege@users.noreply.github.com> Date: Wed, 27 May 2026 11:33:16 +0800 Subject: [PATCH 2/2] fix: pass attn_kwargs as keyword to self_attn/cross_attn (review) --- diffsynth_engine/models/wan/wan_s2v_dit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth_engine/models/wan/wan_s2v_dit.py b/diffsynth_engine/models/wan/wan_s2v_dit.py index 540b91af..7ac82406 100644 --- a/diffsynth_engine/models/wan/wan_s2v_dit.py +++ b/diffsynth_engine/models/wan/wan_s2v_dit.py @@ -293,9 +293,9 @@ def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs, attn_kwargs=None ], dim=1, ) - self_attn_x = self.self_attn(input_x, freqs, attn_kwargs) + self_attn_x = self.self_attn(input_x, freqs, attn_kwargs=attn_kwargs) x += torch.cat([self_attn_x[:, :x_seq_len] * gate_msa, self_attn_x[:, x_seq_len:] * gate_msa_0], dim=1) - x += self.cross_attn(self.norm3(x), context, attn_kwargs) + x += self.cross_attn(self.norm3(x), context, attn_kwargs=attn_kwargs) norm2_x = self.norm2(x) input_x = torch.cat( [