Skip to content

Commit 0c97b23

Browse files
v0.0.52
DataParallel cuda handling
1 parent 156f929 commit 0c97b23

3 files changed

Lines changed: 203 additions & 35 deletions

File tree

notebooks/11_spot_hpt_torch_fashion_mnist.ipynb

Lines changed: 196 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
15+
"execution_count": 1,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -22,9 +22,20 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": null,
25+
"execution_count": 2,
2626
"metadata": {},
27-
"outputs": [],
27+
"outputs": [
28+
{
29+
"data": {
30+
"text/plain": [
31+
"'11-torch_p040025_1min_5init_2023-05-06_20-54-49'"
32+
]
33+
},
34+
"execution_count": 2,
35+
"metadata": {},
36+
"output_type": "execute_result"
37+
}
38+
],
2839
"source": [
2940
"import pickle\n",
3041
"import socket\n",
@@ -57,16 +68,26 @@
5768
},
5869
{
5970
"cell_type": "code",
60-
"execution_count": null,
71+
"execution_count": 3,
6172
"metadata": {},
62-
"outputs": [],
73+
"outputs": [
74+
{
75+
"name": "stdout",
76+
"output_type": "stream",
77+
"text": [
78+
"spotPython 0.0.51\n",
79+
"spotRiver 0.0.92\n",
80+
"Note: you may need to restart the kernel to use updated packages.\n"
81+
]
82+
}
83+
],
6384
"source": [
6485
"pip list | grep \"spot[RiverPython]\""
6586
]
6687
},
6788
{
6889
"cell_type": "code",
69-
"execution_count": null,
90+
"execution_count": 4,
7091
"metadata": {},
7192
"outputs": [],
7293
"source": [
@@ -77,7 +98,7 @@
7798
},
7899
{
79100
"cell_type": "code",
80-
"execution_count": null,
101+
"execution_count": 5,
81102
"metadata": {},
82103
"outputs": [],
83104
"source": [
@@ -154,9 +175,18 @@
154175
},
155176
{
156177
"cell_type": "code",
157-
"execution_count": null,
178+
"execution_count": 6,
158179
"metadata": {},
159-
"outputs": [],
180+
"outputs": [
181+
{
182+
"name": "stdout",
183+
"output_type": "stream",
184+
"text": [
185+
"2.0.0\n",
186+
"MPS device: mps\n"
187+
]
188+
}
189+
],
160190
"source": [
161191
"print(torch.__version__)\n",
162192
"# Check that MPS is available\n",
@@ -183,7 +213,7 @@
183213
},
184214
{
185215
"cell_type": "code",
186-
"execution_count": null,
216+
"execution_count": 7,
187217
"metadata": {},
188218
"outputs": [],
189219
"source": [
@@ -200,7 +230,7 @@
200230
},
201231
{
202232
"cell_type": "code",
203-
"execution_count": null,
233+
"execution_count": 8,
204234
"metadata": {},
205235
"outputs": [],
206236
"source": [
@@ -224,17 +254,28 @@
224254
},
225255
{
226256
"cell_type": "code",
227-
"execution_count": null,
257+
"execution_count": 9,
228258
"metadata": {},
229-
"outputs": [],
259+
"outputs": [
260+
{
261+
"data": {
262+
"text/plain": [
263+
"(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))"
264+
]
265+
},
266+
"execution_count": 9,
267+
"metadata": {},
268+
"output_type": "execute_result"
269+
}
270+
],
230271
"source": [
231272
"train, test = load_data()\n",
232273
"train.data.shape, test.data.shape"
233274
]
234275
},
235276
{
236277
"cell_type": "code",
237-
"execution_count": null,
278+
"execution_count": 10,
238279
"metadata": {},
239280
"outputs": [],
240281
"source": [
@@ -257,7 +298,7 @@
257298
},
258299
{
259300
"cell_type": "code",
260-
"execution_count": null,
301+
"execution_count": 11,
261302
"metadata": {},
262303
"outputs": [],
263304
"source": [
@@ -283,7 +324,7 @@
283324
},
284325
{
285326
"cell_type": "code",
286-
"execution_count": null,
327+
"execution_count": 12,
287328
"metadata": {},
288329
"outputs": [],
289330
"source": [
@@ -312,7 +353,7 @@
312353
},
313354
{
314355
"cell_type": "code",
315-
"execution_count": null,
356+
"execution_count": 13,
316357
"metadata": {},
317358
"outputs": [],
318359
"source": [
@@ -330,7 +371,7 @@
330371
},
331372
{
332373
"cell_type": "code",
333-
"execution_count": null,
374+
"execution_count": 14,
334375
"metadata": {},
335376
"outputs": [],
336377
"source": [
@@ -373,7 +414,7 @@
373414
},
374415
{
375416
"cell_type": "code",
376-
"execution_count": null,
417+
"execution_count": 15,
377418
"metadata": {},
378419
"outputs": [],
379420
"source": [
@@ -422,7 +463,7 @@
422463
},
423464
{
424465
"cell_type": "code",
425-
"execution_count": null,
466+
"execution_count": 16,
426467
"metadata": {},
427468
"outputs": [],
428469
"source": [
@@ -437,9 +478,24 @@
437478
},
438479
{
439480
"cell_type": "code",
440-
"execution_count": null,
481+
"execution_count": 17,
441482
"metadata": {},
442-
"outputs": [],
483+
"outputs": [
484+
{
485+
"name": "stdout",
486+
"output_type": "stream",
487+
"text": [
488+
"| name | type | default | lower | upper | transform |\n",
489+
"|------------|--------|-----------|---------|---------|-----------------------|\n",
490+
"| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n",
491+
"| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n",
492+
"| lr | float | 0.001 | 0.0001 | 0.1 | None |\n",
493+
"| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n",
494+
"| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n",
495+
"| k_folds | int | 2 | 0 | 0 | None |\n"
496+
]
497+
}
498+
],
443499
"source": [
444500
"print(gen_design_table(fun_control))"
445501
]
@@ -457,9 +513,20 @@
457513
},
458514
{
459515
"cell_type": "code",
460-
"execution_count": null,
516+
"execution_count": 18,
461517
"metadata": {},
462-
"outputs": [],
518+
"outputs": [
519+
{
520+
"data": {
521+
"text/plain": [
522+
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
523+
]
524+
},
525+
"execution_count": 18,
526+
"metadata": {},
527+
"output_type": "execute_result"
528+
}
529+
],
463530
"source": [
464531
"X0 = get_default_values(fun_control)\n",
465532
"river_hyper_dict_default = fun_control[\"core_model_hyper_dict\"]\n",
@@ -471,9 +538,20 @@
471538
},
472539
{
473540
"cell_type": "code",
474-
"execution_count": null,
541+
"execution_count": 19,
475542
"metadata": {},
476-
"outputs": [],
543+
"outputs": [
544+
{
545+
"data": {
546+
"text/plain": [
547+
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
548+
]
549+
},
550+
"execution_count": 19,
551+
"metadata": {},
552+
"output_type": "execute_result"
553+
}
554+
],
477555
"source": [
478556
"from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n",
479557
"hyper_dict=TorchHyperDict().load()\n",
@@ -483,9 +561,98 @@
483561
},
484562
{
485563
"cell_type": "code",
486-
"execution_count": null,
487-
"metadata": {},
488-
"outputs": [],
564+
"execution_count": 20,
565+
"metadata": {},
566+
"outputs": [
567+
{
568+
"name": "stdout",
569+
"output_type": "stream",
570+
"text": [
571+
"[1, 2000] loss: 2.452\n",
572+
"[1, 4000] loss: 1.163\n",
573+
"[2, 2000] loss: 2.328\n",
574+
"[2, 4000] loss: 1.163\n",
575+
"[3, 2000] loss: 2.328\n",
576+
"[3, 4000] loss: 1.164\n",
577+
"[4, 2000] loss: 2.329\n",
578+
"[4, 4000] loss: 1.163\n",
579+
"Accuracy on hold-out set: 0.10116666666666667\n",
580+
"Loss on hold-out set: 2.320516836643219\n",
581+
"[1, 2000] loss: 2.351\n",
582+
"[1, 4000] loss: 1.166\n",
583+
"[1, 6000] loss: 0.777\n",
584+
"[1, 8000] loss: 0.583\n",
585+
"[1, 10000] loss: 0.466\n",
586+
"[1, 12000] loss: 0.388\n",
587+
"[1, 14000] loss: 0.333\n",
588+
"[1, 16000] loss: 0.291\n",
589+
"[1, 18000] loss: 0.259\n",
590+
"[2, 2000] loss: 2.330\n",
591+
"[2, 4000] loss: 1.164\n",
592+
"[2, 6000] loss: 0.776\n",
593+
"[2, 8000] loss: 0.583\n",
594+
"[2, 10000] loss: 0.467\n",
595+
"[2, 12000] loss: 0.388\n",
596+
"[2, 14000] loss: 0.333\n",
597+
"[2, 16000] loss: 0.291\n",
598+
"[2, 18000] loss: 0.259\n",
599+
"[3, 2000] loss: 2.332\n",
600+
"[3, 4000] loss: 1.165\n",
601+
"[3, 6000] loss: 0.777\n",
602+
"[3, 8000] loss: 0.582\n",
603+
"[3, 10000] loss: 0.467\n",
604+
"[3, 12000] loss: 0.389\n",
605+
"[3, 14000] loss: 0.333\n",
606+
"[3, 16000] loss: 0.292\n",
607+
"[3, 18000] loss: 0.259\n",
608+
"[4, 2000] loss: 2.334\n",
609+
"[4, 4000] loss: 1.165\n",
610+
"[4, 6000] loss: 0.778\n",
611+
"[4, 8000] loss: 0.583\n",
612+
"[4, 10000] loss: 0.466\n",
613+
"[4, 12000] loss: 0.389\n",
614+
"[4, 14000] loss: 0.333\n",
615+
"[4, 16000] loss: 0.291\n",
616+
"[4, 18000] loss: 0.259\n",
617+
"[5, 2000] loss: 2.327\n",
618+
"[5, 4000] loss: 1.169\n",
619+
"[5, 6000] loss: 0.778\n",
620+
"[5, 8000] loss: 0.583\n",
621+
"[5, 10000] loss: 0.466\n",
622+
"[5, 12000] loss: 0.389\n",
623+
"[5, 14000] loss: 0.333\n",
624+
"[5, 16000] loss: 0.291\n",
625+
"[5, 18000] loss: 0.259\n",
626+
"[6, 2000] loss: 2.335\n",
627+
"[6, 4000] loss: 1.166\n",
628+
"[6, 6000] loss: 0.778\n",
629+
"[6, 8000] loss: 0.582\n",
630+
"[6, 10000] loss: 0.466\n",
631+
"[6, 12000] loss: 0.389\n",
632+
"[6, 14000] loss: 0.333\n",
633+
"[6, 16000] loss: 0.292\n",
634+
"[6, 18000] loss: 0.259\n",
635+
"[7, 2000] loss: 2.330\n",
636+
"[7, 4000] loss: 1.166\n",
637+
"[7, 6000] loss: 0.776\n",
638+
"[7, 8000] loss: 0.582\n",
639+
"[7, 10000] loss: 0.466\n",
640+
"[7, 12000] loss: 0.389\n",
641+
"[7, 14000] loss: 0.333\n",
642+
"[7, 16000] loss: 0.291\n",
643+
"[7, 18000] loss: 0.259\n",
644+
"[8, 2000] loss: 2.327\n",
645+
"[8, 4000] loss: 1.167\n",
646+
"[8, 6000] loss: 0.777\n",
647+
"[8, 8000] loss: 0.582\n",
648+
"[8, 10000] loss: 0.466\n",
649+
"[8, 12000] loss: 0.389\n",
650+
"[8, 14000] loss: 0.334\n",
651+
"[8, 16000] loss: 0.292\n",
652+
"[8, 18000] loss: 0.259\n"
653+
]
654+
}
655+
],
489656
"source": [
490657
"spot_torch = spot.Spot(fun=fun,\n",
491658
" lower = lower,\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.51"
10+
version = "0.0.52"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

0 commit comments

Comments
 (0)