diff --git a/examples/txt2img/spacing_compare.py b/examples/txt2img/spacing_compare.py new file mode 100644 index 000000000..0e226ede0 --- /dev/null +++ b/examples/txt2img/spacing_compare.py @@ -0,0 +1,146 @@ +""" +Sweep num_inference_steps (S in {10, 15, 20, 25, 30}) across all samplers and both +production models to find where timestep-spacing actually changes output. + +Key insight: the LCM distillation grid is [19, 39, ..., 999] (congruent 19 mod 20). + normal: always subsamples this grid -> always on-grid at every S. + sgm_uniform: (trailing) coincides with the distillation grid only when S divides 50 + (divisors in this range: {10, 25}) -> MSE~0 vs normal. + At non-divisors (15, 20, 30) it steps off-grid and diverges from normal. + ddim: (leading) off-grid at essentially all non-trivial S; excludes t=999. + simple: (linspace) off-grid; spans full [0, 999]. + +t_index is scaled proportionally to hold denoising strength constant across S: + production [32, 45] at S=50 -> fractions 0.64, 0.90 + -> t_index = [round(0.64*S), round(0.90*S)] + +Run from the repo's StreamDiffusion/ dir in its venv: + python examples/txt2img/spacing_compare.py +""" + +import os +import sys + +import numpy as np +import torch +from PIL import Image + + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from streamdiffusion import StreamDiffusionWrapper + + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +INPUT_IMAGE = os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "input.png") +OUTPUT_DIR = os.path.join(CURRENT_DIR, "..", "..", "images", "outputs", "spacing_compare") + +MODELS = [ + "stabilityai/sd-turbo", + "stabilityai/sdxl-turbo", +] +STEP_COUNTS = [10, 15, 20, 25, 30] # divisors of 50: {10, 25}; non-divisors: {15, 20, 30} +SAMPLERS = [ + ("normal", "LCM native (baseline, always on distillation grid)"), + ("sgm_uniform", "trailing — on-grid only when S divides 50"), + ("ddim", "leading — off-grid, excludes t=999"), + ("simple", "linspace — off-grid, spans [0, 999]"), +] +PROMPT = "a peaceful mountain landscape at golden hour" +SEED = 2 +WIDTH = HEIGHT = 512 + + +def proportional_t_index(S: int) -> list[int]: + """Scale production t_index=[32,45] at S=50 (fractions 0.64, 0.90) to arbitrary S.""" + a = min(round(0.64 * S), S - 2) + b = min(round(0.90 * S), S - 1) + if b <= a: + b = a + 1 + return [a, b] + + +def mean_brightness(img: Image.Image) -> float: + return float(np.array(img).astype(np.float32).mean() / 255.0) + + +def mse(a: Image.Image, b: Image.Image) -> float: + return float(((np.array(a).astype(np.float32) - np.array(b).astype(np.float32)) ** 2).mean()) + + +def on_grid(sub_ts) -> bool: + """All sub-timesteps on the LCM distillation grid (congruent 19 mod 20).""" + return all(int(t) % 20 == 19 for t in sub_ts) + + +def run_model(model_id: str) -> None: + os.makedirs(OUTPUT_DIR, exist_ok=True) + model_tag = model_id.split("/")[-1] + + for S in STEP_COUNTS: + t_index = proportional_t_index(S) + is_divisor = 50 % S == 0 + print(f"\n{'=' * 70}") + print( + f"Model: {model_id} S={S} t_index={t_index} " + f"({'divides 50 -> trailing==normal' if is_divisor else 'does NOT divide 50 -> trailing diverges'})" + ) + print(f"{'=' * 70}") + + results: dict[str, Image.Image] = {} + + for sampler_name, description in SAMPLERS: + wrapper = StreamDiffusionWrapper( + model_id_or_path=model_id, + t_index_list=t_index, + frame_buffer_size=1, + width=WIDTH, + height=HEIGHT, + warmup=10, + acceleration="none", + mode="img2img", + use_denoising_batch=True, + cfg_type="none", + seed=SEED, + scheduler="lcm", + sampler=sampler_name, + ) + wrapper.prepare( + prompt=PROMPT, + num_inference_steps=S, + ) + + image_tensor = wrapper.preprocess_image(INPUT_IMAGE) + + sub_ts = wrapper.stream.sub_timesteps + grid_flag = on_grid(sub_ts) + + for _ in range(wrapper.batch_size - 1): + wrapper(image=image_tensor) + img = wrapper(image=image_tensor) + + out_path = os.path.join(OUTPUT_DIR, f"{model_tag}_S{S:02d}_{sampler_name}.png") + img.save(out_path) + results[sampler_name] = img + + brightness = mean_brightness(img) + print( + f" {sampler_name:12s} on_grid={str(grid_flag):5s} sub_ts={[int(t) for t in sub_ts]}\n" + f" brightness={brightness:.4f} ({description})" + ) + + del wrapper + torch.cuda.empty_cache() + + print(f"\n MSE vs 'normal' baseline (S={S}):") + baseline = results["normal"] + for sampler_name, _ in SAMPLERS: + if sampler_name != "normal": + err = mse(results[sampler_name], baseline) + print(f" {sampler_name:12s} MSE={err:.2f}") + + print(f"\n Images saved to: {OUTPUT_DIR}/{model_tag}_S*_*.png") + + +if __name__ == "__main__": + for model_id in MODELS: + run_model(model_id) diff --git a/setup.py b/setup.py index da3031552..da7be49ad 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ def get_cuda_constraint(): "Pillow>=12.2.0", # CVE-2026-25990: out-of-bounds write in PSD loading; 12.2.0 verified "fire==0.7.1", "omegaconf==2.3.0", - "onnx==1.19.1", # IR 11; modelopt needs FLOAT4E2M1 (added in 1.18); onnx-gs 0.6.1 no longer needs float32_to_bfloat16 + "onnx==1.19.1", # IR 11; modelopt FLOAT4E2M1 (1.18+); 1.21.0 breaks FP8 quant (external-data loading → negative QDQ scale); 6 path-traversal CVEs accepted: require untrusted model loading "onnxruntime-gpu==1.24.4", # TRT EP, supports IR 11; never co-install CPU onnxruntime — shared files conflict "onnxoptimizer==0.4.2", "onnxslim==0.1.91", diff --git a/src/streamdiffusion/acceleration/tensorrt/__init__.py b/src/streamdiffusion/acceleration/tensorrt/__init__.py index 20b8f5da0..9120f7755 100644 --- a/src/streamdiffusion/acceleration/tensorrt/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/__init__.py @@ -17,6 +17,9 @@ def cosine_distance(image_embeds, text_embeds): + # Returns cosine SIMILARITY (dot product of unit-normalised vectors ∈ [−1,1]), + # not a distance. Name preserved for compatibility with HF diffusers' + # StableDiffusionSafetyChecker, which defines it identically. normalized_image_embeds = nn.functional.normalize(image_embeds) normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) diff --git a/src/streamdiffusion/image_filter.py b/src/streamdiffusion/image_filter.py index 5523c8869..61a9c88b0 100644 --- a/src/streamdiffusion/image_filter.py +++ b/src/streamdiffusion/image_filter.py @@ -1,11 +1,21 @@ -from typing import Optional import random +from typing import Optional import torch import torch.nn.functional as F class SimilarImageFilter: + """Stochastic frame-skip filter (StreamDiffusion §3.3). + + NOTE: the StreamDiffusion paper describes cosine similarity in latent space to + compute the skip probability. This implementation uses pixel-space MSE for + simplicity and speed. `threshold` is remapped via ``mse_threshold = 1 − threshold``, + so ``threshold=0.98`` means frames with MSE < 0.02 have a non-zero skip probability. + The stochastic skip logic uses a 1-frame delay (skip probability is computed + asynchronously and applied on the next frame) to avoid GPU stalls. + """ + def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: self.threshold = threshold self._mse_threshold: float = max(1e-7, 1.0 - threshold) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 55922e695..3840208b2 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -49,7 +49,7 @@ def __init__( normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, scheduler: Literal["lcm", "tcd"] = "lcm", - sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", + sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = "normal", kvo_cache: List[torch.Tensor] = [], cache_interval: int = 1, cache_maxframes: int = 1, @@ -180,11 +180,11 @@ def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): # Map sampler types to configuration parameters sampler_config = { "simple": {"timestep_spacing": "linspace"}, - "sgm uniform": {"timestep_spacing": "trailing"}, + "sgm_uniform": {"timestep_spacing": "trailing"}, "normal": {}, # Default configuration "ddim": {"timestep_spacing": "leading"}, - "beta": {"beta_schedule": "scaled_linear"}, - "karras": {}, # Karras sigmas handled per scheduler + "beta": {"beta_schedule": "scaled_linear"}, # no-op: equals SD/SDXL default + "karras": {}, # no-op: LCM/TCD have no karras-sigma logic } # Get sampler-specific configuration @@ -203,6 +203,27 @@ def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") return LCMScheduler.from_config(config, **sampler_params) + def _get_spaced_timesteps(self, spacing: str, num_inference_steps: int) -> torch.Tensor: + """Return a descending timestep schedule per the paper's Table 2 spacing strategies. + + LCMScheduler and TCDScheduler store timestep_spacing in config but never consume it + in set_timesteps (they always build a linearly-spaced grid). This produces the + correct leading/linspace/trailing schedules so the sampler option actually takes + effect. 'normal' bypasses this and uses the LCM/TCD native grid. + """ + T = self.scheduler.config.num_train_timesteps + S = num_inference_steps + if spacing == "trailing": + # round(arange(T, 0, -T/S)) - 1 — descending, includes T-1 + ts = torch.arange(T, 0, -T / S).round().long() - 1 + elif spacing == "linspace": + # linspace(0, T-1, S) reversed — includes both 0 and T-1 + ts = torch.linspace(0, T - 1, S).round().long().flip(0) + else: # "leading" + # arange(S) * (T//S) reversed — includes 0, excludes T-1 + ts = (torch.arange(S, dtype=torch.float32) * (T // S)).round().long().flip(0) + return ts.clamp(0, T - 1) + def _check_unet_tensorrt(self) -> bool: """Cache TensorRT detection to avoid repeated hasattr calls""" if self._is_unet_tensorrt is None: @@ -482,6 +503,13 @@ def prepare( self.prompt_embeds = embeds_ctx.prompt_embeds self.scheduler.set_timesteps(num_inference_steps, self.device) + # LCM/TCD ignore timestep_spacing natively. Apply the correct grid for samplers + # that explicitly request a spacing; leave "normal" on the LCM native schedule. + _SPACING_SAMPLERS = {"simple", "sgm_uniform", "ddim"} + if self.sampler_type in _SPACING_SAMPLERS: + spacing = getattr(self.scheduler.config, "timestep_spacing", "leading") + if spacing in ("trailing", "linspace", "leading"): + self.scheduler.timesteps = self._get_spaced_timesteps(spacing, num_inference_steps).to(self.device) self.timesteps = self.scheduler.timesteps.to(self.device) # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list @@ -516,6 +544,9 @@ def prepare( alpha_prod_t_sqrt_list = [] beta_prod_t_sqrt_list = [] + # alpha_prod_t_sqrt = √ᾱₜ (cumulative signal-retention std) + # beta_prod_t_sqrt = √(1−ᾱₜ) (cumulative MARGINAL noise std, NOT per-step βₜ) + # Notation follows diffusers convention; ᾱₜ = Πᵢ(1−βᵢ) is computed by the scheduler. for timestep in self.sub_timesteps: alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() @@ -621,7 +652,7 @@ def get_normalize_seed_weights(self) -> bool: def set_scheduler( self, scheduler: Literal["lcm", "tcd"] = None, - sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None, + sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = None, ) -> None: """ Change the scheduler and/or sampler at runtime. @@ -842,15 +873,24 @@ def unet_step( if hook_mid_res is not None: ip_scale_kw["mid_block_additional_residual"] = hook_mid_res - model_pred, kvo_cache_out = self.unet( - x_t_latent_plus_uc, - t_list, - encoder_hidden_states=self.prompt_embeds, - kvo_cache=self.kvo_cache, - return_dict=False, - **ip_scale_kw, - ) - self.update_kvo_cache(kvo_cache_out) + if self._check_unet_tensorrt(): + model_pred, kvo_cache_out = self.unet( + x_t_latent_plus_uc, + t_list, + encoder_hidden_states=self.prompt_embeds, + kvo_cache=self.kvo_cache, + return_dict=False, + **ip_scale_kw, + ) + self.update_kvo_cache(kvo_cache_out) + else: + model_pred = self.unet( + x_t_latent_plus_uc, + t_list, + encoder_hidden_states=self.prompt_embeds, + return_dict=False, + **ip_scale_kw, + )[0] if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): noise_pred_text = model_pred[1:] @@ -1226,7 +1266,7 @@ def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor: self.sub_timesteps_tensor, encoder_hidden_states=self.prompt_embeds, return_dict=False, - ) + )[0] x_0_pred_out = ((x_t_latent - self.beta_prod_t_sqrt * model_pred).float() / self.alpha_prod_t_sqrt.float()).to( x_t_latent.dtype diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 8f81b6f89..d8755be74 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -1,15 +1,18 @@ -from typing import List, Optional, Dict, Tuple, Literal, Any, Callable +import logging import threading +from typing import Any, Dict, List, Literal, Optional, Tuple + import torch import torch.nn.functional as F -import gc -import logging + logger = logging.getLogger(__name__) from .preprocessing.orchestrator_user import OrchestratorUser + class CacheStats: """Helper class to track cache statistics""" + def __init__(self): self.hits = 0 self.misses = 0 @@ -22,7 +25,13 @@ def record_miss(self): class StreamParameterUpdater(OrchestratorUser): - def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True): + def __init__( + self, + stream_diffusion, + wrapper=None, + normalize_prompt_weights: bool = True, + normalize_seed_weights: bool = True, + ): self.stream = stream_diffusion self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure self.normalize_prompt_weights = normalize_prompt_weights @@ -39,8 +48,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._seed_cache: Dict[int, Dict] = {} self._current_seed_list: List[Tuple[int, float]] = [] self._seed_cache_stats = CacheStats() - - + # Attach shared orchestrator once (lazy-creates on stream if absent) self.attach_orchestrator(self.stream) @@ -50,6 +58,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._current_style_images: Dict[str, Any] = {} # Use the shared orchestrator attached via OrchestratorUser self._embedding_orchestrator = self._preprocessing_orchestrator + def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses @@ -68,7 +77,7 @@ def get_cache_info(self) -> Dict: "seed_cache_hits": self._seed_cache_stats.hits, "seed_cache_misses": self._seed_cache_stats.misses, "seed_hit_rate": f"{seed_hit_rate:.2%}", - "current_seeds": len(self._current_seed_list) + "current_seeds": len(self._current_seed_list), } def clear_caches(self) -> None: @@ -81,7 +90,7 @@ def clear_caches(self) -> None: self._seed_cache.clear() self._current_seed_list.clear() self._seed_cache_stats = CacheStats() - + # Clear embedding caches self._embedding_cache.clear() self._current_style_images.clear() @@ -93,13 +102,13 @@ def get_normalize_prompt_weights(self) -> bool: def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self.normalize_seed_weights - + # Deprecated enhancer registration removed; embedding composition is handled via stream.embedding_hooks def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: str) -> None: """ Register an embedding preprocessor for parallel processing. - + Args: preprocessor: IPAdapterEmbeddingPreprocessor instance style_image_key: Unique key for the style image this preprocessor handles @@ -108,28 +117,27 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st # Ensure orchestrator is present self.attach_orchestrator(self.stream) self._embedding_orchestrator = self._preprocessing_orchestrator - + self._embedding_preprocessors.append((preprocessor, style_image_key)) - + def unregister_embedding_preprocessor(self, style_image_key: str) -> None: """Unregister an embedding preprocessor by style image key.""" original_count = len(self._embedding_preprocessors) self._embedding_preprocessors = [ - (preprocessor, key) for preprocessor, key in self._embedding_preprocessors - if key != style_image_key + (preprocessor, key) for preprocessor, key in self._embedding_preprocessors if key != style_image_key ] removed_count = original_count - len(self._embedding_preprocessors) - + # Clear cached embeddings for this key if style_image_key in self._embedding_cache: del self._embedding_cache[style_image_key] if style_image_key in self._current_style_images: del self._current_style_images[style_image_key] - + def update_style_image(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: """ Update a style image and trigger embedding preprocessing. - + Args: style_image_key: Unique key for the style image style_image: The style image (PIL Image, path, etc.) @@ -138,14 +146,16 @@ def update_style_image(self, style_image_key: str, style_image: Any, is_stream: """ # Store the style image self._current_style_images[style_image_key] = style_image - + # Trigger preprocessing for this style image self._preprocess_style_image_parallel(style_image_key, style_image, is_stream) - - def _preprocess_style_image_parallel(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: + + def _preprocess_style_image_parallel( + self, style_image_key: str, style_image: Any, is_stream: bool = False + ) -> None: """ Preprocessing for a specific style image with mode selection - + Args: style_image_key: Unique key for the style image style_image: The style image to process @@ -153,57 +163,47 @@ def _preprocess_style_image_parallel(self, style_image_key: str, style_image: An """ if not self._embedding_preprocessors or self._embedding_orchestrator is None: return - + # Find preprocessors for this key relevant_preprocessors = [ - preprocessor for preprocessor, key in self._embedding_preprocessors - if key == style_image_key + preprocessor for preprocessor, key in self._embedding_preprocessors if key == style_image_key ] - + if not relevant_preprocessors: return - + # Choose processing mode based on is_stream parameter try: if is_stream: # Pipelined processing - optimized for throughput with 1-frame lag embedding_results = self._embedding_orchestrator.process_pipelined( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, "ipadapter" ) else: # Synchronous processing - immediate results for discrete updates embedding_results = self._embedding_orchestrator.process_sync( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - None, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, None, "ipadapter" ) - + # Cache results for this style image key if embedding_results and embedding_results[0] is not None: self._embedding_cache[style_image_key] = embedding_results[0] else: # This is an error condition - we should always have results - raise RuntimeError(f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'") - - except Exception as e: + raise RuntimeError( + f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'" + ) + + except Exception: import traceback + traceback.print_exc() - + def get_cached_embeddings(self, style_image_key: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: """Get cached embeddings for a style image key""" cached_result = self._embedding_cache.get(style_image_key, None) return cached_result - def _normalize_weights(self, weights: List[float], normalize: bool) -> torch.Tensor: """Generic weight normalization helper""" weights_tensor = torch.tensor(weights, device=self.stream.device, dtype=self.stream.dtype) @@ -218,7 +218,7 @@ def _validate_index(self, index: int, item_list: List, operation_name: str) -> b return False if index < 0 or index >= len(item_list): - logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list)-1})") + logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list) - 1})") return False return True @@ -281,28 +281,27 @@ def update_stream_params( f"provided t_index_list (max index: {max_t_index}). Adjusting to {max_t_index + 1}." ) num_inference_steps = max_t_index + 1 - + old_num_steps = len(self.stream.timesteps) self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device) self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device) - + # If t_index_list wasn't explicitly provided, rescale existing t_list proportionally if t_index_list is None and old_num_steps > 0: # Rescale each index proportionally to the new number of steps # e.g., if t_list = [0, 16, 32, 45] with 50 steps -> [0, 3, 6, 8] with 9 steps scale_factor = (num_inference_steps - 1) / (old_num_steps - 1) if old_num_steps > 1 else 1.0 - t_index_list = [ - min(round(t * scale_factor), num_inference_steps - 1) - for t in self.stream.t_list - ] - + t_index_list = [min(round(t * scale_factor), num_inference_steps - 1) for t in self.stream.t_list] + # Now update timestep-dependent parameters with the correct t_index_list if t_index_list is not None: self._recalculate_timestep_dependent_params(t_index_list) if guidance_scale is not None: if self.stream.cfg_type == "none" and guidance_scale > 1.0: - logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect") + logger.warning( + "update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect" + ) self.stream.guidance_scale = guidance_scale if delta is not None: @@ -310,7 +309,7 @@ def update_stream_params( if seed is not None: self._update_seed(seed) - + if normalize_prompt_weights is not None: self.normalize_prompt_weights = normalize_prompt_weights logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}") @@ -324,44 +323,42 @@ def update_stream_params( self._update_blended_prompts( prompt_list=prompt_list, negative_prompt=negative_prompt or self._current_negative_prompt, - prompt_interpolation_method=prompt_interpolation_method + prompt_interpolation_method=prompt_interpolation_method, ) # Handle seed blending if seed_list is provided if seed_list is not None: - self._update_blended_seeds( - seed_list=seed_list, - interpolation_method=seed_interpolation_method - ) - + self._update_blended_seeds(seed_list=seed_list, interpolation_method=seed_interpolation_method) # Handle ControlNet configuration updates if controlnet_config is not None: - #TODO: happy path for control images + # TODO: happy path for control images self._update_controlnet_config(controlnet_config) - + # Handle IPAdapter configuration updates if ipadapter_config is not None: - logger.info(f"update_stream_params: Updating IPAdapter configuration") + logger.info("update_stream_params: Updating IPAdapter configuration") self._update_ipadapter_config(ipadapter_config) - + # Handle Hook configuration updates if image_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors") + logger.info( + f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors" + ) logger.info(f"update_stream_params: image_preprocessing_config = {image_preprocessing_config}") - self._update_hook_config('image_preprocessing', image_preprocessing_config) - + self._update_hook_config("image_preprocessing", image_preprocessing_config) + if image_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating image postprocessing configuration") - self._update_hook_config('image_postprocessing', image_postprocessing_config) - + logger.info("update_stream_params: Updating image postprocessing configuration") + self._update_hook_config("image_postprocessing", image_postprocessing_config) + if latent_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent preprocessing configuration") - self._update_hook_config('latent_preprocessing', latent_preprocessing_config) - + logger.info("update_stream_params: Updating latent preprocessing configuration") + self._update_hook_config("latent_preprocessing", latent_preprocessing_config) + if latent_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent postprocessing configuration") - self._update_hook_config('latent_postprocessing', latent_postprocessing_config) + logger.info("update_stream_params: Updating latent postprocessing configuration") + self._update_hook_config("latent_postprocessing", latent_postprocessing_config) if self.stream.kvo_cache: if cache_interval is not None: @@ -375,9 +372,7 @@ def update_stream_params( # runtime — resizing one-at-a-time races with TRT inference (causes "Dimensions # with name C must be equal" errors). cache_maxframes is a logical write window. actual_cache_size = ( - self.stream.kvo_cache[0].shape[1] - if self.stream.kvo_cache - else cache_maxframes + self.stream.kvo_cache[0].shape[1] if self.stream.kvo_cache else cache_maxframes ) if cache_maxframes > actual_cache_size: logger.warning( @@ -395,9 +390,7 @@ def update_stream_params( @torch.inference_mode() def update_prompt_weights( - self, - prompt_weights: List[float], - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt_weights: List[float], prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update weights for current prompt list without re-encoding prompts.""" if not self._current_prompt_list: @@ -405,7 +398,9 @@ def update_prompt_weights( return if len(prompt_weights) != len(self._current_prompt_list): - logger.warning(f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}") + logger.warning( + f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}" + ) return # Update the current prompt list with new weights @@ -420,9 +415,7 @@ def update_prompt_weights( @torch.inference_mode() def update_seed_weights( - self, - seed_weights: List[float], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_weights: List[float], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update weights for current seed list without regenerating noise.""" if not self._current_seed_list: @@ -430,7 +423,9 @@ def update_seed_weights( return if len(seed_weights) != len(self._current_seed_list): - logger.warning(f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}") + logger.warning( + f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}" + ) return # Update the current seed list with new weights @@ -448,7 +443,7 @@ def _update_blended_prompts( self, prompt_list: List[Tuple[str, float]], negative_prompt: str = "", - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", ) -> None: """Update prompt embeddings using multiple weighted prompts.""" # Store current state @@ -461,14 +456,10 @@ def _update_blended_prompts( # Apply blending self._apply_prompt_blending(prompt_interpolation_method) - def _cache_prompt_embeddings( - self, - prompt_list: List[Tuple[str, float]], - negative_prompt: str - ) -> None: + def _cache_prompt_embeddings(self, prompt_list: List[Tuple[str, float]], negative_prompt: str) -> None: """Cache prompt embeddings for efficient reuse.""" for idx, (prompt_text, weight) in enumerate(prompt_list): - if idx not in self._prompt_cache or self._prompt_cache[idx]['text'] != prompt_text: + if idx not in self._prompt_cache or self._prompt_cache[idx]["text"] != prompt_text: # Cache miss - encode the prompt self._prompt_cache_stats.record_miss() encoder_output = self.stream.pipe.encode_prompt( @@ -482,10 +473,7 @@ def _cache_prompt_embeddings( if len(self._prompt_cache) >= 32: oldest_key = next(iter(self._prompt_cache)) del self._prompt_cache[oldest_key] - self._prompt_cache[idx] = { - 'embed': encoder_output[0], - 'text': prompt_text - } + self._prompt_cache[idx] = {"embed": encoder_output[0], "text": prompt_text} else: # Cache hit self._prompt_cache_stats.record_hit() @@ -500,7 +488,7 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", for idx, (prompt_text, weight) in enumerate(self._current_prompt_list): if idx in self._prompt_cache: - embeddings.append(self._prompt_cache[idx]['embed']) + embeddings.append(self._prompt_cache[idx]["embed"]) weights.append(weight) if not embeddings: @@ -545,13 +533,14 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", # No CFG, just use the blended embeddings final_prompt_embeds = combined_embeds.repeat(self.stream.batch_size, 1, 1) final_negative_embeds = None # Will be set by enhancers if needed - + # Enhancer mechanism removed in favor of embedding_hooks # Run embedding hooks to compose final embeddings (e.g., append IP-Adapter tokens) try: - if hasattr(self.stream, 'embedding_hooks') and self.stream.embedding_hooks: + if hasattr(self.stream, "embedding_hooks") and self.stream.embedding_hooks: from .hooks import EmbedsCtx # local import to avoid cycles + embeds_ctx = EmbedsCtx( prompt_embeds=final_prompt_embeds, negative_prompt_embeds=final_negative_embeds, @@ -562,15 +551,21 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", final_negative_embeds = embeds_ctx.negative_prompt_embeds except Exception as e: import logging + logging.getLogger(__name__).error(f"_apply_prompt_blending: embedding hook failed: {e}") - + # Set final embeddings on stream self.stream.prompt_embeds = final_prompt_embeds if final_negative_embeds is not None: self.stream.negative_prompt_embeds = final_negative_embeds def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch.Tensor: - """Spherical linear interpolation between two embeddings.""" + """Spherical linear interpolation between two embeddings. + + Traces the geodesic on the unit sphere (MML §3.3), then rescales to the + linearly-interpolated norm so embeddings of unequal magnitude are handled + correctly. + """ # Handle case where t is 0 or 1 if t <= 0: return embed1 @@ -582,31 +577,32 @@ def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch. flat1 = embed1.view(-1) flat2 = embed2.view(-1) - # Normalize + # Preserve norms for magnitude interpolation, then normalize for angle calc. + norm1 = flat1.norm() + norm2 = flat2.norm() flat1_norm = F.normalize(flat1, dim=0) flat2_norm = F.normalize(flat2, dim=0) - # Calculate angle + # Calculate angle between unit vectors dot_product = torch.clamp(torch.dot(flat1_norm, flat2_norm), -1.0, 1.0) theta = torch.acos(dot_product) - # Handle parallel vectors + # Handle parallel vectors (degenerate SLERP → LERP) if theta.abs() < 1e-6: result = (1 - t) * flat1 + t * flat2 else: - # SLERP formula + # SLERP on unit sphere, rescaled to linearly-interpolated magnitude. sin_theta = torch.sin(theta) w1 = torch.sin((1 - t) * theta) / sin_theta w2 = torch.sin(t * theta) / sin_theta - result = w1 * flat1 + w2 * flat2 + unit_result = w1 * flat1_norm + w2 * flat2_norm + result = unit_result * ((1 - t) * norm1 + t * norm2) return result.view(original_shape) @torch.inference_mode() def _update_blended_seeds( - self, - seed_list: List[Tuple[int, float]], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_list: List[Tuple[int, float]], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update seed tensors using multiple weighted seeds.""" # Store current state @@ -621,7 +617,7 @@ def _update_blended_seeds( def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: """Cache seed noise tensors for efficient reuse.""" for idx, (seed_value, weight) in enumerate(seed_list): - if idx not in self._seed_cache or self._seed_cache[idx]['seed'] != seed_value: + if idx not in self._seed_cache or self._seed_cache[idx]["seed"] != seed_value: # Cache miss - generate noise for the seed self._seed_cache_stats.record_miss() generator = torch.Generator(device=self.stream.device) @@ -631,13 +627,10 @@ def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[idx] = { - 'noise': noise, - 'seed': seed_value - } + self._seed_cache[idx] = {"noise": noise, "seed": seed_value} else: # Cache hit self._seed_cache_stats.record_hit() @@ -652,7 +645,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) for idx, (seed_value, weight) in enumerate(self._current_seed_list): if idx in self._seed_cache: - noise_tensors.append(self._seed_cache[idx]['noise']) + noise_tensors.append(self._seed_cache[idx]["noise"]) weights.append(weight) if not noise_tensors: @@ -663,6 +656,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) weights = self._normalize_weights(weights, self.normalize_seed_weights) # Apply interpolation + # SLERP only activates for exactly 2 seeds; 3+ seeds always use linear blending. if interpolation_method == "slerp" and len(noise_tensors) == 2: # Spherical linear interpolation for 2 seeds noise1, noise2 = noise_tensors[0], noise_tensors[1] @@ -673,20 +667,26 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) combined_noise = torch.zeros_like(noise_tensors[0]) for noise, weight in zip(noise_tensors, weights): combined_noise += weight * noise - - # Preserve noise magnitude when weights are normalized + + # For normalized weights (Σwᵢ=1), Var(Σwᵢεᵢ)=Σwᵢ² ≤ 1 — the blend is + # under-dispersed. Restore 𝒩(0,I) variance with the exact closed-form + # factor 1/√(Σwᵢ²) (MML §6.4 / Bishop §2.3 variance of a sum). if self.normalize_seed_weights and len(noise_tensors) > 1: - original_magnitude = torch.mean(torch.stack([torch.norm(noise) for noise in noise_tensors])) - current_magnitude = torch.norm(combined_noise) - if current_magnitude > 1e-8: # Avoid division by zero - combined_noise = combined_noise * (original_magnitude / current_magnitude) + sum_sq = (weights * weights).sum() + combined_noise = combined_noise / torch.sqrt(sum_sq) # Update stream noise self.stream.init_noise = combined_noise self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) def _slerp_noise(self, noise1: torch.Tensor, noise2: torch.Tensor, t: float) -> torch.Tensor: - """Spherical linear interpolation between two noise tensors.""" + """Spherical linear interpolation between two noise tensors. + + NOTE: weights are applied to the raw (un-normalised) flats. For independent + Gaussian tensors this is benign — both norms ≈ √N, so the equal-magnitude + assumption nearly holds; at θ≈90° the blend also preserves 𝒩(0,I) variance. + For embeddings of unequal norm use ``_slerp`` (exact magnitude-rescaling). + """ # Handle case where t is 0 or 1 if t <= 0: return noise1 @@ -748,6 +748,7 @@ def _update_seed(self, seed: int) -> None: def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) return c_skip, c_out @@ -765,9 +766,7 @@ def _update_timestep_calculations(self) -> None: for t in self.stream.t_list: self.stream.sub_timesteps.append(self.stream.timesteps[t]) - sub_timesteps_tensor = torch.tensor( - self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device - ) + sub_timesteps_tensor = torch.tensor(self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device) self.stream.sub_timesteps_tensor = torch.repeat_interleave( sub_timesteps_tensor, repeats=self.stream.frame_bff_size if self.stream.use_denoising_batch else 1, @@ -793,12 +792,8 @@ def _update_timestep_calculations(self) -> None: ) if self.stream.use_denoising_batch: - self.stream.c_skip = torch.repeat_interleave( - self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0 - ) - self.stream.c_out = torch.repeat_interleave( - self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0 - ) + self.stream.c_skip = torch.repeat_interleave(self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0) + self.stream.c_out = torch.repeat_interleave(self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0) # Update alpha_prod_t_sqrt and beta_prod_t_sqrt alpha_prod_t_sqrt_list = [] @@ -838,29 +833,25 @@ def _update_timestep_values_only(self, t_index_list: List[int]) -> None: def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> None: """Recalculate all parameters that depend on t_index_list.""" - + # Check if this is a structural change (length) or just value change if len(t_index_list) == len(self.stream.t_list): # Same length - only values changed, use lightweight update (working branch behavior) self._update_timestep_values_only(t_index_list) return - + # Length changed - do full recalculation including batch-dependent parameters (broken branch logic - but it works for this case!) self.stream.t_list = t_index_list self.stream.denoising_steps_num = len(self.stream.t_list) old_batch_size = self.stream.batch_size - + if self.stream.use_denoising_batch: self.stream.batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size if self.stream.cfg_type == "initialize": - self.stream.trt_unet_batch_size = ( - self.stream.denoising_steps_num + 1 - ) * self.stream.frame_bff_size + self.stream.trt_unet_batch_size = (self.stream.denoising_steps_num + 1) * self.stream.frame_bff_size elif self.stream.cfg_type == "full": - self.stream.trt_unet_batch_size = ( - 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size - ) + self.stream.trt_unet_batch_size = 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size else: self.stream.trt_unet_batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size else: @@ -891,27 +882,33 @@ def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> Non # Resize kvo_cache tensors if batch size changed if self.stream.kvo_cache and old_batch_size != self.stream.batch_size: - logger.info(f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}") + logger.info( + f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}" + ) for i, cache_tensor in enumerate(self.stream.kvo_cache): # KVO cache shape: (2, cache_maxframes, batch_size, seq_length, hidden_dim) current_shape = cache_tensor.shape - new_shape = (current_shape[0], current_shape[1], self.stream.batch_size, current_shape[3], current_shape[4]) - new_cache_tensor = torch.zeros( - new_shape, - dtype=cache_tensor.dtype, - device=cache_tensor.device + new_shape = ( + current_shape[0], + current_shape[1], + self.stream.batch_size, + current_shape[3], + current_shape[4], ) - + new_cache_tensor = torch.zeros(new_shape, dtype=cache_tensor.dtype, device=cache_tensor.device) + # Copy over as much data as possible from old cache min_batch = min(old_batch_size, self.stream.batch_size) new_cache_tensor[:, :, :min_batch, :, :] = cache_tensor[:, :, :min_batch, :, :] - + self.stream.kvo_cache[i] = new_cache_tensor # Drop bucketed storage refs so update_kvo_cache falls back to # per-layer writes against the new tensors. self.stream._kvo_buckets = None self.stream._kvo_outputs_by_bucket = None - logger.info(f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}") + logger.info( + f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}" + ) # Update timestep-dependent calculations (shared with value-only path) self._update_timestep_calculations() @@ -930,10 +927,7 @@ def _recalculate_controlnet_inputs(self, width: int, height: int) -> None: @torch.inference_mode() def update_prompt_at_index( - self, - index: int, - new_prompt: str, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, new_prompt: str, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update a single prompt at the specified index without re-encoding others.""" if not self._validate_index(index, self._current_prompt_list, "update_prompt_at_index"): @@ -947,11 +941,11 @@ def update_prompt_at_index( self._cache_prompt_embeddings([(new_prompt, weight)], self._current_negative_prompt) # Update cache index to point to the new prompt - if index in self._prompt_cache and self._prompt_cache[index]['text'] != new_prompt: + if index in self._prompt_cache and self._prompt_cache[index]["text"] != new_prompt: # Find if this prompt is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._prompt_cache.items(): - if cache_data['text'] == new_prompt: + if cache_data["text"] == new_prompt: existing_cache_key = cache_idx break @@ -969,10 +963,7 @@ def update_prompt_at_index( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[index] = { - 'embed': encoder_output[0], - 'text': new_prompt - } + self._prompt_cache[index] = {"embed": encoder_output[0], "text": new_prompt} # Recompute blended embeddings with updated prompt self._apply_prompt_blending(prompt_interpolation_method) @@ -984,16 +975,12 @@ def get_current_prompts(self) -> List[Tuple[str, float]]: @torch.inference_mode() def add_prompt( - self, - prompt: str, - weight: float = 1.0, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt: str, weight: float = 1.0, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Add a new prompt to the current list.""" new_index = len(self._current_prompt_list) self._current_prompt_list.append((prompt, weight)) - # Cache the new prompt encoder_output = self.stream.pipe.encode_prompt( prompt=prompt, @@ -1002,10 +989,7 @@ def add_prompt( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[new_index] = { - 'embed': encoder_output[0], - 'text': prompt - } + self._prompt_cache[new_index] = {"embed": encoder_output[0], "text": prompt} self._prompt_cache_stats.record_miss() # Recompute blended embeddings @@ -1013,9 +997,7 @@ def add_prompt( @torch.inference_mode() def remove_prompt_at_index( - self, - index: int, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Remove a prompt at the specified index.""" if not self._validate_index(index, self._current_prompt_list, "remove_prompt_at_index"): @@ -1040,10 +1022,7 @@ def remove_prompt_at_index( @torch.inference_mode() def update_seed_at_index( - self, - index: int, - new_seed: int, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, index: int, new_seed: int, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update a single seed at the specified index without regenerating others.""" if not self._validate_index(index, self._current_seed_list, "update_seed_at_index"): @@ -1053,16 +1032,15 @@ def update_seed_at_index( old_seed, weight = self._current_seed_list[index] self._current_seed_list[index] = (new_seed, weight) - # Cache the new seed noise self._cache_seed_noise([(new_seed, weight)]) # Update cache index to point to the new seed - if index in self._seed_cache and self._seed_cache[index]['seed'] != new_seed: + if index in self._seed_cache and self._seed_cache[index]["seed"] != new_seed: # Find if this seed is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._seed_cache.items(): - if cache_data['seed'] == new_seed: + if cache_data["seed"] == new_seed: existing_cache_key = cache_idx break @@ -1080,13 +1058,10 @@ def update_seed_at_index( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[index] = { - 'noise': noise, - 'seed': new_seed - } + self._seed_cache[index] = {"noise": noise, "seed": new_seed} # Recompute blended noise with updated seed self._apply_seed_blending(interpolation_method) @@ -1098,10 +1073,7 @@ def get_current_seeds(self) -> List[Tuple[int, float]]: @torch.inference_mode() def add_seed( - self, - seed: int, - weight: float = 1.0, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed: int, weight: float = 1.0, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Add a new seed to the current list.""" new_index = len(self._current_seed_list) @@ -1117,24 +1089,17 @@ def add_seed( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[new_index] = { - 'noise': noise, - 'seed': seed - } + self._seed_cache[new_index] = {"noise": noise, "seed": seed} self._seed_cache_stats.record_miss() # Recompute blended noise self._apply_seed_blending(interpolation_method) @torch.inference_mode() - def remove_seed_at_index( - self, - index: int, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: + def remove_seed_at_index(self, index: int, interpolation_method: Literal["linear", "slerp"] = "linear") -> None: """Remove a seed at the specified index.""" if not self._validate_index(index, self._current_seed_list, "remove_seed_at_index"): return @@ -1159,7 +1124,7 @@ def remove_seed_at_index( def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> None: """ Update ControlNet configuration by diffing current vs desired state. - + Args: desired_config: Complete ControlNet configuration list defining the desired state. Each dict contains: model_id, preprocessor, conditioning_scale, enabled, etc. @@ -1167,41 +1132,47 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Find the ControlNet pipeline/module (module-aware) controlnet_pipeline = self._get_controlnet_pipeline() if not controlnet_pipeline: - logger.debug("_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)") + logger.debug( + "_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)" + ) return - + current_config = self._get_current_controlnet_config() - + # Simple approach: detect what changed and apply minimal updates - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} - desired_models = {cfg['model_id']: cfg for cfg in desired_config} - + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } + desired_models = {cfg["model_id"]: cfg for cfg in desired_config} + # Reorder to match desired order (module supports stable reordering) try: - desired_order = [cfg['model_id'] for cfg in desired_config if 'model_id' in cfg] - if hasattr(controlnet_pipeline, 'reorder_controlnets_by_model_ids'): + desired_order = [cfg["model_id"] for cfg in desired_config if "model_id" in cfg] + if hasattr(controlnet_pipeline, "reorder_controlnets_by_model_ids"): controlnet_pipeline.reorder_controlnets_by_model_ids(desired_order) except Exception: pass # Recompute current models after potential reorder - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } # Remove controlnets not in desired config for i in reversed(range(len(controlnet_pipeline.controlnets))): - model_id = current_models.get(i, f'controlnet_{i}') + model_id = current_models.get(i, f"controlnet_{i}") if model_id not in desired_models: logger.info(f"_update_controlnet_config: Removing ControlNet {model_id}") try: controlnet_pipeline.remove_controlnet(i) except Exception: raise - + # Add new controlnets and update existing ones for desired_cfg in desired_config: - model_id = desired_cfg['model_id'] + model_id = desired_cfg["model_id"] existing_index = next((i for i, mid in current_models.items() if mid == model_id), None) - + if existing_index is None: # Add new controlnet logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}") @@ -1209,15 +1180,16 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Prefer module path: construct ControlNetConfig try: from .modules.controlnet_module import ControlNetConfig # type: ignore + cn_cfg = ControlNetConfig( - model_id=desired_cfg.get('model_id'), - preprocessor=desired_cfg.get('preprocessor'), - conditioning_scale=desired_cfg.get('conditioning_scale', 1.0), - enabled=desired_cfg.get('enabled', True), - conditioning_channels=desired_cfg.get('conditioning_channels'), - preprocessor_params=desired_cfg.get('preprocessor_params'), + model_id=desired_cfg.get("model_id"), + preprocessor=desired_cfg.get("preprocessor"), + conditioning_scale=desired_cfg.get("conditioning_scale", 1.0), + enabled=desired_cfg.get("enabled", True), + conditioning_channels=desired_cfg.get("conditioning_channels"), + preprocessor_params=desired_cfg.get("preprocessor_params"), ) - controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get('control_image')) + controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get("control_image")) except Exception: # No fallback raise @@ -1225,114 +1197,136 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non logger.error(f"_update_controlnet_config: add_controlnet failed for {model_id}: {e}") else: # Update existing controlnet - if 'conditioning_scale' in desired_cfg: - current_scale = current_config[existing_index].get('conditioning_scale', 1.0) - desired_scale = desired_cfg['conditioning_scale'] - + if "conditioning_scale" in desired_cfg: + current_scale = current_config[existing_index].get("conditioning_scale", 1.0) + desired_scale = desired_cfg["conditioning_scale"] + if current_scale != desired_scale: - logger.info(f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}") - if hasattr(controlnet_pipeline, 'controlnet_scales') and 0 <= existing_index < len(controlnet_pipeline.controlnet_scales): + logger.info( + f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}" + ) + if hasattr(controlnet_pipeline, "controlnet_scales") and 0 <= existing_index < len( + controlnet_pipeline.controlnet_scales + ): controlnet_pipeline.controlnet_scales[existing_index] = float(desired_scale) - + # Enable/disable toggle - if 'enabled' in desired_cfg and hasattr(controlnet_pipeline, 'enabled_list'): + if "enabled" in desired_cfg and hasattr(controlnet_pipeline, "enabled_list"): if 0 <= existing_index < len(controlnet_pipeline.enabled_list): - controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg['enabled']) + controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg["enabled"]) - if 'preprocessor_params' in desired_cfg and hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[existing_index]: + if ( + "preprocessor_params" in desired_cfg + and hasattr(controlnet_pipeline, "preprocessors") + and controlnet_pipeline.preprocessors[existing_index] + ): preprocessor = controlnet_pipeline.preprocessors[existing_index] - preprocessor.params.update(desired_cfg['preprocessor_params']) - for param_name, param_value in desired_cfg['preprocessor_params'].items(): + preprocessor.params.update(desired_cfg["preprocessor_params"]) + for param_name, param_value in desired_cfg["preprocessor_params"].items(): if hasattr(preprocessor, param_name): setattr(preprocessor, param_name, param_value) - + # Pipeline references are now automatically managed during preprocessor creation # No need to manually re-establish pipeline references for pipeline-aware processors - def _get_controlnet_pipeline(self): """ Get the ControlNet module or legacy pipeline from the structure (module-aware). """ # Module-installed path - if hasattr(self.stream, '_controlnet_module'): + if hasattr(self.stream, "_controlnet_module"): return self.stream._controlnet_module # Legacy paths - if hasattr(self.stream, 'controlnets'): + if hasattr(self.stream, "controlnets"): return self.stream - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'controlnets'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "controlnets"): return self.stream.stream - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, '_controlnet_module'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "_controlnet_module"): return self.wrapper.stream._controlnet_module - if hasattr(self.wrapper.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "controlnets"): return self.wrapper.stream - if hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "controlnets"): return self.wrapper.stream.stream return None def _get_current_controlnet_config(self) -> List[Dict[str, Any]]: """ Get current ControlNet configuration state. - + Returns: List of current ControlNet configurations """ controlnet_pipeline = self._get_controlnet_pipeline() - if not controlnet_pipeline or not hasattr(controlnet_pipeline, 'controlnets') or not controlnet_pipeline.controlnets: + if ( + not controlnet_pipeline + or not hasattr(controlnet_pipeline, "controlnets") + or not controlnet_pipeline.controlnets + ): return [] - + current_config = [] for i, controlnet in enumerate(controlnet_pipeline.controlnets): - model_id = getattr(controlnet, 'model_id', f'controlnet_{i}') - scale = controlnet_pipeline.controlnet_scales[i] if hasattr(controlnet_pipeline, 'controlnet_scales') and i < len(controlnet_pipeline.controlnet_scales) else 1.0 + model_id = getattr(controlnet, "model_id", f"controlnet_{i}") + scale = ( + controlnet_pipeline.controlnet_scales[i] + if hasattr(controlnet_pipeline, "controlnet_scales") and i < len(controlnet_pipeline.controlnet_scales) + else 1.0 + ) enabled_val = True try: - if hasattr(controlnet_pipeline, 'enabled_list') and i < len(controlnet_pipeline.enabled_list): + if hasattr(controlnet_pipeline, "enabled_list") and i < len(controlnet_pipeline.enabled_list): enabled_val = bool(controlnet_pipeline.enabled_list[i]) except Exception: enabled_val = True config = { - 'model_id': model_id, - 'conditioning_scale': scale, - 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {}, - 'enabled': enabled_val, + "model_id": model_id, + "conditioning_scale": scale, + "preprocessor_params": getattr(controlnet_pipeline.preprocessors[i], "params", {}) + if hasattr(controlnet_pipeline, "preprocessors") and controlnet_pipeline.preprocessors[i] + else {}, + "enabled": enabled_val, } current_config.append(config) - + return current_config def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: """ Update IPAdapter configuration. - + Args: - desired_config: IPAdapter configuration dict containing: + desired_config: IPAdapter configuration dict containing: ipadapter_model_path, image_encoder_path, style_image, scale, enabled, etc. """ # Find the IPAdapter pipeline ipadapter_pipeline = self._get_ipadapter_pipeline() - + if not ipadapter_pipeline: - logger.warning(f"_update_ipadapter_config: No IPAdapter pipeline found") + logger.warning("_update_ipadapter_config: No IPAdapter pipeline found") return - - if 'scale' in desired_config and desired_config['scale'] is not None: - desired_scale = float(desired_config['scale']) + + if "scale" in desired_config and desired_config["scale"] is not None: + desired_scale = float(desired_config["scale"]) # Get current scale from IPAdapter instance - current_scale = getattr(self.stream.ipadapter, 'scale', 1.0) if hasattr(self.stream, 'ipadapter') else 1.0 - + current_scale = getattr(self.stream.ipadapter, "scale", 1.0) if hasattr(self.stream, "ipadapter") else 1.0 + if current_scale != desired_scale: logger.info(f"_update_ipadapter_config: Updating scale: {current_scale} → {desired_scale}") - + # Get weight_type from IPAdapter instance - weight_type = getattr(self.stream.ipadapter, 'weight_type', None) if hasattr(self.stream, 'ipadapter') else None - + weight_type = ( + getattr(self.stream.ipadapter, "weight_type", None) if hasattr(self.stream, "ipadapter") else None + ) + # Apply scale with weight type consideration - if weight_type is not None and hasattr(self.stream, 'ipadapter'): + if weight_type is not None and hasattr(self.stream, "ipadapter"): try: from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) weights = build_layer_weights(num_layers, desired_scale, weight_type) if weights is not None: @@ -1340,47 +1334,51 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) + setattr(self.stream.ipadapter, "scale", desired_scale) except Exception: # Do not add fallback mechanisms raise else: # Simple uniform scale - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): # Tell diffusers_ipadapter to set the scale self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) - + setattr(self.stream.ipadapter, "scale", desired_scale) # Update enabled state if provided - if 'enabled' in desired_config and desired_config['enabled'] is not None: - enabled_state = bool(desired_config['enabled']) + if "enabled" in desired_config and desired_config["enabled"] is not None: + enabled_state = bool(desired_config["enabled"]) # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - current_enabled = getattr(self.stream.ipadapter, 'enabled', True) + if hasattr(self.stream, "ipadapter"): + current_enabled = getattr(self.stream.ipadapter, "enabled", True) if current_enabled != enabled_state: - logger.info(f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}") - setattr(self.stream.ipadapter, 'enabled', enabled_state) + logger.info( + f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}" + ) + setattr(self.stream.ipadapter, "enabled", enabled_state) # Update weight type if provided (affects per-layer distribution and/or per-step factor) - if 'weight_type' in desired_config and desired_config['weight_type'] is not None: - weight_type = desired_config['weight_type'] + if "weight_type" in desired_config and desired_config["weight_type"] is not None: + weight_type = desired_config["weight_type"] # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - setattr(self.stream.ipadapter, 'weight_type', weight_type) - + if hasattr(self.stream, "ipadapter"): + setattr(self.stream.ipadapter, "weight_type", weight_type) + # For PyTorch UNet, immediately apply a per-layer scale vector so layers reflect selection types try: - is_tensorrt_engine = hasattr(self.stream.unet, 'engine') and hasattr(self.stream.unet, 'stream') + is_tensorrt_engine = hasattr(self.stream.unet, "engine") and hasattr(self.stream.unet, "stream") if not is_tensorrt_engine: # Compute per-layer vector using Diffusers_IPAdapter helper from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + # Count installed IP layers by scanning processors with _ip_layer_index - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) # Get base weight from IPAdapter instance - base_weight = float(getattr(self.stream.ipadapter, 'scale', 1.0)) + base_weight = float(getattr(self.stream.ipadapter, "scale", 1.0)) weights = build_layer_weights(num_layers, base_weight, weight_type) # If None, keep uniform base scale; else set per-layer vector if weights is not None: @@ -1388,7 +1386,7 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(base_weight) # Keep our tracking attribute in sync - setattr(self.stream.ipadapter, 'scale', base_weight) + setattr(self.stream.ipadapter, "scale", base_weight) except Exception: # Do not add fallback mechanisms raise @@ -1396,191 +1394,207 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: def _get_ipadapter_pipeline(self): """ Get the IPAdapter pipeline from the pipeline structure (following ControlNet pattern). - + Returns: IPAdapter pipeline object or None if not found """ # Check if stream is IPAdapter pipeline directly - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): return self.stream - + # Check if stream has nested stream (ControlNet wrapper) - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'ipadapter'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "ipadapter"): return self.stream.stream - + # Check if we have a wrapper reference and can access through it - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, 'ipadapter'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "ipadapter"): return self.wrapper.stream - elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'ipadapter'): + elif hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "ipadapter"): return self.wrapper.stream.stream - + return None def _get_current_ipadapter_config(self) -> Optional[Dict[str, Any]]: """ Get current IPAdapter configuration by introspecting the IPAdapter instance. - + Returns: Current IPAdapter configuration dict or None if no IPAdapter """ # Get config from IPAdapter instance - if hasattr(self.stream, 'ipadapter') and self.stream.ipadapter is not None: + if hasattr(self.stream, "ipadapter") and self.stream.ipadapter is not None: ipadapter = self.stream.ipadapter - + config = { - 'scale': getattr(ipadapter, 'scale', 1.0), - 'weight_type': getattr(ipadapter, 'weight_type', None), - 'enabled': getattr(ipadapter, 'enabled', True), # Check actual enabled state + "scale": getattr(ipadapter, "scale", 1.0), + "weight_type": getattr(ipadapter, "weight_type", None), + "enabled": getattr(ipadapter, "enabled", True), # Check actual enabled state } - + # Add static initialization fields - if hasattr(self.stream, '_ipadapter_module'): + if hasattr(self.stream, "_ipadapter_module"): module_config = self.stream._ipadapter_module.config - config.update({ - 'style_image_key': module_config.style_image_key, - 'num_image_tokens': module_config.num_image_tokens, - 'type': module_config.type.value, - }) - + config.update( + { + "style_image_key": module_config.style_image_key, + "num_image_tokens": module_config.num_image_tokens, + "type": module_config.type.value, + } + ) + # Check if style image is set ipadapter_pipeline = self._get_ipadapter_pipeline() - if ipadapter_pipeline and hasattr(ipadapter_pipeline, 'style_image') and ipadapter_pipeline.style_image: - config['has_style_image'] = True + if ipadapter_pipeline and hasattr(ipadapter_pipeline, "style_image") and ipadapter_pipeline.style_image: + config["has_style_image"] = True else: - config['has_style_image'] = False - + config["has_style_image"] = False + return config - + # No IPAdapter instance found return None def _get_current_hook_config(self, hook_type: str) -> List[Dict[str, Any]]: """ Get current hook configuration by introspecting the hook module state. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) - + Returns: List of processor configurations or empty list if no module """ # Get the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: return [] - + # Get processors from the module - processors = getattr(hook_module, 'processors', []) - + processors = getattr(hook_module, "processors", []) + config = [] for i, processor in enumerate(processors): proc_config = { - 'type': getattr(processor, '__class__').__name__, - 'order': getattr(processor, 'order', i), - 'enabled': getattr(processor, 'enabled', True), + "type": getattr(processor, "__class__").__name__, + "order": getattr(processor, "order", i), + "enabled": getattr(processor, "enabled", True), } - + # Try to get processor parameters - if hasattr(processor, 'params'): - proc_config['params'] = dict(processor.params) - + if hasattr(processor, "params"): + proc_config["params"] = dict(processor.params) + config.append(proc_config) - + return config def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any]]) -> None: """ Update hook configuration by modifying existing processors in-place instead of recreating them. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) desired_config: List of processor configurations """ logger.info(f"_update_hook_config: Updating {hook_type} with {len(desired_config)} processors") - + # Get or create the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: logger.info(f"_update_hook_config: No existing {hook_type} module, creating new one") # Create the appropriate hook module try: if hook_type in ["image_preprocessing", "image_postprocessing"]: - from streamdiffusion.modules.image_processing_module import ImagePreprocessingModule, ImagePostprocessingModule + from streamdiffusion.modules.image_processing_module import ( + ImagePostprocessingModule, + ImagePreprocessingModule, + ) + if hook_type == "image_preprocessing": hook_module = ImagePreprocessingModule() else: hook_module = ImagePostprocessingModule() elif hook_type in ["latent_preprocessing", "latent_postprocessing"]: - from streamdiffusion.modules.latent_processing_module import LatentPreprocessingModule, LatentPostprocessingModule + from streamdiffusion.modules.latent_processing_module import ( + LatentPostprocessingModule, + LatentPreprocessingModule, + ) + if hook_type == "latent_preprocessing": hook_module = LatentPreprocessingModule() else: hook_module = LatentPostprocessingModule() else: raise ValueError(f"Unknown hook type: {hook_type}") - + # Install the module hook_module.install(self.stream) setattr(self.stream, module_attr_name, hook_module) logger.info(f"_update_hook_config: Created and installed {hook_type} module") - + except Exception as e: logger.error(f"_update_hook_config: Failed to create {hook_type} module: {e}") return - - logger.info(f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors") - + + logger.info( + f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors" + ) + # Modify existing processors in-place instead of clearing and recreating for i, proc_config in enumerate(desired_config): - processor_type = proc_config.get('type', 'unknown') - enabled = proc_config.get('enabled', True) - params = proc_config.get('params', {}) - + processor_type = proc_config.get("type", "unknown") + enabled = proc_config.get("enabled", True) + params = proc_config.get("params", {}) + logger.info(f"_update_hook_config: Processing config {i}: type={processor_type}, enabled={enabled}") - + if i < len(hook_module.processors): # Modify existing processor existing_processor = hook_module.processors[i] - + # Get the current processor type from registry name if available, otherwise use class name - current_type = existing_processor.params.get('_registry_name') if hasattr(existing_processor, 'params') else None + current_type = ( + existing_processor.params.get("_registry_name") if hasattr(existing_processor, "params") else None + ) if not current_type: current_type = existing_processor.__class__.__name__ - - logger.info(f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}") - + + logger.info( + f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}" + ) + # If processor type changed, replace it if current_type.lower() != processor_type.lower(): logger.info(f"_update_hook_config: Type changed, replacing processor {i}") try: from streamdiffusion.preprocessing.processors import get_preprocessor - + # Determine normalization context from hook type - if 'latent' in hook_type: - normalization_context = 'latent' + if "latent" in hook_type: + normalization_context = "latent" else: # Image preprocessing/postprocessing uses 'pipeline' context - normalization_context = 'pipeline' - + normalization_context = "pipeline" + new_processor = get_preprocessor( - processor_type, - pipeline_ref=getattr(self, 'stream', None), - normalization_context=normalization_context + processor_type, + pipeline_ref=getattr(self, "stream", None), + normalization_context=normalization_context, ) - + # Copy attributes from old processor - setattr(new_processor, 'order', getattr(existing_processor, 'order', i)) - setattr(new_processor, 'enabled', enabled) - + setattr(new_processor, "order", getattr(existing_processor, "order", i)) + setattr(new_processor, "enabled", enabled) + # Set parameters - if hasattr(new_processor, 'params'): + if hasattr(new_processor, "params"): new_processor.params.update(params) - + hook_module.processors[i] = new_processor logger.info(f"_update_hook_config: Successfully replaced processor {i} with {processor_type}") except Exception as e: @@ -1588,15 +1602,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any else: # Same type, just update attributes logger.info(f"_update_hook_config: Same type, updating attributes for processor {i}") - setattr(existing_processor, 'enabled', enabled) - + setattr(existing_processor, "enabled", enabled) + # Update parameters - if hasattr(existing_processor, 'params'): + if hasattr(existing_processor, "params"): existing_processor.params.update(params) for param_name, param_value in params.items(): if hasattr(existing_processor, param_name): setattr(existing_processor, param_name, param_value) - + logger.info(f"_update_hook_config: Updated processor {i} enabled={enabled}, params={params}") else: # Add new processor @@ -1606,12 +1620,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any logger.info(f"_update_hook_config: Successfully added processor {i}: {processor_type}") except Exception as e: logger.error(f"_update_hook_config: Failed to add processor {i}: {e}") - + # Remove extra processors if config is shorter while len(hook_module.processors) > len(desired_config): removed_idx = len(hook_module.processors) - 1 removed_processor = hook_module.processors.pop() - logger.info(f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}") - - logger.info(f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors") + logger.info( + f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}" + ) + logger.info( + f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors" + ) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 36abe02b0..651ff92e0 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -108,7 +108,7 @@ def __init__( normalize_seed_weights: bool = True, # Scheduler and sampler options scheduler: Literal["lcm", "tcd"] = "lcm", - sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", + sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = "normal", # ControlNet options use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, @@ -216,7 +216,7 @@ def __init__( by default True. When False, weights > 1 will amplify noise. scheduler : Literal["lcm", "tcd"], optional The scheduler type to use for denoising, by default "lcm". - sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + sampler : Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"], optional The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional Whether to enable ControlNet support, by default False. @@ -1070,7 +1070,7 @@ def _load_model( normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, scheduler: Literal["lcm", "tcd"] = "lcm", - sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", + sampler: Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"] = "normal", use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, use_ipadapter: bool = False, @@ -1147,7 +1147,7 @@ def _load_model( When False, weights > 1 will amplify noise. scheduler : Literal["lcm", "tcd"], optional The scheduler type to use for denoising, by default "lcm". - sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + sampler : Literal["simple", "sgm_uniform", "normal", "ddim", "beta", "karras"], optional The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional Whether to enable ControlNet support, by default False.