diff --git a/algorithms/common/metrics/video/base_fid.py b/algorithms/common/metrics/video/base_fid.py index d297e83..b4cf33b 100644 --- a/algorithms/common/metrics/video/base_fid.py +++ b/algorithms/common/metrics/video/base_fid.py @@ -3,7 +3,10 @@ import torch from torch import Tensor from torchmetrics import Metric -from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance +try: + from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance +except ImportError: + from torchmetrics.image.fid import FrechetInceptionDistance as _FrechetInceptionDistance from .shared_registry import SharedVideoMetricModelRegistry diff --git a/algorithms/common/metrics/video/lpips.py b/algorithms/common/metrics/video/lpips.py index 1b4090b..d5d5352 100644 --- a/algorithms/common/metrics/video/lpips.py +++ b/algorithms/common/metrics/video/lpips.py @@ -4,9 +4,18 @@ from torchmetrics import Metric from torchmetrics.image.lpip import ( LearnedPerceptualImagePatchSimilarity as _LearnedPerceptualImagePatchSimilarity, - _valid_img, ) -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +try: + from torchmetrics.image.lpip import _valid_img +except ImportError: + def _valid_img(img, normalize): + if normalize: + return img.min() >= 0.0 and img.max() <= 1.0 and img.shape[1] == 3 + return img.min() >= -1.0 and img.max() <= 1.0 and img.shape[1] == 3 +try: + from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +except ImportError: + _LPIPS_AVAILABLE = True from .shared_registry import SharedVideoMetricModelRegistry from .types import VideoMetricModelType diff --git a/algorithms/common/metrics/video/shared_registry.py b/algorithms/common/metrics/video/shared_registry.py index 79a6525..3489277 100644 --- a/algorithms/common/metrics/video/shared_registry.py +++ b/algorithms/common/metrics/video/shared_registry.py @@ -1,7 +1,13 @@ from typing import Dict, List from torch import nn, Tensor -from torchmetrics.image.lpip import NoTrainLpips -from torchmetrics.image.fid import NoTrainInceptionV3 +try: + from torchmetrics.image.lpip import NoTrainLpips +except ImportError: + from torchmetrics.image.lpip import _NoTrainLpips as NoTrainLpips +try: + from torchmetrics.image.fid import NoTrainInceptionV3 +except ImportError: + from torchmetrics.image.fid import _NoTrainInceptionV3 as NoTrainInceptionV3 from utils.torch_utils import freeze_model from utils.print_utils import suppress_warnings from .models import I3D, MotionExtractor, CLIP, DINO, LAION, MUSIQ, RAFT, AMT_S diff --git a/algorithms/dfot/dfot_video.py b/algorithms/dfot/dfot_video.py index 2d41674..d3b7b43 100644 --- a/algorithms/dfot/dfot_video.py +++ b/algorithms/dfot/dfot_video.py @@ -537,7 +537,7 @@ def _predict_videos( sliding_context_len=self.cfg.tasks.prediction.sliding_context_len or self.max_tokens // 2, ) - xs_pred[:, keyframe_indices] = xs_pred_key + xs_pred[:, keyframe_indices] = xs_pred_key.clone() # if is_rank_zero: # uncomment to visualize history guidance # history_guidance.log(logger=self.logger) diff --git a/requirements.txt b/requirements.txt index d201dc6..49e9f71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,13 +21,13 @@ internetarchive pytorchvideo scipy rotary_embedding_torch -torchmetrics[image]==0.11.4 +torchmetrics[image]>=0.11.4 # tfrecord[torch] # bezier timm zmq openai-clip -pyiqa==0.1.10 +pyiqa>=0.1.10 easydict zmq pytubefix