@@ -45,7 +45,6 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
4545 flux_dim = 3072
4646 dit_rename_dict = flux_dit_config ["civitai" ]["rename_dict" ]
4747 dit_suffix_rename_dict = flux_dit_config ["civitai" ]["suffix_rename_dict" ]
48- clip_rename_dict = flux_text_encoder_config ["diffusers" ]["rename_dict" ]
4948 clip_attn_rename_dict = flux_text_encoder_config ["diffusers" ]["attn_rename_dict" ]
5049
5150 dit_dict = {}
@@ -138,27 +137,18 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
138137 lora_args ["rank" ] = lora_args ["up" ].shape [1 ]
139138 rename = rename .replace (".weight" , "" )
140139 dit_dict [rename ] = lora_args
141- elif "lora_te" in key :
142- name = key .replace ("lora_te1" , "text_encoder" )
143- name = name .replace ("text_model_encoder_layers" , "text_model.encoder.layers" )
144- name = name .replace (".alpha" , ".weight" )
145- rename = ""
146- if name in clip_rename_dict :
147- if name == "text_model.embeddings.position_embedding.weight" :
148- param = param .reshape ((1 , param .shape [0 ], param .shape [1 ]))
149- rename = clip_rename_dict [name ]
150- elif name .startswith ("text_model.encoder.layers." ):
151- names = name .split ("." )
152- layer_id , layer_type , tail = names [3 ], "." .join (names [4 :- 1 ]), names [- 1 ]
153- rename = "." .join (["encoders" , layer_id , clip_attn_rename_dict [layer_type ], tail ])
154- else :
155- raise ValueError (f"Unsupported key: { key } " )
140+ elif "lora_te1_text_model_encoder_layers_" in key :
141+ name = key .replace ("lora_te1_text_model_encoder_layers_" , "" )
142+ name = name .replace (".alpha" , "" )
143+ layer_id , layer_type = name .split ("_" , 1 )
144+ layer_type = layer_type .replace ("self_attn_" , "self_attn." ).replace ("mlp_" , "mlp." )
145+ rename = "." .join (["encoders" , layer_id , clip_attn_rename_dict [layer_type ]])
146+
156147 lora_args = {}
157148 lora_args ["alpha" ] = param
158149 lora_args ["up" ] = lora_state_dict [origin_key .replace (".alpha" , ".lora_up.weight" )]
159150 lora_args ["down" ] = lora_state_dict [origin_key .replace (".alpha" , ".lora_down.weight" )]
160151 lora_args ["rank" ] = lora_args ["up" ].shape [1 ]
161- rename = rename .replace (".weight" , "" )
162152 te_dict [rename ] = lora_args
163153 else :
164154 raise ValueError (f"Unsupported key: { key } " )
0 commit comments