-
Notifications
You must be signed in to change notification settings - Fork 124
Expand file tree
/
Copy pathpsnrmeter.py
More file actions
28 lines (22 loc) · 723 Bytes
/
psnrmeter.py
File metadata and controls
28 lines (22 loc) · 723 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from math import log10
import torch
from torchnet.meter import meter
class PSNRMeter(meter.Meter):
    """Accumulate the peak signal-to-noise ratio (PSNR) over many batches.

    Squared errors and element counts are summed across calls to :meth:`add`;
    :meth:`value` converts the running mean squared error (MSE) into
    ``10 * log10(peak ** 2 / mse)`` decibels.

    Args:
        peak: Maximum possible signal value. The default ``1.0`` corresponds
            to data scaled to ``[0, 1]``; pass e.g. ``255`` for 8-bit images.
    """

    def __init__(self, peak=1.0):
        super(PSNRMeter, self).__init__()
        self.peak = float(peak)
        self.reset()

    def reset(self):
        """Clear the accumulated squared error and element count."""
        self.n = 0
        self.sesum = 0.0

    def add(self, output, target):
        """Accumulate the squared error between ``output`` and ``target``.

        Args:
            output: Predicted values, as a tensor or anything
                ``torch.as_tensor`` accepts (e.g. a NumPy array).
            target: Reference values in the same form; expected to have the
                same number of elements as ``output`` (only ``output.numel()``
                is counted).
        """
        # Convert each argument independently so a NumPy array paired with a
        # tensor is handled as well as two arrays or two tensors.
        output = torch.as_tensor(output)
        target = torch.as_tensor(target)
        # detach(): the running sum must not keep the autograd graph alive.
        # double(): squaring integer (e.g. uint8) differences would wrap.
        output = output.detach().cpu().double()
        target = target.detach().cpu().double()
        self.n += output.numel()
        # .item() stores a plain Python float instead of a growing tensor.
        self.sesum += torch.sum((output - target) ** 2).item()

    def value(self):
        """Return the PSNR in decibels over everything added since reset.

        Returns:
            ``nan`` if nothing has been added yet, ``inf`` if the accumulated
            squared error is zero (perfect reconstruction), otherwise
            ``10 * log10(peak ** 2 / mse)``.
        """
        if self.n == 0:
            return float("nan")
        mse = self.sesum / self.n
        if mse == 0:
            return float("inf")
        return 10 * log10(self.peak ** 2 / mse)