@@ -15,10 +15,6 @@ def __init__(self, lr, batch_size, epochs, k_folds):
1515 self .epochs = epochs
1616 self .k_folds = k_folds
1717 self .results = {}
18- # TODO: move the following to the net initialization:
19- # if torch.cuda.device_count() > 1:
20- # print("We will use", torch.cuda.device_count(), "GPUs!")
21- # self = nn.DataParallel(self)
2218
2319 # def evaluate_cv_old(self, dataset, shuffle=False):
2420 # try:
@@ -206,13 +202,9 @@ def evaluate_cv(self, dataset, shuffle=False):
206202
207203 def evaluate_hold_out (self , dataset , shuffle ):
208204 lr = self .lr
209- # del self.lr
210205 epochs = self .epochs
211- # del self.epochs
212206 try :
213207 device = getDevice ()
214- # if torch.cuda.device_count() > 1:
215- # self = nn.DataParallel(self)
216208 self .to (device )
217209 criterion = nn .CrossEntropyLoss ()
218210 optimizer = optim .Adam (self .parameters (), lr = lr )
@@ -221,14 +213,13 @@ def evaluate_hold_out(self, dataset, shuffle):
221213 for epoch in range (epochs ):
222214 self .train_hold_out (trainloader , criterion , optimizer , device = device , epoch = epoch )
223215 scheduler .step ()
224- df_eval = self .validate_hold_out (valloader = valloader , criterion = criterion , device = device )
216+ val_accuracy , val_loss = self .validate_hold_out (valloader = valloader , criterion = criterion , device = device )
217+ df_eval = val_loss
225218 df_preds = np .nan
226219 except Exception as err :
227220 print (f"Error in Net_Core_CV. Call to evaluate_hold_out() failed. { err = } , { type (err )= } " )
228221 df_eval = np .nan
229222 df_preds = np .nan
230- # self.lr = lr
231- # self.epochs = epochs
232223 return df_eval , df_preds
233224
234225 def create_data_loaders (self , dataset , shuffle ):
@@ -283,4 +274,4 @@ def validate_hold_out(self, valloader, criterion, device):
283274 loss = val_loss / val_steps
284275 print (f"Accuracy on hold-out set: { accuracy } " )
285276 print (f"Loss on hold-out set: { loss } " )
286- return loss
277+ return accuracy , loss
0 commit comments