Support WanT2V LoRA training #1148
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for WanT2V training and inference, adding new datasets (WanT2VVideoDataset, WanT2VCachedDataset, PromptDataset), a WanT2VInferencer, a WanT2VModel, and dynamic time-shifting scheduling. Feedback focuses on improving robustness and performance: handling empty video paths and directory checks in dataset resolution, wrapping metadata parsing in try-except blocks, allowing configurable batch sizes, caching VAE latent statistics to avoid redundant tensor creation, and preventing a potential division-by-zero error in the dynamic time-shift scheduler.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _maybe_add_target_hw(self, sample, height, width): | ||
| if height in (None, "") or width in (None, ""): | ||
| return | ||
| sample["target_height"] = int(height) | ||
| sample["target_width"] = int(width) |
There was a problem hiding this comment.
Wrap the integer parsing of height and width in a try-except block so that malformed metadata is logged as a warning instead of crashing dataset loading.
| def _maybe_add_target_hw(self, sample, height, width): | |
| if height in (None, "") or width in (None, ""): | |
| return | |
| sample["target_height"] = int(height) | |
| sample["target_width"] = int(width) | |
| def _maybe_add_target_hw(self, sample, height, width): | |
| if height in (None, "") or width in (None, ""): | |
| return | |
| try: | |
| sample["target_height"] = int(height) | |
| sample["target_width"] = int(width) | |
| except ValueError: | |
| logger.warning("Invalid height/width values: height={}, width={}", height, width) |
| return DataLoader( | ||
| dataset, | ||
| batch_size=1, | ||
| shuffle=shuffle if sampler is None else False, | ||
| sampler=sampler, | ||
| num_workers=data_config.get("num_workers", 8), | ||
| pin_memory=data_config.get("pin_memory", True), | ||
| ) |
There was a problem hiding this comment.
Allow the batch_size to be configured from data_config instead of hardcoding it to 1.
| return DataLoader( | |
| dataset, | |
| batch_size=1, | |
| shuffle=shuffle if sampler is None else False, | |
| sampler=sampler, | |
| num_workers=data_config.get("num_workers", 8), | |
| pin_memory=data_config.get("pin_memory", True), | |
| ) | |
| return DataLoader( | |
| dataset, | |
| batch_size=data_config.get("batch_size", 1), | |
| shuffle=shuffle if sampler is None else False, | |
| sampler=sampler, | |
| num_workers=data_config.get("num_workers", 8), | |
| pin_memory=data_config.get("pin_memory", True), | |
| ) |
| if self.load_vae: | ||
| self.vae = AutoencoderKLWan.from_pretrained( | ||
| model_path, | ||
| subfolder="vae", | ||
| torch_dtype=self.vae_dtype, | ||
| ).to(self.device) | ||
| self.vae.requires_grad_(False) | ||
| if model_config.get("enable_vae_tiling", False): | ||
| self.vae.enable_tiling() |
There was a problem hiding this comment.
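Pre-compute and cache latent_mean and latent_std on the target device during initialization to avoid recreating these tensors on every training step. Note that the cached latent_std holds the reciprocal (1.0 / latents_std), so normalization multiplies by it and denormalization divides by it.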
| if self.load_vae: | |
| self.vae = AutoencoderKLWan.from_pretrained( | |
| model_path, | |
| subfolder="vae", | |
| torch_dtype=self.vae_dtype, | |
| ).to(self.device) | |
| self.vae.requires_grad_(False) | |
| if model_config.get("enable_vae_tiling", False): | |
| self.vae.enable_tiling() | |
| if self.load_vae: | |
| self.vae = AutoencoderKLWan.from_pretrained( | |
| model_path, | |
| subfolder="vae", | |
| torch_dtype=self.vae_dtype, | |
| ).to(self.device) | |
| self.vae.requires_grad_(False) | |
| if model_config.get("enable_vae_tiling", False): | |
| self.vae.enable_tiling() | |
| self.latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | |
| self.latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) |
| latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | ||
| latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | ||
| latent = (latent - latent_mean) * latent_std |
There was a problem hiding this comment.
Use the cached self.latent_mean and self.latent_std attributes instead of recreating them on every call.
| latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | |
| latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | |
| latent = (latent - latent_mean) * latent_std | |
| latent = (latent - self.latent_mean) * self.latent_std |
| latents_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | ||
| latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | ||
| latent = latent.to(dtype=self.vae_dtype) / latents_std + latents_mean |
There was a problem hiding this comment.
Use the cached self.latent_mean and self.latent_std attributes instead of recreating them on every call.
| latents_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | |
| latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) | |
| latent = latent.to(dtype=self.vae_dtype) / latents_std + latents_mean | |
| latent = latent.to(dtype=self.vae_dtype) / self.latent_std + self.latent_mean |
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
No description provided.