-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance #1256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| import torch | ||
| from ..device.npu_compatible_device import get_device_type | ||
| try: | ||
| import torch_npu | ||
| except: | ||
| pass | ||
|
|
||
|
|
||
| def rms_norm_forward_npu(self, hidden_states): | ||
| "npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py" | ||
| if hidden_states.dtype != self.weight.dtype: | ||
| hidden_states = hidden_states.to(self.weight.dtype) | ||
| return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0] | ||
|
|
||
|
|
||
| def rms_norm_forward_transformers_npu(self, hidden_states): | ||
| "npu rms fused operator for transformers" | ||
| if hidden_states.dtype != self.weight.dtype: | ||
| hidden_states = hidden_states.to(self.weight.dtype) | ||
| return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] | ||
|
|
||
|
|
||
| def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor): | ||
| "npu rope fused operator for Zimage" | ||
| with torch.amp.autocast(get_device_type(), enabled=False): | ||
| freqs_cis = freqs_cis.unsqueeze(2) | ||
| cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1) | ||
| cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2) | ||
| sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2) | ||
| return torch_npu.npu_rotary_mul(x_in, cos, sin).to(x_in) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Union, List, Optional, Tuple, Iterable, Dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..core.device.npu_compatible_device import get_device_type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..diffusion import FlowMatchScheduler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..core import ModelConfig, gradient_checkpoint_forward | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..core.data.operators import ImageCropAndResize | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -63,6 +63,7 @@ def from_pretrained( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_configs: list[ModelConfig] = [], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vram_limit: float = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_npu_patch: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -84,6 +85,8 @@ def from_pretrained( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # VRAM Management | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipe.vram_management_enabled = pipe.check_vram_management_state() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # NPU patch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_npu_patch(enable_npu_patch) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return pipe | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = rearrange(x, "C B H W -> B C H W") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = -x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def apply_npu_patch(enable_npu_patch: bool=True): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if IS_NPU_AVAILABLE and enable_npu_patch: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.general_modules import RMSNorm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.z_image_dit import Attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..core.npu_patch.npu_fused_operator import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rms_norm_forward_npu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rms_norm_forward_transformers_npu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_emb_Zimage_npu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RMSNorm.forward = rms_norm_forward_npu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Attention.apply_rotary_emb = rotary_emb_Zimage_npu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+677
to
+688
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import from
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -20,12 +20,13 @@ def __init__( | |||||
| offload_models=None, | ||||||
| device="cpu", | ||||||
| task="sft", | ||||||
| enable_npu_patch=True, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default value for
Suggested change
|
||||||
| ): | ||||||
| super().__init__() | ||||||
| # Load models | ||||||
| model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) | ||||||
| tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) | ||||||
| self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) | ||||||
| self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch) | ||||||
| self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) | ||||||
|
|
||||||
| # Training mode | ||||||
|
|
@@ -94,6 +95,7 @@ def z_image_parser(): | |||||
| parser = add_general_config(parser) | ||||||
| parser = add_image_size_config(parser) | ||||||
| parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") | ||||||
| parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.") | ||||||
| return parser | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -136,6 +138,7 @@ def z_image_parser(): | |||||
| offload_models=args.offload_models, | ||||||
| task=args.task, | ||||||
| device=accelerator.device, | ||||||
| enable_npu_patch=args.enable_npu_patch | ||||||
| ) | ||||||
| model_logger = ModelLogger( | ||||||
| args.output_path, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bare
except:is too broad and can hide other unexpected errors. It's better to catch the specificImportErrorthat can occur iftorch_npuis not installed.