-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss_functions.py
More file actions
67 lines (50 loc) · 1.9 KB
/
loss_functions.py
File metadata and controls
67 lines (50 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from abc import ABC, abstractmethod
import numpy as np
from tensorflow.keras.losses import BinaryCrossentropy # type: ignore
def safe_Y_hat(Y_hat, epsilon=1e-15):
    """Clamp predictions into the open interval (epsilon, 1 - epsilon).

    Keeps downstream log() calls finite when a prediction is exactly 0 or 1.
    """
    lower, upper = epsilon, 1 - epsilon
    return np.clip(Y_hat, lower, upper)
class Loss(ABC):
    """Base class for loss functions with optional L2 regularization.

    Subclasses implement ``_loss`` (scalar loss value) and
    ``_dloss_dy_hat`` (gradient of the loss w.r.t. the predictions).
    """

    def __init__(self, C=0.1, regularization=False) -> None:
        # C: L2 penalty strength; regularization: whether loss() adds the penalty.
        self.C = C
        self.regularization = regularization

    def l2_loss(self, Y_hat, Y, params):
        """Return the base loss plus the L2 penalty C * sum(||p||^2) / (2n).

        n is taken from Y_hat.shape[1], i.e. samples are columns —
        assumes a (features, samples) layout; TODO confirm against callers.
        """
        n = Y_hat.shape[1]
        loss = self._loss(Y_hat, Y)
        # Sum of squared entries over every parameter tensor.
        penalty = sum(np.sum(param ** 2) for param in params)
        return loss + self.C * penalty / (2 * n)

    def loss(self, Y_hat, Y, params):
        """Return the loss, regularized iff self.regularization is set."""
        if self.regularization:
            return self.l2_loss(Y_hat, Y, params)
        return self._loss(Y_hat, Y)

    def dloss_dy_hat(self, Y_hat, Y):
        """Return the gradient of the loss w.r.t. the predictions Y_hat."""
        return self._dloss_dy_hat(Y_hat, Y)

    def dreg_loss_dparam(self, Y_hat, param):
        """Gradient of the L2 penalty w.r.t. a single parameter tensor."""
        n = Y_hat.shape[1]
        return self.C * param / n

    @abstractmethod
    def _loss(self, Y_hat, Y) -> float:
        """Compute the scalar loss; implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def _dloss_dy_hat(self, Y_hat, Y) -> np.ndarray:
        """Compute dLoss/dY_hat; implemented by subclasses."""
        raise NotImplementedError()
class BinaryCrossEntropyLoss(Loss):
    """Binary cross-entropy loss with mean reduction.

    Computed directly in NumPy rather than delegating to
    tf.keras.losses.BinaryCrossentropy — avoids the TensorFlow
    dependency and a tensor round-trip per call; numerically equivalent
    up to clipping epsilon.
    """

    def _loss(self, Y_hat, Y) -> float:
        # Y_hat are the predictions; clip away from 0/1 so log() is finite.
        eps = 1e-15
        Y_hat = np.clip(Y_hat, eps, 1 - eps)
        loss = -(Y * np.log(Y_hat) + (1 - Y) * np.log(1 - Y_hat))
        return float(np.mean(loss))

    def _dloss_dy_hat(self, Y_hat, Y) -> np.ndarray:
        # (Y_hat - Y) / n: the combined gradient of BCE through a sigmoid
        # output layer — presumably the upstream layer is sigmoid; verify.
        n = Y_hat.shape[1]
        return (Y_hat - Y) / n
class MSE(Loss):
    """Mean squared error loss: mean((Y_hat - Y)^2)."""

    def _loss(self, Y_hat, Y) -> float:
        # Y_hat are the predictions; cast so the return matches the
        # declared float annotation (np.mean yields np.float64).
        squared_error = (Y_hat - Y) ** 2
        return float(np.mean(squared_error))

    def _dloss_dy_hat(self, Y_hat, Y) -> np.ndarray:
        # Derivative of the mean squared error, averaged over the n
        # samples in axis 1 (assumes (features, samples) layout — TODO confirm).
        n = Y_hat.shape[1]
        return 2 * (Y_hat - Y) / n