Commit ece6fec

fix key convert for lora from kohya (#183)
* fix key convert for lora from kohya
* simplify rename logic
1 parent e9521cf commit ece6fec

1 file changed, +7 −17 lines changed

diffsynth_engine/pipelines/flux_image.py

Lines changed: 7 additions & 17 deletions
@@ -45,7 +45,6 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
         flux_dim = 3072
         dit_rename_dict = flux_dit_config["civitai"]["rename_dict"]
         dit_suffix_rename_dict = flux_dit_config["civitai"]["suffix_rename_dict"]
-        clip_rename_dict = flux_text_encoder_config["diffusers"]["rename_dict"]
         clip_attn_rename_dict = flux_text_encoder_config["diffusers"]["attn_rename_dict"]
 
         dit_dict = {}
@@ -138,27 +137,18 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
                 lora_args["rank"] = lora_args["up"].shape[1]
                 rename = rename.replace(".weight", "")
                 dit_dict[rename] = lora_args
-            elif "lora_te" in key:
-                name = key.replace("lora_te1", "text_encoder")
-                name = name.replace("text_model_encoder_layers", "text_model.encoder.layers")
-                name = name.replace(".alpha", ".weight")
-                rename = ""
-                if name in clip_rename_dict:
-                    if name == "text_model.embeddings.position_embedding.weight":
-                        param = param.reshape((1, param.shape[0], param.shape[1]))
-                    rename = clip_rename_dict[name]
-                elif name.startswith("text_model.encoder.layers."):
-                    names = name.split(".")
-                    layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
-                    rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type], tail])
-                else:
-                    raise ValueError(f"Unsupported key: {key}")
+            elif "lora_te1_text_model_encoder_layers_" in key:
+                name = key.replace("lora_te1_text_model_encoder_layers_", "")
+                name = name.replace(".alpha", "")
+                layer_id, layer_type = name.split("_", 1)
+                layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
+                rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
+
                 lora_args = {}
                 lora_args["alpha"] = param
                 lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
                 lora_args["down"] = lora_state_dict[origin_key.replace(".alpha", ".lora_down.weight")]
                 lora_args["rank"] = lora_args["up"].shape[1]
-                rename = rename.replace(".weight", "")
                 te_dict[rename] = lora_args
             else:
                 raise ValueError(f"Unsupported key: {key}")
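
For illustration, below is a minimal standalone sketch of the new text-encoder key handling from this diff. The clip_attn_rename_dict entry and the sample key are placeholders assumed for the example; in the pipeline the mapping comes from flux_text_encoder_config["diffusers"]["attn_rename_dict"].

# Hedged sketch of the new kohya text-encoder rename path, not the pipeline code itself.
# The single mapping entry below is an assumption for demonstration only.
clip_attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}  # placeholder entry

def rename_te_key(key: str) -> str:
    # Strip the kohya prefix and the ".alpha" suffix, e.g.
    # "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.alpha" -> "0_self_attn_q_proj"
    name = key.replace("lora_te1_text_model_encoder_layers_", "").replace(".alpha", "")
    # The first underscore separates the layer index from the layer type.
    layer_id, layer_type = name.split("_", 1)
    # Restore the dots that kohya flattens into underscores.
    layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
    return ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])

print(rename_te_key("lora_te1_text_model_encoder_layers_0_self_attn_q_proj.alpha"))
# prints "encoders.0.attn.to_q" with the placeholder mapping above

The corresponding up/down weights are then fetched by swapping ".alpha" for ".lora_up.weight" / ".lora_down.weight" on the original key, exactly as the diff shows.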
