Skip to content

support want2v lora train#1148

Open
gushiqiao wants to merge 4 commits into
mainfrom
gsq/dev-train
Open

support want2v lora train#1148
gushiqiao wants to merge 4 commits into
mainfrom
gsq/dev-train

Conversation

@gushiqiao

Copy link
Copy Markdown
Contributor

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for WanT2V training and inference, adding new datasets (WanT2VVideoDataset, WanT2VCachedDataset, PromptDataset), a WanT2VInferencer, a WanT2VModel, and dynamic time-shifting scheduling. Feedback focuses on improving robustness and performance: handling empty video paths and directory checks in dataset resolution, wrapping metadata parsing in try-except blocks, allowing configurable batch sizes, caching VAE latent statistics to avoid redundant tensor creation, and preventing a potential division-by-zero error in the dynamic time-shift scheduler.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread lightx2v_train/lightx2v_train/data/video_dataset.py
Comment thread lightx2v_train/lightx2v_train/data/video_dataset.py Outdated
Comment on lines +393 to +397
def _maybe_add_target_hw(self, sample, height, width):
if height in (None, "") or width in (None, ""):
return
sample["target_height"] = int(height)
sample["target_width"] = int(width)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Wrap the integer parsing of height and width in a try-except block to prevent the dataset loading from crashing due to malformed metadata.

Suggested change
def _maybe_add_target_hw(self, sample, height, width):
if height in (None, "") or width in (None, ""):
return
sample["target_height"] = int(height)
sample["target_width"] = int(width)
def _maybe_add_target_hw(self, sample, height, width):
if height in (None, "") or width in (None, ""):
return
try:
sample["target_height"] = int(height)
sample["target_width"] = int(width)
except ValueError:
logger.warning("Invalid height/width values: height={}, width={}", height, width)

Comment on lines +414 to +421
return DataLoader(
dataset,
batch_size=1,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=data_config.get("num_workers", 8),
pin_memory=data_config.get("pin_memory", True),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Allow the batch_size to be configured from data_config instead of hardcoding it to 1.

Suggested change
return DataLoader(
dataset,
batch_size=1,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=data_config.get("num_workers", 8),
pin_memory=data_config.get("pin_memory", True),
)
return DataLoader(
dataset,
batch_size=data_config.get("batch_size", 1),
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=data_config.get("num_workers", 8),
pin_memory=data_config.get("pin_memory", True),
)

Comment on lines +33 to +41
if self.load_vae:
self.vae = AutoencoderKLWan.from_pretrained(
model_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
).to(self.device)
self.vae.requires_grad_(False)
if model_config.get("enable_vae_tiling", False):
self.vae.enable_tiling()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pre-compute and cache latent_mean and latent_std on the target device during initialization to avoid recreating these tensors on every training step.

Suggested change
if self.load_vae:
self.vae = AutoencoderKLWan.from_pretrained(
model_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
).to(self.device)
self.vae.requires_grad_(False)
if model_config.get("enable_vae_tiling", False):
self.vae.enable_tiling()
if self.load_vae:
self.vae = AutoencoderKLWan.from_pretrained(
model_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
).to(self.device)
self.vae.requires_grad_(False)
if model_config.get("enable_vae_tiling", False):
self.vae.enable_tiling()
self.latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
self.latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)

Comment on lines +110 to +112
latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent = (latent - latent_mean) * latent_std

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the cached self.latent_mean and self.latent_std attributes instead of recreating them on every call.

Suggested change
latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent = (latent - latent_mean) * latent_std
latent = (latent - self.latent_mean) * self.latent_std

Comment on lines +171 to +173
latents_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent = latent.to(dtype=self.vae_dtype) / latents_std + latents_mean

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the cached self.latent_mean and self.latent_std attributes instead of recreating them on every call.

Suggested change
latents_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.vae_dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
latent = latent.to(dtype=self.vae_dtype) / latents_std + latents_mean
latent = latent.to(dtype=self.vae_dtype) / self.latent_std + self.latent_mean

Comment thread lightx2v_train/lightx2v_train/schedulers/time_shift.py Outdated
gushiqiao and others added 3 commits June 12, 2026 16:41
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant