diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index aac4835fe849..9f56a27a174e 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -120,7 +120,10 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], - "z-image-turbo": "cap_embedder.0.weight", + "z-image-turbo": [ + "model.diffusion_model.layers.0.adaLN_modulation.0.weight", + "layers.0.adaLN_modulation.0.weight", + ], "z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight", "z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight", "sana": [ @@ -727,10 +730,7 @@ def infer_diffusers_model_type(checkpoint): ): model_type = "instruct-pix2pix" - elif ( - CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560 - ): + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]): model_type = "z-image-turbo" elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): @@ -3852,6 +3852,7 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): ".attention.k_norm.weight": ".attention.norm_k.weight", ".attention.q_norm.weight": ".attention.norm_q.weight", ".attention.out.weight": ".attention.to_out.0.weight", + "model.diffusion_model.": "", } def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None: @@ -3886,6 +3887,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) update_state_dict(converted_state_dict, key, new_key) + if "norm_final.weight" in converted_state_dict.keys(): + _ = converted_state_dict.pop("norm_final.weight") + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in # special_keys_remap for key in list(converted_state_dict.keys()):