From 72087f4237f838adc9f493eb66304f26b4455da9 Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Thu, 14 May 2026 11:26:19 +0800 Subject: [PATCH] fix: PyTorch 2.6+ and torchmetrics v1.x compatibility Fix tensor aliasing RuntimeError in dfot_video.py by cloning xs_pred_key before in-place assignment (PyTorch 2.6 stricter check). Handle torchmetrics v1.x API changes where NoTrainLpips, NoTrainInceptionV3, _valid_img, _LPIPS_AVAILABLE, and FrechetInceptionDistance were renamed or relocated. Relax strict version pins for torchmetrics and pyiqa in requirements.txt. Tested end-to-end inference on realestate10k_mini with PyTorch 2.6. --- algorithms/common/metrics/video/base_fid.py | 5 ++++- algorithms/common/metrics/video/lpips.py | 13 +++++++++++-- algorithms/common/metrics/video/shared_registry.py | 10 ++++++++-- algorithms/dfot/dfot_video.py | 2 +- requirements.txt | 4 ++-- 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/algorithms/common/metrics/video/base_fid.py b/algorithms/common/metrics/video/base_fid.py index d297e83..b4cf33b 100644 --- a/algorithms/common/metrics/video/base_fid.py +++ b/algorithms/common/metrics/video/base_fid.py @@ -3,7 +3,10 @@ import torch from torch import Tensor from torchmetrics import Metric -from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance +try: + from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance +except ImportError: + from torchmetrics.image.fid import FrechetInceptionDistance as _FrechetInceptionDistance from .shared_registry import SharedVideoMetricModelRegistry diff --git a/algorithms/common/metrics/video/lpips.py b/algorithms/common/metrics/video/lpips.py index 1b4090b..d5d5352 100644 --- a/algorithms/common/metrics/video/lpips.py +++ b/algorithms/common/metrics/video/lpips.py @@ -4,9 +4,18 @@ from torchmetrics import Metric from torchmetrics.image.lpip import ( LearnedPerceptualImagePatchSimilarity as _LearnedPerceptualImagePatchSimilarity, - _valid_img, ) -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +try: + from torchmetrics.image.lpip import _valid_img +except ImportError: + def _valid_img(img, normalize): + if normalize: + return img.min() >= 0.0 and img.max() <= 1.0 and img.shape[1] == 3 + return img.min() >= -1.0 and img.max() <= 1.0 and img.shape[1] == 3 +try: + from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +except ImportError: + _LPIPS_AVAILABLE = True from .shared_registry import SharedVideoMetricModelRegistry from .types import VideoMetricModelType diff --git a/algorithms/common/metrics/video/shared_registry.py b/algorithms/common/metrics/video/shared_registry.py index 79a6525..3489277 100644 --- a/algorithms/common/metrics/video/shared_registry.py +++ b/algorithms/common/metrics/video/shared_registry.py @@ -1,7 +1,13 @@ from typing import Dict, List from torch import nn, Tensor -from torchmetrics.image.lpip import NoTrainLpips -from torchmetrics.image.fid import NoTrainInceptionV3 +try: + from torchmetrics.image.lpip import NoTrainLpips +except ImportError: + from torchmetrics.image.lpip import _NoTrainLpips as NoTrainLpips +try: + from torchmetrics.image.fid import NoTrainInceptionV3 +except ImportError: + from torchmetrics.image.fid import _NoTrainInceptionV3 as NoTrainInceptionV3 from utils.torch_utils import freeze_model from utils.print_utils import suppress_warnings from .models import I3D, MotionExtractor, CLIP, DINO, LAION, MUSIQ, RAFT, AMT_S diff --git a/algorithms/dfot/dfot_video.py b/algorithms/dfot/dfot_video.py index 2d41674..d3b7b43 100644 --- a/algorithms/dfot/dfot_video.py +++ b/algorithms/dfot/dfot_video.py @@ -537,7 +537,7 @@ def _predict_videos( sliding_context_len=self.cfg.tasks.prediction.sliding_context_len or self.max_tokens // 2, ) - xs_pred[:, keyframe_indices] = xs_pred_key + xs_pred[:, keyframe_indices] = xs_pred_key.clone() # if is_rank_zero: # uncomment to visualize history guidance # history_guidance.log(logger=self.logger) diff --git a/requirements.txt b/requirements.txt index d201dc6..49e9f71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,13 +21,13 @@ internetarchive pytorchvideo scipy rotary_embedding_torch -torchmetrics[image]==0.11.4 +torchmetrics[image]>=0.11.4 # tfrecord[torch] # bezier timm zmq openai-clip -pyiqa==0.1.10 +pyiqa>=0.1.10 easydict zmq pytubefix