Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions examples/txt2img/spacing_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Sweep num_inference_steps (S in {10, 15, 20, 25, 30}) across all samplers and both
production models to find where timestep-spacing actually changes output.

Key insight: the LCM distillation grid is [19, 39, ..., 999] (congruent 19 mod 20).
normal: always subsamples this grid -> always on-grid at every S.
sgm_uniform: (trailing) coincides with the distillation grid only when S divides 50
(divisors in this range: {10, 25}) -> MSE~0 vs normal.
At non-divisors (15, 20, 30) it steps off-grid and diverges from normal.
ddim: (leading) off-grid at essentially all non-trivial S; excludes t=999.
simple: (linspace) off-grid; spans full [0, 999].

t_index is scaled proportionally to hold denoising strength constant across S:
production [32, 45] at S=50 -> fractions 0.64, 0.90
-> t_index = [round(0.64*S), round(0.90*S)]

Run from the repo's StreamDiffusion/ dir in its venv:
python examples/txt2img/spacing_compare.py
"""

import os
import sys

import numpy as np
import torch
from PIL import Image


sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from streamdiffusion import StreamDiffusionWrapper


CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
INPUT_IMAGE = os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "input.png")
OUTPUT_DIR = os.path.join(CURRENT_DIR, "..", "..", "images", "outputs", "spacing_compare")

MODELS = [
"stabilityai/sd-turbo",
"stabilityai/sdxl-turbo",
]
STEP_COUNTS = [10, 15, 20, 25, 30] # divisors of 50: {10, 25}; non-divisors: {15, 20, 30}
SAMPLERS = [
("normal", "LCM native (baseline, always on distillation grid)"),
("sgm_uniform", "trailing — on-grid only when S divides 50"),
("ddim", "leading — off-grid, excludes t=999"),
("simple", "linspace — off-grid, spans [0, 999]"),
]
PROMPT = "a peaceful mountain landscape at golden hour"
SEED = 2
WIDTH = HEIGHT = 512


def proportional_t_index(S: int) -> list[int]:
"""Scale production t_index=[32,45] at S=50 (fractions 0.64, 0.90) to arbitrary S."""
a = min(round(0.64 * S), S - 2)
b = min(round(0.90 * S), S - 1)
if b <= a:
b = a + 1
return [a, b]


def mean_brightness(img: Image.Image) -> float:
return float(np.array(img).astype(np.float32).mean() / 255.0)


def mse(a: Image.Image, b: Image.Image) -> float:
return float(((np.array(a).astype(np.float32) - np.array(b).astype(np.float32)) ** 2).mean())


def on_grid(sub_ts) -> bool:
"""All sub-timesteps on the LCM distillation grid (congruent 19 mod 20)."""
return all(int(t) % 20 == 19 for t in sub_ts)


def run_model(model_id: str) -> None:
os.makedirs(OUTPUT_DIR, exist_ok=True)
model_tag = model_id.split("/")[-1]

for S in STEP_COUNTS:
t_index = proportional_t_index(S)
is_divisor = 50 % S == 0
print(f"\n{'=' * 70}")
print(
f"Model: {model_id} S={S} t_index={t_index} "
f"({'divides 50 -> trailing==normal' if is_divisor else 'does NOT divide 50 -> trailing diverges'})"
)
print(f"{'=' * 70}")

results: dict[str, Image.Image] = {}

for sampler_name, description in SAMPLERS:
wrapper = StreamDiffusionWrapper(
model_id_or_path=model_id,
t_index_list=t_index,
frame_buffer_size=1,
width=WIDTH,
height=HEIGHT,
warmup=10,
acceleration="none",
mode="img2img",
use_denoising_batch=True,
cfg_type="none",
seed=SEED,
scheduler="lcm",
sampler=sampler_name,
)
wrapper.prepare(
prompt=PROMPT,
num_inference_steps=S,
)

image_tensor = wrapper.preprocess_image(INPUT_IMAGE)

sub_ts = wrapper.stream.sub_timesteps
grid_flag = on_grid(sub_ts)

for _ in range(wrapper.batch_size - 1):
wrapper(image=image_tensor)
img = wrapper(image=image_tensor)

out_path = os.path.join(OUTPUT_DIR, f"{model_tag}_S{S:02d}_{sampler_name}.png")
img.save(out_path)
results[sampler_name] = img

brightness = mean_brightness(img)
print(
f" {sampler_name:12s} on_grid={str(grid_flag):5s} sub_ts={[int(t) for t in sub_ts]}\n"
f" brightness={brightness:.4f} ({description})"
)

del wrapper
torch.cuda.empty_cache()

print(f"\n MSE vs 'normal' baseline (S={S}):")
baseline = results["normal"]
for sampler_name, _ in SAMPLERS:
if sampler_name != "normal":
err = mse(results[sampler_name], baseline)
print(f" {sampler_name:12s} MSE={err:.2f}")

print(f"\n Images saved to: {OUTPUT_DIR}/{model_tag}_S*_*.png")


if __name__ == "__main__":
for model_id in MODELS:
run_model(model_id)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_cuda_constraint():
"Pillow>=12.2.0", # CVE-2026-25990: out-of-bounds write in PSD loading; 12.2.0 verified
"fire==0.7.1",
"omegaconf==2.3.0",
"onnx==1.19.1", # IR 11; modelopt needs FLOAT4E2M1 (added in 1.18); onnx-gs 0.6.1 no longer needs float32_to_bfloat16
"onnx==1.19.1", # IR 11; modelopt FLOAT4E2M1 (1.18+); 1.21.0 breaks FP8 quant (external-data loading → negative QDQ scale); 6 path-traversal CVEs accepted: require untrusted model loading
"onnxruntime-gpu==1.24.4", # TRT EP, supports IR 11; never co-install CPU onnxruntime — shared files conflict
"onnxoptimizer==0.4.2",
"onnxslim==0.1.91",
Expand Down
3 changes: 3 additions & 0 deletions src/streamdiffusion/acceleration/tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@


def cosine_distance(image_embeds, text_embeds):
# Returns cosine SIMILARITY (dot product of unit-normalised vectors ∈ [−1,1]),
# not a distance. Name preserved for compatibility with HF diffusers'
# StableDiffusionSafetyChecker, which defines it identically.
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
Expand Down
12 changes: 11 additions & 1 deletion src/streamdiffusion/image_filter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from typing import Optional
import random
from typing import Optional

import torch
import torch.nn.functional as F


class SimilarImageFilter:
"""Stochastic frame-skip filter (StreamDiffusion §3.3).

