Skip to content

Commit ff89581

Browse files
val or best_val
1 parent 2fb0791 commit ff89581

2 files changed

Lines changed: 166 additions & 40 deletions

File tree

notebooks/14_spot_ray_hpt_torch_cifar10.ipynb

Lines changed: 160 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
15+
"execution_count": 1,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -22,9 +22,20 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": null,
26-
"metadata": {},
27-
"outputs": [],
25+
"execution_count": 2,
26+
"metadata": {},
27+
"outputs": [
28+
{
29+
"data": {
30+
"text/plain": [
31+
"'12-torch_p040025_1min_5init_2023-05-07_18-06-33'"
32+
]
33+
},
34+
"execution_count": 2,
35+
"metadata": {},
36+
"output_type": "execute_result"
37+
}
38+
],
2839
"source": [
2940
"import pickle\n",
3041
"import socket\n",
@@ -57,16 +68,26 @@
5768
},
5869
{
5970
"cell_type": "code",
60-
"execution_count": null,
61-
"metadata": {},
62-
"outputs": [],
71+
"execution_count": 3,
72+
"metadata": {},
73+
"outputs": [
74+
{
75+
"name": "stdout",
76+
"output_type": "stream",
77+
"text": [
78+
"spotPython 0.0.60\n",
79+
"spotRiver 0.0.92\n",
80+
"Note: you may need to restart the kernel to use updated packages.\n"
81+
]
82+
}
83+
],
6384
"source": [
6485
"pip list | grep \"spot[RiverPython]\""
6586
]
6687
},
6788
{
6889
"cell_type": "code",
69-
"execution_count": null,
90+
"execution_count": 4,
7091
"metadata": {},
7192
"outputs": [],
7293
"source": [
@@ -77,7 +98,7 @@
7798
},
7899
{
79100
"cell_type": "code",
80-
"execution_count": null,
101+
"execution_count": 5,
81102
"metadata": {},
82103
"outputs": [],
83104
"source": [
@@ -154,9 +175,18 @@
154175
},
155176
{
156177
"cell_type": "code",
157-
"execution_count": null,
158-
"metadata": {},
159-
"outputs": [],
178+
"execution_count": 6,
179+
"metadata": {},
180+
"outputs": [
181+
{
182+
"name": "stdout",
183+
"output_type": "stream",
184+
"text": [
185+
"2.0.0\n",
186+
"MPS device: mps\n"
187+
]
188+
}
189+
],
160190
"source": [
161191
"print(torch.__version__)\n",
162192
"# Check that MPS is available\n",
@@ -183,7 +213,7 @@
183213
},
184214
{
185215
"cell_type": "code",
186-
"execution_count": null,
216+
"execution_count": 7,
187217
"metadata": {},
188218
"outputs": [],
189219
"source": [
@@ -208,7 +238,7 @@
208238
},
209239
{
210240
"cell_type": "code",
211-
"execution_count": null,
241+
"execution_count": 8,
212242
"metadata": {},
213243
"outputs": [],
214244
"source": [
@@ -229,17 +259,36 @@
229259
},
230260
{
231261
"cell_type": "code",
232-
"execution_count": null,
233-
"metadata": {},
234-
"outputs": [],
262+
"execution_count": 9,
263+
"metadata": {},
264+
"outputs": [
265+
{
266+
"name": "stdout",
267+
"output_type": "stream",
268+
"text": [
269+
"Files already downloaded and verified\n",
270+
"Files already downloaded and verified\n"
271+
]
272+
},
273+
{
274+
"data": {
275+
"text/plain": [
276+
"((50000, 32, 32, 3), (10000, 32, 32, 3))"
277+
]
278+
},
279+
"execution_count": 9,
280+
"metadata": {},
281+
"output_type": "execute_result"
282+
}
283+
],
235284
"source": [
236285
"train, test = load_data()\n",
237286
"train.data.shape, test.data.shape"
238287
]
239288
},
240289
{
241290
"cell_type": "code",
242-
"execution_count": null,
291+
"execution_count": 10,
243292
"metadata": {},
244293
"outputs": [],
245294
"source": [
@@ -264,7 +313,7 @@
264313
},
265314
{
266315
"cell_type": "code",
267-
"execution_count": null,
316+
"execution_count": 11,
268317
"metadata": {},
269318
"outputs": [],
270319
"source": [
@@ -307,7 +356,7 @@
307356
},
308357
{
309358
"cell_type": "code",
310-
"execution_count": null,
359+
"execution_count": 12,
311360
"metadata": {},
312361
"outputs": [],
313362
"source": [
@@ -436,9 +485,49 @@
436485
},
437486
{
438487
"cell_type": "code",
439-
"execution_count": null,
440-
"metadata": {},
441-
"outputs": [],
488+
"execution_count": 13,
489+
"metadata": {},
490+
"outputs": [
491+
{
492+
"data": {
493+
"text/plain": [
494+
"{'l1': {'type': 'int',\n",
495+
" 'default': 5,\n",
496+
" 'transform': 'transform_power_2_int',\n",
497+
" 'lower': 2,\n",
498+
" 'upper': 9},\n",
499+
" 'l2': {'type': 'int',\n",
500+
" 'default': 5,\n",
501+
" 'transform': 'transform_power_2_int',\n",
502+
" 'lower': 2,\n",
503+
" 'upper': 9},\n",
504+
" 'lr': {'type': 'float',\n",
505+
" 'default': 0.001,\n",
506+
" 'transform': 'None',\n",
507+
" 'lower': 1e-05,\n",
508+
" 'upper': 0.01},\n",
509+
" 'batch_size': {'type': 'int',\n",
510+
" 'default': 4,\n",
511+
" 'transform': 'transform_power_2_int',\n",
512+
" 'lower': 1,\n",
513+
" 'upper': 4},\n",
514+
" 'epochs': {'type': 'int',\n",
515+
" 'default': 3,\n",
516+
" 'transform': 'transform_power_2_int',\n",
517+
" 'lower': 1,\n",
518+
" 'upper': 4},\n",
519+
" 'k_folds': {'type': 'int',\n",
520+
" 'default': 2,\n",
521+
" 'transform': 'None',\n",
522+
" 'lower': 0,\n",
523+
" 'upper': 0}}"
524+
]
525+
},
526+
"execution_count": 13,
527+
"metadata": {},
528+
"output_type": "execute_result"
529+
}
530+
],
442531
"source": [
443532
"fun_control = modify_hyper_parameter_bounds(fun_control, \"k_folds\", bounds=[0, 0])\n",
444533
"fun_control[\"core_model_hyper_dict\"]"
@@ -462,7 +551,7 @@
462551
},
463552
{
464553
"cell_type": "code",
465-
"execution_count": null,
554+
"execution_count": 14,
466555
"metadata": {},
467556
"outputs": [],
468557
"source": [
@@ -501,7 +590,7 @@
501590
},
502591
{
503592
"cell_type": "code",
504-
"execution_count": null,
593+
"execution_count": 15,
505594
"metadata": {},
506595
"outputs": [],
507596
"source": [
@@ -552,7 +641,7 @@
552641
},
553642
{
554643
"cell_type": "code",
555-
"execution_count": null,
644+
"execution_count": 16,
556645
"metadata": {},
557646
"outputs": [],
558647
"source": [
@@ -567,9 +656,24 @@
567656
},
568657
{
569658
"cell_type": "code",
570-
"execution_count": null,
571-
"metadata": {},
572-
"outputs": [],
659+
"execution_count": 17,
660+
"metadata": {},
661+
"outputs": [
662+
{
663+
"name": "stdout",
664+
"output_type": "stream",
665+
"text": [
666+
"| name | type | default | lower | upper | transform |\n",
667+
"|------------|--------|-----------|---------|---------|-----------------------|\n",
668+
"| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n",
669+
"| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n",
670+
"| lr | float | 0.001 | 1e-05 | 0.01 | None |\n",
671+
"| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n",
672+
"| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n",
673+
"| k_folds | int | 2 | 0 | 0 | None |\n"
674+
]
675+
}
676+
],
573677
"source": [
574678
"print(gen_design_table(fun_control))"
575679
]
@@ -587,9 +691,20 @@
587691
},
588692
{
589693
"cell_type": "code",
590-
"execution_count": null,
591-
"metadata": {},
592-
"outputs": [],
694+
"execution_count": 18,
695+
"metadata": {},
696+
"outputs": [
697+
{
698+
"data": {
699+
"text/plain": [
700+
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
701+
]
702+
},
703+
"execution_count": 18,
704+
"metadata": {},
705+
"output_type": "execute_result"
706+
}
707+
],
593708
"source": [
594709
"from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n",
595710
"hyper_dict=TorchHyperDict().load()\n",
@@ -599,9 +714,19 @@
599714
},
600715
{
601716
"cell_type": "code",
602-
"execution_count": null,
603-
"metadata": {},
604-
"outputs": [],
717+
"execution_count": 19,
718+
"metadata": {},
719+
"outputs": [
720+
{
721+
"name": "stdout",
722+
"output_type": "stream",
723+
"text": [
724+
"Epoch: 1, Batch: 1000. Batch Size: 8. Training Loss: 2.061\n",
725+
"Epoch: 1, Batch: 2000. Batch Size: 8. Training Loss: 0.885\n",
726+
"Epoch: 1, Batch: 3000. Batch Size: 8. Training Loss: 0.549\n"
727+
]
728+
}
729+
],
605730
"source": [
606731
"spot_torch = spot.Spot(fun=fun,\n",
607732
" lower = lower,\n",

src/spotPython/torch/netcore.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,20 @@ def evaluate_hold_out(self, dataset, shuffle, test_dataset=None):
8888
device = getDevice()
8989
self.to(device)
9090
criterion = nn.CrossEntropyLoss()
91-
optimizer = optim.Adam(self.parameters(), lr=lr)
91+
# TODO: optimizer = optim.Adam(self.parameters(), lr=lr)
92+
optimizer = optim.SGD(self.parameters(), lr=lr, momentum=0.9)
9293
if test_dataset is None:
9394
trainloader, valloader = self.create_train_val_data_loaders(dataset, shuffle)
9495
else:
9596
trainloader, valloader = self.create_train_test_data_loaders(dataset, shuffle, test_dataset)
96-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
97+
# TODO: scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
9798
# Early stopping parameters
9899
patience = 5
99100
best_val_loss = float("inf")
100101
counter = 0
101102
for epoch in range(epochs):
102103
self.train_hold_out(trainloader, criterion, optimizer, device=device, epoch=epoch)
103-
scheduler.step()
104+
# TODO: scheduler.step()
104105
# Early stopping check
105106
val_accuracy, val_loss = self.validate_hold_out(valloader=valloader, criterion=criterion, device=device)
106107
if val_loss < best_val_loss:
@@ -111,13 +112,13 @@ def evaluate_hold_out(self, dataset, shuffle, test_dataset=None):
111112
if counter >= patience:
112113
print(f"Early stopping at epoch {epoch}")
113114
break
114-
df_eval = best_val_loss
115+
df_eval = val_loss
115116
df_preds = np.nan
116117
except Exception as err:
117118
print(f"Error in Net_Core. Call to evaluate_hold_out() failed. {err=}, {type(err)=}")
118119
df_eval = np.nan
119120
df_preds = np.nan
120-
print(f"Returned to Spot: Best validation loss: {df_eval}")
121+
print(f"Returned to Spot: Validation loss: {df_eval}")
121122
return df_eval, df_preds
122123

123124
def create_train_val_data_loaders(self, dataset, shuffle):

0 commit comments

Comments
 (0)