diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index d88aef4dcf2a..664f70b95e5d 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -233,6 +233,11 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs + @lru_cache_unless_export(maxsize=None) + def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """Return pos_freqs and neg_freqs on the given device.""" + return self.pos_freqs.to(device), self.neg_freqs.to(device) + def forward( self, video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], @@ -300,8 +305,9 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call + pos_freqs_device, _ = self._get_device_freqs(device) + txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -311,8 +317,9 @@ def _compute_video_freqs( self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None ) -> torch.Tensor: seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = ( + self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) + ) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -367,6 +374,11 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs + @lru_cache_unless_export(maxsize=None) + def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """Return pos_freqs and neg_freqs on the given device.""" + return self.pos_freqs.to(device), self.neg_freqs.to(device) + def forward( self, video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], @@ -421,8 +433,9 @@ def forward( max_vid_index = max(max_vid_index, layer_num) max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call + pos_freqs_device, _ = self._get_device_freqs(device) + txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -430,8 +443,9 @@ def forward( @lru_cache_unless_export(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = ( + self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) + ) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -452,8 +466,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device @lru_cache_unless_export(maxsize=None) def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = ( + self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) + ) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)