8 changes: 2 additions & 6 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
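For context: the branch deleted from forward() was an Ascend NPU workaround that computed the rotary frequencies on the CPU and copied the resulting cos/sin tables back with .npu(); the single added line calls the positional embedding directly on whatever device ids lives on. A minimal sketch of the old and new call patterns, assuming a pos_embed callable that returns a (cos, sin) tuple — the helper names and the dummy embedding below are illustrative, not part of the library:

import torch

def rotary_emb_old(pos_embed, ids, npu_available):
    # Removed pattern: route ids through the CPU, then move the tables
    # back onto the NPU (the .npu() calls require torch_npu).
    if npu_available:
        freqs_cos, freqs_sin = pos_embed(ids.cpu())
        return freqs_cos.npu(), freqs_sin.npu()
    return pos_embed(ids)

def rotary_emb_new(pos_embed, ids):
    # Kept pattern: one device-agnostic call on ids.
    return pos_embed(ids)

# Dummy stand-in for self.pos_embed, purely for demonstration.
dummy_pos_embed = lambda ids: (torch.cos(ids.float()), torch.sin(ids.float()))
ids = torch.arange(12).reshape(4, 3)
freqs_cos, freqs_sin = rotary_emb_new(dummy_pos_embed, ids)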
12 changes: 3 additions & 9 deletions src/diffusers/models/transformers/transformer_flux2.py
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -835,14 +835,8 @@ def forward(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        if is_torch_npu_available():
-            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
-            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
-            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
-            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
-        else:
-            image_rotary_emb = self.pos_embed(img_ids)
-            text_rotary_emb = self.pos_embed(txt_ids)
+        image_rotary_emb = self.pos_embed(img_ids)
+        text_rotary_emb = self.pos_embed(txt_ids)
         concat_rotary_emb = (
             torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
             torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
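Unlike the other files, transformer_flux2.py computes the text and image rotary embeddings separately and only then concatenates them along the sequence dimension, as the unchanged concat_rotary_emb lines above show. A small shape sketch of that concatenation, using made-up sizes (16 text tokens, 64 image tokens, 128 rotary features) purely for illustration:

import torch

# (cos, sin) tables per modality; the shapes here are hypothetical.
text_rotary_emb = (torch.randn(16, 128), torch.randn(16, 128))
image_rotary_emb = (torch.randn(64, 128), torch.randn(64, 128))

# Text tokens come first, mirroring torch.cat((txt_ids, img_ids), dim=0)
# in the single-call variants used by the other files.
concat_rotary_emb = (
    torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
    torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
)
assert concat_rotary_emb[0].shape == (80, 128)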
8 changes: 2 additions & 6 deletions (file path not captured)
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
8 changes: 2 additions & 6 deletions src/diffusers/models/transformers/transformer_ovis_image.py
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing: