Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions diffsynth/pipelines/flux2_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..models.flux2_vae import Flux2VAE
from ..models.z_image_text_encoder import ZImageTextEncoder

from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer

class Flux2ImagePipeline(BasePipeline):

Expand Down Expand Up @@ -94,6 +95,11 @@ def __call__(
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
# SES
enable_ses: bool = False,
ses_reward_model: str = "pick",
ses_eval_budget: int = 50,
ses_inference_steps: int = 10,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

Expand All @@ -115,6 +121,58 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

# Inference-Time Scaling (SES)
if enable_ses:
print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
self.load_models_to_device(list(self.in_iteration_models) + ['vae'])
models = {name: getattr(self, name) for name in self.in_iteration_models}

h_latent = height // 16
w_latent = width // 16

def ses_generate_callback(trial_latents_spatial):
    """Run a short (ses_inference_steps) denoising pass from a trial noise
    sample and decode it to an image, for SES reward scoring.

    Args:
        trial_latents_spatial: candidate initial noise in spatial layout
            (b, c, h, w) — assumed to match (h_latent, w_latent); TODO confirm
            against run_ses_cem's sampling shape.

    Returns:
        A PIL image (via self.vae_output_to_image) decoded from the final latents.

    NOTE(review): this mutates the pipeline's shared scheduler state; the caller
    restores the full-step schedule after SES finishes. It also reuses
    inputs_shared via a shallow copy, so nested values are shared with the
    outer call — verify no unit mutates them in place.
    """
    # Shallow copy: only the "latents" key is rebound per trial below.
    trial_inputs = inputs_shared.copy()

    # Reconfigure the (shared) scheduler for the cheaper evaluation schedule.
    self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent)
    eval_timesteps = self.scheduler.timesteps

    # The Flux2 model consumes latents as a token sequence (b, h*w, c).
    curr_latents_seq = rearrange(trial_latents_spatial, "b c h w -> b (h w) c")

    for progress_id, timestep in enumerate(eval_timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

        trial_inputs["latents"] = curr_latents_seq

        # CFG-guided prediction using the models preloaded by the enclosing block.
        noise_pred = self.cfg_guided_model_fn(
            self.model_fn, cfg_scale,
            trial_inputs, inputs_posi, inputs_nega,
            **models, timestep=timestep, progress_id=progress_id
        )
        curr_latents_seq = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)

    # Back to spatial layout for the VAE decoder.
    curr_latents_spatial = rearrange(curr_latents_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent)

    decoded_img = self.vae.decode(curr_latents_spatial)
    return self.vae_output_to_image(decoded_img)
initial_noise_seq = inputs_shared["latents"]
initial_noise_spatial = rearrange(initial_noise_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent)

optimized_latents_spatial = run_ses_cem(
base_latents=initial_noise_spatial,
pipeline_callback=ses_generate_callback,
prompt=prompt,
scorer=scorer,
total_eval_budget=ses_eval_budget,
popsize=10,
k_elites=5
)
optimized_latents_seq = rearrange(optimized_latents_spatial, "b c h w -> b (h w) c")
inputs_shared["latents"] = optimized_latents_seq
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent)
del scorer
torch.cuda.empty_cache()
Comment on lines +125 to +174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for Inference-Time Scaling (SES) is largely duplicated across multiple pipeline files (flux_image.py, flux2_image.py, qwen_image.py, z_image.py). This duplication makes the code harder to maintain and update. Consider refactoring this logic into a shared helper function or a method in the BasePipeline class. This would centralize the SES implementation, making it easier to manage and reducing the risk of inconsistencies between pipelines. A base method could accept pipeline-specific parameters (like scheduler settings, VAE model name, and latent shape handling) to accommodate the variations between models.


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
Expand Down
50 changes: 50 additions & 0 deletions diffsynth/pipelines/flux_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ..models.step1x_text_encoder import Step1xEditEmbedder
from ..core.vram.layers import AutoWrappedLinear

from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer

