2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
@@ -83,6 +83,7 @@ def text_encoder_attn_modules(text_encoder):
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
"LongCatLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -120,6 +121,7 @@ def text_encoder_attn_modules(text_encoder):
        HiDreamImageLoraLoaderMixin,
        HunyuanVideoLoraLoaderMixin,
        KandinskyLoraLoaderMixin,
+        LongCatLoraLoaderMixin,
        LoraLoaderMixin,
        LTXVideoLoraLoaderMixin,
        Lumina2LoraLoaderMixin,
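Together, the `_import_structure` entry and the `TYPE_CHECKING` import expose the new mixin through the loaders package's lazy-import machinery. A quick smoke test (assuming this PR's branch is installed) could look like:

```python
# Confirm the new mixin is exported from diffusers.loaders (requires this PR's branch).
from diffusers.loaders import LongCatLoraLoaderMixin

print(LongCatLoraLoaderMixin._lora_loadable_modules)  # ['transformer']
print(LongCatLoraLoaderMixin.transformer_name)        # 'transformer' (TRANSFORMER_NAME)
```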
207 changes: 207 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -5387,6 +5387,213 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        super().unfuse_lora(components=components, **kwargs)


class LongCatLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        if any("dora_scale" in k for k in state_dict):
            logger.warning(
                "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. "
                "The keys associated with `dora_scale` will be filtered out of the state dict. If you think this is "
                "a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
            )
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Flux-family formats
        if any(".lora_down.weight" in k for k in state_dict):
            state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
        elif any("processor" in k for k in state_dict):
            state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
        elif any("query_norm.scale" in k for k in state_dict):
            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
        elif any("base_model" in k for k in state_dict):
            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)

        # Generic non-diffusers formats
        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
        has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
        has_default = any("default." in k for k in state_dict)

        if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
            converted, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
            state_dict = converted
            if network_alphas is not None:
                metadata = {} if metadata is None else dict(metadata)
                metadata["network_alphas"] = network_alphas

        # Keep only transformer keys
        def _is_non_transformer_key(k: str) -> bool:
            bad_prefixes = ("unet.", "text_encoder.", "text_encoder_2.", "vae.", "controlnet.")
            if k.startswith(bad_prefixes):
                return True
            bad_substrings = (".unet.", ".text_encoder.", ".text_encoder_2.", ".vae.", ".controlnet.")
            return any(s in k for s in bad_substrings)

        state_dict = {k: v for k, v in state_dict.items() if not _is_non_transformer_key(k)}

        return (state_dict, metadata) if return_lora_metadata else state_dict

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. "
                "Please update it with `pip install -U peft`."
            )

        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. "
                "Please update it with `pip install -U peft`."
            )

        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least `transformer_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        super().unfuse_lora(components=components, **kwargs)

class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
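For reviewers, here is a minimal sketch of how the new mixin is exercised end to end. The repository ids and the weight name are placeholders, and `LongCatImagePipeline` is assumed to be importable from the top-level `diffusers` namespace:

```python
# Sketch only: "<longcat-image-checkpoint>" and "<lora-repo>" are placeholder ids.
import torch
from diffusers import LongCatImagePipeline

pipe = LongCatImagePipeline.from_pretrained("<longcat-image-checkpoint>", torch_dtype=torch.bfloat16)

# Inspect the converted, transformer-only state dict without loading it into the model.
state_dict, metadata = pipe.lora_state_dict(
    "<lora-repo>", weight_name="pytorch_lora_weights.safetensors", return_lora_metadata=True
)
print(f"{len(state_dict)} LoRA tensors, metadata: {metadata}")

# Load the adapter into the transformer (requires the PEFT backend).
pipe.load_lora_weights(
    "<lora-repo>", weight_name="pytorch_lora_weights.safetensors", adapter_name="style"
)
```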
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -66,6 +66,7 @@
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LongCatImageTransformer2DModel": lambda model_cls, weights: weights,
}


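If this hunk edits the per-model-class table that `set_adapters` consults when expanding adapter scales (the dict name is not shown in this extraction), the identity lambda means LongCat adapter weights are passed through unchanged, so the standard scaling API applies. Continuing the placeholder sketch above:

```python
# Continuing the sketch: weight the "style" adapter at 80% strength.
pipe.set_adapters(["style"], adapter_weights=[0.8])
```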
@@ -18,9 +18,8 @@
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, LongCatLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -202,7 +201,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
+class LongCatImagePipeline(DiffusionPipeline, LongCatLoraLoaderMixin, FromSingleFileMixin):
    r"""
    The pipeline for text-to-image generation.
    """
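With `LongCatLoraLoaderMixin` in the MRO, the pipeline picks up `load_lora_weights`, `fuse_lora`, and the rest of the LoRA API. A generation call after loading an adapter might look like the following, assuming the usual diffusers call signature for this pipeline:

```python
# Assumes `pipe` from the earlier sketch, with a LoRA adapter already loaded.
image = pipe(prompt="a watercolor cat on a rooftop", num_inference_steps=28).images[0]
image.save("longcat_lora_sample.png")
```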
@@ -20,9 +20,8 @@
import PIL
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, LongCatLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -228,7 +227,7 @@ def calculate_dimensions(target_area, ratio):
    return width, height


-class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin):
+class LongCatImageEditPipeline(DiffusionPipeline, LongCatLoraLoaderMixin, FromSingleFileMixin):
    r"""
    The LongCat-Image-Edit pipeline for image editing.
    """
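The edit pipeline inherits the same fuse/unfuse path. A sketch of fusing an adapter into the transformer for inference and unfusing it afterwards (the checkpoint and LoRA ids are placeholders, and the top-level export is assumed):

```python
# Sketch: fuse LoRA weights into the transformer for faster inference, then undo it.
from diffusers import LongCatImageEditPipeline  # assumed top-level export

edit_pipe = LongCatImageEditPipeline.from_pretrained("<longcat-image-edit-checkpoint>")
edit_pipe.load_lora_weights("<lora-repo>", adapter_name="edit-style")
edit_pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
# ... run image edits ...
edit_pipe.unfuse_lora(components=["transformer"])
```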