From 4533474043fbe10d6641ed85aea75fe1c1cc104a Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 16:39:42 +0800 Subject: [PATCH 1/9] Add ERNIE-Image --- .../api/models/ernie_image_transformer2d.md | 19 + docs/source/en/api/pipelines/ernie_image.md | 57 +++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_ernie_image.py | 311 ++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/ernie_image/__init__.py | 47 ++ .../ernie_image/pipeline_ernie_image.py | 457 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_ernie_image.py | 199 ++++++++ 12 files changed, 1129 insertions(+) create mode 100644 docs/source/en/api/models/ernie_image_transformer2d.md create mode 100644 docs/source/en/api/pipelines/ernie_image.md create mode 100644 src/diffusers/models/transformers/transformer_ernie_image.py create mode 100644 src/diffusers/pipelines/ernie_image/__init__.py create mode 100644 src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py create mode 100644 tests/models/transformers/test_models_transformer_ernie_image.py diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md new file mode 100644 index 000000000000..058616c47814 --- /dev/null +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -0,0 +1,19 @@ + + +# ErnieImageTransformer2DModel + +A Transformer model for image-like data from [Ernie-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo). + +## ErnieImageTransformer2DModel + +[[autodoc]] ErnieImageTransformer2DModel \ No newline at end of file diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md new file mode 100644 index 000000000000..34c1dc4cd489 --- /dev/null +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -0,0 +1,57 @@ + + +# Ernie-Image + +
+ LoRA +
+ +[Ernie-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released: + +|Model|Hugging Face| +|---|---| +|Ernie-Image|https://huggingface.co/Tongyi-MAI/Ernie-Image-Turbo| + +## Ernie-Image + +Ernie-Image-Turbo is a distilled version of Ernie-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. + +## ZImagePipeline + +Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt. + +```python +import torch +from diffusers import ErnieImagePipeline +from diffusers.utils import load_image + +pipe = ErnieImagePipeline.from_pretrained("Tongyi-MAI/Ernie-Image-Turbo", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "一只黑白相间的中华田园犬" +images = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=5.0, + generator=generator, +).images +images[0].save("ernie-image-output.png") +``` + +## ZImagePipeline + +[[autodoc]] ErnieImagePipeline + - all + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7d966452d1a2..8fea3482d1c3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -292,6 +292,7 @@ "ZImageControlNetModel", "ZImageTransformer2DModel", "attention_backend", + "ErnieImageTransformer2DModel" ] ) _import_structure["modular_pipelines"].extend( @@ -732,6 +733,7 @@ "ZImageInpaintPipeline", "ZImageOmniPipeline", "ZImagePipeline", + "ErnieImagePipeline", ] ) @@ -1079,6 +1081,7 @@ ZImageControlNetModel, ZImageTransformer2DModel, attention_backend, + ErnieImageTransformer2DModel, ) from .modular_pipelines import ( AutoPipelineBlocks, @@ -1493,6 +1496,7 @@ ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, + ErnieImagePipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..6c62db841b14 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] + _import_structure["transformers.transformer_ernie_image"] = ["ErnieImageTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] @@ -218,6 +219,7 @@ DiTTransformer2DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, + ErnieImageTransformer2DModel, Flux2Transformer2DModel, FluxTransformer2DModel, GlmImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..bf9bd49881e4 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -52,3 +52,4 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel + from .transformer_ernie_image import ErnieImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py new file mode 100644 index 000000000000..63fe3c47b811 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -0,0 +1,311 @@ +# Copyright (c) 2025, Baidu Inc. All rights reserved. +# Author: fengzhida (fengzhida@baidu.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ernie-Image Transformer2DModel for HuggingFace Diffusers. +""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...configuration_utils import ConfigMixin, register_to_config +from ..embeddings import Timesteps +from ..modeling_utils import ModelMixin +from ...utils import BaseOutput + + +@dataclass +class ErnieImageTransformer2DModelOutput(BaseOutput): + sample: torch.Tensor + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta ** scale) + out = torch.einsum("...n,d->...nd", pos, omega) + return out.float() + + +class EmbedND3(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = list(axes_dim) + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) + emb = emb.unsqueeze(1).permute(2, 0, 1, 3) + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + + +class PatchEmbedDynamic(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: int): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + B, D, Hp, Wp = x.shape + return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous() + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.linear_1(sample.to(self.linear_1.weight.dtype)) + return self.linear_2(self.act(sample).to(self.linear_2.weight.dtype)) + + +class RMSNorm(nn.Module): + """RMSNorm implementation matching Megatron's TENorm.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # 内部计算转换为FP32,对齐transform engine的TENorm计算精度 + x_norm = self._norm(x.float()) + output = x_norm * self.weight.float() + return output.to(x.dtype) + + +class Attention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-6, qk_layernorm: bool = True): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + # Separate Q, K, V projections (matches converted weights) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.qk_layernorm = qk_layernorm + if qk_layernorm: + # self.q_layernorm = RMSNorm(self.head_dim, eps=eps) + # self.k_layernorm = RMSNorm(self.head_dim, eps=eps) + self.q_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) + self.k_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) + + def forward(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + S, B, H = x.shape + # Separate Q, K, V projections + q = self.q_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + k = self.k_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + v = self.v_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + if self.qk_layernorm: + q, k = self.q_layernorm(q), self.k_layernorm(k) + q, k = self._apply_rotary(q, rotary_pos_emb), self._apply_rotary(k, rotary_pos_emb) + q, k, v = q.permute(1, 2, 0, 3), k.permute(1, 2, 0, 3), v.permute(1, 2, 0, 3) + attn_mask = ~attention_mask if attention_mask is not None else None + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + return self.linear_proj(out.permute(2, 0, 1, 3).reshape(S, B, H)) + + def _apply_rotary(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Apply rotary position embedding. + + Matches Megatron's _apply_rotary_pos_emb_bshd with rotary_interleaved=False. + freqs: [S, B, 1, dim] containing angles [θ0, θ0, θ1, θ1, ...] + """ + rot_dim = freqs.shape[-1] + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + + cos_ = torch.cos(freqs).to(x.dtype) + sin_ = torch.sin(freqs).to(x.dtype) + + # Non-interleaved rotate_half: [-x2, x1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + + x = x * cos_ + x_rotated * sin_ + return torch.cat((x, x_pass), dim=-1) + + +class FeedForward(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int): + super().__init__() + # Separate gate and up projections (matches converted weights) + self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + + +class SharedAdaLNBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): + super().__init__() + # self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) + self.adaLN_sa_ln = torch.nn.RMSNorm(hidden_size, eps=eps) + self.self_attention = Attention(hidden_size, num_heads, eps=eps, qk_layernorm=qk_layernorm) + # self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) + self.adaLN_mlp_ln = torch.nn.RMSNorm(hidden_size, eps=eps) + self.mlp = FeedForward(hidden_size, ffn_hidden_size) + + def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): + residual = x + x = self.adaLN_sa_ln(x) + x = self._modulate(x, shift_msa, scale_msa) + attn_out = self.self_attention(x, rotary_pos_emb, attention_mask) + x = residual + self._apply_gate(gate_msa, attn_out) + residual = x + x = self._modulate(self.adaLN_mlp_ln(x), shift_mlp, scale_mlp) + return residual + self._apply_gate(gate_mlp, self.mlp(x)) + + def _modulate(self, x, shift, scale): + """AdaLN modulation: x * (1 + scale) + shift,在FP32下计算确保数值稳定""" + x_fp32 = x.float() + shift_fp32 = shift.float() + scale_fp32 = scale.float() + out = x_fp32 * (1 + scale_fp32) + shift_fp32 + return out.to(x.dtype) + + def _apply_gate(self, gate, x): + """Gate乘法在FP32下计算,对齐TE精度""" + return (gate.float() * x.float()).to(x.dtype) + +class AdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps) + self.linear = nn.Linear(hidden_size, hidden_size * 2) + # 对齐 Megatron 实现:zero init + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + # Broadcast conditioning to sequence dimension + x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0) + return x + + +class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 3072, + num_attention_heads: int = 24, + num_layers: int = 24, + ffn_hidden_size: int = 8192, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 2560, + rope_theta: int = 256, + rope_axes_dim: Tuple[int, int, int] = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + ): + super().__init__() + self.gradient_checkpointing = False + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_layers = num_layers + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.text_in_dim = text_in_dim + + self.x_embedder = PatchEmbedDynamic(in_channels, hidden_size, patch_size) + self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None + self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) + self.pos_embed = EmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + self.layers = nn.ModuleList([SharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm) for _ in range(num_layers)]) + self.final_norm = AdaLNContinuous(hidden_size, eps) + self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + + def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: List[torch.Tensor], return_dict: bool = True): + device, dtype = hidden_states.device, hidden_states.dtype + B, C, H, W = hidden_states.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = Hp * Wp + + img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous() + text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + text_sbh = text_bth.transpose(0, 1).contiguous() + + x = torch.cat([img_sbh, text_sbh], dim=0) + S = x.shape[0] + + # Position IDs + text_ids = torch.cat([torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), torch.zeros((B, Tmax, 2), device=device)], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device) + grid_yx = torch.stack(torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32), torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"), dim=-1).reshape(-1, 2) + image_ids = torch.cat([text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], dim=-1) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) + + # Attention mask + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) + attention_mask = (~torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1))[:, None, None, :] + + # AdaLN + c = self.time_embedding(self.time_proj(timestep.to(dtype))) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] + for layer in self.layers: + if self.gradient_checkpointing and self.training: + x = self._gradient_checkpointing_func( + layer.__call__, + x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask + ) + else: + x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + x = self.final_norm(x, c).type_as(x) + patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() + output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) + + return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) + + def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): + B = len(text_hiddens) + if B == 0: + return torch.zeros((0, 0, self.text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] + lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + Tmax = int(lens.max().item()) + text_bth = torch.zeros((B, Tmax, self.text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): + text_bth[i, :t.shape[0], :] = t + return text_bth, lens diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..1985a12846e9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -313,6 +313,7 @@ _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] + _import_structure["ernie_image"] = ["ErnieImagePipeline"] _import_structure["ovis_image"] = ["OvisImagePipeline"] _import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] @@ -750,6 +751,7 @@ from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .omnigen import OmniGenPipeline + from .ernie_image import ErnieImagePipeline from .ovis_image import OvisImagePipeline from .pag import ( AnimateDiffPAGPipeline, diff --git a/src/diffusers/pipelines/ernie_image/__init__.py b/src/diffusers/pipelines/ernie_image/__init__.py new file mode 100644 index 000000000000..97355fb609f3 --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_ernie_image"] = ["ErnieImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_ernie_image import ErnieImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py new file mode 100644 index 000000000000..29c1b9fd9bab --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -0,0 +1,457 @@ +# Copyright (c) 2025, Baidu Inc. All rights reserved. +# Author: fengzhida (fengzhida@baidu.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ernie-Image Pipeline for HuggingFace Diffusers. +""" + +import json +import os +import torch +from PIL import Image +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import BaseOutput +from ...models import AutoencoderKLFlux2 +from ...models.transformers import ErnieImageTransformer2DModel + + +@dataclass +class ErnieImagePipelineOutput(BaseOutput): + images: List[Image.Image] + + +class ErnieImagePipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using ErnieImageTransformer2DModel. + + This pipeline uses: + - A custom DiT transformer model + - A Flux2-style VAE for encoding/decoding latents + - A text encoder (e.g., Qwen) for text conditioning + - Flow Matching Euler Discrete Scheduler + """ + + model_cpu_offload_seq = "pe->text_encoder->transformer->vae" + + def __init__( + self, + transformer, + vae, + text_encoder, + tokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + pe=None, + pe_tokenizer=None, + ): + super().__init__() + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + pe=pe, + pe_tokenizer=pe_tokenizer, + ) + self.vae_scale_factor = 16 # VAE downsample factor + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + """ + Load pipeline from a pretrained model directory. + + Args: + pretrained_model_name_or_path: Path to the saved pipeline directory + **kwargs: Additional arguments passed to component loaders + - torch_dtype: Data type for model weights (default: torch.bfloat16) + - device_map: Device map for model loading + - trust_remote_code: Whether to trust remote code for text encoder + + Returns: + ErnieImagePipeline instance + """ + + torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) + trust_remote_code = kwargs.pop("trust_remote_code", True) + + # Determine whether this is a local directory or a Hub repo ID. + # For local paths we join sub-directories; for Hub IDs we use `subfolder`. + is_local = os.path.isdir(pretrained_model_name_or_path) + + def _path_or_subfolder(subfolder: str): + if is_local: + return {"pretrained_model_name_or_path": os.path.join(pretrained_model_name_or_path, subfolder)} + return {"pretrained_model_name_or_path": pretrained_model_name_or_path, "subfolder": subfolder} + + # Load transformer + transformer = ErnieImageTransformer2DModel.from_pretrained( + **_path_or_subfolder("transformer"), + torch_dtype=torch_dtype, + ) + + # Load VAE + vae = AutoencoderKLFlux2.from_pretrained( + **_path_or_subfolder("vae"), + torch_dtype=torch_dtype, + ) + + # Load text encoder + text_encoder = AutoModel.from_pretrained( + **_path_or_subfolder("text_encoder"), + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + **_path_or_subfolder("tokenizer"), + trust_remote_code=trust_remote_code, + ) + + # Load PE + pe = AutoModelForCausalLM.from_pretrained( + **_path_or_subfolder("pe"), + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) + pe_tokenizer = AutoTokenizer.from_pretrained( + **_path_or_subfolder("pe"), + trust_remote_code=trust_remote_code, + ) + + # Load scheduler + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + **_path_or_subfolder("scheduler"), + ) + + return cls( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + pe=pe, + pe_tokenizer=pe_tokenizer, + scheduler=scheduler, + ) + + @torch.no_grad() + def _enhance_prompt_with_pe( + self, + prompt: str, + device: torch.device, + width: int = 1024, + height: int = 1024, + system_prompt: Optional[str] = None, + max_length: int = 1536, + temperature: float = 0.6, + top_p: float = 0.95, + ) -> str: + """Use PE model to rewrite/enhance a short prompt via chat_template.""" + # Build user message as JSON carrying prompt text and target resolution + user_content = json.dumps( + {"prompt": prompt, "width": width, "height": height}, + ensure_ascii=False, + ) + messages = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_content}) + + # apply_chat_template picks up the chat_template.jinja loaded with pe_tokenizer + input_text = self.pe_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, # "Output:" is already in the user block + ) + # pe_device = next(self.pe.parameters()).device + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + + output_ids = self.pe.generate( + **inputs, + max_new_tokens=max_length, + do_sample=temperature != 1.0 or top_p != 1.0, + temperature=temperature, + top_p=top_p, + pad_token_id=self.pe_tokenizer.pad_token_id, + eos_token_id=self.pe_tokenizer.eos_token_id, + ) + # Decode only newly generated tokens + generated_ids = output_ids[0][inputs["input_ids"].shape[1]:] + return self.pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + + @torch.no_grad() + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_images_per_prompt: int = 1, + max_length: int = 64, + ) -> List[torch.Tensor]: + """Encode text prompts to embeddings.""" + if isinstance(prompt, str): + prompt = [prompt] + + text_hiddens = [] + + for p in prompt: + ids = self.tokenizer( + p, + add_special_tokens=True, + truncation=True, + max_length=max_length, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if self.tokenizer.bos_token_id is not None: + ids = [self.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=device) + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=input_ids, + output_hidden_states=True, + ) + # Use second to last hidden state (matches training) + hidden = outputs.hidden_states[-2][0] # [T, H] + + # Repeat for num_images_per_prompt + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + + return text_hiddens + + @torch.no_grad() + def _encode_negative_prompt( + self, + negative_prompt: List[str], + device: torch.device, + num_images_per_prompt: int = 1, + max_length: int = 64, + ) -> List[torch.Tensor]: + """Encode negative prompts for CFG.""" + text_hiddens = [] + + for np in negative_prompt: + ids = self.tokenizer( + np, + add_special_tokens=True, + truncation=True, + max_length=max_length, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if self.tokenizer.bos_token_id is not None: + ids = [self.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=device) + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=input_ids, + output_hidden_states=True, + ) + hidden = outputs.hidden_states[-2][0] + + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + + return text_hiddens + + @staticmethod + def _patchify_latents(latents: torch.Tensor) -> torch.Tensor: + """2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]""" + b, c, h, w = latents.shape + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + return latents.reshape(b, c * 4, h // 2, w // 2) + + @staticmethod + def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + """Reverse patchify: [B, 128, H/2, W/2] -> [B, 32, H, W]""" + b, c, h, w = latents.shape + latents = latents.reshape(b, c // 4, 2, 2, h, w) + latents = latents.permute(0, 1, 4, 2, 5, 3) + return latents.reshape(b, c // 4, h * 2, w * 2) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = "", + height: int = 256, + width: int = 256, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + max_length: int = 1536, + use_pe: bool = True, # 默认使用PE进行改写 + ): + """ + Generate images from text prompts. + + Args: + prompt: Text prompt(s) + negative_prompt: Negative prompt(s) for CFG. Default is "". + height: Image height (must be divisible by 16) + width: Image width (must be divisible by 16) + num_inference_steps: Number of denoising steps + guidance_scale: CFG scale (1.0 = no guidance) + num_images_per_prompt: Number of images per prompt + generator: Random generator for reproducibility + latents: Pre-generated latents (optional) + output_type: "pil" or "latent" + return_dict: Whether to return a dataclass + callback: Optional callback function + callback_steps: Steps between callbacks + max_length: Max token length for text encoding + + Returns: + Generated images + """ + device = self._execution_device + dtype = self.transformer.dtype + + self.pe.to(device) + # Validate dimensions + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") + + # Handle prompts + if isinstance(prompt, str): + prompt = [prompt] + + # Enhance prompts with PE if enabled + if use_pe and self.pe is not None and self.pe_tokenizer is not None: + prompt = [ + self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) + for p in prompt + ] + + batch_size = len(prompt) + total_batch_size = batch_size * num_images_per_prompt + + # Handle negative prompt + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") + + # Encode prompts + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) + + # CFG with negative prompt + do_cfg = guidance_scale > 1.0 + if do_cfg: + uncond_text_hiddens = self._encode_negative_prompt( + negative_prompt, device, num_images_per_prompt, max_length + ) + + # Latent dimensions + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + latent_channels = 128 # After patchify + + # Initialize latents + if latents is None: + latents = torch.randn( + (total_batch_size, latent_channels, latent_h, latent_w), + device=device, + dtype=dtype, + generator=generator, + ) + + # Setup scheduler + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # Denoising loop + if do_cfg: + cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) + else: + cfg_text_hiddens = text_hiddens + + for i, t in enumerate(self.scheduler.timesteps): + if do_cfg: + latent_model_input = torch.cat([latents, latents], dim=0) + t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) + else: + latent_model_input = latents + t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) + + # Model prediction + pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_batch, + encoder_hidden_states=cfg_text_hiddens, + return_dict=False, + )[0] + + # Apply CFG + if do_cfg: + pred_uncond, pred_cond = pred.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents).prev_sample + + # Callback + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return latents + + # Decode latents to images + # Unnormalize latents using VAE's BN stats + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) + latents = latents * bn_std + bn_mean + + # Unpatchify + latents = self._unpatchify_latents(latents) + + # Decode + images = self.vae.decode(latents, return_dict=False)[0] + + # Post-process + images = (images.clamp(-1, 1) + 1) / 2 + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + + if not return_dict: + return (images,) + + return ErnieImagePipelineOutput(images=images) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index fa37388fe75a..2a5b2bd6b8a2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1016,6 +1016,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ErnieImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Flux2Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e4d14566160..aa1c9fbb3c10 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2582,6 +2582,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ErnieImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class OvisImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py new file mode 100644 index 000000000000..7ef855609ed8 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import unittest + +import torch + +from diffusers import ErnieImageTransformer2DModel + +from ...testing_utils import IS_GITHUB_ACTIONS, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +@unittest.skipIf( + IS_GITHUB_ACTIONS, + reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.", +) +class ErnieImageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = ErnieImageTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.9, 0.9, 0.9] + + def prepare_dummy_input(self, height=16, width=16): + batch_size = 1 + num_channels = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = [ + torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size) + ] + timestep = torch.tensor([1.0]).to(torch_device) + + return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep} + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 16, 16) + + @property + def output_shape(self): + return (16, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "hidden_size": 16, + "num_attention_heads": 1, + "num_layers": 1, + "ffn_hidden_size": 16, + "in_channels": 16, + "out_channels": 16, + "patch_size": 1, + "text_in_dim": 16, + "rope_theta": 256, + "rope_axes_dim": (8, 4, 4), + "eps": 1e-6, + "qk_layernorm": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ErnieImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_training(self): + super().test_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_ema_training(self): + super().test_ema_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing() + + @unittest.skip( + "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices." + ) + def test_layerwise_casting_training(self): + super().test_layerwise_casting_training() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_group_offloading_with_layerwise_casting(self): + super().test_group_offloading_with_layerwise_casting() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_group_offloading_with_layerwise_casting_0(self): + pass + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_group_offloading_with_layerwise_casting_1(self): + pass + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") + def test_group_offloading(self): + super().test_group_offloading() + + @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") + def test_group_offloading_with_disk(self): + super().test_group_offloading_with_disk() + + +class ErnieImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = ErnieImageTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return ErnieImageTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return ErnieImageTransformerTests().prepare_dummy_input(height=height, width=width) + + @unittest.skip( + "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice." + ) + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() + + @unittest.skip("Fullgraph AoT is broken") + def test_compile_works_with_aot(self): + super().test_compile_works_with_aot() + + @unittest.skip("Fullgraph is broken") + def test_compile_on_different_shapes(self): + super().test_compile_on_different_shapes() From 4049a2072901ea38e9ad664c36944df6fec79f02 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 17:51:24 +0800 Subject: [PATCH 2/9] Update doc --- .../api/models/ernie_image_transformer2d.md | 2 +- docs/source/en/api/pipelines/ernie_image.md | 24 +++++++------------ 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md index 058616c47814..8be37d56bf42 100644 --- a/docs/source/en/api/models/ernie_image_transformer2d.md +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # ErnieImageTransformer2DModel -A Transformer model for image-like data from [Ernie-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo). +A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo). ## ErnieImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 34c1dc4cd489..6162713f2c75 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -16,26 +16,26 @@ specific language governing permissions and limitations under the License. LoRA -[Ernie-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released: +[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. Currently there's only one model with two more to be released: |Model|Hugging Face| |---|---| -|Ernie-Image|https://huggingface.co/Tongyi-MAI/Ernie-Image-Turbo| +|ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo| ## Ernie-Image -Ernie-Image-Turbo is a distilled version of Ernie-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. +ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. -## ZImagePipeline +## ErnieImagePipeline -Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt. +Use [`ErnieImagePipeline`] to generate an image based on a text prompt. ```python import torch from diffusers import ErnieImagePipeline from diffusers.utils import load_image -pipe = ErnieImagePipeline.from_pretrained("Tongyi-MAI/Ernie-Image-Turbo", torch_dtype=torch.bfloat16) +pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16) pipe.to("cuda") prompt = "一只黑白相间的中华田园犬" @@ -43,15 +43,9 @@ images = pipe( prompt=prompt, height=1024, width=1024, - num_inference_steps=50, + num_inference_steps=8, guidance_scale=5.0, generator=generator, ).images -images[0].save("ernie-image-output.png") -``` - -## ZImagePipeline - -[[autodoc]] ErnieImagePipeline - - all - - __call__ +images[0].save("ernie-image-turbo-output.png") +``` \ No newline at end of file From 579e6c7f6642508796a09f8d69bf3cb6ef9195cb Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 18:53:56 +0800 Subject: [PATCH 3/9] Update doc --- docs/source/en/api/pipelines/ernie_image.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 6162713f2c75..5bb6f550096f 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -28,7 +28,7 @@ ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds ## ErnieImagePipeline -Use [`ErnieImagePipeline`] to generate an image based on a text prompt. +Use [`ErnieImagePipeline`] to generate an image based on a text prompt. If you do not want to use PE, please set use_pe=False. ```python import torch From d16d16e9b2b755659e2bbf08509a5b60620d10d9 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 3 Apr 2026 22:57:21 +0800 Subject: [PATCH 4/9] Change from Custom-Attention to Diffusers Style Attention --- .../transformers/transformer_ernie_image.py | 187 +++++++++--------- .../ernie_image/pipeline_ernie_image.py | 1 + 2 files changed, 100 insertions(+), 88 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 63fe3c47b811..7a63e6d6818d 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -28,6 +28,9 @@ from ..embeddings import Timesteps from ..modeling_utils import ModelMixin from ...utils import BaseOutput +from ..normalization import RMSNorm +from ..attention_processor import Attention +from ..attention_dispatch import dispatch_attention_fn @dataclass @@ -52,8 +55,8 @@ def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): def forward(self, ids: torch.Tensor) -> torch.Tensor: emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) - emb = emb.unsqueeze(1).permute(2, 0, 1, 3) - return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + emb = emb.unsqueeze(2) # [B, S, 1, head_dim//2] + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] class PatchEmbedDynamic(nn.Module): @@ -76,78 +79,84 @@ def __init__(self, in_channels: int, time_embed_dim: int): self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.linear_1(sample.to(self.linear_1.weight.dtype)) - return self.linear_2(self.act(sample).to(self.linear_2.weight.dtype)) + sample = sample.to(self.linear_1.weight.dtype) + return self.linear_2(self.act(self.linear_1(sample))) -class RMSNorm(nn.Module): - """RMSNorm implementation matching Megatron's TENorm.""" - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # 内部计算转换为FP32,对齐transform engine的TENorm计算精度 - x_norm = self._norm(x.float()) - output = x_norm * self.weight.float() - return output.to(x.dtype) +class ErnieImageSingleStreamAttnProcessor: + _attention_backend = None + _parallel_config = None + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) -class Attention(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-6, qk_layernorm: bool = True): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - # Separate Q, K, V projections (matches converted weights) - self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.qk_layernorm = qk_layernorm - if qk_layernorm: - # self.q_layernorm = RMSNorm(self.head_dim, eps=eps) - # self.k_layernorm = RMSNorm(self.head_dim, eps=eps) - self.q_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) - self.k_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) - - def forward(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - S, B, H = x.shape - # Separate Q, K, V projections - q = self.q_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() - k = self.k_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() - v = self.v_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() - if self.qk_layernorm: - q, k = self.q_layernorm(q), self.k_layernorm(k) - q, k = self._apply_rotary(q, rotary_pos_emb), self._apply_rotary(k, rotary_pos_emb) - q, k, v = q.permute(1, 2, 0, 3), k.permute(1, 2, 0, 3), v.permute(1, 2, 0, 3) - attn_mask = ~attention_mask if attention_mask is not None else None - out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) - return self.linear_proj(out.permute(2, 0, 1, 3).reshape(S, B, H)) - - def _apply_rotary(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - """Apply rotary position embedding. - - Matches Megatron's _apply_rotary_pos_emb_bshd with rotary_interleaved=False. - freqs: [S, B, 1, dim] containing angles [θ0, θ0, θ1, θ1, ...] - """ - rot_dim = freqs.shape[-1] - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - - cos_ = torch.cos(freqs).to(x.dtype) - sin_ = torch.sin(freqs).to(x.dtype) - - # Non-interleaved rotate_half: [-x2, x1] - x1, x2 = x.chunk(2, dim=-1) - x_rotated = torch.cat((-x2, x1), dim=-1) - - x = x * cos_ + x_rotated * sin_ - return torch.cat((x, x_pass), dim=-1) + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + freqs_cis: torch.Tensor | None = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE: same rotate_half logic as Megatron _apply_rotary_pos_emb_bshd (rotary_interleaved=False) + # x_in: [B, S, heads, head_dim], freqs_cis: [B, S, 1, head_dim] with angles [θ0,θ0,θ1,θ1,...] + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + rot_dim = freqs_cis.shape[-1] + x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] + cos_ = torch.cos(freqs_cis).to(x.dtype) + sin_ = torch.sin(freqs_cis).to(x.dtype) + # Non-interleaved rotate_half: [-x2, x1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + output = attn.to_out[0](hidden_states) + + return output class FeedForward(nn.Module): @@ -161,22 +170,31 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) - class SharedAdaLNBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): super().__init__() - # self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) - self.adaLN_sa_ln = torch.nn.RMSNorm(hidden_size, eps=eps) - self.self_attention = Attention(hidden_size, num_heads, eps=eps, qk_layernorm=qk_layernorm) - # self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) - self.adaLN_mlp_ln = torch.nn.RMSNorm(hidden_size, eps=eps) + self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) + self.self_attention = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=hidden_size // num_heads, + heads=num_heads, + qk_norm="rms_norm" if qk_layernorm else None, + eps=eps, + bias=False, + out_bias=False, + processor=ErnieImageSingleStreamAttnProcessor(), + ) + self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) self.mlp = FeedForward(hidden_size, ffn_hidden_size) def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): residual = x x = self.adaLN_sa_ln(x) x = self._modulate(x, shift_msa, scale_msa) - attn_out = self.self_attention(x, rotary_pos_emb, attention_mask) + x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) + attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, freqs_cis=rotary_pos_emb) + attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] x = residual + self._apply_gate(gate_msa, attn_out) residual = x x = self._modulate(self.adaLN_mlp_ln(x), shift_mlp, scale_mlp) @@ -231,7 +249,6 @@ def __init__( qk_layernorm: bool = True, ): super().__init__() - self.gradient_checkpointing = False self.hidden_size = hidden_size self.num_heads = num_attention_heads self.head_dim = hidden_size // num_attention_heads @@ -277,26 +294,20 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h image_ids = torch.cat([text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], dim=-1) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) - # Attention mask + # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) - attention_mask = (~torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1))[:, None, None, :] + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[:, None, None, :] # AdaLN c = self.time_embedding(self.time_proj(timestep.to(dtype))) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: - if self.gradient_checkpointing and self.training: - x = self._gradient_checkpointing_func( - layer.__call__, - x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask - ) - else: - x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) x = self.final_norm(x, c).type_as(x) patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) - return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) + return Text2ImgDiTTransformer2DModelOutput(sample=output) if return_dict else (output,) def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): B = len(text_hiddens) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 29c1b9fd9bab..242017d1caf7 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -19,6 +19,7 @@ import json import os +import numpy as np import torch from PIL import Image from dataclasses import dataclass From 9cbbf5d92b69b43034210979355b402aedbbc84a Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 3 Apr 2026 23:03:08 +0800 Subject: [PATCH 5/9] Change from Custom-Attention to Diffusers Style Attention --- src/diffusers/models/transformers/transformer_ernie_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 7a63e6d6818d..a0a0fc08ab31 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -90,7 +90,7 @@ class ErnieImageSingleStreamAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + "ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." ) def __call__( @@ -307,7 +307,7 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) - return Text2ImgDiTTransformer2DModelOutput(sample=output) if return_dict else (output,) + return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): B = len(text_hiddens) From 9fca91205f75f4a7914d18e20cd538b0aace36ba Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Tue, 7 Apr 2026 17:13:22 +0800 Subject: [PATCH 6/9] =?UTF-8?q?=E5=85=BC=E5=AE=B9SGLang?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pipelines/ernie_image/pipeline_ernie_image.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 242017d1caf7..56bb1af87444 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -50,6 +50,8 @@ class ErnieImagePipeline(DiffusionPipeline): """ model_cpu_offload_seq = "pe->text_encoder->transformer->vae" + # For SGLang fallback ... + _optional_components = ["pe", "pe_tokenizer"] def __init__( self, @@ -350,8 +352,8 @@ def __call__( # Handle prompts if isinstance(prompt, str): prompt = [prompt] - - # Enhance prompts with PE if enabled + + # [Phase 1] PE: enhance prompts, then offload to CPU if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) @@ -369,7 +371,7 @@ def __call__( if len(negative_prompt) != batch_size: raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") - # Encode prompts + # [Phase 2] Text encoding, then offload text_encoder to CPU text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) # CFG with negative prompt From 465f00979b3170b71ab4a3844c4047723f578ea3 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Tue, 7 Apr 2026 19:24:58 +0800 Subject: [PATCH 7/9] =?UTF-8?q?=E4=BC=98=E5=8C=96PE=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E7=9A=84=E5=8A=A0=E8=BD=BD=E4=B8=8Eoffload=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/api/pipelines/ernie_image.md | 2 + .../ernie_image/pipeline_ernie_image.py | 95 +++++++++++-------- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 5bb6f550096f..4f2ad62bdd5c 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -37,6 +37,8 @@ from diffusers.utils import load_image pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16) pipe.to("cuda") +# 如果显存不足,可以开启offload +pipe.enable_model_cpu_offload() prompt = "一只黑白相间的中华田园犬" images = pipe( diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 56bb1af87444..7ee8293d896e 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -52,6 +52,7 @@ class ErnieImagePipeline(DiffusionPipeline): model_cpu_offload_seq = "pe->text_encoder->transformer->vae" # For SGLang fallback ... _optional_components = ["pe", "pe_tokenizer"] + _callback_tensor_inputs = ["latents"] def __init__( self, @@ -93,6 +94,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) trust_remote_code = kwargs.pop("trust_remote_code", True) + device_map = kwargs.pop("device_map", None) # Determine whether this is a local directory or a Hub repo ID. # For local paths we join sub-directories; for Hub IDs we use `subfolder`. @@ -133,6 +135,8 @@ def _path_or_subfolder(subfolder: str): **_path_or_subfolder("pe"), torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + **({"device_map": device_map} if device_map else {}), ) # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) @@ -185,8 +189,13 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - # pe_device = next(self.pe.parameters()).device - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + # When accelerate offload hooks are installed, use the hook's execution_device + # to ensure inputs land on the same device as the model weights during forward() + if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"): + pe_device = self.pe._hf_hook.execution_device + else: + pe_device = device + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) output_ids = self.pe.generate( **inputs, @@ -314,8 +323,8 @@ def __call__( latents: Optional[torch.Tensor] = None, output_type: str = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, + callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_length: int = 1536, use_pe: bool = True, # 默认使用PE进行改写 ): @@ -334,8 +343,12 @@ def __call__( latents: Pre-generated latents (optional) output_type: "pil" or "latent" return_dict: Whether to return a dataclass - callback: Optional callback function - callback_steps: Steps between callbacks + callback_on_step_end: Optional callback invoked at the end of each denoising step. + Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where + `callback_kwargs` contains the tensors listed in `callback_on_step_end_tensor_inputs`. + The callback may return a dict to override those tensors for subsequent steps. + callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. + Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). max_length: Max token length for text encoding Returns: @@ -344,7 +357,6 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype - self.pe.to(device) # Validate dimensions if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") @@ -353,7 +365,7 @@ def __call__( if isinstance(prompt, str): prompt = [prompt] - # [Phase 1] PE: enhance prompts, then offload to CPU + # [Phase 1] PE: enhance prompts if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) @@ -371,7 +383,7 @@ def __call__( if len(negative_prompt) != batch_size: raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") - # [Phase 2] Text encoding, then offload text_encoder to CPU + # [Phase 2] Text encoding text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) # CFG with negative prompt @@ -396,7 +408,8 @@ def __call__( ) # Setup scheduler - self.scheduler.set_timesteps(num_inference_steps, device=device) + sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1) + self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device) # Denoising loop if do_cfg: @@ -404,33 +417,38 @@ def __call__( else: cfg_text_hiddens = text_hiddens - for i, t in enumerate(self.scheduler.timesteps): - if do_cfg: - latent_model_input = torch.cat([latents, latents], dim=0) - t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) - else: - latent_model_input = latents - t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) - - # Model prediction - pred = self.transformer( - hidden_states=latent_model_input, - timestep=t_batch, - encoder_hidden_states=cfg_text_hiddens, - return_dict=False, - )[0] - - # Apply CFG - if do_cfg: - pred_uncond, pred_cond = pred.chunk(2, dim=0) - pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) - - # Scheduler step - latents = self.scheduler.step(pred, t, latents).prev_sample - - # Callback - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(self.scheduler.timesteps): + if do_cfg: + latent_model_input = torch.cat([latents, latents], dim=0) + t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) + else: + latent_model_input = latents + t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) + + # Model prediction + pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_batch, + encoder_hidden_states=cfg_text_hiddens, + return_dict=False, + )[0] + + # Apply CFG + if do_cfg: + pred_uncond, pred_cond = pred.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents).prev_sample + + # Callback + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + progress_bar.update() if output_type == "latent": return latents @@ -454,6 +472,9 @@ def __call__( if output_type == "pil": images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (images,) From 6afd5342bb2fec21af31f75a12f59fb4cc0a20a4 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Wed, 8 Apr 2026 11:56:09 +0800 Subject: [PATCH 8/9] =?UTF-8?q?=E6=9B=B4=E6=96=B0Doc=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=8Econfig=E9=85=8D=E7=BD=AE=E7=9B=B8=E5=85=B3=E5=86=85?= =?UTF-8?q?=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api/models/ernie_image_transformer2d.md | 2 + docs/source/en/api/pipelines/ernie_image.md | 41 ++++++++++++++-- .../transformers/transformer_ernie_image.py | 3 +- .../ernie_image/pipeline_ernie_image.py | 48 +++++++------------ .../pipelines/ernie_image/pipeline_output.py | 36 ++++++++++++++ 5 files changed, 94 insertions(+), 36 deletions(-) create mode 100644 src/diffusers/pipelines/ernie_image/pipeline_output.py diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md index 8be37d56bf42..9fe03090577f 100644 --- a/docs/source/en/api/models/ernie_image_transformer2d.md +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -12,6 +12,8 @@ specific language governing permissions and limitations under the License. # ErnieImageTransformer2DModel +A Transformer model for image-like data from [ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image). + A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo). ## ErnieImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 4f2ad62bdd5c..69c0234d4cbf 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -16,19 +16,51 @@ specific language governing permissions and limitations under the License. LoRA -[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. Currently there's only one model with two more to be released: +[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. Currently there's only two models to be released: |Model|Hugging Face| |---|---| +|ERNIE-Image|https://huggingface.co/baidu/ERNIE-Image| |ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo| -## Ernie-Image +## ERNIE-Image -ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. +ERNIE-Image is designed with a relatively compact architecture and solid instruction-following capability, emphasizing parameter efficiency. Based on an 8B DiT backbone, it provides performance that is comparable in some scenarios to larger (20B+) models, while maintaining reasonable parameter efficiency. It offers a relatively stable level of performance in instruction understanding and execution, text generation (e.g., English / Chinese / Japanese), and overall stability. + +## ERNIE-Image-Turbo + +ERNIE-Image-Turbo is a distilled variant of ERNIE-Image, requiring only 8 NFEs (Number of Function Evaluations) and offering a more efficient alternative with relatively comparable performance to the full model in certain cases. ## ErnieImagePipeline -Use [`ErnieImagePipeline`] to generate an image based on a text prompt. If you do not want to use PE, please set use_pe=False. +Use [ErnieImagePipeline] to generate images from text prompts. The pipeline supports Prompt Enhancer (PE) by default, which enhances the user’s raw prompt to improve output quality, though it may reduce instruction-following accuracy. + +We provide a pretrained 3B-parameter PE model; however, using larger language models (e.g., Gemini or ChatGPT) for prompt enhancement may yield better results. The system prompt template is available at: https://huggingface.co/baidu/ERNIE-Image/blob/main/pe/chat_template.jinja. + +If you prefer not to use PE, set use_pe=False. + +```python +import torch +from diffusers import ErnieImagePipeline +from diffusers.utils import load_image + +pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16) +pipe.to("cuda") +# 如果显存不足,可以开启offload +pipe.enable_model_cpu_offload() + +prompt = "一只黑白相间的中华田园犬" +images = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=5.0, + generator=generator, + use_pe=True, +).images +images[0].save("ernie-image-output.png") +``` ```python import torch @@ -48,6 +80,7 @@ images = pipe( num_inference_steps=8, guidance_scale=5.0, generator=generator, + use_pe=True, ).images images[0].save("ernie-image-turbo-output.png") ``` \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index a0a0fc08ab31..f92d65eb40b5 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025, Baidu Inc. All rights reserved. -# Author: fengzhida (fengzhida@baidu.com) +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 7ee8293d896e..4f79a4501059 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025, Baidu Inc. All rights reserved. -# Author: fengzhida (fengzhida@baidu.com) +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,20 +21,14 @@ import numpy as np import torch from PIL import Image -from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import BaseOutput from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel - - -@dataclass -class ErnieImagePipelineOutput(BaseOutput): - images: List[Image.Image] +from .pipeline_output import ErnieImagePipelineOutput class ErnieImagePipeline(DiffusionPipeline): @@ -168,7 +161,6 @@ def _enhance_prompt_with_pe( width: int = 1024, height: int = 1024, system_prompt: Optional[str] = None, - max_length: int = 1536, temperature: float = 0.6, top_p: float = 0.95, ) -> str: @@ -196,10 +188,9 @@ def _enhance_prompt_with_pe( else: pe_device = device inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) - output_ids = self.pe.generate( **inputs, - max_new_tokens=max_length, + max_new_tokens=self.pe_tokenizer.model_max_length, do_sample=temperature != 1.0 or top_p != 1.0, temperature=temperature, top_p=top_p, @@ -216,7 +207,6 @@ def encode_prompt( prompt: Union[str, List[str]], device: torch.device, num_images_per_prompt: int = 1, - max_length: int = 64, ) -> List[torch.Tensor]: """Encode text prompts to embeddings.""" if isinstance(prompt, str): @@ -229,7 +219,6 @@ def encode_prompt( p, add_special_tokens=True, truncation=True, - max_length=max_length, padding=False, )["input_ids"] @@ -260,7 +249,6 @@ def _encode_negative_prompt( negative_prompt: List[str], device: torch.device, num_images_per_prompt: int = 1, - max_length: int = 64, ) -> List[torch.Tensor]: """Encode negative prompts for CFG.""" text_hiddens = [] @@ -270,7 +258,6 @@ def _encode_negative_prompt( np, add_special_tokens=True, truncation=True, - max_length=max_length, padding=False, )["input_ids"] @@ -314,10 +301,10 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = "", - height: int = 256, - width: int = 256, + height: int = 1024, + width: int = 1024, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, @@ -325,7 +312,6 @@ def __call__( return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_length: int = 1536, use_pe: bool = True, # 默认使用PE进行改写 ): """ @@ -334,10 +320,10 @@ def __call__( Args: prompt: Text prompt(s) negative_prompt: Negative prompt(s) for CFG. Default is "". - height: Image height (must be divisible by 16) - width: Image width (must be divisible by 16) + height: Image height in pixels (must be divisible by 16). Default: 1024. + width: Image width in pixels (must be divisible by 16). Default: 1024. num_inference_steps: Number of denoising steps - guidance_scale: CFG scale (1.0 = no guidance) + guidance_scale: CFG scale (1.0 = no guidance). Default: 4.0. num_images_per_prompt: Number of images per prompt generator: Random generator for reproducibility latents: Pre-generated latents (optional) @@ -349,10 +335,10 @@ def __call__( The callback may return a dict to override those tensors for subsequent steps. callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). - max_length: Max token length for text encoding + use_pe: Whether to use the PE model to enhance prompts before generation. Returns: - Generated images + :class:`ErnieImagePipelineOutput` with `images` and `revised_prompts`. """ device = self._execution_device dtype = self.transformer.dtype @@ -366,11 +352,13 @@ def __call__( prompt = [prompt] # [Phase 1] PE: enhance prompts + revised_prompts: Optional[List[str]] = None if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ - self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) + self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt ] + revised_prompts = list(prompt) batch_size = len(prompt) total_batch_size = batch_size * num_images_per_prompt @@ -384,13 +372,13 @@ def __call__( raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") # [Phase 2] Text encoding - text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) # CFG with negative prompt do_cfg = guidance_scale > 1.0 if do_cfg: uncond_text_hiddens = self._encode_negative_prompt( - negative_prompt, device, num_images_per_prompt, max_length + negative_prompt, device, num_images_per_prompt ) # Latent dimensions @@ -478,4 +466,4 @@ def __call__( if not return_dict: return (images,) - return ErnieImagePipelineOutput(images=images) + return ErnieImagePipelineOutput(images=images, revised_prompts=revised_prompts) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_output.py b/src/diffusers/pipelines/ernie_image/pipeline_output.py new file mode 100644 index 000000000000..8919db0c0aca --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/pipeline_output.py @@ -0,0 +1,36 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional + +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class ErnieImagePipelineOutput(BaseOutput): + """ + Output class for ERNIE-Image pipelines. + + Args: + images (`List[PIL.Image.Image]`): + List of generated images. + revised_prompts (`List[str]`, *optional*): + List of PE-revised prompts. `None` when PE is disabled or unavailable. + """ + + images: List[PIL.Image.Image] + revised_prompts: Optional[List[str]] From b360596fa8933d59abf4edc91f036807ee6bbe61 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Wed, 8 Apr 2026 21:18:17 +0800 Subject: [PATCH 9/9] =?UTF-8?q?Fix=E5=AE=98=E6=96=B9=E5=8F=8D=E9=A6=88?= =?UTF-8?q?=E7=9A=84=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../transformers/transformer_ernie_image.py | 134 +++++++++++++---- .../ernie_image/pipeline_ernie_image.py | 136 +++++------------- 2 files changed, 138 insertions(+), 132 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index f92d65eb40b5..e87995d57bfb 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -17,6 +17,7 @@ """ import math +import inspect from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -25,11 +26,13 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ..embeddings import Timesteps +from ..embeddings import TimestepEmbedding from ..modeling_utils import ModelMixin from ...utils import BaseOutput from ..normalization import RMSNorm from ..attention_processor import Attention from ..attention_dispatch import dispatch_attention_fn +from ..attention import AttentionMixin, AttentionModuleMixin @dataclass @@ -45,7 +48,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: return out.float() -class EmbedND3(nn.Module): +class ErnieImageEmbedND3(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): super().__init__() self.dim = dim @@ -70,18 +73,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous() -class TimestepEmbedding(nn.Module): - def __init__(self, in_channels: int, time_embed_dim: int): - super().__init__() - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = sample.to(self.linear_1.weight.dtype) - return self.linear_2(self.act(self.linear_1(sample))) - - class ErnieImageSingleStreamAttnProcessor: _attention_backend = None _parallel_config = None @@ -157,6 +148,89 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return output +class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = ErnieImageSingleStreamAttnProcessor + _available_processors = [ErnieImageSingleStreamAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + qk_norm: str = "rms_norm", + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + if qk_norm == "layer_norm": + self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm": + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." + ) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + class FeedForward(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int): @@ -173,9 +247,8 @@ class SharedAdaLNBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): super().__init__() self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) - self.self_attention = Attention( + self.self_attention = ErnieImageAttention( query_dim=hidden_size, - cross_attention_dim=None, dim_head=hidden_size // num_heads, heads=num_heads, qk_norm="rms_norm" if qk_layernorm else None, @@ -192,7 +265,7 @@ def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, x = self.adaLN_sa_ln(x) x = self._modulate(x, shift_msa, scale_msa) x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) - attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, freqs_cis=rotary_pos_emb) + attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] x = residual + self._apply_gate(gate_msa, attn_out) residual = x @@ -261,7 +334,7 @@ def __init__( self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) - self.pos_embed = EmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) nn.init.zeros_(self.adaLN_modulation[-1].weight) nn.init.zeros_(self.adaLN_modulation[-1].bias) @@ -271,14 +344,22 @@ def __init__( nn.init.zeros_(self.final_linear.weight) nn.init.zeros_(self.final_linear.bias) - def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: List[torch.Tensor], return_dict: bool = True): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + # encoder_hidden_states: List[torch.Tensor], + text_bth: torch.Tensor, + text_lens: torch.Tensor, + return_dict: bool = True + ): device, dtype = hidden_states.device, hidden_states.dtype B, C, H, W = hidden_states.shape p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size N_img = Hp * Wp img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous() - text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) + # text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) if self.text_proj is not None and text_bth.numel() > 0: text_bth = self.text_proj(text_bth) Tmax = text_bth.shape[1] @@ -298,7 +379,9 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[:, None, None, :] # AdaLN - c = self.time_embedding(self.time_proj(timestep.to(dtype))) + sample = self.time_proj(timestep.to(dtype)) + sample = sample.to(self.time_embedding.linear_1.weight.dtype) + c = self.time_embedding(sample) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) @@ -308,14 +391,3 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) - def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): - B = len(text_hiddens) - if B == 0: - return torch.zeros((0, 0, self.text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] - lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) - Tmax = int(lens.max().item()) - text_bth = torch.zeros((B, Tmax, self.text_in_dim), device=device, dtype=dtype) - for i, t in enumerate(normalized): - text_bth[i, :t.shape[0], :] = t - return text_bth, lens diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 4f79a4501059..e2526c0f5700 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -49,13 +49,13 @@ class ErnieImagePipeline(DiffusionPipeline): def __init__( self, - transformer, - vae, - text_encoder, - tokenizer, + transformer: ErnieImageTransformer2DModel, + vae: AutoencoderKLFlux2, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, scheduler: FlowMatchEulerDiscreteScheduler, - pe=None, - pe_tokenizer=None, + pe: Optional[AutoModelForCausalLM] = None, + pe_tokenizer: Optional[AutoTokenizer] = None, ): super().__init__() self.register_modules( @@ -69,89 +69,13 @@ def __init__( ) self.vae_scale_factor = 16 # VAE downsample factor - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): - """ - Load pipeline from a pretrained model directory. - - Args: - pretrained_model_name_or_path: Path to the saved pipeline directory - **kwargs: Additional arguments passed to component loaders - - torch_dtype: Data type for model weights (default: torch.bfloat16) - - device_map: Device map for model loading - - trust_remote_code: Whether to trust remote code for text encoder + @property + def guidance_scale(self): + return self._guidance_scale - Returns: - ErnieImagePipeline instance - """ - - torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) - trust_remote_code = kwargs.pop("trust_remote_code", True) - device_map = kwargs.pop("device_map", None) - - # Determine whether this is a local directory or a Hub repo ID. - # For local paths we join sub-directories; for Hub IDs we use `subfolder`. - is_local = os.path.isdir(pretrained_model_name_or_path) - - def _path_or_subfolder(subfolder: str): - if is_local: - return {"pretrained_model_name_or_path": os.path.join(pretrained_model_name_or_path, subfolder)} - return {"pretrained_model_name_or_path": pretrained_model_name_or_path, "subfolder": subfolder} - - # Load transformer - transformer = ErnieImageTransformer2DModel.from_pretrained( - **_path_or_subfolder("transformer"), - torch_dtype=torch_dtype, - ) - - # Load VAE - vae = AutoencoderKLFlux2.from_pretrained( - **_path_or_subfolder("vae"), - torch_dtype=torch_dtype, - ) - - # Load text encoder - text_encoder = AutoModel.from_pretrained( - **_path_or_subfolder("text_encoder"), - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - ) - - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained( - **_path_or_subfolder("tokenizer"), - trust_remote_code=trust_remote_code, - ) - - # Load PE - pe = AutoModelForCausalLM.from_pretrained( - **_path_or_subfolder("pe"), - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - low_cpu_mem_usage=True, - **({"device_map": device_map} if device_map else {}), - ) - - # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) - pe_tokenizer = AutoTokenizer.from_pretrained( - **_path_or_subfolder("pe"), - trust_remote_code=trust_remote_code, - ) - - # Load scheduler - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - **_path_or_subfolder("scheduler"), - ) - - return cls( - transformer=transformer, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - pe=pe, - pe_tokenizer=pe_tokenizer, - scheduler=scheduler, - ) + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 @torch.no_grad() def _enhance_prompt_with_pe( @@ -181,13 +105,7 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - # When accelerate offload hooks are installed, use the hook's execution_device - # to ensure inputs land on the same device as the model weights during forward() - if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"): - pe_device = self.pe._hf_hook.execution_device - else: - pe_device = device - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) output_ids = self.pe.generate( **inputs, max_new_tokens=self.pe_tokenizer.model_max_length, @@ -296,6 +214,19 @@ def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 1, 4, 2, 5, 3) return latents.reshape(b, c // 4, h * 2, w * 2) + def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): + text_in_dim = self.transformer.config.text_in_dim + B = len(text_hiddens) + if B == 0: + return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] + lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + Tmax = int(lens.max().item()) + text_bth = torch.zeros((B, Tmax, text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): + text_bth[i, :t.shape[0], :] = t + return text_bth, lens + @torch.no_grad() def __call__( self, @@ -343,6 +274,7 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype + self._guidance_scale = guidance_scale # Validate dimensions if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") @@ -375,8 +307,7 @@ def __call__( text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) # CFG with negative prompt - do_cfg = guidance_scale > 1.0 - if do_cfg: + if self.do_classifier_free_guidance: uncond_text_hiddens = self._encode_negative_prompt( negative_prompt, device, num_images_per_prompt ) @@ -400,14 +331,14 @@ def __call__( self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device) # Denoising loop - if do_cfg: + if self.do_classifier_free_guidance: cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) else: cfg_text_hiddens = text_hiddens with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(self.scheduler.timesteps): - if do_cfg: + if self.do_classifier_free_guidance: latent_model_input = torch.cat([latents, latents], dim=0) t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) else: @@ -415,15 +346,18 @@ def __call__( t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) # Model prediction + text_bth, text_lens = self._pad_text(cfg_text_hiddens, device, dtype) pred = self.transformer( hidden_states=latent_model_input, timestep=t_batch, - encoder_hidden_states=cfg_text_hiddens, + # encoder_hidden_states=cfg_text_hiddens, + text_bth=text_bth, + text_lens=text_lens, return_dict=False, )[0] # Apply CFG - if do_cfg: + if self.do_classifier_free_guidance: pred_uncond, pred_cond = pred.chunk(2, dim=0) pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)