class MultiControlNet(torch.nn.Module):
def __init__(self, models: list[torch.nn.Module]):
super().__init__()
Expand Down Expand Up @@ -240,6 +242,11 @@ def __call__(
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
# SES
enable_ses: bool = False,
ses_reward_model: str = "pick",
ses_eval_budget: int = 50,
ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
Expand Down Expand Up @@ -274,6 +281,49 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

# Inference-Time Scaling (SES)
if enable_ses:
print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder'])
models = {name: getattr(self, name) for name in self.in_iteration_models}

def ses_generate_callback(trial_latents):
    """Denoise a trial noise sample for ses_inference_steps steps and decode
    it to an image, for SES reward scoring.

    Args:
        trial_latents: candidate initial latents; passed straight to the
            denoising loop, so presumed to already be in this pipeline's
            latent layout — TODO confirm against run_ses_cem.

    Returns:
        A PIL image (via self.vae_output_to_image) decoded from the final latents.

    NOTE(review): mutates the shared scheduler's timesteps; the caller resets
    the full num_inference_steps schedule after SES. inputs_shared is only
    shallow-copied, so nested values are shared with the outer call.
    """
    trial_inputs = inputs_shared.copy()

    # Short evaluation schedule on the shared scheduler.
    self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
    eval_timesteps = self.scheduler.timesteps
    curr_latents = trial_latents

    for progress_id, timestep in enumerate(eval_timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

        trial_inputs["latents"] = curr_latents
        # CFG-guided prediction using the models preloaded by the enclosing block.
        noise_pred = self.cfg_guided_model_fn(
            self.model_fn, cfg_scale,
            trial_inputs, inputs_posi, inputs_nega,
            **models, timestep=timestep, progress_id=progress_id
        )
        curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)

    # Decode with the caller's tiling settings (vae_decoder was preloaded above).
    decoded_img = self.vae_decoder(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    return self.vae_output_to_image(decoded_img)

initial_noise = inputs_shared["latents"]
optimized_latents = run_ses_cem(
base_latents=initial_noise,
pipeline_callback=ses_generate_callback,
prompt=prompt,
scorer=scorer,
total_eval_budget=ses_eval_budget,
popsize=10,
k_elites=5
)
inputs_shared["latents"] = optimized_latents
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
del scorer
torch.cuda.empty_cache()
Comment on lines +285 to +325
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for Inference-Time Scaling (SES) is largely duplicated across multiple pipeline files. This duplication makes the code harder to maintain and update. Consider refactoring this logic into a shared helper function or a method in the BasePipeline class. This would centralize the SES implementation, making it easier to manage and reducing the risk of inconsistencies between pipelines.


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
Expand Down
53 changes: 52 additions & 1 deletion diffsynth/pipelines/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel

from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer

class QwenImagePipeline(BasePipeline):

Expand Down Expand Up @@ -141,6 +142,11 @@ def __call__(
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
# SES
enable_ses: bool = False,
ses_reward_model: str = "pick",
ses_eval_budget: int = 50,
ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
Expand Down Expand Up @@ -171,6 +177,51 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

# Inference-Time Scaling (SES)
if enable_ses:
print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)

self.load_models_to_device(list(self.in_iteration_models) + ['vae'])
models = {name: getattr(self, name) for name in self.in_iteration_models}

def ses_generate_callback(trial_latents):
    """Denoise a trial noise sample for ses_inference_steps steps and decode
    it to an image, for SES reward scoring.

    Args:
        trial_latents: candidate initial latents; presumed to already be in
            this pipeline's latent layout — TODO confirm against run_ses_cem.

    Returns:
        A PIL image (via self.vae_output_to_image) decoded from the final latents.

    NOTE(review): mutates the shared scheduler's timesteps; the caller resets
    the full num_inference_steps schedule after SES. inputs_shared is only
    shallow-copied, so nested values are shared with the outer call.
    """
    trial_inputs = inputs_shared.copy()
    # Short evaluation schedule, same dynamic-shift settings as the full run.
    self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
    eval_timesteps = self.scheduler.timesteps
    curr_latents = trial_latents

    for progress_id, timestep in enumerate(eval_timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

        trial_inputs["latents"] = curr_latents

        # CFG-guided prediction using the models preloaded by the enclosing block.
        noise_pred = self.cfg_guided_model_fn(
            self.model_fn, cfg_scale,
            trial_inputs, inputs_posi, inputs_nega,
            **models, timestep=timestep, progress_id=progress_id
        )
        curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)

    # Decode with the caller's tiling settings (vae was preloaded above).
    decoded_img = self.vae.decode(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    return self.vae_output_to_image(decoded_img)

initial_noise = inputs_shared["latents"]

optimized_latents = run_ses_cem(
base_latents=initial_noise,
pipeline_callback=ses_generate_callback,
prompt=prompt,
scorer=scorer,
total_eval_budget=ses_eval_budget,
popsize=10,
k_elites=5
)
inputs_shared["latents"] = optimized_latents
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
del scorer
torch.cuda.empty_cache()
Comment on lines +181 to +223
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for Inference-Time Scaling (SES) is largely duplicated across multiple pipeline files. This duplication makes the code harder to maintain and update. Consider refactoring this logic into a shared helper function or a method in the BasePipeline class. This would centralize the SES implementation, making it easier to manage and reducing the risk of inconsistencies between pipelines.


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
Expand All @@ -182,7 +233,7 @@ def __call__(
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
Expand Down
46 changes: 46 additions & 0 deletions diffsynth/pipelines/z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel

from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer

class ZImagePipeline(BasePipeline):

Expand Down Expand Up @@ -116,6 +117,11 @@ def __call__(
positive_only_lora: Dict[str, torch.Tensor] = None,
# Progress bar
progress_bar_cmd = tqdm,
# SES
enable_ses: bool = False,
ses_reward_model: str = "pick",
ses_eval_budget: int = 50,
ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
Expand All @@ -140,6 +146,46 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

if enable_ses:
print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder'])
models = {name: getattr(self, name) for name in self.in_iteration_models}
def ses_generate_callback(trial_latents):
    """Denoise a trial noise sample for ses_inference_steps steps and decode
    it to an image, for SES reward scoring.

    Args:
        trial_latents: candidate initial latents; presumed to already be in
            this pipeline's latent layout — TODO confirm against run_ses_cem.

    Returns:
        A PIL image (via self.vae_output_to_image) decoded from the final latents.

    NOTE(review): mutates the shared scheduler's timesteps; the caller resets
    the full num_inference_steps schedule after SES. inputs_shared is only
    shallow-copied, so nested values are shared with the outer call.
    """
    trial_inputs = inputs_shared.copy()
    # Short evaluation schedule on the shared scheduler.
    self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
    eval_timesteps = self.scheduler.timesteps
    curr_latents = trial_latents

    for progress_id, timestep in enumerate(eval_timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
        trial_inputs["latents"] = curr_latents
        # CFG-guided prediction using the models preloaded by the enclosing block.
        noise_pred = self.cfg_guided_model_fn(
            self.model_fn, cfg_scale,
            trial_inputs, inputs_posi, inputs_nega,
            **models, timestep=timestep, progress_id=progress_id
        )
        curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)
    # Decode (vae_decoder was preloaded above); no tiling args in this pipeline.
    decoded_img = self.vae_decoder(curr_latents)
    return self.vae_output_to_image(decoded_img)

initial_noise = inputs_shared["latents"]

optimized_latents = run_ses_cem(
base_latents=initial_noise,
pipeline_callback=ses_generate_callback,
prompt=prompt,
scorer=scorer,
total_eval_budget=ses_eval_budget,
popsize=10,
k_elites=5
)
inputs_shared["latents"] = optimized_latents
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

del scorer
torch.cuda.empty_cache()
Comment on lines +149 to +187
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for Inference-Time Scaling (SES) is largely duplicated across multiple pipeline files. This duplication makes the code harder to maintain and update. Consider refactoring this logic into a shared helper function or a method in the BasePipeline class. This would centralize the SES implementation, making it easier to manage and reducing the risk of inconsistencies between pipelines.


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
Expand Down
Empty file.
Loading