NOTE: the StreamDiffusion paper describes cosine similarity in latent space to
compute the skip probability. This implementation uses pixel-space MSE for
simplicity and speed. `threshold` is remapped via ``mse_threshold = 1 − threshold``,
so ``threshold=0.98`` means frames with MSE < 0.02 have a non-zero skip probability.
The stochastic skip logic uses a 1-frame delay (skip probability is computed
asynchronously and applied on the next frame) to avoid GPU stalls.
"""

def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
self.threshold = threshold
self._mse_threshold: float = max(1e-7, 1.0 - threshold)
Expand Down
70 changes: 55 additions & 15 deletions src/streamdiffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
normalize_prompt_weights: bool = True,
normalize_seed_weights: bool = True,
scheduler: Literal["lcm", "tcd"] = "lcm",
sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal",
sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = "normal",
kvo_cache: List[torch.Tensor] = [],
cache_interval: int = 1,
cache_maxframes: int = 1,
Expand Down Expand Up @@ -180,11 +180,11 @@ def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config):
# Map sampler types to configuration parameters
sampler_config = {
"simple": {"timestep_spacing": "linspace"},
"sgm uniform": {"timestep_spacing": "trailing"},
"sgm_uniform": {"timestep_spacing": "trailing"},
"normal": {}, # Default configuration
"ddim": {"timestep_spacing": "leading"},
"beta": {"beta_schedule": "scaled_linear"},
"karras": {}, # Karras sigmas handled per scheduler
"beta": {"beta_schedule": "scaled_linear"}, # no-op: equals SD/SDXL default
"karras": {}, # no-op: LCM/TCD have no karras-sigma logic
}

# Get sampler-specific configuration
Expand All @@ -203,6 +203,27 @@ def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config):
logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM")
return LCMScheduler.from_config(config, **sampler_params)

def _get_spaced_timesteps(self, spacing: str, num_inference_steps: int) -> torch.Tensor:
"""Return a descending timestep schedule per the paper's Table 2 spacing strategies.

