Skip to content

Commit 5e86c87

Browse files
load encoder optional (#196)
* load encoder optional
* Apply suggestion from @gemini-code-assist[bot]

  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* remove redundant code

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 90c73be commit 5e86c87

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

diffsynth_engine/configs/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,8 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
242242
vae_tile_size: Tuple[int, int] = (34, 34)
243243
vae_tile_stride: Tuple[int, int] = (18, 16)
244244

245+
load_encoder: bool = True
246+
245247
@classmethod
246248
def basic_config(
247249
cls,

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) ->
186186
logger.info(f"loading state dict from {config.vae_path} ...")
187187
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
188188

189+
encoder_state_dict = None
189190
if config.encoder_path is None:
190191
config.encoder_path = fetch_model(
191192
"MusePublic/Qwen-image",
@@ -197,8 +198,9 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) ->
197198
"text_encoder/model-00004-of-00004.safetensors",
198199
],
199200
)
200-
logger.info(f"loading state dict from {config.encoder_path} ...")
201-
encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
201+
if config.load_encoder:
202+
logger.info(f"loading state dict from {config.encoder_path} ...")
203+
encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
202204

203205
state_dicts = QwenImageStateDicts(
204206
model=model_state_dict,
@@ -225,22 +227,25 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe
225227
@classmethod
226228
def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
227229
init_device = "cpu" if config.offload_mode is not None else config.device
228-
tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
229-
processor = Qwen2VLProcessor.from_pretrained(
230-
tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
231-
image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
232-
)
233-
with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
234-
vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
235-
with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
236-
text_config = Qwen2_5_VLConfig(**json.load(f))
237-
encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
238-
state_dicts.encoder,
239-
vision_config=vision_config,
240-
config=text_config,
241-
device=("cpu" if config.use_fsdp else init_device),
242-
dtype=config.encoder_dtype,
243-
)
230+
tokenizer, processor, encoder = None, None, None
231+
if config.load_encoder:
232+
tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
233+
processor = Qwen2VLProcessor.from_pretrained(
234+
tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
235+
image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
236+
)
237+
with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
238+
vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
239+
with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
240+
text_config = Qwen2_5_VLConfig(**json.load(f))
241+
encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
242+
state_dicts.encoder,
243+
vision_config=vision_config,
244+
config=text_config,
245+
device=("cpu" if config.use_fsdp else init_device),
246+
dtype=config.encoder_dtype,
247+
)
248+
244249
with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
245250
vae_config = json.load(f)
246251
vae = QwenImageVAE.from_state_dict(

0 commit comments

Comments
 (0)