Skip to content

Commit 23a1762

Browse files
v0.056 train test val
1 parent 43197f4 commit 23a1762

7 files changed

Lines changed: 172 additions & 49 deletions

File tree

notebooks/11_spot_hpt_torch_fashion_mnist.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@
421421
"fun = HyperTorch(seed=123, log_level=50).fun_torch\n",
422422
"weights = 1.0\n",
423423
"shuffle = True\n",
424+
"eval = \"train_hold_out\"\n",
424425
"\n",
425426
"fun_control.update({\n",
426427
" \"data_dir\": None,\n",
@@ -434,6 +435,7 @@
434435
" \"metric\": None,\n",
435436
" \"metric_sklearn\": None,\n",
436437
" \"shuffle\": shuffle,\n",
438+
" \"eval\": eval\n",
437439
" })"
438440
]
439441
},

notebooks/12_spot_hpt_torch_cifar10.ipynb

Lines changed: 141 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
15+
"execution_count": 1,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -26,9 +26,20 @@
2626
},
2727
{
2828
"cell_type": "code",
29-
"execution_count": null,
29+
"execution_count": 2,
3030
"metadata": {},
31-
"outputs": [],
31+
"outputs": [
32+
{
33+
"data": {
34+
"text/plain": [
35+
"'12-torch_p040025_1min_5init_2023-05-07_10-00-44'"
36+
]
37+
},
38+
"execution_count": 2,
39+
"metadata": {},
40+
"output_type": "execute_result"
41+
}
42+
],
3243
"source": [
3344
"import pickle\n",
3445
"import socket\n",
@@ -61,16 +72,26 @@
6172
},
6273
{
6374
"cell_type": "code",
64-
"execution_count": null,
75+
"execution_count": 3,
6576
"metadata": {},
66-
"outputs": [],
77+
"outputs": [
78+
{
79+
"name": "stdout",
80+
"output_type": "stream",
81+
"text": [
82+
"spotPython 0.0.56\n",
83+
"spotRiver 0.0.92\n",
84+
"Note: you may need to restart the kernel to use updated packages.\n"
85+
]
86+
}
87+
],
6788
"source": [
6889
"pip list | grep \"spot[RiverPython]\""
6990
]
7091
},
7192
{
7293
"cell_type": "code",
73-
"execution_count": null,
94+
"execution_count": 4,
7495
"metadata": {},
7596
"outputs": [],
7697
"source": [
@@ -81,7 +102,7 @@
81102
},
82103
{
83104
"cell_type": "code",
84-
"execution_count": null,
105+
"execution_count": 5,
85106
"metadata": {},
86107
"outputs": [],
87108
"source": [
@@ -158,9 +179,18 @@
158179
},
159180
{
160181
"cell_type": "code",
161-
"execution_count": null,
182+
"execution_count": 6,
162183
"metadata": {},
163-
"outputs": [],
184+
"outputs": [
185+
{
186+
"name": "stdout",
187+
"output_type": "stream",
188+
"text": [
189+
"2.0.0\n",
190+
"MPS device: mps\n"
191+
]
192+
}
193+
],
164194
"source": [
165195
"print(torch.__version__)\n",
166196
"# Check that MPS is available\n",
@@ -187,7 +217,7 @@
187217
},
188218
{
189219
"cell_type": "code",
190-
"execution_count": null,
220+
"execution_count": 7,
191221
"metadata": {},
192222
"outputs": [],
193223
"source": [
@@ -212,7 +242,7 @@
212242
},
213243
{
214244
"cell_type": "code",
215-
"execution_count": null,
245+
"execution_count": 8,
216246
"metadata": {},
217247
"outputs": [],
218248
"source": [
@@ -241,7 +271,7 @@
241271
},
242272
{
243273
"cell_type": "code",
244-
"execution_count": null,
274+
"execution_count": 9,
245275
"metadata": {},
246276
"outputs": [],
247277
"source": [
@@ -262,7 +292,7 @@
262292
},
263293
{
264294
"cell_type": "code",
265-
"execution_count": null,
295+
"execution_count": 10,
266296
"metadata": {},
267297
"outputs": [],
268298
"source": [
@@ -298,7 +328,7 @@
298328
},
299329
{
300330
"cell_type": "code",
301-
"execution_count": null,
331+
"execution_count": 11,
302332
"metadata": {},
303333
"outputs": [],
304334
"source": [
@@ -319,17 +349,36 @@
319349
},
320350
{
321351
"cell_type": "code",
322-
"execution_count": null,
352+
"execution_count": 12,
323353
"metadata": {},
324-
"outputs": [],
354+
"outputs": [
355+
{
356+
"name": "stdout",
357+
"output_type": "stream",
358+
"text": [
359+
"Files already downloaded and verified\n",
360+
"Files already downloaded and verified\n"
361+
]
362+
},
363+
{
364+
"data": {
365+
"text/plain": [
366+
"((50000, 32, 32, 3), (10000, 32, 32, 3))"
367+
]
368+
},
369+
"execution_count": 12,
370+
"metadata": {},
371+
"output_type": "execute_result"
372+
}
373+
],
325374
"source": [
326375
"train, test = load_data()\n",
327376
"train.data.shape, test.data.shape"
328377
]
329378
},
330379
{
331380
"cell_type": "code",
332-
"execution_count": null,
381+
"execution_count": 13,
333382
"metadata": {},
334383
"outputs": [],
335384
"source": [
@@ -352,7 +401,7 @@
352401
},
353402
{
354403
"cell_type": "code",
355-
"execution_count": null,
404+
"execution_count": 14,
356405
"metadata": {},
357406
"outputs": [],
358407
"source": [
@@ -378,7 +427,7 @@
378427
},
379428
{
380429
"cell_type": "code",
381-
"execution_count": null,
430+
"execution_count": 15,
382431
"metadata": {},
383432
"outputs": [],
384433
"source": [
@@ -407,7 +456,7 @@
407456
},
408457
{
409458
"cell_type": "code",
410-
"execution_count": null,
459+
"execution_count": 16,
411460
"metadata": {},
412461
"outputs": [],
413462
"source": [
@@ -425,16 +474,7 @@
425474
},
426475
{
427476
"cell_type": "code",
428-
"execution_count": null,
429-
"metadata": {},
430-
"outputs": [],
431-
"source": [
432-
"fun_control"
433-
]
434-
},
435-
{
436-
"cell_type": "code",
437-
"execution_count": null,
477+
"execution_count": 17,
438478
"metadata": {},
439479
"outputs": [],
440480
"source": [
@@ -477,13 +517,14 @@
477517
},
478518
{
479519
"cell_type": "code",
480-
"execution_count": null,
520+
"execution_count": 18,
481521
"metadata": {},
482522
"outputs": [],
483523
"source": [
484524
"fun = HyperTorch(seed=123, log_level=50).fun_torch\n",
485525
"weights = 1.0\n",
486526
"shuffle = True\n",
527+
"eval = \"train_hold_out\"\n",
487528
"\n",
488529
"fun_control.update({\n",
489530
" \"data_dir\": None,\n",
@@ -497,6 +538,7 @@
497538
" \"metric\": None,\n",
498539
" \"metric_sklearn\": None,\n",
499540
" \"shuffle\": shuffle,\n",
541+
" \"eval\": eval\n",
500542
" })"
501543
]
502544
},
@@ -526,7 +568,7 @@
526568
},
527569
{
528570
"cell_type": "code",
529-
"execution_count": null,
571+
"execution_count": 19,
530572
"metadata": {},
531573
"outputs": [],
532574
"source": [
@@ -541,9 +583,24 @@
541583
},
542584
{
543585
"cell_type": "code",
544-
"execution_count": null,
586+
"execution_count": 20,
545587
"metadata": {},
546-
"outputs": [],
588+
"outputs": [
589+
{
590+
"name": "stdout",
591+
"output_type": "stream",
592+
"text": [
593+
"| name | type | default | lower | upper | transform |\n",
594+
"|------------|--------|-----------|---------|---------|-----------------------|\n",
595+
"| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n",
596+
"| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n",
597+
"| lr | float | 0.001 | 1e-05 | 0.01 | None |\n",
598+
"| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n",
599+
"| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n",
600+
"| k_folds | int | 2 | 0 | 0 | None |\n"
601+
]
602+
}
603+
],
547604
"source": [
548605
"print(gen_design_table(fun_control))"
549606
]
@@ -561,9 +618,20 @@
561618
},
562619
{
563620
"cell_type": "code",
564-
"execution_count": null,
621+
"execution_count": 21,
565622
"metadata": {},
566-
"outputs": [],
623+
"outputs": [
624+
{
625+
"data": {
626+
"text/plain": [
627+
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
628+
]
629+
},
630+
"execution_count": 21,
631+
"metadata": {},
632+
"output_type": "execute_result"
633+
}
634+
],
567635
"source": [
568636
"from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n",
569637
"hyper_dict=TorchHyperDict().load()\n",
@@ -573,9 +641,17 @@
573641
},
574642
{
575643
"cell_type": "code",
576-
"execution_count": null,
644+
"execution_count": 22,
577645
"metadata": {},
578-
"outputs": [],
646+
"outputs": [
647+
{
648+
"name": "stdout",
649+
"output_type": "stream",
650+
"text": [
651+
"[1, 2000] loss: 2.306\n"
652+
]
653+
}
654+
],
579655
"source": [
580656
"spot_torch = spot.Spot(fun=fun,\n",
581657
" lower = lower,\n",
@@ -799,6 +875,33 @@
799875
"min(spot_torch.y), max(spot_torch.y)"
800876
]
801877
},
878+
{
879+
"cell_type": "code",
880+
"execution_count": null,
881+
"metadata": {},
882+
"outputs": [],
883+
"source": [
884+
"trainset = fun_control[\"train\"]"
885+
]
886+
},
887+
{
888+
"cell_type": "code",
889+
"execution_count": null,
890+
"metadata": {},
891+
"outputs": [],
892+
"source": [
893+
"model_default.evaluate_hold_out(dataset=trainset, shuffle=False, test_dataset=testset)"
894+
]
895+
},
896+
{
897+
"cell_type": "code",
898+
"execution_count": null,
899+
"metadata": {},
900+
"outputs": [],
901+
"source": [
902+
"model_spot.evaluate_hold_out(dataset=trainset, shuffle=False, test_dataset=testset)"
903+
]
904+
},
802905
{
803906
"attachments": {},
804907
"cell_type": "markdown",

notebooks/13_spot_hpt_torch_cv_fashion_mnist.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@
381381
"fun = HyperTorch(seed=123, log_level=50).fun_torch\n",
382382
"weights = -1.0\n",
383383
"shuffle = True\n",
384-
"eval=\"cv\"\n",
384+
"eval=\"train_cv\"\n",
385385
"\n",
386386
"fun_control.update({\n",
387387
" \"data_dir\": None,\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.55"
10+
version = "0.0.56"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

0 commit comments

Comments
 (0)