diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index d5dc35bd4..5daec1644 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -17,6 +17,7 @@ from ..models.flux2_vae import Flux2VAE from ..models.z_image_text_encoder import ZImageTextEncoder +from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer class Flux2ImagePipeline(BasePipeline): @@ -94,6 +95,11 @@ def __call__( num_inference_steps: int = 30, # Progress bar progress_bar_cmd = tqdm, + # SES + enable_ses: bool = False, + ses_reward_model: str = "pick", + ses_eval_budget: int = 50, + ses_inference_steps: int = 10, ): self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) @@ -115,6 +121,58 @@ def __call__( for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Inference-Time Scaling (SES) + if enable_ses: + print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}") + scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype) + self.load_models_to_device(list(self.in_iteration_models) + ['vae']) + models = {name: getattr(self, name) for name in self.in_iteration_models} + + h_latent = height // 16 + w_latent = width // 16 + + def ses_generate_callback(trial_latents_spatial): + trial_inputs = inputs_shared.copy() + + self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent) + eval_timesteps = self.scheduler.timesteps + + curr_latents_seq = rearrange(trial_latents_spatial, "b c h w -> b (h w) c") + + for progress_id, timestep in enumerate(eval_timesteps): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + trial_inputs["latents"] = curr_latents_seq + + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + trial_inputs, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + curr_latents_seq = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs) + + curr_latents_spatial = rearrange(curr_latents_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent) + + decoded_img = self.vae.decode(curr_latents_spatial) + return self.vae_output_to_image(decoded_img) + initial_noise_seq = inputs_shared["latents"] + initial_noise_spatial = rearrange(initial_noise_seq, "b (h w) c -> b c h w", h=h_latent, w=w_latent) + + optimized_latents_spatial = run_ses_cem( + base_latents=initial_noise_spatial, + pipeline_callback=ses_generate_callback, + prompt=prompt, + scorer=scorer, + total_eval_budget=ses_eval_budget, + popsize=10, + k_elites=5 + ) + optimized_latents_seq = rearrange(optimized_latents_spatial, "b c h w -> b (h w) c") + inputs_shared["latents"] = optimized_latents_seq + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=h_latent*w_latent) + del scorer + torch.cuda.empty_cache() + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index bfc53e505..c6751494f 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -20,6 +20,8 @@ from ..models.step1x_text_encoder import Step1xEditEmbedder from 
..core.vram.layers import AutoWrappedLinear +from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer + class MultiControlNet(torch.nn.Module): def __init__(self, models: list[torch.nn.Module]): super().__init__() @@ -240,6 +242,11 @@ def __call__( tile_stride: int = 64, # Progress bar progress_bar_cmd = tqdm, + # SES + enable_ses: bool = False, + ses_reward_model: str = "pick", + ses_eval_budget: int = 50, + ses_inference_steps: int = 10, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -274,6 +281,49 @@ def __call__( for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Inference-Time Scaling (SES) + if enable_ses: + print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}") + scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype) + self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder']) + models = {name: getattr(self, name) for name in self.in_iteration_models} + + def ses_generate_callback(trial_latents): + trial_inputs = inputs_shared.copy() + + self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + eval_timesteps = self.scheduler.timesteps + curr_latents = trial_latents + + for progress_id, timestep in enumerate(eval_timesteps): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + trial_inputs["latents"] = curr_latents + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + trial_inputs, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs) + + decoded_img = self.vae_decoder(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return self.vae_output_to_image(decoded_img) + + initial_noise = inputs_shared["latents"] + optimized_latents = run_ses_cem( + base_latents=initial_noise, + pipeline_callback=ses_generate_callback, + prompt=prompt, + scorer=scorer, + total_eval_budget=ses_eval_budget, + popsize=10, + k_elites=5 + ) + inputs_shared["latents"] = optimized_latents + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + del scorer + torch.cuda.empty_cache() + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 75cfbee77..8a75db74f 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -20,6 +20,7 @@ from ..models.dinov3_image_encoder import DINOv3ImageEncoder from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel +from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer class QwenImagePipeline(BasePipeline): @@ -141,6 +142,11 @@ def __call__( tile_stride: int = 64, # Progress bar progress_bar_cmd = tqdm, + # SES + enable_ses: bool = False, + ses_reward_model: str = "pick", + ses_eval_budget: int = 50, + ses_inference_steps: int = 10, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) @@ -171,6 +177,51 
@@ def __call__( for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Inference-Time Scaling (SES) + if enable_ses: + print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}") + scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype) + + self.load_models_to_device(list(self.in_iteration_models) + ['vae']) + models = {name: getattr(self, name) for name in self.in_iteration_models} + + def ses_generate_callback(trial_latents): + trial_inputs = inputs_shared.copy() + self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) + eval_timesteps = self.scheduler.timesteps + curr_latents = trial_latents + + for progress_id, timestep in enumerate(eval_timesteps): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + trial_inputs["latents"] = curr_latents + + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + trial_inputs, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs) + + decoded_img = self.vae.decode(curr_latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return self.vae_output_to_image(decoded_img) + + initial_noise = inputs_shared["latents"] + + optimized_latents = run_ses_cem( + base_latents=initial_noise, + pipeline_callback=ses_generate_callback, + prompt=prompt, + scorer=scorer, + total_eval_budget=ses_eval_budget, + popsize=10, + k_elites=5 + ) + inputs_shared["latents"] = optimized_latents + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) + del scorer + torch.cuda.empty_cache() + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} @@ -182,7 +233,7 @@ def __call__( **models, timestep=timestep, progress_id=progress_id ) inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) - + # Decode self.load_models_to_device(['vae']) image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 2c5b68730..decfd91d9 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -23,6 +23,7 @@ from ..models.dinov3_image_encoder import DINOv3ImageEncoder from ..models.z_image_image2lora import ZImageImage2LoRAModel +from ..utils.inference_time_scaling.ses import run_ses_cem, SESRewardScorer class ZImagePipeline(BasePipeline): @@ -116,6 +117,11 @@ def __call__( positive_only_lora: Dict[str, torch.Tensor] = None, # Progress bar progress_bar_cmd = tqdm, + # SES + enable_ses: bool = False, + ses_reward_model: str = "pick", + ses_eval_budget: int = 50, + ses_inference_steps: int = 10, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -140,6 +146,46 @@ def __call__( for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, 
inputs_nega) + if enable_ses: + print(f"[SES] Starting optimization with budget={ses_eval_budget}, steps={ses_inference_steps}") + scorer = SESRewardScorer(ses_reward_model, device=self.device, dtype=self.torch_dtype) + self.load_models_to_device(list(self.in_iteration_models) + ['vae_decoder']) + models = {name: getattr(self, name) for name in self.in_iteration_models} + def ses_generate_callback(trial_latents): + trial_inputs = inputs_shared.copy() + self.scheduler.set_timesteps(ses_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + eval_timesteps = self.scheduler.timesteps + curr_latents = trial_latents + + for progress_id, timestep in enumerate(eval_timesteps): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + trial_inputs["latents"] = curr_latents + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + trial_inputs, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + curr_latents = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **trial_inputs) + decoded_img = self.vae_decoder(curr_latents) + return self.vae_output_to_image(decoded_img) + + initial_noise = inputs_shared["latents"] + + optimized_latents = run_ses_cem( + base_latents=initial_noise, + pipeline_callback=ses_generate_callback, + prompt=prompt, + scorer=scorer, + total_eval_budget=ses_eval_budget, + popsize=10, + k_elites=5 + ) + inputs_shared["latents"] = optimized_latents + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + del scorer + torch.cuda.empty_cache() + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} diff --git a/diffsynth/utils/inference_time_scaling/__init__.py b/diffsynth/utils/inference_time_scaling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diffsynth/utils/inference_time_scaling/ses.py b/diffsynth/utils/inference_time_scaling/ses.py new file mode 100644 index 000000000..f00b7fd9a --- /dev/null +++ b/diffsynth/utils/inference_time_scaling/ses.py @@ -0,0 +1,174 @@ +import torch +import pywt +import numpy as np +import torch.nn as nn +from transformers import AutoProcessor, AutoModel +from tqdm import tqdm +import os + +def split_dwt(z_tensor_cpu, wavelet_name, dwt_level): + all_clow_np = [] + all_chigh_list = [] + z_tensor_cpu = z_tensor_cpu.cpu().float() + + for i in range(z_tensor_cpu.shape[0]): + z_numpy_ch = z_tensor_cpu[i].numpy() + + coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1)) + + clow_np = coeffs_ch[0] + chigh_list = coeffs_ch[1:] + + all_clow_np.append(clow_np) + all_chigh_list.append(chigh_list) + + all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0)) + return all_clow_tensor, all_chigh_list + +def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape): + H_high, W_high = original_shape + c_low_tensor_cpu = c_low_tensor_cpu.cpu().float() + + clow_np = c_low_tensor_cpu.numpy() + + if clow_np.ndim == 4 and clow_np.shape[0] == 1: + clow_np = clow_np[0] + + coeffs_combined = [clow_np] + c_high_coeffs + z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1)) + if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high: + z_recon_np = z_recon_np[..., :H_high, :W_high] + z_recon_tensor = torch.from_numpy(z_recon_np) + if z_recon_tensor.ndim == 3: + 
z_recon_tensor = z_recon_tensor.unsqueeze(0) + return z_recon_tensor + +class SESRewardScorer: + def __init__(self, reward_name, device, dtype): + self.reward_name = reward_name + self.device = device + self.dtype = dtype + self.model = None + self.processor = None + self._load_model() + + def _load_model(self): + print(f"[SES] Loading Reward Model: {self.reward_name}...") + if self.reward_name == "pick": + self.processor = AutoProcessor.from_pretrained("yuvalkirstain/PickScore_v1") + self.model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1", torch_dtype=self.dtype).to(self.device).eval() + elif self.reward_name == "clip": + self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + self.model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype).to(self.device).eval() + elif self.reward_name == "hps": + self.processor = AutoProcessor.from_pretrained("adams-story/HPSv2-hf") + self.model = AutoModel.from_pretrained("adams-story/HPSv2-hf", torch_dtype=self.dtype, trust_remote_code=True).to(self.device).eval() + else: + print(f"[SES] Warning: Reward model {self.reward_name} not implemented in wrapper, skipping.") + + def get_score(self, image_pil, text_prompt): + try: + with torch.no_grad(): + if self.reward_name == "pick": + inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(self.device) + inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype) + outputs = self.model(**inputs) + return outputs.logits_per_image[0, 0].item() + + elif self.reward_name == "clip": + inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding=True, truncation=True).to(self.device) + if 'pixel_values' in inputs: inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype) + outputs = self.model(**inputs) + return outputs.logits_per_image.item() + + elif self.reward_name == "hps": + inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(self.device) + inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype) + outputs = self.model(**inputs) + return outputs.logits_per_image[0, 0].item() + except Exception as e: + print(f"Error computing score: {e}") + return 0.0 + return 0.0 +def run_ses_cem( + base_latents, + pipeline_callback, + prompt, + scorer, + total_eval_budget=50, + popsize=10, + k_elites=5, + wavelet_name="db1", + dwt_level=5, + lambda_prior=1e-3 +): + latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1] + c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level) + c_high_fixed = c_high_fixed_batch[0] + c_low_shape = c_low_init.shape[1:] + mu = c_low_init.view(-1).cpu() + sigma_sq = torch.ones_like(mu) * 1.0 + + best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]} + eval_count = 0 + + elite_db = [] + n_generations = (total_eval_budget // popsize) + 5 + pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img") + + for gen in range(n_generations): + if eval_count >= total_eval_budget: break + + std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9)) + z_noise = torch.randn(popsize, mu.shape[0]) + samples_flat = mu + z_noise * std + samples_reshaped = samples_flat.view(popsize, *c_low_shape) + + batch_results = [] + + for i in range(popsize): + if eval_count >= total_eval_budget: break + + c_low_sample = 
samples_reshaped[i].unsqueeze(0)
+            z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
+            z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
+            img = pipeline_callback(z_recon)
+
+            score = scorer.get_score(img, prompt)
+            penalty = lambda_prior * (torch.norm(c_low_sample.float())**2).item()
+            fitness = score - penalty
+
+            res = {
+                "fitness": fitness,
+                "score": score,
+                "c_low": c_low_sample.cpu()
+            }
+            batch_results.append(res)
+            if fitness > best_overall['fitness']:
+                best_overall = res
+
+            eval_count += 1
+            pbar.update(1)
+            pbar.set_postfix({
+                "Gen": gen,
+                "Best": f"{best_overall['score']:.4f}"
+            })
+
+        if not batch_results: break
+        elite_db.extend(batch_results)
+        elite_db.sort(key=lambda x: x['fitness'], reverse=True)
+        elite_db = elite_db[:k_elites]
+        elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
+        mu_new = torch.mean(elites_flat, dim=0)
+
+        if len(elite_db) > 1:
+            sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
+        else:
+            sigma_sq_new = sigma_sq
+        mu = mu_new
+        sigma_sq = sigma_sq_new
+    pbar.close()
+    best_c_low = best_overall['c_low']
+    final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))
+
+    return final_latents.to(base_latents.device, dtype=base_latents.dtype)
\ No newline at end of file
diff --git a/docs/en/Research_Tutorial/inference_time_scaling.md b/docs/en/Research_Tutorial/inference_time_scaling.md
new file mode 100644
index 000000000..3de452d14
--- /dev/null
+++ b/docs/en/Research_Tutorial/inference_time_scaling.md
@@ -0,0 +1,99 @@
+# Inference-Time Scaling
+
+DiffSynth-Studio supports inference-time scaling via the **Spectral Evolution Search (SES)** algorithm, which lets users improve generated image quality at inference time by spending additional compute, without retraining the model.
+
+## 1. Basic Principle
+
+Conventional text-to-image inference starts from random Gaussian noise and produces an image through a fixed number of denoising steps, so the quality of the result depends heavily on the random initial noise.
+
+**SES (Spectral Evolution Search)** turns inference into a **search problem over the initial noise**:
+
+1. **Search space**: the lowest-frequency band of the initial noise in the wavelet (DWT) domain.
+2. **Evolutionary strategy**: the Cross-Entropy Method (CEM) iteratively samples populations of candidate noises.
+3. **Reward feedback**: a reward model such as PickScore scores the low-step preview images generated from each candidate.
+4. **Output**: the highest-scoring noise is denoised with the full, high-quality schedule.
+
+In essence, the method **trades inference compute for generation quality**.
+
+For more technical details, please refer to the paper: **[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)**.
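+
+The sketch below summarizes this search loop in a simplified, self-contained form (the full implementation lives in `diffsynth/utils/inference_time_scaling/ses.py`). It is illustrative only: `generate_preview` and `score_image` are hypothetical stand-ins for the pipeline's low-step preview rendering and the reward scorer, and the real implementation additionally maintains an elite pool across generations and handles batched latents.
+
+```python
+import pywt
+import torch
+
+
+def ses_search(base_noise, generate_preview, score_image, budget=50, popsize=10,
+               k_elites=5, wavelet="db1", level=5, lambda_prior=1e-3):
+    # base_noise: initial latent noise of shape (C, H, W).
+    # Split it into a low-frequency band (searched) and high-frequency bands (kept fixed).
+    coeffs = pywt.wavedec2(base_noise.float().cpu().numpy(), wavelet, level=level,
+                           mode="symmetric", axes=(-2, -1))
+    c_low, c_high = coeffs[0], list(coeffs[1:])
+
+    def to_latent(c_low_np):
+        # Recombine a candidate low band with the fixed high bands, crop to the original size.
+        z = pywt.waverec2([c_low_np] + c_high, wavelet, mode="symmetric", axes=(-2, -1))
+        z = z[..., :base_noise.shape[-2], :base_noise.shape[-1]]
+        return torch.from_numpy(z).to(base_noise.device, base_noise.dtype)
+
+    mu = torch.from_numpy(c_low).flatten()   # mean of the search distribution
+    sigma_sq = torch.ones_like(mu)           # per-dimension variance
+    best = (-float("inf"), c_low)
+    evals = 0
+
+    while evals < budget:
+        std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
+        samples = mu + torch.randn(popsize, mu.numel()) * std   # CEM population
+        generation = []
+        for s in samples:
+            if evals >= budget:
+                break
+            cand = s.view(c_low.shape).numpy()
+            # Render a low-step preview, score it with the reward model, and apply
+            # a small L2 prior that keeps the candidate close to Gaussian noise.
+            preview = generate_preview(to_latent(cand))
+            fitness = score_image(preview) - lambda_prior * float((s ** 2).sum())
+            generation.append((fitness, cand))
+            best = max(best, generation[-1], key=lambda r: r[0])
+            evals += 1
+        # Refit the Gaussian to the top-k elites of this generation.
+        generation.sort(key=lambda r: r[0], reverse=True)
+        elites = torch.stack([torch.from_numpy(c).flatten() for _, c in generation[:k_elites]])
+        mu = elites.mean(dim=0)
+        if len(elites) > 1:
+            sigma_sq = elites.var(dim=0, unbiased=True) + 1e-7
+
+    return to_latent(best[1])
+```
+
+In the pipelines, this search runs before the main denoising loop; the returned latent replaces the random initial noise.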
+
+## 2. Quick Start
+
+In DiffSynth-Studio, SES is integrated into the pipelines of the mainstream text-to-image models. Simply set `enable_ses=True` when calling `pipe()` to enable it.
+
+The following [quick start code](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py) uses **Z-Image-Turbo** as an example:
+
+```python
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+pipe = ZImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights."
+
+image = pipe(
+    prompt=prompt,
+    seed=42,
+    rand_device="cuda",
+    enable_ses=True,
+    ses_reward_model="pick",
+    ses_eval_budget=50,
+    ses_inference_steps=8
+)
+image.save("image_Z-Image-Turbo_ses.jpg")
+```
+
+## 3. Supported Models and Parameters
+
+### 3.1 Core Parameters
+
+In the `pipe()` call, the following parameters control SES behavior:
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `enable_ses` | `bool` | `False` | Whether to enable SES optimization. |
+| `ses_reward_model` | `str` | `"pick"` | Reward model. Supports `"pick"` (PickScore), `"hps"` (HPSv2), and `"clip"`. |
+| `ses_eval_budget` | `int` | `50` | Total search budget (number of evaluated samples). Higher values raise the quality ceiling but take longer. |
+| `ses_inference_steps` | `int` | `10` | Number of denoising steps used for preview images during the search. Higher values assess candidate noises more accurately but take longer; 8-15 is recommended. |
+
+### 3.2 Supported Models
+
+Currently, the following text-to-image models support SES:
+
+* **[Qwen-Image](../../../examples/qwen_image/model_inference/Qwen-Image-SES.py)**
+* **[FLUX.1-dev](../../../examples/flux/model_inference/FLUX.1-dev-SES.py)**
+* **[FLUX.2-dev](../../../examples/flux2/model_inference/FLUX.2-dev-SES.py)**
+* **[Z-Image](../../../examples/z_image/model_inference/Z-Image-SES.py) / [Z-Image-Turbo](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py)**
+
+## 4. Effect Demonstration
+
+As the search budget (`ses_eval_budget`) increases, SES steadily improves image quality. The following shows how quality changes with different computational budgets under the same random seed.
+ +**Scenario 1: Qwen-Image** + +* **Prompt**: *"Springtime in the style of Paul Delvaux"* +* **Reward Model**: PickScore + +| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** | +| --- | --- | --- | --- | +| | | | | +| Image | Image | Image | Image | + +**Scenario 2: FLUX.1-dev** + +* **Prompt**: *"A masterful painting of a young woman in the style of Diego Velázquez."* +* **Reward Model**: HPSv2 + +| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** | +| --- | --- | --- | --- | +| | | | | +| Image | Image | Image | Image | \ No newline at end of file diff --git a/docs/zh/Research_Tutorial/inference_time_scaling.md b/docs/zh/Research_Tutorial/inference_time_scaling.md new file mode 100644 index 000000000..5ba28dfec --- /dev/null +++ b/docs/zh/Research_Tutorial/inference_time_scaling.md @@ -0,0 +1,100 @@ +# 推理时扩展 + +DiffSynth-Studio 支持推理时扩展(Inference-time Scaling)技术,具体实现了 **Spectral Evolution Search (SES)** 算法。该技术允许用户在推理阶段通过增加计算量来提升生成图像质量,而无需重新训练模型。 + +## 1. 基本原理 + +传统的文生图推理过程是从一个随机的高斯噪声开始,经过固定的去噪步数生成图像。这种方式生成的质量高度依赖于初始噪声的随机性。 + +**SES (Spectral Evolution Search)** 将推理过程转化为一个针对初始噪声的**搜索优化问题**: + +1. **搜索空间**:在小波变换的频域空间中搜索初始噪声的最低频部分。 +2. **进化策略**:使用交叉熵方法迭代采样噪声群体。 +3. **奖励反馈**:利用 PickScore 等奖励模型对生成的低步数预览图进行评分。 +4. **结果输出**:找到得分最高的噪声,进行完整的高质量去噪。 + +这种方法本质上是用**推理计算时间换取生成质量**。 + +关于该方法的更多技术细节,请参考论文:**[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)**。 + +## 2. 快速开始 + +在 DiffSynth-Studio 中,SES 已集成到主流文生图模型的 Pipeline 中。你只需在调用 `pipe()` 时设置 `enable_ses=True` 即可开启。 + +以下是以 **Z-Image-Turbo** 为例的[快速上手代码](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py): + +```python +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights." + +image = pipe( + prompt=prompt, + seed=42, + rand_device="cuda", + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=50, + ses_inference_steps=8 +) +image.save("image_Z-Image-Turbo_ses.jpg") +``` + +## 3. 
支持的模型与参数 + +### 3.1 核心参数详解 + +在 `pipe()` 调用中,以下参数控制 SES 的行为: + +| 参数名 | 类型 | 默认值 | 说明 | +| --- | --- | --- | --- | +| `enable_ses` | `bool` | `False` | 是否开启 SES 优化。 | +| `ses_reward_model` | `str` | `"pick"` | 奖励模型选择。支持 `"pick"` (PickScore), `"hps"` (HPSv2), `"clip"`。 | +| `ses_eval_budget` | `int` | `50` | 搜索的总预算(评估样本总数)。数值越高,质量上限越高,但耗时越长。 | +| `ses_inference_steps` | `int` | `10` | 搜索阶段生成预览图使用的步数。数值越高,对于候选噪声的质量评估越准确,但耗时越长,建议设为 8~15 。 | + +### 3.2 支持模型列表 + +目前以下文生图模型均已支持 SES: + +* **[Qwen-Image](../../../examples/qwen_image/model_inference/Qwen-Image-SES.py)** +* **[FLUX.1-dev](../../../examples/flux/model_inference/FLUX.1-dev-SES.py)** +* **[FLUX.2-dev](../../../examples/flux2/model_inference/FLUX.2-dev-SES.py)** +* **[Z-Image](../../../examples/z_image/model_inference/Z-Image-SES.py) / [Z-Image-Turbo](../../../examples/z_image/model_inference/Z-Image-Turbo-SES.py)** + + +## 4. 效果展示 + +随着搜索预算(`ses_eval_budget`)的增加,SES 能够稳定地提升图像质量。以下展示了在相同随机种子下,不同计算预算带来的质量变化。 + +**场景 1:Qwen-Image** + +* **Prompt**: *"Springtime in the style of Paul Delvaux"* +* **Reward Model**: PickScore + +| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** | +| --- | --- | --- | --- | +| | | | | +| Image | Image | Image | Image | + +**场景 2:FLUX.1-dev** + +* **Prompt**: *"A masterful painting of a young woman in the style of Diego Velázquez."* +* **Reward Model**: HPSv2 + +| **Budget = 0** | **Budget = 10** | **Budget = 30** | **Budget = 50** | +| --- | --- | --- | --- | +| | | | | +| Image | Image | Image | Image | \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-SES.py b/examples/flux/model_inference/FLUX.1-dev-SES.py new file mode 100644 index 000000000..5e4280c7d --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-SES.py @@ -0,0 +1,29 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +prompt = "A solo girl with silver wavy hair and blue eyes, wearing a blue dress, underwater, air bubbles, floating hair." 
+negative_prompt = "nsfw, low quality" + +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, + cfg_scale=2, + num_inference_steps=50, + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=20, + ses_inference_steps=20 +) +image.save("flux_ses_optimized.jpg") \ No newline at end of file diff --git a/examples/flux2/model_inference/FLUX.2-dev-SES.py b/examples/flux2/model_inference/FLUX.2-dev-SES.py new file mode 100644 index 000000000..af3c608b6 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-dev-SES.py @@ -0,0 +1,37 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), +) +prompt = "A hermit crab using a soda can as its shell on the beach. The can has the text 'BFL Diffusers' on it." + +image = pipe( + prompt, + seed=42, + rand_device="cuda", + num_inference_steps=50, + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=20, + ses_inference_steps=10 +) + +image.save("image_FLUX.2-dev_ses.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-SES.py b/examples/qwen_image/model_inference/Qwen-Image-SES.py new file mode 100644 index 000000000..62b047323 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-SES.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +prompt = "水下少女,身穿蓝裙,周围有气泡。" + +image = pipe( + prompt, + seed=0, + num_inference_steps=40, + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=20, + ses_inference_steps=10 +) + +image.save("image_ses.jpg") \ No newline at end of file diff --git a/examples/z_image/model_inference/Z-Image-SES.py b/examples/z_image/model_inference/Z-Image-SES.py new file mode 100644 index 000000000..2f4f8c111 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-SES.py @@ -0,0 +1,30 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"), + 
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." + +image = pipe( + prompt=prompt, + seed=42, + num_inference_steps=50, + cfg_scale=4, + rand_device="cuda", + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=20, + ses_inference_steps=10 +) +image.save("image_Z-Image_ses.jpg") + + diff --git a/examples/z_image/model_inference/Z-Image-Turbo-SES.py b/examples/z_image/model_inference/Z-Image-Turbo-SES.py new file mode 100644 index 000000000..4248c975f --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-SES.py @@ -0,0 +1,26 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." + +image = pipe( + prompt=prompt, + seed=42, + rand_device="cuda", + enable_ses=True, + ses_reward_model="pick", + ses_eval_budget=50, + ses_inference_steps=8 +) +image.save("image_Z-Image-Turbo_ses.jpg") \ No newline at end of file