diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3a50634d82d8..e415307f4cfb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -418,6 +418,8 @@ "QwenImageEditModularPipeline", "QwenImageEditPlusAutoBlocks", "QwenImageEditPlusModularPipeline", + "QwenImageLayeredAutoBlocks", + "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", @@ -1138,6 +1140,8 @@ QwenImageEditModularPipeline, QwenImageEditPlusAutoBlocks, QwenImageEditPlusModularPipeline, + QwenImageLayeredAutoBlocks, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 5fcc1a176d1b..e64db23f3831 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -63,6 +63,8 @@ "QwenImageEditAutoBlocks", "QwenImageEditPlusModularPipeline", "QwenImageEditPlusAutoBlocks", + "QwenImageLayeredModularPipeline", + "QwenImageLayeredAutoBlocks", ] _import_structure["z_image"] = [ "ZImageAutoBlocks", @@ -96,6 +98,8 @@ QwenImageEditModularPipeline, QwenImageEditPlusAutoBlocks, QwenImageEditPlusModularPipeline, + QwenImageLayeredAutoBlocks, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index 8309eebfeb37..45b1c6bc136f 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -121,7 +121,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip return components, state -# Adapted from `QwenImageInputsDynamicStep` +# Adapted from `QwenImageAdditionalInputsStep` class FluxInputsDynamicStep(ModularPipelineBlocks): model_name = "flux" diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c5fa4cf9921f..d857fd040955 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -62,6 +62,7 @@ ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), + ("qwenimage-layered", "QwenImageLayeredModularPipeline"), ("z-image", "ZImageModularPipeline"), ] ) @@ -231,7 +232,7 @@ def format_value(v): class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ - Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, + Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks [`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks. @@ -527,9 +528,10 @@ def doc(self): ) -class AutoPipelineBlocks(ModularPipelineBlocks): +class ConditionalPipelineBlocks(ModularPipelineBlocks): """ - A Pipeline Blocks that automatically selects a block to run based on the inputs. + A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the + `select_block` method to define the logic for selecting the block. This class inherits from [`ModularPipelineBlocks`]. 
Check the superclass documentation for the generic methods the library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -539,12 +541,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
     Attributes:
         block_classes: List of block classes to be used
         block_names: List of prefixes for each block
-        block_trigger_inputs: List of input names that trigger specific blocks, with None for default
+        block_trigger_inputs: List of input names that select_block() uses to determine which block to run
     """

     block_classes = []
     block_names = []
     block_trigger_inputs = []
+    default_block_name = None  # the block to run when no trigger input is provided; if None, the whole block is skipped instead

     def __init__(self):
         sub_blocks = InsertableDict()
@@ -554,26 +557,15 @@ def __init__(self):
             else:
                 sub_blocks[block_name] = block
         self.sub_blocks = sub_blocks

-        if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+        if not (len(self.block_classes) == len(self.block_names)):
             raise ValueError(
-                f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+                f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
             )
-        default_blocks = [t for t in self.block_trigger_inputs if t is None]
-        # can only have 1 or 0 default block, and has to put in the last
-        # the order of blocks matters here because the first block with matching trigger will be dispatched
-        # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
-        # as long as mask is provided, it is inpaint; if only image is provided, it is img2img
-        if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
+        if self.default_block_name is not None and self.default_block_name not in self.block_names:
            raise ValueError(
-                f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
-                "in block_trigger_inputs."
+ f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}" ) - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys())) - self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs)) - @property def model_name(self): return next(iter(self.sub_blocks.values())).model_name @@ -602,8 +594,10 @@ def expected_configs(self): @property def required_inputs(self) -> List[str]: - if None not in self.block_trigger_inputs: + # no default block means this conditional block can be skipped entirely + if self.default_block_name is None: return [] + first_block = next(iter(self.sub_blocks.values())) required_by_all = set(getattr(first_block, "required_inputs", set())) @@ -614,7 +608,6 @@ def required_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] @@ -639,36 +632,9 @@ def outputs(self) -> List[str]: combined_outputs = self.combine_outputs(*named_outputs) return combined_outputs - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and state.get(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.info(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): + def _get_trigger_inputs(self) -> set: """ - Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique - block_trigger_inputs values + Returns a set of all unique trigger input values found in this block and nested blocks. """ def fn_recursive_get_trigger(blocks): @@ -676,9 +642,8 @@ def fn_recursive_get_trigger(blocks): if blocks is not None: for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. 
auto block)
+                # Check if current block has block_trigger_inputs
                 if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
-                    # Add all non-None values from the trigger inputs list
                     trigger_values.update(t for t in block.block_trigger_inputs if t is not None)

                 # If block has sub_blocks, recursively check them
@@ -688,15 +653,57 @@ def fn_recursive_get_trigger(blocks):

             return trigger_values

-        trigger_inputs = set(self.block_trigger_inputs)
-        trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
+        # Start with this block's block_trigger_inputs
+        all_triggers = {t for t in self.block_trigger_inputs if t is not None}
+        # Add nested triggers
+        all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))

-        return trigger_inputs
+        return all_triggers

     @property
     def trigger_inputs(self):
+        """All trigger inputs including from nested blocks."""
         return self._get_trigger_inputs()

+    def select_block(self, **kwargs) -> Optional[str]:
+        """
+        Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
+        for selecting the block.
+
+        Args:
+            **kwargs: Trigger input names and their values from the state.
+
+        Returns:
+            Optional[str]: The name of the block to run, or None to fall back to `default_block_name` (the block is
+                skipped entirely if that is also None).
+        """
+        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
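+
+    # A minimal sketch of a custom subclass (hypothetical, for illustration only; the
+    # block classes and input names below are assumptions, not part of this PR):
+    # route to an inpaint block whenever a mask is present, otherwise fall back to
+    # the default block.
+    #
+    #     class MyConditionalBlocks(ConditionalPipelineBlocks):
+    #         block_classes = [InpaintBlock, Img2ImgBlock]
+    #         block_names = ["inpaint", "img2img"]
+    #         block_trigger_inputs = ["mask_image"]
+    #         default_block_name = "img2img"
+    #
+    #         def select_block(self, **kwargs) -> Optional[str]:
+    #             if kwargs.get("mask_image") is not None:
+    #                 return "inpaint"
+    #             return None  # __call__ then falls back to default_block_name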
+
+    @torch.no_grad()
+    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
+        block_name = self.select_block(**trigger_kwargs)
+
+        if block_name is None:
+            block_name = self.default_block_name
+
+        if block_name is None:
+            logger.info(f"skipping conditional block: {self.__class__.__name__}")
+            return pipeline, state
+
+        block = self.sub_blocks[block_name]
+
+        try:
+            logger.info(f"Running block: {block.__class__.__name__}")
+            return block(pipeline, state)
+        except Exception as e:
+            error_msg = (
+                f"\nError in block: {block.__class__.__name__}\n"
+                f"Error details: {str(e)}\n"
+                f"Traceback:\n{traceback.format_exc()}"
+            )
+            logger.error(error_msg)
+            raise
+
     def __repr__(self):
         class_name = self.__class__.__name__
         base_class = self.__class__.__bases__[0].__name__
@@ -708,7 +715,7 @@ def __repr__(self):
         header += "\n"
         header += "  " + "=" * 100 + "\n"
         header += "  This pipeline contains blocks that are selected at runtime based on inputs.\n"
-        header += f"  Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+        header += f"  Trigger Inputs: {sorted(self.trigger_inputs)}\n"
         header += "  " + "=" * 100 + "\n\n"

         # Format description with proper indentation
@@ -729,31 +736,20 @@ def __repr__(self):
         expected_configs = getattr(self, "expected_configs", [])
         configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)

-        # Blocks section - moved to the end with simplified format
+        # Blocks section
         blocks_str = "  Sub-Blocks:\n"
         for i, (name, block) in enumerate(self.sub_blocks.items()):
-            # Get trigger input for this block
-            trigger = None
-            if hasattr(self, "block_to_trigger_map"):
-                trigger = self.block_to_trigger_map.get(name)
-            # Format the trigger info
-            if trigger is None:
-                trigger_str = "[default]"
-            elif isinstance(trigger, (list, tuple)):
-                trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
-            else:
-                trigger_str = f"[trigger: {trigger}]"
-            # For AutoPipelineBlocks, add bullet points
-            blocks_str += f"    • {name} {trigger_str} ({block.__class__.__name__})\n"
+            if name == self.default_block_name:
+                additional_str = " [default]"
             else:
-                # For SequentialPipelineBlocks, show execution order
-                blocks_str += f"    [{i}] {name} ({block.__class__.__name__})\n"
+                additional_str = ""
+            blocks_str += f"    • {name}{additional_str} ({block.__class__.__name__})\n"

             # Add block description
-            desc_lines = block.description.split("\n")
-            indented_desc = desc_lines[0]
-            if len(desc_lines) > 1:
-                indented_desc += "\n" + "\n".join("      " + line for line in desc_lines[1:])
+            block_desc_lines = block.description.split("\n")
+            indented_desc = block_desc_lines[0]
+            if len(block_desc_lines) > 1:
+                indented_desc += "\n" + "\n".join("      " + line for line in block_desc_lines[1:])
             blocks_str += f"       Description: {indented_desc}\n\n"

         # Build the representation with conditional sections
@@ -784,6 +780,35 @@ def doc(self):
         )


+class AutoPipelineBlocks(ConditionalPipelineBlocks):
+    """
+    A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+            raise ValueError(
+                f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+            )
+
+    @property
+    def default_block_name(self) -> Optional[str]:
+        """Derive default_block_name from block_trigger_inputs (None entry)."""
+        if None in self.block_trigger_inputs:
+            idx = self.block_trigger_inputs.index(None)
+            return self.block_names[idx]
+        return None
+
+    def select_block(self, **kwargs) -> Optional[str]:
+        """Select block based on which trigger input is present (not None)."""
+        for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
+            if trigger_input is not None and kwargs.get(trigger_input) is not None:
+                return block_name
+        return None
+
+
 class SequentialPipelineBlocks(ModularPipelineBlocks):
     """
     A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
@@ -885,7 +910,8 @@ def _get_inputs(self):

             # Only add outputs if the block cannot be skipped
             should_add_outputs = True
-            if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+            if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
+                # ConditionalPipelineBlocks without default can be skipped
                 should_add_outputs = False

             if should_add_outputs:
@@ -948,8 +974,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:

     def _get_trigger_inputs(self):
         """
-        Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
-        block_trigger_inputs values
+        Returns a set of all unique trigger input values found in the blocks.
         """

         def fn_recursive_get_trigger(blocks):
             trigger_values = set()

             if blocks is not None:
                 for name, block in blocks.items():
-                    # Check if current block has trigger inputs(i.e. 
auto block) + # Check if current block has block_trigger_inputs (ConditionalPipelineBlocks) if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has sub_blocks, recursively check them @@ -975,82 +999,84 @@ def fn_recursive_get_trigger(blocks): def trigger_inputs(self): return self._get_trigger_inputs() - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) + def _traverse_trigger_blocks(self, active_inputs): + """ + Traverse blocks and select which ones would run given the active inputs. + + Args: + active_inputs: Dict of input names to values that are "present" + + Returns: + OrderedDict of block_name -> block that would execute + """ - def fn_recursive_traverse(block, block_name, active_triggers): + def fn_recursive_traverse(block, block_name, active_inputs): result_blocks = OrderedDict() - # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, "block_trigger_inputs"): - if block.sub_blocks: - # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.sub_blocks.items(): - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} - result_blocks.update(blocks_to_update) + # ConditionalPipelineBlocks (includes AutoPipelineBlocks) + if isinstance(block, ConditionalPipelineBlocks): + trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs} + selected_block_name = block.select_block(**trigger_kwargs) + + if selected_block_name is None: + selected_block_name = block.default_block_name + + if selected_block_name is None: + return result_blocks + + selected_block = block.sub_blocks[selected_block_name] + + if selected_block.sub_blocks: + result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs)) else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, "outputs"): - active_triggers.update(out.name for out in block.outputs) + result_blocks[block_name] = selected_block + if hasattr(selected_block, "outputs"): + for out in selected_block.outputs: + active_inputs[out.name] = True + return result_blocks - # auto + # SequentialPipelineBlocks or LoopSequentialPipelineBlocks + if block.sub_blocks: + for sub_block_name, sub_block in block.sub_blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs) + blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} + result_blocks.update(blocks_to_update) else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - - if this_block is not None: - # sequential/auto (keep traversing) - if this_block.sub_blocks: - 
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
-                    else:
-                        # PipelineBlock
-                        result_blocks[block_name] = this_block
-                        # Add this block's output names to active triggers if defined
-                        # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
-                        if hasattr(this_block, "outputs"):
-                            active_triggers.update(out.name for out in this_block.outputs)
+                result_blocks[block_name] = block
+                if hasattr(block, "outputs"):
+                    for out in block.outputs:
+                        active_inputs[out.name] = True

             return result_blocks

         all_blocks = OrderedDict()
         for block_name, block in self.sub_blocks.items():
-            blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
+            blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
             all_blocks.update(blocks_to_update)

         return all_blocks

-    def get_execution_blocks(self, *trigger_inputs):
+    def get_execution_blocks(self, **kwargs):
+        """
+        Get the blocks that would execute given the specified inputs.

-        trigger_inputs_all = self.trigger_inputs
-        if trigger_inputs is not None:
-            if not isinstance(trigger_inputs, (list, tuple, set)):
-                trigger_inputs = [trigger_inputs]
-            invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
-            if invalid_inputs:
-                logger.warning(
-                    f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
-                )
-                trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
+        Args:
+            **kwargs: Input names and values. Only trigger inputs affect block selection.
+                Pass any inputs that would be non-None at runtime.

-        if trigger_inputs is None:
-            if None in trigger_inputs_all:
-                trigger_inputs = [None]
-            else:
-                trigger_inputs = [trigger_inputs_all[0]]
-        blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
+        Returns:
+            SequentialPipelineBlocks containing only the blocks that would execute
+
+        Example:
+            # Get blocks for inpainting workflow
+            blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
+
+            # Get blocks for text2image workflow
+            blocks = pipeline.get_execution_blocks(prompt="a cat")
+        """
+        # Filter out None values
+        active_inputs = {k: v for k, v in kwargs.items() if v is not None}
+
+        blocks_triggered = self._traverse_trigger_blocks(active_inputs)
         return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)

     def __repr__(self):
@@ -1067,7 +1093,7 @@ def __repr__(self):
         header += f"  Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
         # Get first trigger input as example
         example_input = next(t for t in self.trigger_inputs if t is not None)
-        header += f"  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
+        header += f"  Use `get_execution_blocks()` to see selected blocks (e.g. 
`get_execution_blocks({example_input}=...)`).\n" header += " " + "=" * 100 + "\n\n" # Format description with proper indentation @@ -1091,22 +1117,8 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, "block_to_trigger_map"): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + # show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" # Add block description desc_lines = block.description.split("\n") @@ -1230,15 +1242,9 @@ def _get_inputs(self): if inp.name not in outputs and inp not in inputs: inputs.append(inp) - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediate_outputs = [out.name for out in block.intermediate_outputs] - outputs.update(block_intermediate_outputs) + # Add this block's outputs + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) for input_param in inputs: if input_param.name in self.required_inputs: @@ -1295,6 +1301,14 @@ def __init__(self): sub_blocks[block_name] = block self.sub_blocks = sub_blocks + # Validate that sub_blocks are only leaf blocks + for block_name, block in self.sub_blocks.items(): + if block.sub_blocks: + raise ValueError( + f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). " + f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks." 
+ ) + @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": """ diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py deleted file mode 100644 index f7ee1dd3097b..000000000000 --- a/src/diffusers/modular_pipelines/node_utils.py +++ /dev/null @@ -1,661 +0,0 @@ -import json -import logging -import os -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import numpy as np -import PIL -import torch - -from ..configuration_utils import ConfigMixin -from ..image_processor import PipelineImageInput -from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks -from .modular_pipeline_utils import InputParam - - -logger = logging.getLogger(__name__) - -# YiYi Notes: this is actually for SDXL, put it here for now -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam( - "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" - ), - "prompt_2": InputParam( - "prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", - ), - "negative_prompt": InputParam( - "negative_prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation", - ), - "negative_prompt_2": InputParam( - "negative_prompt_2", - type_hint=Union[str, List[str]], - description="The negative prompt or prompts for text_encoder_2", - ), - "cross_attention_kwargs": InputParam( - "cross_attention_kwargs", - type_hint=Optional[dict], - description="Kwargs dictionary passed to the AttentionProcessor", - ), - "clip_skip": InputParam( - "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder" - ), - "image": InputParam( - "image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify for img2img or inpainting", - ), - "mask_image": InputParam( - "mask_image", - type_hint=PipelineImageInput, - required=True, - description="Mask image for inpainting, white pixels will be repainted", - ), - "generator": InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="Generator(s) for deterministic generation", - ), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam( - "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" - ), - "num_inference_steps": InputParam( - "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" - ), - "timesteps": InputParam( - "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" - ), - "sigmas": InputParam( - "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" - ), - "denoising_end": InputParam( - "denoising_end", - type_hint=Optional[float], - description="Fraction of denoising process to complete before termination", - ), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam( - "strength", type_hint=float, default=0.3, description="How much to transform the reference image" - ), - "denoising_start": InputParam( - "denoising_start", type_hint=Optional[float], 
description="Starting point of the denoising process" - ), - "latents": InputParam( - "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" - ), - "padding_mask_crop": InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - description="Size of margin in crop for image and mask", - ), - "original_size": InputParam( - "original_size", - type_hint=Optional[Tuple[int, int]], - description="Original size of the image for SDXL's micro-conditioning", - ), - "target_size": InputParam( - "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" - ), - "negative_original_size": InputParam( - "negative_original_size", - type_hint=Optional[Tuple[int, int]], - description="Negative conditioning based on image resolution", - ), - "negative_target_size": InputParam( - "negative_target_size", - type_hint=Optional[Tuple[int, int]], - description="Negative conditioning based on target resolution", - ), - "crops_coords_top_left": InputParam( - "crops_coords_top_left", - type_hint=Tuple[int, int], - default=(0, 0), - description="Top-left coordinates for SDXL's micro-conditioning", - ), - "negative_crops_coords_top_left": InputParam( - "negative_crops_coords_top_left", - type_hint=Tuple[int, int], - default=(0, 0), - description="Negative conditioning crop coordinates", - ), - "aesthetic_score": InputParam( - "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" - ), - "negative_aesthetic_score": InputParam( - "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" - ), - "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam( - "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" - ), - "ip_adapter_image": InputParam( - "ip_adapter_image", - type_hint=PipelineImageInput, - required=True, - description="Image(s) to be used as IP adapter", - ), - "control_image": InputParam( - "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" - ), - "control_guidance_start": InputParam( - "control_guidance_start", - type_hint=Union[float, List[float]], - default=0.0, - description="When ControlNet starts applying", - ), - "control_guidance_end": InputParam( - "control_guidance_end", - type_hint=Union[float, List[float]], - default=1.0, - description="When ControlNet stops applying", - ), - "controlnet_conditioning_scale": InputParam( - "controlnet_conditioning_scale", - type_hint=Union[float, List[float]], - default=1.0, - description="Scale factor for ControlNet outputs", - ), - "guess_mode": InputParam( - "guess_mode", - type_hint=bool, - default=False, - description="Enables ControlNet encoder to recognize input without prompts", - ), - "control_mode": InputParam( - "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" - ), -} - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam( - "prompt_embeds", - type_hint=torch.Tensor, - required=True, - description="Text embeddings used to guide image generation", - ), - "negative_prompt_embeds": InputParam( - "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" - ), - "pooled_prompt_embeds": InputParam( - "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text 
embeddings" - ), - "negative_pooled_prompt_embeds": InputParam( - "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" - ), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam( - "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" - ), - "latents": InputParam( - "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" - ), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam( - "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" - ), - "latent_timestep": InputParam( - "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" - ), - "image_latents": InputParam( - "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" - ), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam( - "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" - ), - "add_time_ids": InputParam( - "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" - ), - "negative_add_time_ids": InputParam( - "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" - ), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam( - "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" - ), - "negative_ip_adapter_embeds": InputParam( - "negative_ip_adapter_embeds", - type_hint=List[torch.Tensor], - description="Negative image embeddings for IP-Adapter", - ), - "images": InputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - required=True, - description="Generated images", - ), -} - -SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} - - -DEFAULT_PARAM_MAPS = { - "prompt": { - "label": "Prompt", - "type": "string", - "default": "a bear sitting in a chair drinking a milkshake", - "display": "textarea", - }, - "negative_prompt": { - "label": "Negative Prompt", - "type": "string", - "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", - "display": "textarea", - }, - "num_inference_steps": { - "label": "Steps", - "type": "int", - "default": 25, - "min": 1, - "max": 1000, - }, - "seed": { - "label": "Seed", - "type": "int", - "default": 0, - "min": 0, - "display": "random", - }, - "width": { - "label": "Width", - "type": "int", - "display": "text", - "default": 1024, - "min": 8, - "max": 8192, - "step": 8, - "group": "dimensions", - }, - "height": { - "label": "Height", - "type": "int", - "display": "text", - "default": 1024, - "min": 8, - "max": 8192, - "step": 8, - "group": "dimensions", - }, - "images": { - "label": "Images", - "type": "image", 
- "display": "output", - }, - "image": { - "label": "Image", - "type": "image", - "display": "input", - }, -} - -DEFAULT_TYPE_MAPS = { - "int": { - "type": "int", - "default": 0, - "min": 0, - }, - "float": { - "type": "float", - "default": 0.0, - "min": 0.0, - }, - "str": { - "type": "string", - "default": "", - }, - "bool": { - "type": "boolean", - "default": False, - }, - "image": { - "type": "image", - }, -} - -DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"] -DEFAULT_CATEGORY = "Modular Diffusers" -DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"] -DEFAULT_PARAMS_GROUPS_KEYS = { - "text_encoders": ["text_encoder", "tokenizer"], - "ip_adapter_embeds": ["ip_adapter_embeds"], - "prompt_embeddings": ["prompt_embeds"], -} - - -def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): - """ - Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> - "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None - """ - if name is None: - return None - for group_name, group_keys in group_params_keys.items(): - for group_key in group_keys: - if group_key in name: - return group_name - return None - - -class ModularNode(ConfigMixin): - """ - A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper - around a ModularPipelineBlocks object. - - > [!WARNING] > This is an experimental feature and is likely to change in the future. - """ - - config_name = "node_config.json" - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str, - trust_remote_code: Optional[bool] = None, - **kwargs, - ): - blocks = ModularPipelineBlocks.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) - return cls(blocks, **kwargs) - - def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): - self.blocks = blocks - - if label is None: - label = self.blocks.__class__.__name__ - # blocks param name -> mellon param name - self.name_mapping = {} - - input_params = {} - # pass or create a default param dict for each input - # e.g. for prompt, - # prompt = { - # "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers - # "label": "Prompt", - # "type": "string", - # "default": "a bear sitting in a chair drinking a milkshake", - # "display": "textarea"} - # if type is not specified, it'll be a "custom" param of its own type - # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) - # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} - # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} - inputs = self.blocks.inputs + self.blocks.intermediate_inputs - for inp in inputs: - param = kwargs.pop(inp.name, None) - if param: - # user can pass a param dict for all inputs, e.g. 
ModularNode(prompt = {...}) - input_params[inp.name] = param - mellon_name = param.pop("name", inp.name) - if mellon_name != inp.name: - self.name_mapping[inp.name] = mellon_name - continue - - if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): - continue - - if inp.name in DEFAULT_PARAM_MAPS: - # first check if it's in the default param map, if so, directly use that - param = DEFAULT_PARAM_MAPS[inp.name].copy() - elif get_group_name(inp.name): - param = get_group_name(inp.name) - if inp.name not in self.name_mapping: - self.name_mapping[inp.name] = param - else: - # if not, check if it's in the SDXL input schema, if so, - # 1. use the type hint to determine the type - # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} - if inp.type_hint is not None: - type_str = str(inp.type_hint).lower() - else: - inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None) - type_str = str(inp_spec.type_hint).lower() if inp_spec else "" - for type_key, type_param in DEFAULT_TYPE_MAPS.items(): - if type_key in type_str: - param = type_param.copy() - param["label"] = inp.name - param["display"] = "input" - break - else: - param = inp.name - # add the param dict to the inp_params dict - input_params[inp.name] = param - - component_params = {} - for comp in self.blocks.expected_components: - param = kwargs.pop(comp.name, None) - if param: - component_params[comp.name] = param - mellon_name = param.pop("name", comp.name) - if mellon_name != comp.name: - self.name_mapping[comp.name] = mellon_name - continue - - to_exclude = False - for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: - if exclude_key in comp.name: - to_exclude = True - break - if to_exclude: - continue - - if get_group_name(comp.name): - param = get_group_name(comp.name) - if comp.name not in self.name_mapping: - self.name_mapping[comp.name] = param - elif comp.name in DEFAULT_MODEL_KEYS: - param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"} - else: - param = comp.name - # add the param dict to the model_params dict - component_params[comp.name] = param - - output_params = {} - if isinstance(self.blocks, SequentialPipelineBlocks): - last_block_name = list(self.blocks.sub_blocks.keys())[-1] - outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs - else: - outputs = self.blocks.intermediate_outputs - - for out in outputs: - param = kwargs.pop(out.name, None) - if param: - output_params[out.name] = param - mellon_name = param.pop("name", out.name) - if mellon_name != out.name: - self.name_mapping[out.name] = mellon_name - continue - - if out.name in DEFAULT_PARAM_MAPS: - param = DEFAULT_PARAM_MAPS[out.name].copy() - param["display"] = "output" - else: - group_name = get_group_name(out.name) - if group_name: - param = group_name - if out.name not in self.name_mapping: - self.name_mapping[out.name] = param - else: - param = out.name - # add the param dict to the outputs dict - output_params[out.name] = param - - if len(kwargs) > 0: - logger.warning(f"Unused kwargs: {kwargs}") - - register_dict = { - "category": category, - "label": label, - "input_params": input_params, - "component_params": component_params, - "output_params": output_params, - "name_mapping": self.name_mapping, - } - self.register_to_config(**register_dict) - - def setup(self, components_manager, collection=None): - self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection) - 
self._components_manager = components_manager - - @property - def mellon_config(self): - return self._convert_to_mellon_config() - - def _convert_to_mellon_config(self): - node = {} - node["label"] = self.config.label - node["category"] = self.config.category - - node_param = {} - for inp_name, inp_param in self.config.input_params.items(): - if inp_name in self.name_mapping: - mellon_name = self.name_mapping[inp_name] - else: - mellon_name = inp_name - if isinstance(inp_param, str): - param = { - "label": inp_param, - "type": inp_param, - "display": "input", - } - else: - param = inp_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") - - for comp_name, comp_param in self.config.component_params.items(): - if comp_name in self.name_mapping: - mellon_name = self.name_mapping[comp_name] - else: - mellon_name = comp_name - if isinstance(comp_param, str): - param = { - "label": comp_param, - "type": comp_param, - "display": "input", - } - else: - param = comp_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") - - for out_name, out_param in self.config.output_params.items(): - if out_name in self.name_mapping: - mellon_name = self.name_mapping[out_name] - else: - mellon_name = out_name - if isinstance(out_param, str): - param = { - "label": out_param, - "type": out_param, - "display": "output", - } - else: - param = out_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") - node["params"] = node_param - return node - - def save_mellon_config(self, file_path): - """ - Save the Mellon configuration to a JSON file. - - Args: - file_path (str or Path): Path where the JSON file will be saved - - Returns: - Path: Path to the saved config file - """ - file_path = Path(file_path) - - # Create directory if it doesn't exist - os.makedirs(file_path.parent, exist_ok=True) - - # Create a combined dictionary with module definition and name mapping - config = {"module": self.mellon_config, "name_mapping": self.name_mapping} - - # Save the config to file - with open(file_path, "w", encoding="utf-8") as f: - json.dump(config, f, indent=2) - - logger.info(f"Mellon config and name mapping saved to {file_path}") - - return file_path - - @classmethod - def load_mellon_config(cls, file_path): - """ - Load a Mellon configuration from a JSON file. 
- - Args: - file_path (str or Path): Path to the JSON file containing Mellon config - - Returns: - dict: The loaded combined configuration containing 'module' and 'name_mapping' - """ - file_path = Path(file_path) - - if not file_path.exists(): - raise FileNotFoundError(f"Config file not found: {file_path}") - - with open(file_path, "r", encoding="utf-8") as f: - config = json.load(f) - - logger.info(f"Mellon config loaded from {file_path}") - - return config - - def process_inputs(self, **kwargs): - params_components = {} - for comp_name, comp_param in self.config.component_params.items(): - logger.debug(f"component: {comp_name}") - mellon_comp_name = self.name_mapping.get(comp_name, comp_name) - if mellon_comp_name in kwargs: - if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: - comp = kwargs[mellon_comp_name].pop(comp_name) - else: - comp = kwargs.pop(mellon_comp_name) - if comp: - params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) - - params_run = {} - for inp_name, inp_param in self.config.input_params.items(): - logger.debug(f"input: {inp_name}") - mellon_inp_name = self.name_mapping.get(inp_name, inp_name) - if mellon_inp_name in kwargs: - if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: - inp = kwargs[mellon_inp_name].pop(inp_name) - else: - inp = kwargs.pop(mellon_inp_name) - if inp is not None: - params_run[inp_name] = inp - - return_output_names = list(self.config.output_params.keys()) - - return params_components, params_run, return_output_names - - def execute(self, **kwargs): - params_components, params_run, return_output_names = self.process_inputs(**kwargs) - - self.pipeline.update_components(**params_components) - output = self.pipeline(**params_run, output=return_output_names) - return output diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index ae4ec4799fbc..2b01a5b5a4b5 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -21,27 +21,27 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["encoders"] = ["QwenImageTextEncoderStep"] - _import_structure["modular_blocks"] = [ - "ALL_BLOCKS", + _import_structure["modular_blocks_qwenimage"] = [ "AUTO_BLOCKS", - "CONTROLNET_BLOCKS", - "EDIT_AUTO_BLOCKS", - "EDIT_BLOCKS", - "EDIT_INPAINT_BLOCKS", - "EDIT_PLUS_AUTO_BLOCKS", - "EDIT_PLUS_BLOCKS", - "IMAGE2IMAGE_BLOCKS", - "INPAINT_BLOCKS", - "TEXT2IMAGE_BLOCKS", "QwenImageAutoBlocks", + ] + _import_structure["modular_blocks_qwenimage_edit"] = [ + "EDIT_AUTO_BLOCKS", "QwenImageEditAutoBlocks", + ] + _import_structure["modular_blocks_qwenimage_edit_plus"] = [ + "EDIT_PLUS_AUTO_BLOCKS", "QwenImageEditPlusAutoBlocks", ] + _import_structure["modular_blocks_qwenimage_layered"] = [ + "LAYERED_AUTO_BLOCKS", + "QwenImageLayeredAutoBlocks", + ] _import_structure["modular_pipeline"] = [ "QwenImageEditModularPipeline", "QwenImageEditPlusModularPipeline", "QwenImageModularPipeline", + "QwenImageLayeredModularPipeline", ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -51,28 +51,26 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .encoders import ( - QwenImageTextEncoderStep, - ) - from .modular_blocks import ( - ALL_BLOCKS, + from .modular_blocks_qwenimage import ( AUTO_BLOCKS, - CONTROLNET_BLOCKS, - 
EDIT_AUTO_BLOCKS, - EDIT_BLOCKS, - EDIT_INPAINT_BLOCKS, - EDIT_PLUS_AUTO_BLOCKS, - EDIT_PLUS_BLOCKS, - IMAGE2IMAGE_BLOCKS, - INPAINT_BLOCKS, - TEXT2IMAGE_BLOCKS, QwenImageAutoBlocks, + ) + from .modular_blocks_qwenimage_edit import ( + EDIT_AUTO_BLOCKS, QwenImageEditAutoBlocks, + ) + from .modular_blocks_qwenimage_edit_plus import ( + EDIT_PLUS_AUTO_BLOCKS, QwenImageEditPlusAutoBlocks, ) + from .modular_blocks_qwenimage_layered import ( + LAYERED_AUTO_BLOCKS, + QwenImageLayeredAutoBlocks, + ) from .modular_pipeline import ( QwenImageEditModularPipeline, QwenImageEditPlusModularPipeline, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) else: diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index bd92d403539e..0c66d6ea3303 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -23,7 +23,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift @@ -113,7 +113,9 @@ def get_timesteps(scheduler, num_inference_steps, strength): return timesteps, num_inference_steps - t_start -# Prepare Latents steps +# ==================== +# 1. PREPARE LATENTS +# ==================== class QwenImagePrepareLatentsStep(ModularPipelineBlocks): @@ -207,6 +209,98 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents"), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="layers", default=4), + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="generator"), + InputParam( + name="batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + InputParam( + name="dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs, can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -351,7 +445,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -# Set Timesteps steps +# ==================== +# 2. SET TIMESTEPS +# ==================== class QwenImageSetTimestepsStep(ModularPipelineBlocks): @@ -420,6 +516,64 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents." 
+
+        # Default sigmas if not provided
+        sigmas = (
+            np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+            if block_state.sigmas is None
+            else block_state.sigmas
+        )
+
+        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+            components.scheduler,
+            block_state.num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+
+        components.scheduler.set_begin_index(0)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
     model_name = "qwenimage"

@@ -493,7 +647,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state


-# other inputs for denoiser
+# ====================
+# 3. OTHER INPUTS FOR DENOISER
+# ====================


 ## RoPE inputs for denoiser
@@ -522,6 +678,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
                 name="img_shapes",
+                kwargs_type="denoiser_input_fields",
                 type_hint=List[List[Tuple[int, int, int]]],
                 description="The shapes of the images latents, used for RoPE calculation",
             ),
@@ -589,6 +746,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
                 name="img_shapes",
+                kwargs_type="denoiser_input_fields",
                 type_hint=List[List[Tuple[int, int, int]]],
                 description="The shapes of the images latents, used for RoPE calculation",
             ),
@@ -639,19 +797,64 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state


-class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
+class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
     model_name = "qwenimage-edit-plus"

+    @property
+    def description(self) -> str:
+        return (
+            "Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit Plus.\n"
+            "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
+            "Should be placed after prepare_latents step."
+ ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="batch_size", required=True), + InputParam(name="image_height", required=True, type_hint=List[int]), + InputParam(name="image_width", required=True, type_hint=List[int]), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds_mask"), + InputParam(name="negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=List[List[Tuple[int, int, int]]], + description="The shapes of the image latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + ] + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) vae_scale_factor = components.vae_scale_factor + + # Edit Plus: image_height and image_width are lists block_state.img_shapes = [ [ (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), *[ - (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2) - for vae_height, vae_width in zip(block_state.image_height, block_state.image_width) + (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2) + for img_height, img_width in zip(block_state.image_height, block_state.image_width) ], ] ] * block_state.batch_size @@ -670,6 +873,87 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. 
Should be placed after prepare_latents step."
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="batch_size", required=True),
+            InputParam(name="layers", required=True),
+            InputParam(name="height", required=True),
+            InputParam(name="width", required=True),
+            InputParam(name="prompt_embeds_mask"),
+            InputParam(name="negative_prompt_embeds_mask"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="img_shapes",
+                type_hint=List[List[Tuple[int, int, int]]],
+                kwargs_type="denoiser_input_fields",
+                description="The shapes of the image latents, used for RoPE calculation",
+            ),
+            OutputParam(
+                name="txt_seq_lens",
+                type_hint=List[int],
+                kwargs_type="denoiser_input_fields",
+                description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+            ),
+            OutputParam(
+                name="negative_txt_seq_lens",
+                type_hint=List[int],
+                kwargs_type="denoiser_input_fields",
+                description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+            ),
+            OutputParam(
+                name="additional_t_cond",
+                type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
+                description="The additional timestep condition, used for RoPE calculation",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        device = components._execution_device
+
+        # All shapes are the same for Layered
+        shape = (
+            1,
+            block_state.height // components.vae_scale_factor // 2,
+            block_state.width // components.vae_scale_factor // 2,
+        )
+
+        # layers+1 output shapes + 1 condition shape (all same)
+        block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size
+
+        # txt_seq_lens
+        block_state.txt_seq_lens = (
+            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+        )
+        block_state.negative_txt_seq_lens = (
+            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+            if block_state.negative_prompt_embeds_mask is not None
+            else None
+        )
+
+        block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 ## ControlNet inputs for denoiser
 class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
     model_name = "qwenimage"
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
index 6e145f18550a..24a88ebfca3c 100644
--- a/src/diffusers/modular_pipelines/qwenimage/decoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -24,12 +24,13 @@
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
 
 
 logger = logging.get_logger(__name__)
 
 
+# after denoising loop (unpack latents)
 class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -71,6 +72,46 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         return components, state
 
 
+class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
+    model_name = "qwenimage-layered"
+
+    @property
+    def description(self) -> str:
+        return "Unpack latents from (B, seq, C*4) to (B, C, 
layers+1, H, W) after denoising."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("latents", required=True, type_hint=torch.Tensor),
+            InputParam("height", required=True, type_hint=int),
+            InputParam("width", required=True, type_hint=int),
+            InputParam("layers", required=True, type_hint=int),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W)
+        block_state.latents = components.pachifier.unpack_latents(
+            block_state.latents,
+            block_state.height,
+            block_state.width,
+            block_state.layers,
+            components.vae_scale_factor,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+# decode step
 class QwenImageDecoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -135,6 +176,81 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state
 
 
+class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
+    model_name = "qwenimage-layered"
+
+    @property
+    def description(self) -> str:
+        return "Decode unpacked latents (B, C, layers+1, H, W) into layer images."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLQwenImage),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("latents", required=True, type_hint=torch.Tensor),
+            InputParam("output_type", default="pil", type_hint=str),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        latents = block_state.latents
+
+        # 1. VAE normalization
+        latents = latents.to(components.vae.dtype)
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean)
+            .view(1, components.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(latents.device, latents.dtype)
+        latents = latents / latents_std + latents_mean
+
+        # 2. Remove first frame (composite), keep the `layers` layer frames
+        latents = latents[:, :, 1:]
+
+        # 3. Reshape for batch decoding: (B, C, layers, H, W) -> (B*layers, C, 1, H, W).
+        # Note: the shape is read *after* dropping the composite frame so that `f` counts
+        # only the decoded layer frames; the per-batch chunking below relies on this.
+        b, c, f, h, w = latents.shape
+        latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
+
+        # 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W)
+        image = components.vae.decode(latents, return_dict=False)[0]
+        image = image.squeeze(2)
+
+        # 5. Postprocess - returns flat list of B*layers images
+        image = components.image_processor.postprocess(image, output_type=block_state.output_type)
+
+        # 6. 
Chunk into list per batch item + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + block_state.images = images + + self.set_block_state(state, block_state) + return components, state + + +# postprocess the decoded images class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 49acd2dc0295..eb1e5a341c68 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import List, Tuple import torch @@ -28,7 +29,12 @@ logger = logging.get_logger(__name__) +# ==================== +# 1. LOOP STEPS (run at each denoising step) +# ==================== + +# loop step:before denoiser class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): model_name = "qwenimage" @@ -60,7 +66,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit" @property def description(self) -> str: @@ -185,6 +191,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState return components, block_state +# loop step:denoiser class QwenImageLoopDenoiser(ModularPipelineBlocks): model_name = "qwenimage" @@ -253,6 +260,13 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState ), } + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) @@ -264,7 +278,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, timestep=block_state.timestep / 1000, - img_shapes=block_state.img_shapes, attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -284,7 +297,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState class QwenImageEditLoopDenoiser(ModularPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit" @property def description(self) -> str: @@ -351,6 +364,13 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState ), } + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) @@ -362,7 +382,6 @@ def __call__(self, 
components: QwenImageModularPipeline, block_state: BlockState
             guider_state_batch.noise_pred = components.transformer(
                 hidden_states=block_state.latent_model_input,
                 timestep=block_state.timestep / 1000,
-                img_shapes=block_state.img_shapes,
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
                 **cond_kwargs,
@@ -384,6 +403,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
     return components, block_state
 
 
+# loop step: after denoiser
 class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -481,6 +501,9 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
     return components, block_state
 
 
+# ====================
+# 2. DENOISE LOOP WRAPPER: define the denoising loop logic
+# ====================
 class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
     model_name = "qwenimage"
 
@@ -537,8 +560,15 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state
 
 
-# composing the denoising loops
+# ====================
+# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps
+# ====================
+
+
+# Qwen Image (text2image, image2image)
 class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage"
+
     block_classes = [
         QwenImageLoopBeforeDenoiser,
         QwenImageLoopDenoiser,
@@ -559,8 +589,9 @@ def description(self) -> str:
     )
 
 
-# composing the inpainting denoising loops
+# Qwen Image (inpainting)
 class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage"
     block_classes = [
         QwenImageLoopBeforeDenoiser,
         QwenImageLoopDenoiser,
@@ -583,8 +614,9 @@ def description(self) -> str:
     )
 
 
-# composing the controlnet denoising loops
+# Qwen Image (text2image, image2image) with controlnet
 class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage"
     block_classes = [
         QwenImageLoopBeforeDenoiser,
         QwenImageLoopBeforeDenoiserControlNet,
@@ -607,8 +639,9 @@ def description(self) -> str:
     )
 
 
-# composing the controlnet denoising loops
+# Qwen Image (inpainting) with controlnet
 class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage"
     block_classes = [
         QwenImageLoopBeforeDenoiser,
         QwenImageLoopBeforeDenoiserControlNet,
@@ -639,8 +672,9 @@ def description(self) -> str:
     )
 
 
-# composing the denoising loops
+# Qwen Image Edit (image2image)
 class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage-edit"
     block_classes = [
         QwenImageEditLoopBeforeDenoiser,
         QwenImageEditLoopDenoiser,
@@ -661,7 +695,9 @@ def description(self) -> str:
     )
 
 
+# Qwen Image Edit (inpainting)
 class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage-edit"
     block_classes = [
         QwenImageEditLoopBeforeDenoiser,
         QwenImageEditLoopDenoiser,
@@ -682,3 +718,26 @@ def description(self) -> str:
         " - `QwenImageLoopAfterDenoiserInpaint`\n"
         "This block supports inpainting tasks for QwenImage Edit."
     )
+
+
+# Qwen Image Layered (image2image)
+class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
+    model_name = "qwenimage-layered"
+    block_classes = [
+        QwenImageEditLoopBeforeDenoiser,
+        QwenImageEditLoopDenoiser,
+        QwenImageLoopAfterDenoiser,
+    ]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents.\n"
+            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.\n"
+            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+            " - `QwenImageEditLoopBeforeDenoiser`\n"
+            " - `QwenImageEditLoopDenoiser`\n"
+            " - `QwenImageLoopAfterDenoiser`\n"
+            "This block supports QwenImage Layered."
+        )
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index b126a368bfdf..4b66dd32e521 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+Text and VAE encoder blocks for QwenImage pipelines.
+"""
+
 from typing import Dict, List, Optional, Union
 
 import PIL
@@ -28,6 +32,17 @@
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import QwenImageModularPipeline
+from .prompt_templates import (
+    QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE,
+    QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE,
+    QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX,
+    QWENIMAGE_EDIT_PROMPT_TEMPLATE,
+    QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX,
+    QWENIMAGE_LAYERED_CAPTION_PROMPT_CN,
+    QWENIMAGE_LAYERED_CAPTION_PROMPT_EN,
+    QWENIMAGE_PROMPT_TEMPLATE,
+    QWENIMAGE_PROMPT_TEMPLATE_START_IDX,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -45,8 +60,8 @@ def get_qwen_prompt_embeds(
     text_encoder,
     tokenizer,
     prompt: Union[str, List[str]] = None,
-    prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
-    prompt_template_encode_start_idx: int = 34,
+    prompt_template_encode: str = QWENIMAGE_PROMPT_TEMPLATE,
+    prompt_template_encode_start_idx: int = QWENIMAGE_PROMPT_TEMPLATE_START_IDX,
     tokenizer_max_length: int = 1024,
     device: Optional[torch.device] = None,
 ):
@@ -86,8 +101,8 @@ def get_qwen_prompt_embeds_edit(
     processor,
     prompt: Union[str, List[str]] = None,
     image: Optional[torch.Tensor] = None,
-    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
-    prompt_template_encode_start_idx: int = 64,
+    prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE,
+    prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX,
     device: Optional[torch.device] = None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -133,9 +148,9 @@ def get_qwen_prompt_embeds_edit_plus(
     processor,
     prompt: Union[str, List[str]] = None,
     image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
-    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
-    img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
-    prompt_template_encode_start_idx: int = 64,
+    prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE,
+    img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE,
+    prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX,
     device: Optional[torch.device] = None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -241,15 +256,18 @@ def encode_vae_image(
     return image_latents
 
 
-class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
-    model_name = "qwenimage"
-
-    def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
-        """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+# ====================
+# 1. RESIZE
+# ====================
class QwenImageEditResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage-edit"
 
-        This block resizes an input image tensor and exposes the resized result under configurable input and output
-        names. Use this when you need to wire the resize step to different image fields (e.g., "image",
-        "control_image")
+    def __init__(
+        self,
+        input_name: str = "image",
+        output_name: str = "resized_image",
+    ):
+        """Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
 
         Args:
             input_name (str, optional): Name of the image field to read from the
@@ -267,7 +285,7 @@ def __init__(self, input_name: str = "image", output_name: str = "resized_image"
 
     @property
     def description(self) -> str:
-        return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
+        return f"Image Resize step that resizes the {self._image_input_name} to the target area while maintaining the aspect ratio."
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -321,89 +339,289 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state
 
 
-class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
-    model_name = "qwenimage"
+class QwenImageLayeredResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage-layered"
 
     def __init__(
         self,
         input_name: str = "image",
         output_name: str = "resized_image",
-        vae_image_output_name: str = "vae_image",
     ):
-        """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
-
-        This block resizes an input image or a list input images and exposes the resized result under configurable
-        input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
-        "image", "control_image")
+        """Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
 
         Args:
             input_name (str, optional): Name of the image field to read from the
                 pipeline state. Defaults to "image".
            output_name (str, optional): Name of the resized image field to write
                 back to the pipeline state. Defaults to "resized_image".
-            vae_image_output_name (str, optional): Name of the image field
-                to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
-                processes the input image(s) differently for the VL and the VAE.
        """
         if not isinstance(input_name, str) or not isinstance(output_name, str):
             raise ValueError(
                 f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
             )
-        self.condition_image_size = 384 * 384
         self._image_input_name = input_name
         self._resized_image_output_name = output_name
-        self._vae_image_output_name = vae_image_output_name
         super().__init__()
 
+    @property
+    def description(self) -> str:
+        return f"Image Resize step that resizes the {self._image_input_name} to the target area while maintaining the aspect ratio."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_resize_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
+            ),
+            InputParam(
+                name="resolution",
+                default=640,
+                type_hint=int,
+                description="The target area to resize the image to, can be 1024 or 640",
+            ),
+        ]
+
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
-        return super().intermediate_outputs + [
+        return [
             OutputParam(
-                name=self._vae_image_output_name,
-                type_hint=List[PIL.Image.Image],
-                description="The images to be processed which will be further used by the VAE encoder.",
+                name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
             ),
         ]
 
+    @staticmethod
+    def check_inputs(resolution: int):
+        if resolution not in [1024, 640]:
+            raise ValueError(f"Resolution must be 1024 or 640 but is {resolution}")
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
 
+        self.check_inputs(resolution=block_state.resolution)
+
         images = getattr(block_state, self._image_input_name)
 
         if not is_valid_image_imagelist(images):
             raise ValueError(f"Images must be image or list of images but are {type(images)}")
 
-        if (
-            not isinstance(images, torch.Tensor)
-            and isinstance(images, PIL.Image.Image)
-            and not isinstance(images, list)
-        ):
+        if is_valid_image(images):
             images = [images]
 
-        # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
-        condition_images = []
-        vae_images = []
-        for img in images:
-            image_width, image_height = img.size
-            condition_width, condition_height, _ = calculate_dimensions(
-                self.condition_image_size, image_width / image_height
+        image_width, image_height = images[0].size
+        target_area = block_state.resolution * block_state.resolution
+        calculated_width, calculated_height, _ = calculate_dimensions(target_area, image_width / image_height)
+
+        resized_images = [
+            components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+            for image in images
+        ]
+
+        setattr(block_state, self._resized_image_output_name, resized_images)
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
+    """Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus."""
+
+    model_name = "qwenimage-edit-plus"
+
+    def __init__(
+        self,
+        input_name: str = "image",
+        output_name: str = "resized_image",
+        target_area: int = 1024 * 1024,
+    ):
+        """Create a step for resizing images to a target area.
+
+        Each image is resized independently based on its own aspect ratio. 
This is suitable for Edit Plus where + multiple reference images can have different dimensions. + + Args: + input_name (str, optional): Name of the image field to read. Defaults to "image". + output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image". + target_area (int, optional): Target area in pixels. Defaults to 1024*1024. + """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError( + f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" ) - condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width)) - vae_images.append(img) + self._image_input_name = input_name + self._resized_image_output_name = output_name + self._target_area = target_area + super().__init__() + + @property + def description(self) -> str: + return ( + f"Image Resize step that resizes {self._image_input_name} to target area {self._target_area}.\n" + "Each image is resized independently based on its own aspect ratio." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name=self._image_input_name, + required=True, + type_hint=torch.Tensor, + description="The image(s) to resize", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = getattr(block_state, self._image_input_name) + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + # Resize each image independently based on its own aspect ratio + resized_images = [] + for image in images: + image_width, image_height = image.size + calculated_width, calculated_height, _ = calculate_dimensions( + self._target_area, image_width / image_height + ) + resized_images.append( + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + ) + + setattr(block_state, self._resized_image_output_name, resized_images) + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 2. GET IMAGE PROMPT +# ==================== +class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): + """ + Auto-caption step that generates a text prompt from the input image if none is provided. Uses the VL model to + generate a description of the image. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return ( + "Auto-caption step that generates a text prompt from the input image if none is provided.\n" + "Uses the VL model (text_encoder) to generate a description of the image.\n" + "If prompt is already provided, this step passes through unchanged." 
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+            ComponentSpec("processor", Qwen2VLProcessor),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(name="image_caption_prompt_en", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_EN),
+            ConfigSpec(name="image_caption_prompt_cn", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_CN),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt", type_hint=str, description="The prompt to encode"),
+            InputParam(
+                name="resized_image",
+                required=True,
+                type_hint=PIL.Image.Image,
+                description="The image to generate the caption from, should be resized using the resize step",
+            ),
+            InputParam(
+                name="use_en_prompt",
+                default=False,
+                type_hint=bool,
+                description="Whether to use the English prompt template",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        device = components._execution_device
+
+        # If prompt is empty or None, generate caption from image
+        if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ":
+            if block_state.use_en_prompt:
+                caption_prompt = components.config.image_caption_prompt_en
+            else:
+                caption_prompt = components.config.image_caption_prompt_cn
+
+            model_inputs = components.processor(
+                text=caption_prompt,
+                images=block_state.resized_image,
+                padding=True,
+                return_tensors="pt",
+            ).to(device)
+
+            generated_ids = components.text_encoder.generate(**model_inputs, max_new_tokens=512)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
+            ]
+            output_text = components.processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )[0]
+
+            block_state.prompt = output_text.strip()
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+# ====================
+# 3. TEXT ENCODER
+# ====================
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
     @property
     def description(self) -> str:
-        return "Text Encoder step that generate text_embeddings to guide the image generation"
+        return "Text Encoder step that generates text embeddings to guide the image generation."
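# --- Illustrative aside (editor addition, not part of the diff). The caption step above
# decodes only the newly generated tokens: Qwen2.5-VL's generate() returns the prompt and
# the continuation in a single tensor, so the prompt prefix is sliced off per sequence
# before decoding. A minimal standalone version of that idiom, with toy token ids standing
# in for real tokenizer output:
import torch

input_ids = torch.tensor([[101, 102, 103]])               # prompt tokens (B=1, L_in=3)
generated_ids = torch.tensor([[101, 102, 103, 7, 8, 9]])  # prompt + newly generated tokens
trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
assert trimmed[0].tolist() == [7, 8, 9]                   # only the new tokens remain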
@property def expected_components(self) -> List[ComponentSpec]: @@ -421,11 +639,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def expected_configs(self) -> List[ConfigSpec]: return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - ), - ConfigSpec(name="prompt_template_encode_start_idx", default=34), + ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_PROMPT_TEMPLATE), + ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_PROMPT_TEMPLATE_START_IDX), ConfigSpec(name="tokenizer_max_length", default=1024), ] @@ -532,7 +747,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation" + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." @property def expected_components(self) -> List[ComponentSpec]: @@ -550,11 +765,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def expected_configs(self) -> List[ConfigSpec]: return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", - ), - ConfigSpec(name="prompt_template_encode_start_idx", default=64), + ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE), + ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX), ] @property @@ -565,7 +777,7 @@ def inputs(self) -> List[InputParam]: InputParam( name="resized_image", required=True, - type_hint=torch.Tensor, + type_hint=PIL.Image.Image, description="The image prompt to encode, should be resized using resize step", ), ] @@ -647,23 +859,93 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state -class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep): - model_name = "qwenimage" +class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): + """Text encoder for QwenImage Edit Plus (VL encoding with multiple images).""" + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together " + "to generate text embeddings for guiding image generation." 
+ ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] @property def expected_configs(self) -> List[ConfigSpec]: return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE), + ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE), + ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), + InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam( + name="resized_cond_image", + required=True, + type_hint=torch.Tensor, + description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step", ), - ConfigSpec( - name="img_template_encode", - default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>", + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The prompt embeddings", + ), + OutputParam( + name="prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The encoder attention mask", + ), + OutputParam( + name="negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings", + ), + OutputParam( + name="negative_prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings mask", ), - ConfigSpec(name="prompt_template_encode_start_idx", default=64), ] + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) @@ -676,7 +958,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.processor, prompt=block_state.prompt, - image=block_state.resized_image, + image=block_state.resized_cond_image, prompt_template_encode=components.config.prompt_template_encode, img_template_encode=components.config.img_template_encode, 
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, @@ -692,7 +974,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.processor, prompt=negative_prompt, - image=block_state.resized_image, + image=block_state.resized_cond_image, prompt_template_encode=components.config.prompt_template_encode, img_template_encode=components.config.img_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, @@ -704,12 +986,15 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# ==================== +# 4. IMAGE PREPROCESS +# ==================== class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): model_name = "qwenimage" @property def description(self) -> str: - return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep." + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be resized to the given height and width." @property def expected_components(self) -> List[ComponentSpec]: @@ -726,8 +1011,7 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("mask_image", required=True), - InputParam("resized_image"), - InputParam("image"), + InputParam("image", required=True), InputParam("height"), InputParam("width"), InputParam("padding_mask_crop"), @@ -757,23 +1041,73 @@ def check_inputs(height, width, vae_scale_factor): def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - if block_state.resized_image is None and block_state.image is None: - raise ValueError("resized_image and image cannot be None at the same time") + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width - if block_state.resized_image is None: - image = block_state.image - self.check_inputs( - height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, ) - height = block_state.height or components.default_height - width = block_state.width or components.default_width - else: - width, height = block_state.resized_image[0].size - image = block_state.resized_image + ) + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be resized first." 
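# --- Illustrative aside (editor addition, not part of the diff). The non-Edit preprocess
# step above falls back to pipeline defaults with Python's `or`, which treats None as
# falsy; the toy values below stand in for components.default_height/default_width:
default_height, default_width = 1024, 1024

height, width = None, 768
height = height or default_height  # None -> fall back to the default
width = width or default_width     # a provided value is kept as-is
assert (height, width) == (1024, 768)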
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_mask_processor",
+                InpaintProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("mask_image", required=True),
+            InputParam("resized_image", required=True),
+            InputParam("padding_mask_crop"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="processed_image"),
+            OutputParam(name="processed_mask_image"),
+            OutputParam(
+                name="mask_overlay_kwargs",
+                type_hint=Dict,
+                description="The kwargs for the postprocess step to apply the mask overlay",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        width, height = block_state.resized_image[0].size
 
         block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
             components.image_mask_processor.preprocess(
-                image=image,
+                image=block_state.resized_image,
                 mask=block_state.mask_image,
                 height=height,
                 width=width,
@@ -790,7 +1124,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
+        return "Image Preprocess step. Will resize the image to the given height and width."
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -805,7 +1139,11 @@ def expected_components(self) -> List[ComponentSpec]:
 
     @property
     def inputs(self) -> List[InputParam]:
-        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+        return [
+            InputParam("image", required=True),
+            InputParam("height"),
+            InputParam("width"),
+        ]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
@@ -823,22 +1161,58 @@ def check_inputs(height, width, vae_scale_factor):
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
 
-        if block_state.resized_image is None and block_state.image is None:
-            raise ValueError("resized_image and image cannot be None at the same time")
+        self.check_inputs(
+            height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+        )
+        height = block_state.height or components.default_height
+        width = block_state.width or components.default_width
 
-        if block_state.resized_image is None:
-            image = block_state.image
-            self.check_inputs(
-                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
-            )
-            height = block_state.height or components.default_height
-            width = block_state.width or components.default_width
-        else:
-            width, height = block_state.resized_image[0].size
-            image = block_state.resized_image
+        block_state.processed_image = components.image_processor.preprocess(
+            image=block_state.image,
+            height=height,
+            width=width,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):
+    model_name = "qwenimage-edit"
+
+    @property
+    def description(self) -> str:
+        return "Image Preprocess step. Images need to be resized first."
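# --- Illustrative note (editor addition, not part of the diff). PIL's Image.size is
# (width, height), which is why the steps above unpack
# `width, height = resized_image[0].size` in that order before passing the height=/width=
# keyword arguments to the processor. A quick sanity check:
from PIL import Image

img = Image.new("RGB", (640, 480))  # the constructor also takes (width, height)
w, h = img.size
assert (w, h) == (640, 480)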
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam(name="processed_image")] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size block_state.processed_image = components.image_processor.preprocess( - image=image, + image=block_state.resized_image, height=height, width=width, ) @@ -847,59 +1221,64 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state -class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): +class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): model_name = "qwenimage-edit-plus" - def __init__(self): - self.vae_image_size = 1024 * 1024 - super().__init__() - @property def description(self) -> str: - return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing." + return "Image Preprocess step. Images can be resized first using QwenImageEditResizeStep." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] @property def inputs(self) -> List[InputParam]: - return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")] + return [InputParam("resized_image")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam(name="processed_image")] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - if block_state.vae_image is None and block_state.image is None: - raise ValueError("`vae_image` and `image` cannot be None at the same time") + image = block_state.resized_image - vae_image_sizes = None - if block_state.vae_image is None: - image = block_state.image - self.check_inputs( - height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor - ) - height = block_state.height or components.default_height - width = block_state.width or components.default_width - block_state.processed_image = components.image_processor.preprocess( - image=image, height=height, width=width + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + processed_images = [] + for img in image: + img_width, img_height = img.size + processed_images.append( + components.image_processor.preprocess(image=img, height=img_height, width=img_width) ) - else: - # QwenImage Edit Plus can allow multiple input images with varied resolutions - processed_images = [] - vae_image_sizes = [] - for img in block_state.vae_image: - width, height = img.size - vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height) - vae_image_sizes.append((vae_width, vae_height)) - processed_images.append( - components.image_processor.preprocess(image=img, height=vae_height, width=vae_width) - ) + 
block_state.processed_image = processed_images + if is_image_list: block_state.processed_image = processed_images - - block_state.vae_image_sizes = vae_image_sizes + else: + block_state.processed_image = processed_images[0] self.set_block_state(state, block_state) return components, state -class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): +# ==================== +# 5. VAE ENCODER +# ==================== +class QwenImageVaeEncoderStep(ModularPipelineBlocks): + """VAE encoder that handles both single images and lists of images with varied resolutions.""" + model_name = "qwenimage" def __init__( @@ -909,21 +1288,12 @@ def __init__( ): """Initialize a VAE encoder step for converting images to latent representations. - Both the input and output names are configurable so this block can be configured to process to different image - inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + Handles both single images and lists of images. When input is a list, outputs a list of latents. When input is + a single tensor, outputs a single latent tensor. Args: - input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". - Examples: "processed_image" or "processed_control_image" - output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". - Examples: "image_latents" or "control_image_latents" - - Examples: - # Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep() - - # Custom input/output names for control image QwenImageVaeEncoderDynamicStep( - input_name="processed_control_image", output_name="control_image_latents" - ) + input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image". + output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents". """ self._image_input_name = input_name self._image_latents_output_name = output_name @@ -931,17 +1301,18 @@ def __init__( @property def description(self) -> str: - return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + return ( + f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + "Handles both single images and lists of images with varied resolutions." 
+ ) @property def expected_components(self) -> List[ComponentSpec]: - components = [ComponentSpec("vae", AutoencoderKLQwenImage)] - return components + return [ComponentSpec("vae", AutoencoderKLQwenImage)] @property def inputs(self) -> List[InputParam]: - inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")] - return inputs + return [InputParam(self._image_input_name, required=True), InputParam("generator")] @property def intermediate_outputs(self) -> List[OutputParam]: @@ -949,46 +1320,7 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( self._image_latents_output_name, type_hint=torch.Tensor, - description="The latents representing the reference image", - ) - ] - - @torch.no_grad() - def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - device = components._execution_device - dtype = components.vae.dtype - - image = getattr(block_state, self._image_input_name) - - # Encode image into latents - image_latents = encode_vae_image( - image=image, - vae=components.vae, - generator=block_state.generator, - device=device, - dtype=dtype, - latent_channels=components.num_channels_latents, - ) - setattr(block_state, self._image_latents_output_name, image_latents) - - self.set_block_state(state, block_state) - - return components, state - - -class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep): - model_name = "qwenimage-edit-plus" - - @property - def intermediate_outputs(self) -> List[OutputParam]: - # Each reference image latent can have varied resolutions hence we return this as a list. - return [ - OutputParam( - self._image_latents_output_name, - type_hint=List[torch.Tensor], - description="The latents representing the reference image(s).", + description="The latents representing the reference image(s). Single tensor or list depending on input.", ) ] @@ -1000,8 +1332,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - dtype = components.vae.dtype image = getattr(block_state, self._image_input_name) + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] - # Encode image into latents + # Handle both single image and list of images image_latents = [] for img in image: image_latents.append( @@ -1014,6 +1349,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - latent_channels=components.num_channels_latents, ) ) + if not is_image_list: + image_latents = image_latents[0] setattr(block_state, self._image_latents_output_name, image_latents) @@ -1131,3 +1468,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - self.set_block_state(state, block_state) return components, state + + +# ==================== +# 6. PERMUTE LATENTS +# ==================== +class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): + """Permute image latents from VAE format to Layered format.""" + + model_name = "qwenimage-layered" + + def __init__(self, input_name: str = "image_latents"): + self._input_name = input_name + super().__init__() + + @property + def description(self) -> str: + return f"Permute {self._input_name} from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." 
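# --- Illustrative aside (editor addition, not part of the diff). The permute described
# above swaps the channel and frame axes so the VAE's single frame axis becomes the layer
# axis that the Layered pachifier packs over. Toy shapes only; the real channel count
# comes from the VAE config:
import torch

latents = torch.randn(2, 16, 1, 32, 32)    # (B, C, 1, H, W) as produced by the VAE encode step
permuted = latents.permute(0, 2, 1, 3, 4)  # (B, 1, C, H, W) as expected for Layered packing
assert permuted.shape == (2, 1, 16, 32, 32)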
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(self._input_name, required=True), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) + latents = getattr(block_state, self._input_name) + setattr(block_state, self._input_name, latents.permute(0, 2, 1, 3, 4)) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 6e656e484847..4a1cf3700c57 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -19,7 +19,7 @@ from ...models import QwenImageMultiControlNetModel from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier def repeat_tensor_to_batch_size( @@ -221,37 +221,16 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -class QwenImageInputsDynamicStep(ModularPipelineBlocks): - model_name = "qwenimage" - - def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []): - """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" - - This step handles multiple common tasks to prepare inputs for the denoising step: - 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size - 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size - - This is a dynamic block that allows you to configure which inputs to process. +class QwenImageAdditionalInputsStep(ModularPipelineBlocks): + """Input step for QwenImage: update height/width, expand batch, patchify.""" - Args: - image_latent_inputs (List[str], optional): Names of image latent tensors to process. - These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or - list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"] - additional_batch_inputs (List[str], optional): - Names of additional conditional input tensors to expand batch size. These tensors will only have their - batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. - Defaults to []. 
Examples: ["processed_mask_image"] - - Examples: - # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep() - - # Configure to process multiple image latent inputs - QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"]) + model_name = "qwenimage" - # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] - ) - """ + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): if not isinstance(image_latent_inputs, list): image_latent_inputs = [image_latent_inputs] if not isinstance(additional_batch_inputs, list): @@ -263,14 +242,12 @@ def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additiona @property def description(self) -> str: - # Functionality section summary_section = ( "Input processing step that:\n" - " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" + " 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n" " 2. For additional batch inputs: Expands batch dimensions to match final batch size" ) - # Inputs info inputs_info = "" if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" @@ -279,11 +256,16 @@ def description(self) -> str: if self._additional_batch_inputs: inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" - # Placement guidance placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." return summary_section + inputs_info + placement_section + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + @property def inputs(self) -> List[InputParam]: inputs = [ @@ -293,11 +275,9 @@ def inputs(self) -> List[InputParam]: InputParam(name="width"), ] - # Add image latent inputs for image_latent_input_name in self._image_latent_inputs: inputs.append(InputParam(name=image_latent_input_name)) - # Add additional batch inputs for input_name in self._additional_batch_inputs: inputs.append(InputParam(name=input_name)) @@ -306,26 +286,28 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="image_height", type_hint=int, description="The height of the image latents"), - OutputParam(name="image_width", type_hint=int, description="The width of the image latents"), - ] - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # Process image latent inputs (height/width calculation, patchify, and batch expansion) + # Process image latent inputs for image_latent_input_name in self._image_latent_inputs: image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: 
continue - # 1. Calculate height/width from latents + # 1. Calculate height/width from latents and update if not provided height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) block_state.height = block_state.height or height block_state.width = block_state.width or width @@ -335,7 +317,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - if not hasattr(block_state, "image_width"): block_state.image_width = width - # 2. Patchify the image latent tensor + # 2. Patchify image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) # 3. Expand batch size @@ -354,7 +336,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - if input_tensor is None: continue - # Only expand batch size input_tensor = repeat_tensor_to_batch_size( input_name=input_name, input_tensor=input_tensor, @@ -368,63 +349,270 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep): +class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): + """Input step for QwenImage Edit Plus: handles list of latents with different sizes.""" + model_name = "qwenimage-edit-plus" + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Edit Plus that:\n" + " 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size\n" + " Height/width defaults to last image in the list." + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
+ + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"), - OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"), + OutputParam( + name="image_height", + type_hint=List[int], + description="The image heights calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=List[int], + description="The image widths calculated from the image latents dimension", + ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # Process image latent inputs (height/width calculation, patchify, and batch expansion) + # Process image latent inputs for image_latent_input_name in self._image_latent_inputs: image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue - # Each image latent can have different size in QwenImage Edit Plus. + is_list = isinstance(image_latent_tensor, list) + if not is_list: + image_latent_tensor = [image_latent_tensor] + image_heights = [] image_widths = [] packed_image_latent_tensors = [] - for img_latent_tensor in image_latent_tensor: + for i, img_latent_tensor in enumerate(image_latent_tensor): # 1. Calculate height/width from latents height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) image_heights.append(height) image_widths.append(width) - # 2. Patchify the image latent tensor + # 2. Patchify img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) # 3. 
Expand batch size img_latent_tensor = repeat_tensor_to_batch_size( - input_name=image_latent_input_name, + input_name=f"{image_latent_input_name}[{i}]", input_tensor=img_latent_tensor, num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size, ) packed_image_latent_tensors.append(img_latent_tensor) + # Concatenate all packed latents along dim=1 packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + + # Output lists of heights/widths block_state.image_height = image_heights block_state.image_width = image_widths - setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + # Default height/width from last image block_state.height = block_state.height or image_heights[-1] block_state.width = block_state.width or image_widths[-1] + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# YiYi TODO: support define config default component from the ModularPipeline level. +# it is same as QwenImageAdditionalInputsStep, but with layered pachifier. +class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): + """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier.""" + + model_name = "qwenimage-layered" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Layered that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
+ + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + ] + + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + OutputParam(name="height", type_hint=int, description="The height of the image output"), + OutputParam(name="width", type_hint=int, description="The width of the image output"), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents and update if not provided + # Layered latents are (B, layers, C, H, W) + height = image_latent_tensor.shape[3] * components.vae_scale_factor + width = image_latent_tensor.shape[4] * components.vae_scale_factor + block_state.height = height + block_state.width = width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify with layered pachifier + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + # Process additional batch inputs (only batch expansion) for input_name in self._additional_batch_inputs: input_tensor = getattr(block_state, input_name) if input_tensor is None: continue - # Only expand batch size input_tensor = repeat_tensor_to_batch_size( input_name=input_name, input_tensor=input_tensor, diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py deleted file mode 100644 index dcce0cab5dd1..000000000000 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ /dev/null @@ -1,1113 +0,0 @@ -# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict -from .before_denoise import ( - QwenImageControlNetBeforeDenoiserStep, - QwenImageCreateMaskLatentsStep, - QwenImageEditPlusRoPEInputsStep, - QwenImageEditRoPEInputsStep, - QwenImagePrepareLatentsStep, - QwenImagePrepareLatentsWithStrengthStep, - QwenImageRoPEInputsStep, - QwenImageSetTimestepsStep, - QwenImageSetTimestepsWithStrengthStep, -) -from .decoders import ( - QwenImageAfterDenoiseStep, - QwenImageDecoderStep, - QwenImageInpaintProcessImagesOutputStep, - QwenImageProcessImagesOutputStep, -) -from .denoise import ( - QwenImageControlNetDenoiseStep, - QwenImageDenoiseStep, - QwenImageEditDenoiseStep, - QwenImageEditInpaintDenoiseStep, - QwenImageInpaintControlNetDenoiseStep, - QwenImageInpaintDenoiseStep, - QwenImageLoopBeforeDenoiserControlNet, -) -from .encoders import ( - QwenImageControlNetVaeEncoderStep, - QwenImageEditPlusProcessImagesInputStep, - QwenImageEditPlusResizeDynamicStep, - QwenImageEditPlusTextEncoderStep, - QwenImageEditPlusVaeEncoderDynamicStep, - QwenImageEditResizeDynamicStep, - QwenImageEditTextEncoderStep, - QwenImageInpaintProcessImagesInputStep, - QwenImageProcessImagesInputStep, - QwenImageTextEncoderStep, - QwenImageVaeEncoderDynamicStep, -) -from .inputs import ( - QwenImageControlNetInputsStep, - QwenImageEditPlusInputsDynamicStep, - QwenImageInputsDynamicStep, - QwenImageTextInputsStep, -) - - -logger = logging.get_logger(__name__) - -# 1. QwenImage - -## 1.1 QwenImage/text2image - -#### QwenImage/decode -#### (standard decode step works for most tasks except for inpaint) -QwenImageDecodeBlocks = InsertableDict( - [ - ("decode", QwenImageDecoderStep()), - ("postprocess", QwenImageProcessImagesOutputStep()), - ] -) - - -class QwenImageDecodeStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageDecodeBlocks.values() - block_names = QwenImageDecodeBlocks.keys() - - @property - def description(self): - return "Decode step that decodes the latents to images and postprocess the generated image." 
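Every composed step in this module pairs an `InsertableDict` of named sub-blocks with a `SequentialPipelineBlocks` subclass, so a preset step can be instantiated and inspected on its own. A minimal sketch, assuming `QwenImageDecodeStep` remains importable from the qwenimage modular package:

```python
# Minimal inspection sketch; the import path is an assumption.
from diffusers.modular_pipelines.qwenimage import QwenImageDecodeStep

step = QwenImageDecodeStep()
print(list(step.sub_blocks.keys()))  # ['decode', 'postprocess']
print(step.doc)  # auto-generated documentation of the step's inputs/outputs
```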
- - -#### QwenImage/text2image presets -TEXT2IMAGE_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("input", QwenImageTextInputsStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 1.2 QwenImage/inpaint - -#### QwenImage/inpaint vae encoder -QwenImageInpaintVaeEncoderBlocks = InsertableDict( - [ - ( - "preprocess", - QwenImageInpaintProcessImagesInputStep, - ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintVaeEncoderBlocks.values() - block_names = QwenImageInpaintVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step is used for processing image and mask inputs for inpainting tasks. It:\n" - " - Resizes the image to the target size, based on `height` and `width`.\n" - " - Processes and updates `image` and `mask_image`.\n" - " - Creates `image_latents`." - ) - - -#### QwenImage/inpaint inputs -QwenImageInpaintInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ( - "additional_inputs", - QwenImageInputsDynamicStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] - ), - ), - ] -) - - -class QwenImageInpaintInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintInputBlocks.values() - block_names = QwenImageInpaintInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the inpainting denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n" - " - update height/width based `image_latents`, patchify `image_latents`." - - -# QwenImage/inpaint prepare latents -QwenImageInpaintPrepareLatentsBlocks = InsertableDict( - [ - ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("create_mask_latents", QwenImageCreateMaskLatentsStep()), - ] -) - - -class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintPrepareLatentsBlocks.values() - block_names = QwenImageInpaintPrepareLatentsBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. 
It:\n" - " - Add noise to the image latents to create the latents input for the denoiser.\n" - " - Create the pachified latents `mask` based on the processedmask image.\n" - ) - - -#### QwenImage/inpaint decode -QwenImageInpaintDecodeBlocks = InsertableDict( - [ - ("decode", QwenImageDecoderStep()), - ("postprocess", QwenImageInpaintProcessImagesOutputStep()), - ] -) - - -class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintDecodeBlocks.values() - block_names = QwenImageInpaintDecodeBlocks.keys() - - @property - def description(self): - return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image." - - -#### QwenImage/inpaint presets -INPAINT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageInpaintVaeEncoderStep()), - ("input", QwenImageInpaintInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageInpaintDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageInpaintDecodeStep()), - ] -) - - -## 1.3 QwenImage/img2img - -#### QwenImage/img2img vae encoder -QwenImageImg2ImgVaeEncoderBlocks = InsertableDict( - [ - ("preprocess", QwenImageProcessImagesInputStep()), - ("encode", QwenImageVaeEncoderDynamicStep()), - ] -) - - -class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - - block_classes = QwenImageImg2ImgVaeEncoderBlocks.values() - block_names = QwenImageImg2ImgVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that preprocess andencode the image inputs into their latent representations." - - -#### QwenImage/img2img inputs -QwenImageImg2ImgInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), - ] -) - - -class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageImg2ImgInputBlocks.values() - block_names = QwenImageImg2ImgInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the img2img denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" - " - update height/width based `image_latents`, patchify `image_latents`." 
- - -#### QwenImage/img2img presets -IMAGE2IMAGE_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()), - ("input", QwenImageImg2ImgInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 1.4 QwenImage/controlnet - -#### QwenImage/controlnet presets -CONTROLNET_BLOCKS = InsertableDict( - [ - ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image - ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet - ( - "controlnet_before_denoise", - QwenImageControlNetBeforeDenoiserStep(), - ), # before denoise step (after set_timesteps step) - ( - "controlnet_denoise_loop_before", - QwenImageLoopBeforeDenoiserControlNet(), - ), # controlnet loop step (insert before the denoiseloop_denoiser) - ] -) - - -## 1.5 QwenImage/auto encoders - - -#### for inpaint and img2img tasks -class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" - + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" - + " - if `mask_image` or `image` is not provided, step will be skipped." - ) - - -# for controlnet tasks -class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetVaeEncoderStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" - + " - if `control_image` is not provided, step will be skipped." - ) - - -## 1.6 QwenImage/auto inputs - - -# text2image/inpaint/img2img -class QwenImageAutoInputStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep] - block_names = ["inpaint", "img2img", "text2image"] - block_trigger_inputs = ["processed_mask_image", "image_latents", None] - - @property - def description(self): - return ( - "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. 
\n" - " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n" - + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n" - + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" - ) - - -# controlnet -class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetInputsStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image_latents"] - - @property - def description(self): - return ( - "Controlnet input step that prepare the control_image_latents input.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - if `control_image_latents` is not provided, step will be skipped." - ) - - -## 1.7 QwenImage/auto before denoise step -# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step - -# QwenImage/text2image before denoise -QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values() - block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task." - - -# QwenImage/inpaint before denoise -QwenImageInpaintBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintBeforeDenoiseBlocks.values() - block_names = QwenImageInpaintBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." - - -# QwenImage/img2img before denoise -QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values() - block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." 
- - -# auto before_denoise step for text2image, inpaint, img2img tasks -class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [ - QwenImageInpaintBeforeDenoiseStep, - QwenImageImg2ImgBeforeDenoiseStep, - QwenImageText2ImageBeforeDenoiseStep, - ] - block_names = ["inpaint", "img2img", "text2image"] - block_trigger_inputs = ["processed_mask_image", "image_latents", None] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n" - + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n" - + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" - ) - - -# auto before_denoise step for controlnet tasks -class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetBeforeDenoiserStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image_latents"] - - @property - def description(self): - return ( - "Controlnet before denoise step that prepare the controlnet input.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - if `control_image_latents` is not provided, step will be skipped." - ) - - -## 1.8 QwenImage/auto denoise - - -# auto denoise step for controlnet tasks: works for all tasks with controlnet -class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - - @property - def description(self): - return ( - "Controlnet step during the denoising process. \n" - " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n" - + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n" - + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n" - ) - - -# auto denoise step for everything: works for all tasks with or without controlnet -class QwenImageAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - QwenImageControlNetAutoDenoiseStep, - QwenImageInpaintDenoiseStep, - QwenImageDenoiseStep, - ] - block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] - block_trigger_inputs = ["control_image_latents", "mask", None] - - @property - def description(self): - return ( - "Denoise step that iteratively denoise the latents. \n" - " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. 
It also works with controlnet\n" - + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n" - + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n" - ) - - -## 1.9 QwenImage/auto decode -# auto decode step for inpaint and text2image tasks - - -class QwenImageAutoDecodeStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep] - block_names = ["inpaint_decode", "decode"] - block_trigger_inputs = ["mask", None] - - @property - def description(self): - return ( - "Decode step that decode the latents into images. \n" - " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n" - + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" - + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n" - ) - - -class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = [ - QwenImageAutoInputStep, - QwenImageOptionalControlNetInputStep, - QwenImageAutoBeforeDenoiseStep, - QwenImageOptionalControlNetBeforeDenoiseStep, - QwenImageAutoDenoiseStep, - QwenImageAfterDenoiseStep, - ] - block_names = [ - "input", - "controlnet_input", - "before_denoise", - "controlnet_before_denoise", - "denoise", - "after_denoise", - ] - - @property - def description(self): - return ( - "Core step that performs the denoising process. \n" - + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n" - + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n" - + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n" - + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" - + " - for image-to-image generation, you need to provide `image_latents`\n" - + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" - + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" - + " - for text-to-image generation, all you need to provide is prompt embeddings" - ) - - -## 1.10 QwenImage/auto block & presets -AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageAutoVaeEncoderStep()), - ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), - ("denoise", QwenImageCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage" - - block_classes = AUTO_BLOCKS.values() - block_names = AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" - + "- for image-to-image generation, you need to provide `image`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" - + "- to run the controlnet workflow, you need to 
provide `control_image`\n" - + "- for text-to-image generation, all you need to provide is `prompt`" - ) - - -# 2. QwenImage-Edit - -## 2.1 QwenImage-Edit/edit - -#### QwenImage-Edit/edit vl encoder: take both image and text prompts -QwenImageEditVLEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), - ("encode", QwenImageEditTextEncoderStep()), - ] -) - - -class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditVLEncoderBlocks.values() - block_names = QwenImageEditVLEncoderBlocks.keys() - - @property - def description(self) -> str: - return "QwenImage-Edit VL encoder step that encode the image an text prompts together." - - -#### QwenImage-Edit/edit vae encoder -QwenImageEditVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step - ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditVaeEncoderBlocks.values() - block_names = QwenImageEditVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that encode the image inputs into their latent representations." - - -#### QwenImage-Edit/edit input -QwenImageEditInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), - ] -) - - -class QwenImageEditInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInputBlocks.values() - block_names = QwenImageEditInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the edit denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs: \n" - " - `image_latents`.\n" - " - update height/width based `image_latents`, patchify `image_latents`." 
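Where these input-step descriptions say "patchify", the pachifier's `pack_latents` folds each 2x2 spatial patch of the latents into the channel dimension and flattens space into a sequence. A generic sketch of that transform; the Flux-style layout is an assumption about the pachifier, not taken from this diff:

```python
import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H//2)*(W//2), C*4), folding each 2x2 patch into channels.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    return latents.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)

print(pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 1024, 64])
```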
- - -#### QwenImage/edit presets -EDIT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditVaeEncoderStep()), - ("input", QwenImageEditInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ("denoise", QwenImageEditDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 2.2 QwenImage-Edit/edit inpaint - -#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step -QwenImageEditInpaintVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image - ( - "preprocess", - QwenImageInpaintProcessImagesInputStep, - ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs - ( - "encode", - QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"), - ), # processed_image -> image_latents - ] -) - - -class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInpaintVaeEncoderBlocks.values() - block_names = QwenImageEditInpaintVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n" - " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" - " - process the resized image and mask image.\n" - " - create image latents." - ) - - -#### QwenImage-Edit/edit inpaint presets -EDIT_INPAINT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()), - ("input", QwenImageInpaintInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ("denoise", QwenImageEditInpaintDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageInpaintDecodeStep()), - ] -) - - -## 2.3 QwenImage-Edit/auto encoders - - -class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [ - QwenImageEditInpaintVaeEncoderStep, - QwenImageEditVaeEncoderStep, - ] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations. \n" - " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n" - + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" - + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" - + " - if `mask_image` or `image` is not provided, step will be skipped." 
- ) - - -## 2.4 QwenImage-Edit/auto inputs -class QwenImageEditAutoInputStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Input step that prepares the inputs for the edit denoising step.\n" - + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n" - + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n" - + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." - ) - - -## 2.5 QwenImage-Edit/auto before denoise -# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step - -#### QwenImage-Edit/edit before denoise -QwenImageEditBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ] -) - - -class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditBeforeDenoiseBlocks.values() - block_names = QwenImageEditBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." - - -#### QwenImage-Edit/edit inpaint before denoise -QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ] -) - - -class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values() - block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task." - - -# auto before_denoise step for edit and edit_inpaint tasks -class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = [ - QwenImageEditInpaintBeforeDenoiseStep, - QwenImageEditBeforeDenoiseStep, - ] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n" - + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" - + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped." 
- ) - - -## 2.6 QwenImage-Edit/auto denoise - - -class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit" - - block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Denoise step that iteratively denoise the latents. \n" - + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n" - + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n" - + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." - ) - - -## 2.7 QwenImage-Edit/auto blocks & presets - - -class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = [ - QwenImageEditAutoInputStep, - QwenImageEditAutoBeforeDenoiseStep, - QwenImageEditAutoDenoiseStep, - QwenImageAfterDenoiseStep, - ] - block_names = ["input", "before_denoise", "denoise", "after_denoise"] - - @property - def description(self): - return ( - "Core step that performs the denoising process. \n" - + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" - + "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n" - + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n" - + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" - ) - - -EDIT_AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditAutoVaeEncoderStep()), - ("denoise", QwenImageEditCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageEditAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = EDIT_AUTO_BLOCKS.values() - block_names = EDIT_AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" - + "- for edit (img2img) generation, you need to provide `image`\n" - + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" - ) - - -#################### QwenImage Edit Plus ##################### - -# 3. QwenImage-Edit Plus - -## 3.1 QwenImage-Edit Plus / edit - -#### QwenImage-Edit Plus vl encoder: take both image and text prompts -QwenImageEditPlusVLEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditPlusResizeDynamicStep()), - ("encode", QwenImageEditPlusTextEncoderStep()), - ] -) - - -class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditPlusVLEncoderBlocks.values() - block_names = QwenImageEditPlusVLEncoderBlocks.keys() - - @property - def description(self) -> str: - return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together." 
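For auto blocks such as `QwenImageEditAutoBlocks` above, the task is selected purely by which arguments the caller passes. A hedged end-to-end sketch; the repo id, the loading helper, the `output` kwarg, and the image URL are assumptions rather than part of this diff:

```python
import torch
from diffusers.modular_pipelines.qwenimage import QwenImageEditAutoBlocks  # assumed export path
from diffusers.utils import load_image

blocks = QwenImageEditAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit")       # assumed modular repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)  # assumed loader name

init_image = load_image("https://example.com/input.png")  # placeholder URL
# Passing `image` without `mask_image` routes through the edit (img2img) branch.
image = pipe(image=init_image, prompt="make it snow", output="images")[0]
```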
- - -#### QwenImage-Edit Plus vae encoder -QwenImageEditPlusVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step - ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image - ("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = QwenImageEditPlusVaeEncoderBlocks.values() - block_names = QwenImageEditPlusVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that encode the image inputs into their latent representations." - - -#### QwenImage Edit Plus input blocks -QwenImageEditPlusInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ( - "additional_inputs", - QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]), - ), - ] -) - - -class QwenImageEditPlusInputStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = QwenImageEditPlusInputBlocks.values() - block_names = QwenImageEditPlusInputBlocks.keys() - - -#### QwenImage Edit Plus presets -EDIT_PLUS_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditPlusVLEncoderStep()), - ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), - ("input", QwenImageEditPlusInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), - ("denoise", QwenImageEditDenoiseStep()), - ("after_denoise", QwenImageAfterDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), - ] -) - - -class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values() - block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." - - -# auto before_denoise step for edit tasks -class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = [QwenImageEditPlusBeforeDenoiseStep] - block_names = ["edit"] - block_trigger_inputs = ["image_latents"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for edit (img2img) task.\n" - + " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" - + " - if `image_latents` is not provided, step will be skipped." - ) - - -## 3.2 QwenImage-Edit Plus/auto encoders - - -class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [QwenImageEditPlusVaeEncoderStep] - block_names = ["edit"] - block_trigger_inputs = ["image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations. 
\n" - " This is an auto pipeline block that works for edit task.\n" - + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n" - + " - if `image` is not provided, step will be skipped." - ) - - -## 3.3 QwenImage-Edit/auto blocks & presets - - -class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks): - block_classes = [QwenImageEditPlusInputStep] - block_names = ["edit"] - block_trigger_inputs = ["image_latents"] - - @property - def description(self): - return ( - "Input step that prepares the inputs for the edit denoising step.\n" - + " It is an auto pipeline block that works for edit task.\n" - + " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n" - + " - if `image_latents` is not provided, step will be skipped." - ) - - -class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = [ - QwenImageEditPlusAutoInputStep, - QwenImageEditPlusAutoBeforeDenoiseStep, - QwenImageEditAutoDenoiseStep, - QwenImageAfterDenoiseStep, - ] - block_names = ["input", "before_denoise", "denoise", "after_denoise"] - - @property - def description(self): - return ( - "Core step that performs the denoising process. \n" - + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" - + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n" - + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" - ) - - -EDIT_PLUS_AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditPlusVLEncoderStep()), - ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()), - ("denoise", QwenImageEditPlusCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = EDIT_PLUS_AUTO_BLOCKS.values() - block_names = EDIT_PLUS_AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n" - + "- for edit (img2img) generation, you need to provide `image`\n" - ) - - -# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus - - -ALL_BLOCKS = { - "text2image": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "edit": EDIT_BLOCKS, - "edit_inpaint": EDIT_INPAINT_BLOCKS, - "edit_plus": EDIT_PLUS_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "auto": AUTO_BLOCKS, - "edit_auto": EDIT_AUTO_BLOCKS, - "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS, -} diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py new file mode 100644 index 000000000000..63e9f5a28372 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -0,0 +1,469 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    QwenImageControlNetBeforeDenoiserStep,
+    QwenImageCreateMaskLatentsStep,
+    QwenImagePrepareLatentsStep,
+    QwenImagePrepareLatentsWithStrengthStep,
+    QwenImageRoPEInputsStep,
+    QwenImageSetTimestepsStep,
+    QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import (
+    QwenImageAfterDenoiseStep,
+    QwenImageDecoderStep,
+    QwenImageInpaintProcessImagesOutputStep,
+    QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+    QwenImageControlNetDenoiseStep,
+    QwenImageDenoiseStep,
+    QwenImageInpaintControlNetDenoiseStep,
+    QwenImageInpaintDenoiseStep,
+)
+from .encoders import (
+    QwenImageControlNetVaeEncoderStep,
+    QwenImageInpaintProcessImagesInputStep,
+    QwenImageProcessImagesInputStep,
+    QwenImageTextEncoderStep,
+    QwenImageVaeEncoderStep,
+)
+from .inputs import (
+    QwenImageAdditionalInputsStep,
+    QwenImageControlNetInputsStep,
+    QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. VAE ENCODER
+# ====================
+
+
+class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
+    block_names = ["preprocess", "encode"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
+            " - Resizes the image to the target size, based on `height` and `width`.\n"
+            " - Processes and updates `image` and `mask_image`.\n"
+            " - Creates `image_latents`."
+        )
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+
+    block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
+    block_names = ["preprocess", "encode"]
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+# Auto VAE encoder
+class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
+    block_names = ["inpaint", "img2img"]
+    block_trigger_inputs = ["mask_image", "image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block.\n"
+            + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+            + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+            + " - if `mask_image` or `image` is not provided, step will be skipped."
+    )
+
+
+# optional controlnet vae encoder
+class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [QwenImageControlNetVaeEncoderStep]
+    block_names = ["controlnet"]
+    block_trigger_inputs = ["control_image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block.\n"
+            + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+            + " - if `control_image` is not provided, step will be skipped."
+        )
+
+
+# ====================
+# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the img2img denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents`, and patchifies `image_latents`."
+        )
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageAdditionalInputsStep(
+            image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+        ),
+    ]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the inpainting denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents` and `processed_mask_image`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents`, and patchifies `image_latents`."
+        )
+
+
+# assemble prepare latents steps
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
+    block_names = ["add_noise_to_latents", "create_mask_latents"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+            " - Adds noise to the image latents to create the latents input for the denoiser.\n"
+            " - Creates the patchified latents `mask` based on the processed mask image.\n"
+        )
+
+
+# assemble denoising steps
+
+
+# Qwen Image (text2image)
+class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Step that denoises the latents into an image for the text2image task. It includes the denoise loop as well as the steps that prepare its inputs (timesteps, latents, RoPE inputs etc.)."
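Each core denoise step here is a linear chain over a shared pipeline state: every sub-block takes `(components, state)` and returns them updated, matching the `__call__` signature used throughout this diff. A simplified sketch of the control flow (the real `SequentialPipelineBlocks` adds validation and bookkeeping):

```python
# Simplified control-flow sketch of a sequential block chain.
def run_sequential(components, state, sub_blocks):
    for name, block in sub_blocks.items():
        # Same contract as the blocks in this diff:
        # __call__(components, state) -> (components, state)
        components, state = block(components, state)
    return components, state
```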
+
+
+# Qwen Image (inpainting)
+class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageInpaintInputStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsWithStrengthStep(),
+        QwenImageInpaintPrepareLatentsStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageInpaintDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_inpaint_latents",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core step that runs the denoise loop for the inpaint task, and prepares its inputs (timesteps, latents, RoPE inputs, etc.)."
+
+
+# Qwen Image (image2image)
+class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageImg2ImgInputStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsWithStrengthStep(),
+        QwenImagePrepareLatentsWithStrengthStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_img2img_latents",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core step that runs the denoise loop for the img2img task, and prepares its inputs (timesteps, latents, RoPE inputs, etc.)."
+
+
+# Qwen Image (text2image) with controlnet
+class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageControlNetInputsStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageControlNetBeforeDenoiserStep(),
+        QwenImageControlNetDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "controlnet_input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_rope_inputs",
+        "controlnet_before_denoise",
+        "controlnet_denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core step that runs the denoise loop for the text2image task with controlnet, and prepares its inputs (timesteps, latents, RoPE inputs, etc.)."
+
+
+# Qwen Image (inpainting) with controlnet
+class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageInpaintInputStep(),
+        QwenImageControlNetInputsStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsWithStrengthStep(),
+        QwenImageInpaintPrepareLatentsStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageControlNetBeforeDenoiserStep(),
+        QwenImageInpaintControlNetDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "controlnet_input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_inpaint_latents",
+        "prepare_rope_inputs",
+        "controlnet_before_denoise",
+        "controlnet_denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core step that runs the denoise loop for the inpaint task with controlnet, and prepares its inputs (timesteps, latents, RoPE inputs, etc.)."
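+
+
+# Editor's note: each core denoise step above is a plain SequentialPipelineBlocks, so its
+# composition can be inspected directly. A small sketch (illustrative only; assumes the
+# standard block introspection helpers):
+#
+#     blocks = QwenImageControlNetInpaintCoreDenoiseStep()
+#     list(blocks.sub_blocks.keys())
+#     # ['input', 'controlnet_input', 'prepare_latents', 'set_timesteps',
+#     #  'prepare_inpaint_latents', 'prepare_rope_inputs', 'controlnet_before_denoise',
+#     #  'controlnet_denoise', 'after_denoise']
+#     print(blocks.doc)  # auto-generated documentation for the assembled workflow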
+
+
+# Qwen Image (image2image) with controlnet
+class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [
+        QwenImageImg2ImgInputStep(),
+        QwenImageControlNetInputsStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsWithStrengthStep(),
+        QwenImagePrepareLatentsWithStrengthStep(),
+        QwenImageRoPEInputsStep(),
+        QwenImageControlNetBeforeDenoiserStep(),
+        QwenImageControlNetDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "controlnet_input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_img2img_latents",
+        "prepare_rope_inputs",
+        "controlnet_before_denoise",
+        "controlnet_denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core step that runs the denoise loop for the img2img task with controlnet, and prepares its inputs (timesteps, latents, RoPE inputs, etc.)."
+
+
+# Auto denoise step for QwenImage
+class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
+    block_classes = [
+        QwenImageCoreDenoiseStep,
+        QwenImageInpaintCoreDenoiseStep,
+        QwenImageImg2ImgCoreDenoiseStep,
+        QwenImageControlNetCoreDenoiseStep,
+        QwenImageControlNetInpaintCoreDenoiseStep,
+        QwenImageControlNetImg2ImgCoreDenoiseStep,
+    ]
+    block_names = [
+        "text2image",
+        "inpaint",
+        "img2img",
+        "controlnet_text2image",
+        "controlnet_inpaint",
+        "controlnet_img2img",
+    ]
+    block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
+    default_block_name = "text2image"
+
+    def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
+        if control_image_latents is not None:
+            if processed_mask_image is not None:
+                return "controlnet_inpaint"
+            elif image_latents is not None:
+                return "controlnet_img2img"
+            else:
+                return "controlnet_text2image"
+        else:
+            if processed_mask_image is not None:
+                return "inpaint"
+            elif image_latents is not None:
+                return "img2img"
+            else:
+                return "text2image"
+
+    @property
+    def description(self):
+        return (
+            "Core step that performs the denoising process.\n"
+            + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
+            + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
+            + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
+            + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
+            + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
+            + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
+            + "This step supports text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+            + " - for image-to-image generation, you need to provide `image_latents`\n"
+            + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+            + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+            + " - for text-to-image generation, all you need to provide is prompt embeddings"
+        )
+
+
+# ====================
+# 3. DECODE
+# ====================
+
+
+# standard decode step works for most tasks except for inpaint
+class QwenImageDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image."
+
+
+# Inpaint decode step
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+# Auto decode step for QwenImage
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+    block_names = ["inpaint_decode", "decode"]
+    block_trigger_inputs = ["mask", None]
+
+    @property
+    def description(self):
+        return (
+            "Decode step that decodes the latents into images.\n"
+            + "This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+            + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+            + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+        )
+
+
+# ====================
+# 4. AUTO BLOCKS & PRESETS
+# ====================
+AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageTextEncoderStep()),
+        ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+        ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+        ("denoise", QwenImageAutoCoreDenoiseStep()),
+        ("decode", QwenImageAutoDecodeStep()),
+    ]
+)
+
+
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+
+    block_classes = AUTO_BLOCKS.values()
+    block_names = AUTO_BLOCKS.keys()
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+            + "- for image-to-image generation, you need to provide `image`\n"
+            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
+            + "- to run the controlnet workflow, you need to provide `control_image`\n"
+            + "- for text-to-image generation, all you need to provide is `prompt`"
+        )
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py
new file mode 100644
index 000000000000..99a349994c19
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py
@@ -0,0 +1,336 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    QwenImageCreateMaskLatentsStep,
+    QwenImageEditRoPEInputsStep,
+    QwenImagePrepareLatentsStep,
+    QwenImagePrepareLatentsWithStrengthStep,
+    QwenImageSetTimestepsStep,
+    QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import (
+    QwenImageAfterDenoiseStep,
+    QwenImageDecoderStep,
+    QwenImageInpaintProcessImagesOutputStep,
+    QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+    QwenImageEditDenoiseStep,
+    QwenImageEditInpaintDenoiseStep,
+)
+from .encoders import (
+    QwenImageEditInpaintProcessImagesInputStep,
+    QwenImageEditProcessImagesInputStep,
+    QwenImageEditResizeStep,
+    QwenImageEditTextEncoderStep,
+    QwenImageVaeEncoderStep,
+)
+from .inputs import (
+    QwenImageAdditionalInputsStep,
+    QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
+    """VL encoder that takes both image and text prompts."""
+
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditResizeStep(),
+        QwenImageEditTextEncoderStep(),
+    ]
+    block_names = ["resize", "encode"]
+
+    @property
+    def description(self) -> str:
+        return "QwenImage-Edit VL encoder step that encodes the image and text prompts together."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+# Edit VAE encoder
+class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditResizeStep(),
+        QwenImageEditProcessImagesInputStep(),
+        QwenImageVaeEncoderStep(),
+    ]
+    block_names = ["resize", "preprocess", "encode"]
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that encodes the image inputs into their latent representations."
+
+
+# Edit Inpaint VAE encoder
+class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditResizeStep(),
+        QwenImageEditInpaintProcessImagesInputStep(),
+        QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
+    ]
+    block_names = ["resize", "preprocess", "encode"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
+            " - resizes the image to a target area of 1024 * 1024 while maintaining the aspect ratio.\n"
+            " - processes the resized image and mask image.\n"
+            " - creates image latents."
+        )
+
+
+# Auto VAE encoder
+class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
+    block_names = ["edit_inpaint", "edit"]
+    block_trigger_inputs = ["mask_image", "image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            "This is an auto pipeline block.\n"
+            " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+            " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+            " - if neither `mask_image` nor `image` is provided, this step will be skipped."
+        )
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageEditInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+    ]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the edit denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs have a consistent batch size.\n"
+            " - updates height/width based on `image_latents`, and patchifies `image_latents`."
+        )
+
+
+class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageAdditionalInputsStep(
+            image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+        ),
+    ]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs have a consistent batch size.\n"
+            " - updates height/width based on `image_latents`, and patchifies `image_latents`."
+        )
+
+
+# assemble prepare latents steps
+class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
+    block_names = ["add_noise_to_latents", "create_mask_latents"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
+            " - Adds noise to the image latents to create the latents input for the denoiser.\n"
+            " - Creates the patchified latents `mask` based on the processed mask image.\n"
+        )
+
+
+# Qwen Image Edit (image2image) core denoise step
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditInputStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsStep(),
+        QwenImageEditRoPEInputsStep(),
+        QwenImageEditDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core denoising workflow for the QwenImage-Edit edit (img2img) task."
+
+
+# Qwen Image Edit (inpainting) core denoise step
+class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditInpaintInputStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsWithStrengthStep(),
+        QwenImageEditInpaintPrepareLatentsStep(),
+        QwenImageEditRoPEInputsStep(),
+        QwenImageEditInpaintDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_inpaint_latents",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core denoising workflow for the QwenImage-Edit edit inpaint task."
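+
+
+# Editor's sketch of the dispatch contract used by the auto block below (illustrative
+# only; `mask` and `latents` are placeholder tensors): `select_block` receives the
+# declared trigger inputs by name and returns the name of the sub-block to run. Both
+# return values here follow directly from the implementation that follows:
+#
+#     step = QwenImageEditAutoCoreDenoiseStep()
+#     step.select_block(processed_mask_image=mask)    # -> "edit_inpaint"
+#     step.select_block(image_latents=latents)        # -> "edit"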
+
+
+# Auto core denoise step for QwenImage Edit
+class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [
+        QwenImageEditInpaintCoreDenoiseStep,
+        QwenImageEditCoreDenoiseStep,
+    ]
+    block_names = ["edit_inpaint", "edit"]
+    block_trigger_inputs = ["processed_mask_image", "image_latents"]
+    default_block_name = "edit"
+
+    def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
+        if processed_mask_image is not None:
+            return "edit_inpaint"
+        elif image_latents is not None:
+            return "edit"
+        return None
+
+    @property
+    def description(self):
+        return (
+            "Auto core denoising step that selects the appropriate workflow based on inputs.\n"
+            " - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
+            " - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
+            "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
+        )
+
+
+# ====================
+# 4. DECODE
+# ====================
+
+
+# Decode step (standard)
+class QwenImageEditDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image."
+
+
+# Inpaint decode step
+class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+# Auto decode step
+class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
+    block_names = ["inpaint_decode", "decode"]
+    block_trigger_inputs = ["mask", None]
+
+    @property
+    def description(self):
+        return (
+            "Decode step that decodes the latents into images.\n"
+            "This is an auto pipeline block.\n"
+            " - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+            " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
+        )
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageEditVLEncoderStep()),
+        ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+        ("denoise", QwenImageEditAutoCoreDenoiseStep()),
+        ("decode", QwenImageEditAutoDecodeStep()),
+    ]
+)
+
+
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit"
+    block_classes = EDIT_AUTO_BLOCKS.values()
+    block_names = EDIT_AUTO_BLOCKS.keys()
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+            "- for edit (img2img) generation, you need to provide `image`\n"
+            "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
+        )
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py
new file mode 100644
index 000000000000..275e4288eb0a
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py
@@ -0,0 +1,181 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    QwenImageEditPlusRoPEInputsStep,
+    QwenImagePrepareLatentsStep,
+    QwenImageSetTimestepsStep,
+)
+from .decoders import (
+    QwenImageAfterDenoiseStep,
+    QwenImageDecoderStep,
+    QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+    QwenImageEditDenoiseStep,
+)
+from .encoders import (
+    QwenImageEditPlusProcessImagesInputStep,
+    QwenImageEditPlusResizeStep,
+    QwenImageEditPlusTextEncoderStep,
+    QwenImageVaeEncoderStep,
+)
+from .inputs import (
+    QwenImageEditPlusAdditionalInputsStep,
+    QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
+    """VL encoder that takes both image and text prompts. Uses 384x384 target area."""
+
+    model_name = "qwenimage-edit-plus"
+    block_classes = [
+        QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
+        QwenImageEditPlusTextEncoderStep(),
+    ]
+    block_names = ["resize", "encode"]
+
+    @property
+    def description(self) -> str:
+        return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
+    """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
+
+    model_name = "qwenimage-edit-plus"
+    block_classes = [
+        QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
+        QwenImageEditPlusProcessImagesInputStep(),
+        QwenImageVaeEncoderStep(),
+    ]
+    block_names = ["resize", "preprocess", "encode"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "VAE encoder step that encodes image inputs into latent representations.\n"
+            "Each image is resized independently based on its own aspect ratio to a 1024x1024 target area."
+        )
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+    ]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
+            " - Standardizes the text embeddings batch size.\n"
+            " - Processes the list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
+            " - Outputs lists of image_height/image_width for RoPE calculation.\n"
+            " - Defaults height/width from the last image in the list."
+        )
+
+
+# Qwen Image Edit Plus (image2image) core denoise step
+class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = [
+        QwenImageEditPlusInputStep(),
+        QwenImagePrepareLatentsStep(),
+        QwenImageSetTimestepsStep(),
+        QwenImageEditPlusRoPEInputsStep(),
+        QwenImageEditDenoiseStep(),
+        QwenImageAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core denoising workflow for the QwenImage-Edit Plus edit (img2img) task."
+
+
+# ====================
+# 4. DECODE
+# ====================
+
+
+class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image."
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+
+EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+        ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
+        ("denoise", QwenImageEditPlusCoreDenoiseStep()),
+        ("decode", QwenImageEditPlusDecodeStep()),
+    ]
+)
+
+
+class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
+    block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
+            "- `image` is a required input (it can be a single image or a list of images).\n"
+            "- Each image is resized independently based on its own aspect ratio.\n"
+            "- The VL encoder uses a 384x384 target area; the VAE encoder uses a 1024x1024 target area."
+        )
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
new file mode 100644
index 000000000000..fe6f756789af
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
@@ -0,0 +1,159 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...utils import logging
+from ..modular_pipeline import SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    QwenImageLayeredPrepareLatentsStep,
+    QwenImageLayeredRoPEInputsStep,
+    QwenImageLayeredSetTimestepsStep,
+)
+from .decoders import (
+    QwenImageLayeredAfterDenoiseStep,
+    QwenImageLayeredDecoderStep,
+)
+from .denoise import (
+    QwenImageLayeredDenoiseStep,
+)
+from .encoders import (
+    QwenImageEditProcessImagesInputStep,
+    QwenImageLayeredGetImagePromptStep,
+    QwenImageLayeredPermuteLatentsStep,
+    QwenImageLayeredResizeStep,
+    QwenImageTextEncoderStep,
+    QwenImageVaeEncoderStep,
+)
+from .inputs import (
+    QwenImageLayeredAdditionalInputsStep,
+    QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
+    """Text encoder that takes a text prompt; if none is provided, it generates one from the image."""
+
+    model_name = "qwenimage-layered"
+    block_classes = [
+        QwenImageLayeredResizeStep(),
+        QwenImageLayeredGetImagePromptStep(),
+        QwenImageTextEncoderStep(),
+    ]
+    block_names = ["resize", "get_image_prompt", "encode"]
+
+    @property
+    def description(self) -> str:
+        return "QwenImage-Layered text encoder step that encodes the text prompt; if no prompt is provided, it generates one from the image."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+# Layered VAE encoder
+class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-layered"
+    block_classes = [
+        QwenImageLayeredResizeStep(),
+        QwenImageEditProcessImagesInputStep(),
+        QwenImageVaeEncoderStep(),
+        QwenImageLayeredPermuteLatentsStep(),
+    ]
+    block_names = ["resize", "preprocess", "encode", "permute"]
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that encodes the image inputs into their latent representations."
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageLayeredInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-layered"
+    block_classes = [
+        QwenImageTextInputsStep(),
+        QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+    ]
+    block_names = ["text_inputs", "additional_inputs"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the layered denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs have a consistent batch size.\n"
+            " - updates height/width based on `image_latents`, and patchifies `image_latents`."
+        )
+
+
+# Qwen Image Layered (image2image) core denoise step
+class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-layered"
+    block_classes = [
+        QwenImageLayeredInputStep(),
+        QwenImageLayeredPrepareLatentsStep(),
+        QwenImageLayeredSetTimestepsStep(),
+        QwenImageLayeredRoPEInputsStep(),
+        QwenImageLayeredDenoiseStep(),
+        QwenImageLayeredAfterDenoiseStep(),
+    ]
+    block_names = [
+        "input",
+        "prepare_latents",
+        "set_timesteps",
+        "prepare_rope_inputs",
+        "denoise",
+        "after_denoise",
+    ]
+
+    @property
+    def description(self):
+        return "Core denoising workflow for the QwenImage-Layered img2img task."
+
+
+# ====================
+# 4. AUTO BLOCKS & PRESETS
+# ====================
+
+LAYERED_AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageLayeredTextEncoderStep()),
+        ("vae_encoder", QwenImageLayeredVaeEncoderStep()),
+        ("denoise", QwenImageLayeredCoreDenoiseStep()),
+        ("decode", QwenImageLayeredDecoderStep()),
+    ]
+)
+
+
+class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
+    model_name = "qwenimage-layered"
+    block_classes = LAYERED_AUTO_BLOCKS.values()
+    block_names = LAYERED_AUTO_BLOCKS.keys()
+
+    @property
+    def description(self):
+        return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
index 59e1a13a5db2..892435989d00 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -90,6 +90,88 @@ def unpack_latents(self, latents, height, width, vae_scale_factor=8):
         return latents
 
+
+class QwenImageLayeredPachifier(ConfigMixin):
+    """
+    A class to pack and unpack latents for QwenImage Layered.
+
+    Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W).
+    """
+
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(self, patch_size: int = 2):
+        super().__init__()
+
+    def pack_latents(self, latents):
+        """
+        Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4).
+ """ + + if latents.ndim != 5: + raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}") + + batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + layers, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape( + batch_size, + layers * (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + return latents + + def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W). + """ + + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, _, channels = latents.shape + patch_size = self.config.patch_size + + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + layers + 1, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + latents = latents.reshape( + batch_size, + layers + 1, + channels // (patch_size * patch_size), + height, + width, + ) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): """ A ModularPipeline for QwenImage. @@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): """ default_blocks_name = "QwenImageEditPlusAutoBlocks" + + +class QwenImageLayeredModularPipeline(QwenImageModularPipeline): + """ + A ModularPipeline for QwenImage-Layered. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageLayeredAutoBlocks" diff --git a/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py new file mode 100644 index 000000000000..8e7beb555760 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py @@ -0,0 +1,121 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt templates for QwenImage pipelines. 
+ +This module centralizes all prompt templates used across different QwenImage pipeline variants: +- QwenImage (base): Text-only encoding for text-to-image generation +- QwenImage Edit: VL encoding with single image for image editing +- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing +- QwenImage Layered: Auto-captioning for image decomposition +""" + +# ============================================ +# QwenImage Base (text-only encoding) +# ============================================ +# Used for text-to-image generation where only text prompt is encoded + +QWENIMAGE_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34 + + +# ============================================ +# QwenImage Edit (VL encoding with single image) +# ============================================ +# Used for single-image editing where both image and text are encoded together + +QWENIMAGE_EDIT_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Edit Plus (VL encoding with multiple images) +# ============================================ +# Used for multi-reference editing where multiple images and text are encoded together +# The img_template is used to format each image in the prompt + +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Layered (auto-captioning) +# ============================================ +# Used for image decomposition where the VL model generates a caption from the input image +# if no prompt is provided. These prompts instruct the model to describe the image in detail. + +QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# Image Annotator\n" + "You are a professional image annotator. Please write an image caption based on the input image:\n" + "1. Write the caption using natural, descriptive language without structured formats or rich text.\n" + "2. 
Enrich caption details by including:\n"
+    " - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n"
+    " - Visual relations between objects, such as spatial relations, functional relations, possessive relations, "
+    "attachment relations, action relations, comparative relations, causal relations, and so on\n"
+    " - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n"
+    " - Identify the text clearly visible in the image, without translation or explanation, "
+    "and highlight it in the caption with quotation marks\n"
+    "3. Maintain authenticity and accuracy:\n"
+    " - Avoid generalizations\n"
+    " - Describe all visible information in the image, while not adding information not explicitly shown in the image\n"
+    "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
+    "<|im_start|>assistant\n"
+)
+
+QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = (
+    "<|im_start|>system\n"
+    "You are a helpful assistant.<|im_end|>\n"
+    "<|im_start|>user\n"
+    "# 图像标注器\n"
+    "你是一个专业的图像标注器。请基于输入图像,撰写图注:\n"
+    "1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n"
+    "2. 通过加入以下内容,丰富图注细节:\n"
+    " - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n"
+    " - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n"
+    " - 环境细节:例如天气、光照、颜色、纹理、气氛等\n"
+    " - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n"
+    "3. 保持真实性与准确性:\n"
+    " - 不要使用笼统的描述\n"
+    " - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n"
+    "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
+    "<|im_start|>assistant\n"
+)
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index c14910250b54..b5ebe1b81495 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -99,6 +99,7 @@
     QwenImageEditPlusPipeline,
     QwenImageImg2ImgPipeline,
     QwenImageInpaintPipeline,
+    QwenImageLayeredPipeline,
     QwenImagePipeline,
 )
 from .sana import SanaPipeline
@@ -202,6 +203,7 @@
         ("qwenimage", QwenImageImg2ImgPipeline),
         ("qwenimage-edit", QwenImageEditPipeline),
         ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
+        ("qwenimage-layered", QwenImageLayeredPipeline),
         ("z-image", ZImageImg2ImgPipeline),
     ]
 )
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 63cec365799b..47d27741fe88 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -167,6 +167,36 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class QwenImageLayeredAutoBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageLayeredModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class QwenImageModularPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]