diff --git a/diffsynth/core/npu_patch/npu_fused_operator.py b/diffsynth/core/npu_patch/npu_fused_operator.py
new file mode 100644
index 000000000..5b28eea96
--- /dev/null
+++ b/diffsynth/core/npu_patch/npu_fused_operator.py
@@ -0,0 +1,30 @@
+import torch
+from ..device.npu_compatible_device import get_device_type
+try:
+    import torch_npu
+except Exception:
+    pass
+
+
+def rms_norm_forward_npu(self, hidden_states):
+    """NPU fused RMSNorm operator replacing RMSNorm.forward in diffsynth/models/general_modules.py."""
+    if hidden_states.dtype != self.weight.dtype:
+        hidden_states = hidden_states.to(self.weight.dtype)
+    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
+
+
+def rms_norm_forward_transformers_npu(self, hidden_states):
+    """NPU fused RMSNorm operator replacing Qwen3RMSNorm.forward in transformers."""
+    if hidden_states.dtype != self.weight.dtype:
+        hidden_states = hidden_states.to(self.weight.dtype)
+    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+
+
+def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
+    """NPU fused RoPE operator replacing Attention.apply_rotary_emb in the Z-Image DiT."""
+    with torch.amp.autocast(get_device_type(), enabled=False):
+        freqs_cis = freqs_cis.unsqueeze(2)
+        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
+        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
+        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
+        return torch_npu.npu_rotary_mul(x_in, cos, sin).to(x_in)
\ No newline at end of file
diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py
index 70574f01a..810def26f 100644
--- a/diffsynth/models/z_image_dit.py
+++ b/diffsynth/models/z_image_dit.py
@@ -88,6 +88,14 @@ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_k
         self.norm_q = RMSNorm(head_dim, eps=1e-5)
         self.norm_k = RMSNorm(head_dim, eps=1e-5)
 
+    # Apply RoPE
+    def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+        with torch.amp.autocast(get_device_type(), enabled=False):
+            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+            freqs_cis = freqs_cis.unsqueeze(2)
+            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+            return x_out.type_as(x_in)  # todo
+
     def forward(self, hidden_states, freqs_cis, attention_mask):
         query = self.to_q(hidden_states)
         key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ def forward(self, hidden_states, freqs_cis, attention_mask):
         if self.norm_k is not None:
             key = self.norm_k(key)
 
-        # Apply RoPE
-        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast(get_device_type(), enabled=False):
-                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-                freqs_cis = freqs_cis.unsqueeze(2)
-                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-                return x_out.type_as(x_in)  # todo
-
         if freqs_cis is not None:
-            query = apply_rotary_emb(query, freqs_cis)
-            key = apply_rotary_emb(key, freqs_cis)
+            query = self.apply_rotary_emb(query, freqs_cis)
+            key = self.apply_rotary_emb(key, freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py
index 2c5b68730..2bf02bf8e 100644
--- a/diffsynth/pipelines/z_image.py
+++ b/diffsynth/pipelines/z_image.py
@@ -6,7 +6,7 @@
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
 
-from ..core.device.npu_compatible_device import get_device_type
+from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ def from_pretrained(
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
+        enable_npu_patch: bool = True,
     ):
         # Initialize pipeline
         pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
+        # NPU patch
+        apply_npu_patch(enable_npu_patch)
         return pipe
 
 
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
     x = rearrange(x, "C B H W -> B C H W")
     x = -x
     return x
+
+
+def apply_npu_patch(enable_npu_patch: bool = True):
+    if IS_NPU_AVAILABLE and enable_npu_patch:
+        from ..models.general_modules import RMSNorm
+        from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
+        from ..models.z_image_dit import Attention
+        from ..core.npu_patch.npu_fused_operator import (
+            rms_norm_forward_npu,
+            rms_norm_forward_transformers_npu,
+            rotary_emb_Zimage_npu
+        )
+
+        RMSNorm.forward = rms_norm_forward_npu
+        Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
+        Attention.apply_rotary_emb = rotary_emb_Zimage_npu
diff --git a/examples/z_image/model_training/special/npu_training/Z-Image-Turbo-NPU.sh b/examples/z_image/model_training/special/npu_training/Z-Image-Turbo-NPU.sh
index 93cc645d9..75938bfe9 100644
--- a/examples/z_image/model_training/special/npu_training/Z-Image-Turbo-NPU.sh
+++ b/examples/z_image/model_training/special/npu_training/Z-Image-Turbo-NPU.sh
@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
   --output_path "./models/train/Z-Image-Turbo_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing \
-  --dataset_num_workers 8
+  --dataset_num_workers 8 \
+  --enable_npu_patch
diff --git a/examples/z_image/model_training/train.py b/examples/z_image/model_training/train.py
index b4c76e58f..10eb725c7 100644
--- a/examples/z_image/model_training/train.py
+++ b/examples/z_image/model_training/train.py
@@ -20,12 +20,13 @@ def __init__(
         offload_models=None,
         device="cpu",
         task="sft",
+        enable_npu_patch=True,
     ):
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
-        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
+        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
 
         # Training mode
@@ -94,6 +95,7 @@ def z_image_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to enable the NPU fused operator patch to improve performance on NPU devices.")
    return parser
 
 
@@ -136,6 +138,7 @@ def z_image_parser():
         offload_models=args.offload_models,
         task=args.task,
         device=accelerator.device,
+        enable_npu_patch=args.enable_npu_patch
     )
     model_logger = ModelLogger(
         args.output_path,
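
apply_npu_patch() relies on class-level monkey patching: assigning a plain function to Class.forward rebinds the method for every existing and future instance, so the fused NPU kernels take effect without touching the model definitions. The sketch below shows that pattern in isolation; ToyRMSNorm and fused_forward are illustrative stand-ins rather than part of the patch, and torch_npu is not required to run it.

import torch
import torch.nn as nn


class ToyRMSNorm(nn.Module):
    """Illustrative stand-in for the RMSNorm modules patched above."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Eager RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def fused_forward(self, x):
    # Stand-in for a fused kernel such as torch_npu.npu_rms_norm(...)[0]
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


norm = ToyRMSNorm(8)
ToyRMSNorm.forward = fused_forward   # class-level patch, as in apply_npu_patch()
out = norm(torch.randn(2, 8))        # existing instances now route through fused_forward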
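
On an Ascend device with torch_npu installed, a quick parity check along these lines can confirm that the fused kernel used by rms_norm_forward_npu matches the eager RMSNorm math it replaces. This is a sketch, assuming torch_npu.npu_rms_norm returns the normalized tensor as the first element of a tuple (as the [0] indexing above implies); it is not part of the patch.

import torch
import torch_npu  # requires an Ascend NPU environment

x = torch.randn(2, 16, 64, device="npu", dtype=torch.float32)
weight = torch.ones(64, device="npu")
eps = 1e-5

# Eager RMSNorm, matching the implementation being patched out
eager = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
# Fused kernel used by rms_norm_forward_npu
fused = torch_npu.npu_rms_norm(x, weight, eps)[0]

print(torch.allclose(eager, fused, rtol=1e-4, atol=1e-4))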