From 1d07b6643a9acf631043c0a5b045e38cd59a7ec2 Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Fri, 6 Feb 2026 12:38:44 +0800
Subject: [PATCH] add: inference-time scaling (SES)
---
diffsynth/pipelines/flux2_image.py | 58 ++++++
diffsynth/pipelines/flux_image.py | 50 +++++
diffsynth/pipelines/qwen_image.py | 53 +++++-
diffsynth/pipelines/z_image.py | 46 +++++
.../utils/inference_time_scaling/__init__.py | 0
diffsynth/utils/inference_time_scaling/ses.py | 174 ++++++++++++++++++
.../inference_time_scaling.md | 99 ++++++++++
.../inference_time_scaling.md | 100 ++++++++++
.../flux/model_inference/FLUX.1-dev-SES.py | 29 +++
.../flux2/model_inference/FLUX.2-dev-SES.py | 37 ++++
.../model_inference/Qwen-Image-SES.py | 27 +++
.../z_image/model_inference/Z-Image-SES.py | 30 +++
.../model_inference/Z-Image-Turbo-SES.py | 26 +++
13 files changed, 728 insertions(+), 1 deletion(-)
create mode 100644 diffsynth/utils/inference_time_scaling/__init__.py
create mode 100644 diffsynth/utils/inference_time_scaling/ses.py
create mode 100644 docs/en/Research_Tutorial/inference_time_scaling.md
create mode 100644 docs/zh/Research_Tutorial/inference_time_scaling.md
create mode 100644 examples/flux/model_inference/FLUX.1-dev-SES.py
create mode 100644 examples/flux2/model_inference/FLUX.2-dev-SES.py
create mode 100644 examples/qwen_image/model_inference/Qwen-Image-SES.py
create mode 100644 examples/z_image/model_inference/Z-Image-SES.py
create mode 100644 examples/z_image/model_inference/Z-Image-Turbo-SES.py
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index d5dc35bd4..5daec1644 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -17,6 +17,7 @@
from ..models.flux2_vae import Flux2VAE
from ..models.z_image_text_encoder import ZImageTextEncoder
+from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer
class Flux2ImagePipeline(BasePipeline):
@@ -94,6 +95,11 @@ def __call__(
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
+ # SES
+ enable_ses: bool = False,
+ ses_reward_model: str = "pick",
+ ses_eval_budget: int = 50,
+ ses_inference_steps: int = 10,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
@@ -115,6 +121,58 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+ # Inference-Time Scaling (SES)
+ if enable_ses:
+ print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
+ scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
+ self.load_models_to_device(list(self.in_iteration_models) + ['vae'])
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+
+ h_latent = height // 16
+ w_latent = width // 16
+
+            # Preview run: denoise the candidate noise for ses_inference_steps steps and decode it for reward scoring.
+            def ses_generate_callback(trial_latents_spatial):
+ trial_inputs = inputs_shared.copy()
+
+ self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent)
+ eval_timesteps = self.scheduler.timesteps
+
+ curr_latents_seq = rearrange(trial_latents_spatial, "b c h w -> b (h w) c")
+
+ for progress_id, timestep in enumerate(eval_timesteps):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ trial_inputs["latents"] = curr_latents_seq
+
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ trial_inputs, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ curr_latents_seq = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)
+
+ curr_latents_spatial = rearrange(curr_latents_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent)
+
+ decoded_img = self.vae.decode(curr_latents_spatial)
+ return self.vae_output_to_image(decoded_img)
+
+            # SES searches in spatial (b, c, h, w) layout, so unpack the sequence-form latents first.
+            initial_noise_seq = inputs_shared["latents"]
+ initial_noise_spatial = rearrange(initial_noise_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent)
+
+ optimized_latents_spatial = run_ses_cem(
+ base_latents=initial_noise_spatial,
+ pipeline_callback=ses_generate_callback,
+ prompt=prompt,
+ scorer=scorer,
+ total_eval_budget=ses_eval_budget,
+ popsize=10,
+ k_elites=5
+ )
+ optimized_latents_seq = rearrange(optimized_latents_spatial, "b c h w -> b (h w) c")
+ inputs_shared["latents"] = optimized_latents_seq
+            # Restore the full-step schedule for the final denoising pass.
+            self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent)
+ del scorer
+ torch.cuda.empty_cache()
+
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index bfc53e505..c6751494f 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -20,6 +20,8 @@
from ..models.step1x_text_encoder import Step1xEditEmbedder
from ..core.vram.layers import AutoWrappedLinear
+from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer
+
class MultiControlNet(torch.nn.Module):
def __init__(self, models: list[torch.nn.Module]):
super().__init__()
@@ -240,6 +242,11 @@ def __call__(
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
+ # SES
+ enable_ses: bool = False,
+ ses_reward_model: str = "pick",
+ ses_eval_budget: int = 50,
+ ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -274,6 +281,49 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+ # Inference-Time Scaling (SES)
+ if enable_ses:
+ print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
+ scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
+ self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder'])
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+
+ def ses_generate_callback(trial_latents):
+ trial_inputs = inputs_shared.copy()
+
+ self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+ eval_timesteps = self.scheduler.timesteps
+ curr_latents = trial_latents
+
+ for progress_id, timestep in enumerate(eval_timesteps):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ trial_inputs["latents"] = curr_latents
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ trial_inputs, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)
+
+ decoded_img = self.vae_decoder(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return self.vae_output_to_image(decoded_img)
+
+ initial_noise = inputs_shared["latents"]
+ optimized_latents = run_ses_cem(
+ base_latents=initial_noise,
+ pipeline_callback=ses_generate_callback,
+ prompt=prompt,
+ scorer=scorer,
+ total_eval_budget=ses_eval_budget,
+ popsize=10,
+ k_elites=5
+ )
+ inputs_shared["latents"] = optimized_latents
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+ del scorer
+ torch.cuda.empty_cache()
+
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 75cfbee77..8a75db74f 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -20,6 +20,7 @@
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
+from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer
class QwenImagePipeline(BasePipeline):
@@ -141,6 +142,11 @@ def __call__(
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
+ # SES
+ enable_ses: bool = False,
+ ses_reward_model: str = "pick",
+ ses_eval_budget: int = 50,
+ ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
@@ -171,6 +177,51 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+ # Inference-Time Scaling (SES)
+ if enable_ses:
+ print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
+ scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
+
+ self.load_models_to_device(list(self.in_iteration_models) + ['vae'])
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+
+ def ses_generate_callback(trial_latents):
+ trial_inputs = inputs_shared.copy()
+ self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
+ eval_timesteps = self.scheduler.timesteps
+ curr_latents = trial_latents
+
+ for progress_id, timestep in enumerate(eval_timesteps):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ trial_inputs["latents"] = curr_latents
+
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ trial_inputs, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)
+
+ decoded_img = self.vae.decode(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return self.vae_output_to_image(decoded_img)
+
+ initial_noise = inputs_shared["latents"]
+
+ optimized_latents = run_ses_cem(
+ base_latents=initial_noise,
+ pipeline_callback=ses_generate_callback,
+ prompt=prompt,
+ scorer=scorer,
+ total_eval_budget=ses_eval_budget,
+ popsize=10,
+ k_elites=5
+ )
+ inputs_shared["latents"] = optimized_latents
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
+ del scorer
+ torch.cuda.empty_cache()
+
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
@@ -182,7 +233,7 @@ def __call__(
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
-
+
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py
index 2c5b68730..decfd91d9 100644
--- a/diffsynth/pipelines/z_image.py
+++ b/diffsynth/pipelines/z_image.py
@@ -23,6 +23,7 @@
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel
+from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer
class ZImagePipeline(BasePipeline):
@@ -116,6 +117,11 @@ def __call__(
positive_only_lora: Dict[str, torch.Tensor] = None,
# Progress bar
progress_bar_cmd = tqdm,
+ # SES
+ enable_ses: bool = False,
+ ses_reward_model: str = "pick",
+ ses_eval_budget: int = 50,
+ ses_inference_steps: int = 10,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -140,6 +146,46 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+        # Inference-Time Scaling (SES)
+        if enable_ses:
+ print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}")
+ scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype)
+ self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder'])
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ def ses_generate_callback(trial_latents):
+ trial_inputs = inputs_shared.copy()
+ self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+ eval_timesteps = self.scheduler.timesteps
+ curr_latents = trial_latents
+
+ for progress_id, timestep in enumerate(eval_timesteps):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ trial_inputs["latents"] = curr_latents
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ trial_inputs, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs)
+ decoded_img = self.vae_decoder(curr_latents)
+ return self.vae_output_to_image(decoded_img)
+
+ initial_noise = inputs_shared["latents"]
+
+ optimized_latents = run_ses_cem(
+ base_latents=initial_noise,
+ pipeline_callback=ses_generate_callback,
+ prompt=prompt,
+ scorer=scorer,
+ total_eval_budget=ses_eval_budget,
+ popsize=10,
+ k_elites=5
+ )
+ inputs_shared["latents"] = optimized_latents
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ del scorer
+ torch.cuda.empty_cache()
+
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
diff --git a/diffsynth/utils/inference_time_scaling/__init__.py b/diffsynth/utils/inference_time_scaling/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/diffsynth/utils/inference_time_scaling/ses.py b/diffsynth/utils/inference_time_scaling/ses.py
new file mode 100644
index 000000000..f00b7fd9a
--- /dev/null
+++ b/diffsynth/utils/inference_time_scaling/ses.py
@@ -0,0 +1,174 @@
+import torch
+import pywt
+import numpy as np
+from transformers import AutoProcessor, AutoModel
+from tqdm import tqdm
+
+def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
+    """Per-sample 2D wavelet decomposition: returns the low-frequency approximation as a
+    batched tensor and the high-frequency detail coefficients as a per-sample list."""
+ all_clow_np = []
+ all_chigh_list = []
+ z_tensor_cpu = z_tensor_cpu.cpu().float()
+
+ for i in range(z_tensor_cpu.shape[0]):
+ z_numpy_ch = z_tensor_cpu[i].numpy()
+
+ coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))
+
+ clow_np = coeffs_ch[0]
+ chigh_list = coeffs_ch[1:]
+
+ all_clow_np.append(clow_np)
+ all_chigh_list.append(chigh_list)
+
+ all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))
+ return all_clow_tensor, all_chigh_list
+
+def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
+    """Inverse wavelet transform: recombine a (possibly modified) low-frequency band with the
+    fixed high-frequency coefficients and crop the result to the original spatial size."""
+ H_high, W_high = original_shape
+ c_low_tensor_cpu = c_low_tensor_cpu.cpu().float()
+
+ clow_np = c_low_tensor_cpu.numpy()
+
+ if clow_np.ndim == 4 and clow_np.shape[0] == 1:
+ clow_np = clow_np[0]
+
+ coeffs_combined = [clow_np] + c_high_coeffs
+ z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1))
+ if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high:
+ z_recon_np = z_recon_np[..., :H_high, :W_high]
+ z_recon_tensor = torch.from_numpy(z_recon_np)
+ if z_recon_tensor.ndim == 3:
+ z_recon_tensor = z_recon_tensor.unsqueeze(0)
+ return z_recon_tensor
+
+class SESRewardScorer:
+    """Thin wrapper around Hugging Face reward models (PickScore, CLIP, HPSv2) used to score preview images."""
+ def __init__(self, reward_name, device, dtype):
+ self.reward_name = reward_name
+ self.device = device
+ self.dtype = dtype
+ self.model = None
+ self.processor = None
+ self._load_model()
+
+ def _load_model(self):
+ print(f"[SES] Loading Reward Model: {self.reward_name}...")
+ if self.reward_name == "pick":
+ self.processor = AutoProcessor.from_pretrained("yuvalkirstain/PickScore_v1")
+ self.model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1", torch_dtype=self.dtype).to(self.device).eval()
+ elif self.reward_name == "clip":
+ self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ self.model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype).to(self.device).eval()
+ elif self.reward_name == "hps":
+ self.processor = AutoProcessor.from_pretrained("adams-story/HPSv2-hf")
+ self.model = AutoModel.from_pretrained("adams-story/HPSv2-hf", torch_dtype=self.dtype, trust_remote_code=True).to(self.device).eval()
+ else:
+ print(f"[SES] Warning: Reward model {self.reward_name} not implemented in wrapper, skipping.")
+
+ def get_score(self, image_pil, text_prompt):
+ try:
+ with torch.no_grad():
+ if self.reward_name == "pick":
+ inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(self.device)
+ inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
+ outputs = self.model(**inputs)
+ return outputs.logits_per_image[0, 0].item()
+
+ elif self.reward_name == "clip":
+ inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding=True, truncation=True).to(self.device)
+ if 'pixel_values' in inputs: inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
+ outputs = self.model(**inputs)
+ return outputs.logits_per_image.item()
+
+ elif self.reward_name == "hps":
+ inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(self.device)
+ inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
+ outputs = self.model(**inputs)
+ return outputs.logits_per_image[0, 0].item()
+ except Exception as e:
+ print(f"Error computing score: {e}")
+ return 0.0
+        # Fallback when the selected reward model has no scoring branch.
+        return 0.0
+
+
+def run_ses_cem(
+ base_latents,
+ pipeline_callback,
+ prompt,
+ scorer,
+ total_eval_budget=50,
+ popsize=10,
+ k_elites=5,
+ wavelet_name="db1",
+ dwt_level=5,
+ lambda_prior=1e-3
+):
+    """Spectral Evolution Search: a Cross-Entropy Method over the low-frequency wavelet band of the initial noise.
+
+    Each candidate is rendered by `pipeline_callback` (a short preview run) and scored by `scorer`;
+    the best-scoring noise is returned with the shape, device, and dtype of `base_latents`.
+    """
+    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
+    # Keep the high-frequency detail bands fixed; search only over the low-frequency approximation.
+    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
+    c_high_fixed = c_high_fixed_batch[0]
+    c_low_shape = c_low_init.shape[1:]
+    # CEM search distribution: a diagonal Gaussian over the flattened low-frequency coefficients.
+    mu = c_low_init.view(-1).cpu()
+    sigma_sq = torch.ones_like(mu) * 1.0
+
+ best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]}
+ eval_count = 0
+
+ elite_db = []
+ n_generations = (total_eval_budget // popsize) + 5
+ pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img")
+
+ for gen in range(n_generations):
+ if eval_count >= total_eval_budget: break
+
+        # Sample a population of candidate coefficients from the current search distribution.
+        std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
+ z_noise = torch.randn(popsize, mu.shape[0])
+ samples_flat = mu + z_noise * std
+ samples_reshaped = samples_flat.view(popsize, *c_low_shape)
+
+ batch_results = []
+
+ for i in range(popsize):
+ if eval_count >= total_eval_budget: break
+
+ c_low_sample = samples_reshaped[i].unsqueeze(0)
+ z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
+ z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
+ img = pipeline_callback(z_recon)
+
+ score = scorer.get_score(img, prompt)
+            # L2 prior penalty discourages samples that drift far from the standard Gaussian prior.
+            penalty = lambda_prior * (torch.norm(c_low_sample.float())**2).item()
+ fitness = score - penalty
+
+ res = {
+ "fitness": fitness,
+ "score": score,
+ "c_low": c_low_sample.cpu()
+ }
+ batch_results.append(res)
+ if fitness > best_overall['fitness']:
+ best_overall = res
+
+ eval_count += 1
+ pbar.update(1)
+ pbar.set_postfix({
+ "Gen": gen,
+ "Best": f"{best_overall['score']:.4f}"
+ })
+
+ if not batch_results: break
+        # CEM update: refit the Gaussian to the best k_elites candidates seen so far.
+        elite_db.extend(batch_results)
+ elite_db.sort(key=lambda x: x['fitness'], reverse=True)
+ elite_db = elite_db[:k_elites]
+ elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
+ mu_new = torch.mean(elites_flat, dim=0)
+
+ if len(elite_db) > 1:
+ sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
+ else:
+ sigma_sq_new = sigma_sq
+ mu = mu_new
+ sigma_sq = sigma_sq_new
+ pbar.close()
+ best_c_low = best_overall['c_low']
+ final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))
+
+ return final_latents.to(base_latents.device, dtype=base_latents.dtype)
\ No newline at end of file
diff --git a/docs/en/Research_Tutorial/inference_time_scaling.md b/docs/en/Research_Tutorial/inference_time_scaling.md
new file mode 100644
index 000000000..3de452d14
--- /dev/null
+++ b/docs/en/Research_Tutorial/inference_time_scaling.md
@@ -0,0 +1,99 @@
+# Inference-Time Scaling
+
+DiffSynth-Studio supports inference-time scaling via the **Spectral Evolution Search (SES)** algorithm, which improves image quality by spending additional compute at inference time, without retraining the model.
+
+## 1. Basic Principle
+
+Standard text-to-image inference starts from random Gaussian noise and produces an image through a fixed number of denoising steps, so the resulting quality depends heavily on the particular initial noise that was drawn.
+
+**SES (Spectral Evolution Search)** instead turns inference into a **search problem over the initial noise**:
+
+1. **Search space**: the lowest-frequency sub-band of the initial noise under a discrete wavelet transform (DWT).
+2. **Evolution strategy**: the Cross-Entropy Method (CEM) iteratively samples and refines populations of candidate noises.
+3. **Reward feedback**: a reward model such as PickScore scores a low-step preview image generated from each candidate.
+4. **Final output**: the highest-scoring noise is then denoised with the full number of inference steps (see the sketch below).
+
+This method essentially **trades inference computation time for generation quality**.
+
+For more technical details on this method, please refer to the paper: **[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)**.
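+
+The search loop itself is a CEM over the low-frequency wavelet coefficients. The sketch below is illustrative only: the actual implementation is `run_ses_cem` in `diffsynth/utils/inference_time_scaling/ses.py`, which additionally applies an L2 prior penalty and performs the wavelet split/reconstruction with `pywt`; `generate_and_score` here is a hypothetical stand-in for the "short preview run + reward model" step.
+
+```python
+import numpy as np
+
+def cem_search(c_low_init, generate_and_score, budget=50, popsize=10, k_elites=5):
+    # Diagonal Gaussian search distribution over the flattened low-frequency coefficients.
+    mu = c_low_init.ravel().astype(np.float32)
+    sigma = np.ones_like(mu)
+    best_x, best_score, evals = mu.copy(), -np.inf, 0
+    while evals < budget:
+        # Sample a population of candidates and score each one with a low-step preview.
+        population = mu + np.random.randn(popsize, mu.size).astype(np.float32) * sigma
+        scores = []
+        for x in population:
+            if evals >= budget:
+                break
+            score = generate_and_score(x.reshape(c_low_init.shape))
+            scores.append(score)
+            evals += 1
+            if score > best_score:
+                best_score, best_x = score, x
+        # CEM update: refit the Gaussian to the top-k candidates of this generation.
+        elites = population[np.argsort(scores)[-k_elites:]]
+        mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-7
+    return best_x.reshape(c_low_init.shape)
+```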
+
+## 2. Quick Start
+
+SES is integrated into the pipelines of the mainstream text-to-image models in DiffSynth-Studio; to enable it, simply set `enable_ses=True` when calling `pipe()`.
+
+The following is the [quick start code](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py), using **Z-Image-Turbo** as an example:
+
+```python
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights."
+
+image = pipe(
+ prompt=prompt,
+ seed=42,
+ rand_device="cuda",
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=50,
+ ses_inference_steps=8
+)
+image.save("image_Z-Image-Turbo_ses.jpg")
+```
+
+## 3. Supported Models and Parameters
+
+### 3.1 Core Parameters
+
+In the `pipe()` call, the following parameters control the behavior of SES:
+
+| Parameter Name | Type | Default Value | Description |
+| --- | --- | --- | --- |
+| `enable_ses` | `bool` | `False` | Whether to enable SES optimization. |
+| `ses_reward_model` | `str` | `"pick"` | Reward model selection. Supports `"pick"` (PickScore), `"hps"` (HPSv2), `"clip"`. |
+| `ses_eval_budget` | `int` | `50` | Total search budget (number of candidate noises evaluated). Higher values raise the quality ceiling but increase runtime; see the cost sketch below. |
+| `ses_inference_steps` | `int` | `10` | Number of denoising steps used for preview images during the search. Higher values assess candidates more accurately but take longer; 8-15 is recommended. |
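+
+As a rough guide to runtime: each candidate costs `ses_inference_steps` preview denoising steps, and the final image costs the usual `num_inference_steps` on top. The snippet below is only a back-of-the-envelope estimate (it ignores reward-model inference and any classifier-free-guidance doubling):
+
+```python
+# Illustrative cost estimate for one pipe(...) call with SES enabled.
+ses_eval_budget = 50        # candidate noises evaluated during the search
+ses_inference_steps = 10    # denoising steps per preview image
+num_inference_steps = 30    # steps of the final, full-quality pass
+
+total_steps = ses_eval_budget * ses_inference_steps + num_inference_steps
+print(f"~{total_steps} denoising steps in total")  # ~530
+```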
+
+### 3.2 Supported Model List
+
+Currently, the following text-to-image models support SES:
+
+* **[Qwen-Image](../../../examples/qwen_image/model_inference/Qwen-Image-SES.py)**
+* **[FLUX.1-dev](../../../examples/flux/model_inference/FLUX.1-dev-SES.py)**
+* **[FLUX.2-dev](../../../examples/flux2/model_inference/FLUX.2-dev-SES.py)**
+* **[Z-Image](../../../examples/z_image/model_inference/Z-Image-SES.py) / [Z-Image-Turbo](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py)**
+
+## 4. Effect Demonstration
+
+As the search budget (`ses_eval_budget`) increases, SES steadily improves image quality. The comparisons below use the same random seed with different computational budgets.
+
+**Scenario 1: Qwen-Image**
+
+* **Prompt**: *"Springtime in the style of Paul Delvaux"*
+* **Reward Model**: PickScore
+
+| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** |
+| --- | --- | --- | --- |
+| *(image omitted)* | *(image omitted)* | *(image omitted)* | *(image omitted)* |
+
+**Scenario 2: FLUX.1-dev**
+
+* **Prompt**: *"A masterful painting of a young woman in the style of Diego Velázquez."*
+* **Reward Model**: HPSv2
+
+| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** |
+| --- | --- | --- | --- |
+| *(image omitted)* | *(image omitted)* | *(image omitted)* | *(image omitted)* |
\ No newline at end of file
diff --git a/docs/zh/Research_Tutorial/inference_time_scaling.md b/docs/zh/Research_Tutorial/inference_time_scaling.md
new file mode 100644
index 000000000..5ba28dfec
--- /dev/null
+++ b/docs/zh/Research_Tutorial/inference_time_scaling.md
@@ -0,0 +1,100 @@
+# 推理时扩展
+
+DiffSynth-Studio 支持推理时扩展(Inference-time Scaling)技术,具体实现了 **Spectral Evolution Search (SES)** 算法。该技术允许用户在推理阶段通过增加计算量来提升生成图像质量,而无需重新训练模型。
+
+## 1. 基本原理
+
+传统的文生图推理过程是从一个随机的高斯噪声开始,经过固定的去噪步数生成图像。这种方式生成的质量高度依赖于初始噪声的随机性。
+
+**SES (Spectral Evolution Search)** 将推理过程转化为一个针对初始噪声的**搜索优化问题**:
+
+1. **搜索空间**:在小波变换的频域空间中搜索初始噪声的最低频部分。
+2. **进化策略**:使用交叉熵方法迭代采样噪声群体。
+3. **奖励反馈**:利用 PickScore 等奖励模型对生成的低步数预览图进行评分。
+4. **结果输出**:找到得分最高的噪声,进行完整的高质量去噪。
+
+这种方法本质上是用**推理计算时间换取生成质量**。
+
+关于该方法的更多技术细节,请参考论文:**[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)**。
+
+## 2. 快速开始
+
+在 DiffSynth-Studio 中,SES 已集成到主流文生图模型的 Pipeline 中。你只需在调用 `pipe()` 时设置 `enable_ses=True` 即可开启。
+
+以下是以 **Z-Image-Turbo** 为例的[快速上手代码](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py):
+
+```python
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights."
+
+image = pipe(
+ prompt=prompt,
+ seed=42,
+ rand_device="cuda",
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=50,
+ ses_inference_steps=8
+)
+image.save("image_Z-Image-Turbo_ses.jpg")
+```
+
+## 3. 支持的模型与参数
+
+### 3.1 核心参数详解
+
+在 `pipe()` 调用中,以下参数控制 SES 的行为:
+
+| 参数名 | 类型 | 默认值 | 说明 |
+| --- | --- | --- | --- |
+| `enable_ses` | `bool` | `False` | 是否开启 SES 优化。 |
+| `ses_reward_model` | `str` | `"pick"` | 奖励模型选择。支持 `"pick"` (PickScore), `"hps"` (HPSv2), `"clip"`。 |
+| `ses_eval_budget` | `int` | `50` | 搜索的总预算(评估样本总数)。数值越高,质量上限越高,但耗时越长。 |
+| `ses_inference_steps` | `int` | `10` | 搜索阶段生成预览图使用的步数。数值越高,对于候选噪声的质量评估越准确,但耗时越长,建议设为 8~15 。 |
+
+### 3.2 支持模型列表
+
+目前以下文生图模型均已支持 SES:
+
+* **[Qwen-Image](../../../examples/qwen_image/model_inference/Qwen-Image-SES.py)**
+* **[FLUX.1-dev](../../../examples/flux/model_inference/FLUX.1-dev-SES.py)**
+* **[FLUX.2-dev](../../../examples/flux2/model_inference/FLUX.2-dev-SES.py)**
+* **[Z-Image](../../../examples/z_image/model_inference/Z-Image-SES.py) / [Z-Image-Turbo](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py)**
+
+
+## 4. 效果展示
+
+随着搜索预算(`ses_eval_budget`)的增加,SES 能够稳定地提升图像质量。以下展示了在相同随机种子下,不同计算预算带来的质量变化。
+
+**场景 1:Qwen-Image**
+
+* **Prompt**: *"Springtime in the style of Paul Delvaux"*
+* **Reward Model**: PickScore
+
+| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** |
+| --- | --- | --- | --- |
+| *(图片省略)* | *(图片省略)* | *(图片省略)* | *(图片省略)* |
+
+**场景 2:FLUX.1-dev**
+
+* **Prompt**: *"A masterful painting of a young woman in the style of Diego Velázquez."*
+* **Reward Model**: HPSv2
+
+| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** |
+| --- | --- | --- | --- |
+| *(图片省略)* | *(图片省略)* | *(图片省略)* | *(图片省略)* |
\ No newline at end of file
diff --git a/examples/flux/model_inference/FLUX.1-dev-SES.py b/examples/flux/model_inference/FLUX.1-dev-SES.py
new file mode 100644
index 000000000..5e4280c7d
--- /dev/null
+++ b/examples/flux/model_inference/FLUX.1-dev-SES.py
@@ -0,0 +1,29 @@
+import torch
+from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
+
+pipe = FluxImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+ ],
+)
+
+prompt = "A solo girl with silver wavy hair and blue eyes, wearing a blue dress, underwater, air bubbles, floating hair."
+negative_prompt = "nsfw, low quality"
+
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ seed=0,
+ cfg_scale=2,
+ num_inference_steps=50,
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=20,
+ ses_inference_steps=20
+)
+image.save("flux_ses_optimized.jpg")
\ No newline at end of file
diff --git a/examples/flux2/model_inference/FLUX.2-dev-SES.py b/examples/flux2/model_inference/FLUX.2-dev-SES.py
new file mode 100644
index 000000000..af3c608b6
--- /dev/null
+++ b/examples/flux2/model_inference/FLUX.2-dev-SES.py
@@ -0,0 +1,37 @@
+from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cuda",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = Flux2ImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
+)
+prompt = "A hermit crab using a soda can as its shell on the beach. The can has the text 'BFL Diffusers' on it."
+
+image = pipe(
+ prompt,
+ seed=42,
+ rand_device="cuda",
+ num_inference_steps=50,
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=20,
+ ses_inference_steps=10
+)
+
+image.save("image_FLUX.2-dev_ses.jpg")
diff --git a/examples/qwen_image/model_inference/Qwen-Image-SES.py b/examples/qwen_image/model_inference/Qwen-Image-SES.py
new file mode 100644
index 000000000..62b047323
--- /dev/null
+++ b/examples/qwen_image/model_inference/Qwen-Image-SES.py
@@ -0,0 +1,27 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+import torch
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+
+prompt = "水下少女,身穿蓝裙,周围有气泡。"
+
+image = pipe(
+ prompt,
+ seed=0,
+ num_inference_steps=40,
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=20,
+ ses_inference_steps=10
+)
+
+image.save("image_ses.jpg")
\ No newline at end of file
diff --git a/examples/z_image/model_inference/Z-Image-SES.py b/examples/z_image/model_inference/Z-Image-SES.py
new file mode 100644
index 000000000..2f4f8c111
--- /dev/null
+++ b/examples/z_image/model_inference/Z-Image-SES.py
@@ -0,0 +1,30 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
+
+image = pipe(
+ prompt=prompt,
+ seed=42,
+ num_inference_steps=50,
+ cfg_scale=4,
+ rand_device="cuda",
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=20,
+ ses_inference_steps=10
+)
+image.save("image_Z-Image_ses.jpg")
diff --git a/examples/z_image/model_inference/Z-Image-Turbo-SES.py b/examples/z_image/model_inference/Z-Image-Turbo-SES.py
new file mode 100644
index 000000000..4248c975f
--- /dev/null
+++ b/examples/z_image/model_inference/Z-Image-Turbo-SES.py
@@ -0,0 +1,26 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
+
+image = pipe(
+ prompt=prompt,
+ seed=42,
+ rand_device="cuda",
+ enable_ses=True,
+ ses_reward_model="pick",
+ ses_eval_budget=50,
+ ses_inference_steps=8
+)
+image.save("image_Z-Image-Turbo_ses.jpg")
\ No newline at end of file