30 changes: 30 additions & 0 deletions diffsynth/core/npu_patch/npu_fused_operator.py
@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
    import torch_npu
except:
    pass
Comment on lines +5 to +6
medium

The bare except: is too broad and can hide other unexpected errors. It's better to catch the specific ImportError that can occur if torch_npu is not installed.

Suggested change
-except:
-    pass
+except ImportError:
+    pass
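To illustrate the risk the comment describes, a bare except also masks unrelated failures raised while the import runs, not just the missing-package case — a minimal generic sketch:

try:
    import torch_npu  # any error raised during import is swallowed here
except:               # bare except also catches SystemExit and KeyboardInterrupt
    pass

try:
    import torch_npu
except ImportError:   # only the expected missing-dependency case
    pass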



def rms_norm_forward_npu(self, hidden_states):
    "NPU RMS fused operator for RMSNorm.forward from diffsynth/models/general_modules.py"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]


def rms_norm_forward_transformers_npu(self, hidden_states):
    "NPU RMS fused operator for transformers"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]


def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
    "NPU RoPE fused operator for Z-Image"
    with torch.amp.autocast(get_device_type(), enabled=False):
        freqs_cis = freqs_cis.unsqueeze(2)
        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        return torch_npu.npu_rotary_mul(x_in, cos, sin).to(x_in)
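For reference, a minimal pure-PyTorch sketch of the computation that torch_npu.npu_rms_norm fuses into a single kernel (the fused call returns a tuple whose first element is the normalized output, hence the [0] above). This is an illustrative CPU-runnable equivalent, not the NPU implementation:

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square over the last dim, then apply the learned gain
    variance = hidden_states.float().pow(2).mean(dim=-1, keepdim=True)
    return (hidden_states.float() * torch.rsqrt(variance + eps)).to(hidden_states.dtype) * weight

x = torch.randn(2, 8, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
print(rms_norm_reference(x, w, eps=1e-5).shape)  # torch.Size([2, 8, 64])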
20 changes: 10 additions & 10 deletions diffsynth/models/z_image_dit.py
@@ -88,6 +88,14 @@ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_k
         self.norm_q = RMSNorm(head_dim, eps=1e-5)
         self.norm_k = RMSNorm(head_dim, eps=1e-5)
 
+    # Apply RoPE
+    def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+        with torch.amp.autocast(get_device_type(), enabled=False):
+            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+            freqs_cis = freqs_cis.unsqueeze(2)
+            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+            return x_out.type_as(x_in)  # todo
+
     def forward(self, hidden_states, freqs_cis, attention_mask):
         query = self.to_q(hidden_states)
         key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ def forward(self, hidden_states, freqs_cis, attention_mask):
         if self.norm_k is not None:
             key = self.norm_k(key)
 
-        # Apply RoPE
-        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast(get_device_type(), enabled=False):
-                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-                freqs_cis = freqs_cis.unsqueeze(2)
-                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-                return x_out.type_as(x_in)  # todo
-
         if freqs_cis is not None:
-            query = apply_rotary_emb(query, freqs_cis)
-            key = apply_rotary_emb(key, freqs_cis)
+            query = self.apply_rotary_emb(query, freqs_cis)
+            key = self.apply_rotary_emb(key, freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
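As a sanity check on this refactor, the complex-view rotation is numerically identical to an explicit cos/sin formulation, which is the form fused RoPE kernels typically consume. A self-contained sketch (shapes and names are illustrative, not taken from the codebase):

import torch

def rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Rotate interleaved (real, imag) pairs via complex multiplication, as in apply_rotary_emb
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * freqs_cis).flatten(-2).type_as(x)

def rope_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # The same rotation written with explicit cos/sin terms
    cos, sin = freqs_cis.real, freqs_cis.imag
    x0, x1 = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    out = torch.stack((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
    return out.flatten(-2).type_as(x)

x = torch.randn(2, 16, 4, 64)              # (batch, seq, heads, head_dim)
angles = torch.rand(2, 16, 1, 32)          # one angle per (real, imag) pair
freqs_cis = torch.polar(torch.ones_like(angles), angles)
print(torch.allclose(rope_complex(x, freqs_cis), rope_real(x, freqs_cis), atol=1e-6))  # True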
21 changes: 20 additions & 1 deletion diffsynth/pipelines/z_image.py
@@ -6,7 +6,7 @@
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
 
-from ..core.device.npu_compatible_device import get_device_type
+from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ def from_pretrained(
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
+        enable_npu_patch: bool = True,
     ):
         # Initialize pipeline
         pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
+        # NPU patch
+        apply_npu_patch(enable_npu_patch)
         return pipe
 
 
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
     x = rearrange(x, "C B H W -> B C H W")
     x = -x
     return x
+
+
+def apply_npu_patch(enable_npu_patch: bool = True):
+    if IS_NPU_AVAILABLE and enable_npu_patch:
+        from ..models.general_modules import RMSNorm
+        from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
+        from ..models.z_image_dit import Attention
+        from ..core.npu_patch.npu_fused_operator import (
+            rms_norm_forward_npu,
+            rms_norm_forward_transformers_npu,
+            rotary_emb_Zimage_npu
+        )
+
+        RMSNorm.forward = rms_norm_forward_npu
+        Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
+        Attention.apply_rotary_emb = rotary_emb_Zimage_npu
Comment on lines +677 to +688
medium

The import from transformers on line 678 can fail if the library isn't installed or the module path changes. To make this patching more robust, it's best to wrap the transformers-related import and monkey-patch in a try...except ImportError block. This prevents a missing optional dependency from crashing the application, while still allowing other patches to be applied.

Suggested change
-        from ..models.general_modules import RMSNorm
-        from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
-        from ..models.z_image_dit import Attention
-        from ..core.npu_patch.npu_fused_operator import (
-            rms_norm_forward_npu,
-            rms_norm_forward_transformers_npu,
-            rotary_emb_Zimage_npu
-        )
-        RMSNorm.forward = rms_norm_forward_npu
-        Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
-        Attention.apply_rotary_emb = rotary_emb_Zimage_npu
+        from ..models.general_modules import RMSNorm
+        from ..models.z_image_dit import Attention
+        from ..core.npu_patch.npu_fused_operator import (
+            rms_norm_forward_npu,
+            rms_norm_forward_transformers_npu,
+            rotary_emb_Zimage_npu
+        )
+        RMSNorm.forward = rms_norm_forward_npu
+        Attention.apply_rotary_emb = rotary_emb_Zimage_npu
+        try:
+            from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
+            Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
+        except ImportError:
+            pass  # Silently ignore if transformers is not installed
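For context on why patching the class works even for already-constructed modules: assigning a plain function to a class attribute rebinds the method for every instance, existing or future. A minimal generic sketch (names are illustrative):

class Norm:
    def forward(self, x):
        return "eager"

def fused_forward(self, x):
    return "fused"

n = Norm()                   # constructed before the patch is applied
Norm.forward = fused_forward
print(n.forward(0))          # "fused" -- existing instances see the patched method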

@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
   --output_path "./models/train/Z-Image-Turbo_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing \
-  --dataset_num_workers 8
+  --dataset_num_workers 8 \
+  --enable_npu_patch
5 changes: 4 additions & 1 deletion examples/z_image/model_training/train.py
@@ -20,12 +20,13 @@ def __init__(
         offload_models=None,
         device="cpu",
         task="sft",
+        enable_npu_patch=True,
medium

The default value for enable_npu_patch is True here, but the corresponding command-line argument in z_image_parser defaults to False. To maintain consistency between programmatic use and CLI use, it's better to have the same default value. I suggest changing this to False.

Suggested change
-        enable_npu_patch=True,
+        enable_npu_patch=False,

     ):
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
-        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
+        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
 
         # Training mode
@@ -94,6 +95,7 @@ def z_image_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use the NPU fused operator patch to improve performance on NPU.")
     return parser


@@ -136,6 +138,7 @@ def z_image_parser():
         offload_models=args.offload_models,
         task=args.task,
         device=accelerator.device,
+        enable_npu_patch=args.enable_npu_patch
     )
     model_logger = ModelLogger(
         args.output_path,
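Since the parser default is False while from_pretrained defaults to True (the inconsistency flagged above), programmatic callers can pass the flag explicitly. A hedged sketch of such a call — the model_configs entry and its origin_file_pattern are placeholders, not a verified configuration:

import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.z_image import ZImagePipeline

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="npu",
    model_configs=[ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="dit/")],  # placeholder pattern
    enable_npu_patch=True,  # harmless on non-NPU hosts: apply_npu_patch is gated by IS_NPU_AVAILABLE
)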