@@ -186,6 +186,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) ->
186186 logger .info (f"loading state dict from { config .vae_path } ..." )
187187 vae_state_dict = cls .load_model_checkpoint (config .vae_path , device = "cpu" , dtype = config .vae_dtype )
188188
189+ encoder_state_dict = None
189190 if config .encoder_path is None :
190191 config .encoder_path = fetch_model (
191192 "MusePublic/Qwen-image" ,
@@ -197,8 +198,9 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) ->
197198 "text_encoder/model-00004-of-00004.safetensors" ,
198199 ],
199200 )
200- logger .info (f"loading state dict from { config .encoder_path } ..." )
201- encoder_state_dict = cls .load_model_checkpoint (config .encoder_path , device = "cpu" , dtype = config .encoder_dtype )
201+ if config .load_encoder :
202+ logger .info (f"loading state dict from { config .encoder_path } ..." )
203+ encoder_state_dict = cls .load_model_checkpoint (config .encoder_path , device = "cpu" , dtype = config .encoder_dtype )
202204
203205 state_dicts = QwenImageStateDicts (
204206 model = model_state_dict ,
@@ -225,22 +227,25 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe
225227 @classmethod
226228 def _from_state_dict (cls , state_dicts : QwenImageStateDicts , config : QwenImagePipelineConfig ) -> "QwenImagePipeline" :
227229 init_device = "cpu" if config .offload_mode is not None else config .device
228- tokenizer = Qwen2TokenizerFast .from_pretrained (QWEN_IMAGE_TOKENIZER_CONF_PATH )
229- processor = Qwen2VLProcessor .from_pretrained (
230- tokenizer_config_path = QWEN_IMAGE_TOKENIZER_CONF_PATH ,
231- image_processor_config_path = QWEN_IMAGE_PROCESSOR_CONFIG_FILE ,
232- )
233- with open (QWEN_IMAGE_VISION_CONFIG_FILE , "r" , encoding = "utf-8" ) as f :
234- vision_config = Qwen2_5_VLVisionConfig (** json .load (f ))
235- with open (QWEN_IMAGE_CONFIG_FILE , "r" , encoding = "utf-8" ) as f :
236- text_config = Qwen2_5_VLConfig (** json .load (f ))
237- encoder = Qwen2_5_VLForConditionalGeneration .from_state_dict (
238- state_dicts .encoder ,
239- vision_config = vision_config ,
240- config = text_config ,
241- device = ("cpu" if config .use_fsdp else init_device ),
242- dtype = config .encoder_dtype ,
243- )
230+ tokenizer , processor , encoder = None , None , None
231+ if config .load_encoder :
232+ tokenizer = Qwen2TokenizerFast .from_pretrained (QWEN_IMAGE_TOKENIZER_CONF_PATH )
233+ processor = Qwen2VLProcessor .from_pretrained (
234+ tokenizer_config_path = QWEN_IMAGE_TOKENIZER_CONF_PATH ,
235+ image_processor_config_path = QWEN_IMAGE_PROCESSOR_CONFIG_FILE ,
236+ )
237+ with open (QWEN_IMAGE_VISION_CONFIG_FILE , "r" , encoding = "utf-8" ) as f :
238+ vision_config = Qwen2_5_VLVisionConfig (** json .load (f ))
239+ with open (QWEN_IMAGE_CONFIG_FILE , "r" , encoding = "utf-8" ) as f :
240+ text_config = Qwen2_5_VLConfig (** json .load (f ))
241+ encoder = Qwen2_5_VLForConditionalGeneration .from_state_dict (
242+ state_dicts .encoder ,
243+ vision_config = vision_config ,
244+ config = text_config ,
245+ device = ("cpu" if config .use_fsdp else init_device ),
246+ dtype = config .encoder_dtype ,
247+ )
248+
244249 with open (QWEN_IMAGE_VAE_CONFIG_FILE , "r" , encoding = "utf-8" ) as f :
245250 vae_config = json .load (f )
246251 vae = QwenImageVAE .from_state_dict (