Skip to content

Commit 7659642

Browse files
Update netcorecv.py
1 parent c0e50e8 commit 7659642

1 file changed

Lines changed: 3 additions & 12 deletions

File tree

src/spotPython/torch/netcorecv.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ def __init__(self, lr, batch_size, epochs, k_folds):
1515
self.epochs = epochs
1616
self.k_folds = k_folds
1717
self.results = {}
18-
# TODO: move the following to the net initialization:
19-
# if torch.cuda.device_count() > 1:
20-
# print("We will use", torch.cuda.device_count(), "GPUs!")
21-
# self = nn.DataParallel(self)
2218

2319
# def evaluate_cv_old(self, dataset, shuffle=False):
2420
# try:
@@ -206,13 +202,9 @@ def evaluate_cv(self, dataset, shuffle=False):
206202

207203
def evaluate_hold_out(self, dataset, shuffle):
208204
lr = self.lr
209-
# del self.lr
210205
epochs = self.epochs
211-
# del self.epochs
212206
try:
213207
device = getDevice()
214-
# if torch.cuda.device_count() > 1:
215-
# self = nn.DataParallel(self)
216208
self.to(device)
217209
criterion = nn.CrossEntropyLoss()
218210
optimizer = optim.Adam(self.parameters(), lr=lr)
@@ -221,14 +213,13 @@ def evaluate_hold_out(self, dataset, shuffle):
221213
for epoch in range(epochs):
222214
self.train_hold_out(trainloader, criterion, optimizer, device=device, epoch=epoch)
223215
scheduler.step()
224-
df_eval = self.validate_hold_out(valloader=valloader, criterion=criterion, device=device)
216+
val_accuracy, val_loss = self.validate_hold_out(valloader=valloader, criterion=criterion, device=device)
217+
df_eval = val_loss
225218
df_preds = np.nan
226219
except Exception as err:
227220
print(f"Error in Net_Core_CV. Call to evaluate_hold_out() failed. {err=}, {type(err)=}")
228221
df_eval = np.nan
229222
df_preds = np.nan
230-
# self.lr = lr
231-
# self.epochs = epochs
232223
return df_eval, df_preds
233224

234225
def create_data_loaders(self, dataset, shuffle):
@@ -283,4 +274,4 @@ def validate_hold_out(self, valloader, criterion, device):
283274
loss = val_loss / val_steps
284275
print(f"Accuracy on hold-out set: {accuracy}")
285276
print(f"Loss on hold-out set: {loss}")
286-
return loss
277+
return accuracy, loss

0 commit comments

Comments
 (0)