LCMScheduler and TCDScheduler store timestep_spacing in config but never consume it
in set_timesteps (they always build a linearly-spaced grid). This produces the
correct leading/linspace/trailing schedules so the sampler option actually takes
effect. 'normal' bypasses this and uses the LCM/TCD native grid.
"""
T = self.scheduler.config.num_train_timesteps
S = num_inference_steps
if spacing == "trailing":
# round(arange(T, 0, -T/S)) - 1 — descending, includes T-1
ts = torch.arange(T, 0, -T / S).round().long() - 1
elif spacing == "linspace":
# linspace(0, T-1, S) reversed — includes both 0 and T-1
ts = torch.linspace(0, T - 1, S).round().long().flip(0)
else: # "leading"
# arange(S) * (T//S) reversed — includes 0, excludes T-1
ts = (torch.arange(S, dtype=torch.float32) * (T // S)).round().long().flip(0)
return ts.clamp(0, T - 1)

def _check_unet_tensorrt(self) -> bool:
"""Cache TensorRT detection to avoid repeated hasattr calls"""
if self._is_unet_tensorrt is None:
Expand Down Expand Up @@ -482,6 +503,13 @@ def prepare(
self.prompt_embeds = embeds_ctx.prompt_embeds

self.scheduler.set_timesteps(num_inference_steps, self.device)
# LCM/TCD ignore timestep_spacing natively. Apply the correct grid for samplers
# that explicitly request a spacing; leave "normal" on the LCM native schedule.
_SPACING_SAMPLERS = {"simple", "sgm_uniform", "ddim"}
if self.sampler_type in _SPACING_SAMPLERS:
spacing = getattr(self.scheduler.config, "timestep_spacing", "leading")
if spacing in ("trailing", "linspace", "leading"):
self.scheduler.timesteps = self._get_spaced_timesteps(spacing, num_inference_steps).to(self.device)
self.timesteps = self.scheduler.timesteps.to(self.device)

# make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
Expand Down Expand Up @@ -516,6 +544,9 @@ def prepare(

alpha_prod_t_sqrt_list = []
beta_prod_t_sqrt_list = []
# alpha_prod_t_sqrt = √ᾱₜ (cumulative signal-retention std)
# beta_prod_t_sqrt = √(1−ᾱₜ) (cumulative MARGINAL noise std, NOT per-step βₜ)
# Notation follows diffusers convention; ᾱₜ = Πᵢ(1−βᵢ) is computed by the scheduler.
for timestep in self.sub_timesteps:
alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
Expand Down Expand Up @@ -621,7 +652,7 @@ def get_normalize_seed_weights(self) -> bool:
def set_scheduler(
self,
scheduler: Literal["lcm", "tcd"] = None,
sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None,
sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = None,
) -> None:
"""
Change the scheduler and/or sampler at runtime.
Expand Down Expand Up @@ -842,15 +873,24 @@ def unet_step(
if hook_mid_res is not None:
ip_scale_kw["mid_block_additional_residual"] = hook_mid_res

model_pred, kvo_cache_out = self.unet(
x_t_latent_plus_uc,
t_list,
encoder_hidden_states=self.prompt_embeds,
kvo_cache=self.kvo_cache,
return_dict=False,
**ip_scale_kw,
)
self.update_kvo_cache(kvo_cache_out)
if self._check_unet_tensorrt():
model_pred, kvo_cache_out = self.unet(
x_t_latent_plus_uc,
t_list,
encoder_hidden_states=self.prompt_embeds,
kvo_cache=self.kvo_cache,
return_dict=False,
**ip_scale_kw,
)
self.update_kvo_cache(kvo_cache_out)
else:
model_pred = self.unet(
x_t_latent_plus_uc,
t_list,
encoder_hidden_states=self.prompt_embeds,
return_dict=False,
**ip_scale_kw,
)[0]

if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
noise_pred_text = model_pred[1:]
Expand Down Expand Up @@ -1226,7 +1266,7 @@ def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
self.sub_timesteps_tensor,
encoder_hidden_states=self.prompt_embeds,
return_dict=False,
)
)[0]

x_0_pred_out = ((x_t_latent - self.beta_prod_t_sqrt * model_pred).float() / self.alpha_prod_t_sqrt.float()).to(
x_t_latent.dtype
Expand Down
Loading