Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion algorithms/common/metrics/video/base_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance
try:
from torchmetrics.image import FrechetInceptionDistance as _FrechetInceptionDistance
except ImportError:
from torchmetrics.image.fid import FrechetInceptionDistance as _FrechetInceptionDistance
from .shared_registry import SharedVideoMetricModelRegistry


Expand Down
13 changes: 11 additions & 2 deletions algorithms/common/metrics/video/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
from torchmetrics import Metric
from torchmetrics.image.lpip import (
LearnedPerceptualImagePatchSimilarity as _LearnedPerceptualImagePatchSimilarity,
_valid_img,
)
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
try:
from torchmetrics.image.lpip import _valid_img
except ImportError:
def _valid_img(img, normalize):
    """Return whether ``img`` is an acceptable input for the LPIPS metric.

    Local fallback used when ``torchmetrics.image.lpip`` does not export
    ``_valid_img``.

    Args:
        img: Tensor to validate. It must be 4-D with size 3 along dim 1.
        normalize: If true, values must lie in ``[0, 1]``; otherwise they
            must lie in ``[-1, 1]``.

    Returns:
        A plain Python ``bool`` (never a 0-dim tensor).
    """
    # Check the layout first: it is cheap, and it rejects malformed inputs
    # instead of raising IndexError on ``img.shape[1]`` for low-rank tensors.
    if img.ndim != 4 or img.shape[1] != 3:
        return False
    lower_bound = 0.0 if normalize else -1.0
    # ``bool(...)`` collapses the tensor comparisons into a Python bool so
    # every code path returns the same type.
    return bool(img.min() >= lower_bound and img.max() <= 1.0)
try:
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
except ImportError:
_LPIPS_AVAILABLE = True
from .shared_registry import SharedVideoMetricModelRegistry
from .types import VideoMetricModelType

Expand Down
10 changes: 8 additions & 2 deletions algorithms/common/metrics/video/shared_registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Dict, List
from torch import nn, Tensor
from torchmetrics.image.lpip import NoTrainLpips
from torchmetrics.image.fid import NoTrainInceptionV3
try:
from torchmetrics.image.lpip import NoTrainLpips
except ImportError:
from torchmetrics.image.lpip import _NoTrainLpips as NoTrainLpips
try:
from torchmetrics.image.fid import NoTrainInceptionV3
except ImportError:
from torchmetrics.image.fid import _NoTrainInceptionV3 as NoTrainInceptionV3
from utils.torch_utils import freeze_model
from utils.print_utils import suppress_warnings
from .models import I3D, MotionExtractor, CLIP, DINO, LAION, MUSIQ, RAFT, AMT_S
Expand Down
2 changes: 1 addition & 1 deletion algorithms/dfot/dfot_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _predict_videos(
sliding_context_len=self.cfg.tasks.prediction.sliding_context_len
or self.max_tokens // 2,
)
xs_pred[:, keyframe_indices] = xs_pred_key
xs_pred[:, keyframe_indices] = xs_pred_key.clone()
# if is_rank_zero: # uncomment to visualize history guidance
# history_guidance.log(logger=self.logger)

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ internetarchive
pytorchvideo
scipy
rotary_embedding_torch
torchmetrics[image]==0.11.4
torchmetrics[image]>=0.11.4
# tfrecord[torch]
# bezier
timm
zmq
openai-clip
pyiqa==0.1.10
pyiqa>=0.1.10
easydict
zmq
pytubefix
Expand Down