Skip to content

Commit 92d5767

Browse files
lr, epochs
1 parent 0c97b23 commit 92d5767

3 files changed

Lines changed: 40 additions & 201 deletions

File tree

notebooks/11_spot_hpt_torch_fashion_mnist.ipynb

Lines changed: 29 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -22,20 +22,9 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": 2,
25+
"execution_count": null,
2626
"metadata": {},
27-
"outputs": [
28-
{
29-
"data": {
30-
"text/plain": [
31-
"'11-torch_p040025_1min_5init_2023-05-06_20-54-49'"
32-
]
33-
},
34-
"execution_count": 2,
35-
"metadata": {},
36-
"output_type": "execute_result"
37-
}
38-
],
27+
"outputs": [],
3928
"source": [
4029
"import pickle\n",
4130
"import socket\n",
@@ -68,26 +57,16 @@
6857
},
6958
{
7059
"cell_type": "code",
71-
"execution_count": 3,
60+
"execution_count": null,
7261
"metadata": {},
73-
"outputs": [
74-
{
75-
"name": "stdout",
76-
"output_type": "stream",
77-
"text": [
78-
"spotPython 0.0.51\n",
79-
"spotRiver 0.0.92\n",
80-
"Note: you may need to restart the kernel to use updated packages.\n"
81-
]
82-
}
83-
],
62+
"outputs": [],
8463
"source": [
8564
"pip list | grep \"spot[RiverPython]\""
8665
]
8766
},
8867
{
8968
"cell_type": "code",
90-
"execution_count": 4,
69+
"execution_count": null,
9170
"metadata": {},
9271
"outputs": [],
9372
"source": [
@@ -98,7 +77,7 @@
9877
},
9978
{
10079
"cell_type": "code",
101-
"execution_count": 5,
80+
"execution_count": null,
10281
"metadata": {},
10382
"outputs": [],
10483
"source": [
@@ -175,18 +154,9 @@
175154
},
176155
{
177156
"cell_type": "code",
178-
"execution_count": 6,
157+
"execution_count": null,
179158
"metadata": {},
180-
"outputs": [
181-
{
182-
"name": "stdout",
183-
"output_type": "stream",
184-
"text": [
185-
"2.0.0\n",
186-
"MPS device: mps\n"
187-
]
188-
}
189-
],
159+
"outputs": [],
190160
"source": [
191161
"print(torch.__version__)\n",
192162
"# Check that MPS is available\n",
@@ -213,7 +183,7 @@
213183
},
214184
{
215185
"cell_type": "code",
216-
"execution_count": 7,
186+
"execution_count": null,
217187
"metadata": {},
218188
"outputs": [],
219189
"source": [
@@ -230,7 +200,7 @@
230200
},
231201
{
232202
"cell_type": "code",
233-
"execution_count": 8,
203+
"execution_count": null,
234204
"metadata": {},
235205
"outputs": [],
236206
"source": [
@@ -254,28 +224,17 @@
254224
},
255225
{
256226
"cell_type": "code",
257-
"execution_count": 9,
227+
"execution_count": null,
258228
"metadata": {},
259-
"outputs": [
260-
{
261-
"data": {
262-
"text/plain": [
263-
"(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))"
264-
]
265-
},
266-
"execution_count": 9,
267-
"metadata": {},
268-
"output_type": "execute_result"
269-
}
270-
],
229+
"outputs": [],
271230
"source": [
272231
"train, test = load_data()\n",
273232
"train.data.shape, test.data.shape"
274233
]
275234
},
276235
{
277236
"cell_type": "code",
278-
"execution_count": 10,
237+
"execution_count": null,
279238
"metadata": {},
280239
"outputs": [],
281240
"source": [
@@ -298,7 +257,7 @@
298257
},
299258
{
300259
"cell_type": "code",
301-
"execution_count": 11,
260+
"execution_count": null,
302261
"metadata": {},
303262
"outputs": [],
304263
"source": [
@@ -324,7 +283,7 @@
324283
},
325284
{
326285
"cell_type": "code",
327-
"execution_count": 12,
286+
"execution_count": null,
328287
"metadata": {},
329288
"outputs": [],
330289
"source": [
@@ -353,7 +312,7 @@
353312
},
354313
{
355314
"cell_type": "code",
356-
"execution_count": 13,
315+
"execution_count": null,
357316
"metadata": {},
358317
"outputs": [],
359318
"source": [
@@ -371,7 +330,7 @@
371330
},
372331
{
373332
"cell_type": "code",
374-
"execution_count": 14,
333+
"execution_count": null,
375334
"metadata": {},
376335
"outputs": [],
377336
"source": [
@@ -414,7 +373,7 @@
414373
},
415374
{
416375
"cell_type": "code",
417-
"execution_count": 15,
376+
"execution_count": null,
418377
"metadata": {},
419378
"outputs": [],
420379
"source": [
@@ -463,7 +422,7 @@
463422
},
464423
{
465424
"cell_type": "code",
466-
"execution_count": 16,
425+
"execution_count": null,
467426
"metadata": {},
468427
"outputs": [],
469428
"source": [
@@ -478,24 +437,9 @@
478437
},
479438
{
480439
"cell_type": "code",
481-
"execution_count": 17,
440+
"execution_count": null,
482441
"metadata": {},
483-
"outputs": [
484-
{
485-
"name": "stdout",
486-
"output_type": "stream",
487-
"text": [
488-
"| name | type | default | lower | upper | transform |\n",
489-
"|------------|--------|-----------|---------|---------|-----------------------|\n",
490-
"| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n",
491-
"| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n",
492-
"| lr | float | 0.001 | 0.0001 | 0.1 | None |\n",
493-
"| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n",
494-
"| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n",
495-
"| k_folds | int | 2 | 0 | 0 | None |\n"
496-
]
497-
}
498-
],
442+
"outputs": [],
499443
"source": [
500444
"print(gen_design_table(fun_control))"
501445
]
@@ -513,20 +457,9 @@
513457
},
514458
{
515459
"cell_type": "code",
516-
"execution_count": 18,
460+
"execution_count": null,
517461
"metadata": {},
518-
"outputs": [
519-
{
520-
"data": {
521-
"text/plain": [
522-
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
523-
]
524-
},
525-
"execution_count": 18,
526-
"metadata": {},
527-
"output_type": "execute_result"
528-
}
529-
],
462+
"outputs": [],
530463
"source": [
531464
"X0 = get_default_values(fun_control)\n",
532465
"river_hyper_dict_default = fun_control[\"core_model_hyper_dict\"]\n",
@@ -538,20 +471,9 @@
538471
},
539472
{
540473
"cell_type": "code",
541-
"execution_count": 19,
474+
"execution_count": null,
542475
"metadata": {},
543-
"outputs": [
544-
{
545-
"data": {
546-
"text/plain": [
547-
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
548-
]
549-
},
550-
"execution_count": 19,
551-
"metadata": {},
552-
"output_type": "execute_result"
553-
}
554-
],
476+
"outputs": [],
555477
"source": [
556478
"from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n",
557479
"hyper_dict=TorchHyperDict().load()\n",
@@ -561,98 +483,9 @@
561483
},
562484
{
563485
"cell_type": "code",
564-
"execution_count": 20,
565-
"metadata": {},
566-
"outputs": [
567-
{
568-
"name": "stdout",
569-
"output_type": "stream",
570-
"text": [
571-
"[1, 2000] loss: 2.452\n",
572-
"[1, 4000] loss: 1.163\n",
573-
"[2, 2000] loss: 2.328\n",
574-
"[2, 4000] loss: 1.163\n",
575-
"[3, 2000] loss: 2.328\n",
576-
"[3, 4000] loss: 1.164\n",
577-
"[4, 2000] loss: 2.329\n",
578-
"[4, 4000] loss: 1.163\n",
579-
"Accuracy on hold-out set: 0.10116666666666667\n",
580-
"Loss on hold-out set: 2.320516836643219\n",
581-
"[1, 2000] loss: 2.351\n",
582-
"[1, 4000] loss: 1.166\n",
583-
"[1, 6000] loss: 0.777\n",
584-
"[1, 8000] loss: 0.583\n",
585-
"[1, 10000] loss: 0.466\n",
586-
"[1, 12000] loss: 0.388\n",
587-
"[1, 14000] loss: 0.333\n",
588-
"[1, 16000] loss: 0.291\n",
589-
"[1, 18000] loss: 0.259\n",
590-
"[2, 2000] loss: 2.330\n",
591-
"[2, 4000] loss: 1.164\n",
592-
"[2, 6000] loss: 0.776\n",
593-
"[2, 8000] loss: 0.583\n",
594-
"[2, 10000] loss: 0.467\n",
595-
"[2, 12000] loss: 0.388\n",
596-
"[2, 14000] loss: 0.333\n",
597-
"[2, 16000] loss: 0.291\n",
598-
"[2, 18000] loss: 0.259\n",
599-
"[3, 2000] loss: 2.332\n",
600-
"[3, 4000] loss: 1.165\n",
601-
"[3, 6000] loss: 0.777\n",
602-
"[3, 8000] loss: 0.582\n",
603-
"[3, 10000] loss: 0.467\n",
604-
"[3, 12000] loss: 0.389\n",
605-
"[3, 14000] loss: 0.333\n",
606-
"[3, 16000] loss: 0.292\n",
607-
"[3, 18000] loss: 0.259\n",
608-
"[4, 2000] loss: 2.334\n",
609-
"[4, 4000] loss: 1.165\n",
610-
"[4, 6000] loss: 0.778\n",
611-
"[4, 8000] loss: 0.583\n",
612-
"[4, 10000] loss: 0.466\n",
613-
"[4, 12000] loss: 0.389\n",
614-
"[4, 14000] loss: 0.333\n",
615-
"[4, 16000] loss: 0.291\n",
616-
"[4, 18000] loss: 0.259\n",
617-
"[5, 2000] loss: 2.327\n",
618-
"[5, 4000] loss: 1.169\n",
619-
"[5, 6000] loss: 0.778\n",
620-
"[5, 8000] loss: 0.583\n",
621-
"[5, 10000] loss: 0.466\n",
622-
"[5, 12000] loss: 0.389\n",
623-
"[5, 14000] loss: 0.333\n",
624-
"[5, 16000] loss: 0.291\n",
625-
"[5, 18000] loss: 0.259\n",
626-
"[6, 2000] loss: 2.335\n",
627-
"[6, 4000] loss: 1.166\n",
628-
"[6, 6000] loss: 0.778\n",
629-
"[6, 8000] loss: 0.582\n",
630-
"[6, 10000] loss: 0.466\n",
631-
"[6, 12000] loss: 0.389\n",
632-
"[6, 14000] loss: 0.333\n",
633-
"[6, 16000] loss: 0.292\n",
634-
"[6, 18000] loss: 0.259\n",
635-
"[7, 2000] loss: 2.330\n",
636-
"[7, 4000] loss: 1.166\n",
637-
"[7, 6000] loss: 0.776\n",
638-
"[7, 8000] loss: 0.582\n",
639-
"[7, 10000] loss: 0.466\n",
640-
"[7, 12000] loss: 0.389\n",
641-
"[7, 14000] loss: 0.333\n",
642-
"[7, 16000] loss: 0.291\n",
643-
"[7, 18000] loss: 0.259\n",
644-
"[8, 2000] loss: 2.327\n",
645-
"[8, 4000] loss: 1.167\n",
646-
"[8, 6000] loss: 0.777\n",
647-
"[8, 8000] loss: 0.582\n",
648-
"[8, 10000] loss: 0.466\n",
649-
"[8, 12000] loss: 0.389\n",
650-
"[8, 14000] loss: 0.334\n",
651-
"[8, 16000] loss: 0.292\n",
652-
"[8, 18000] loss: 0.259\n"
653-
]
654-
}
655-
],
486+
"execution_count": null,
487+
"metadata": {},
488+
"outputs": [],
656489
"source": [
657490
"spot_torch = spot.Spot(fun=fun,\n",
658491
" lower = lower,\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.52"
10+
version = "0.0.53"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

0 commit comments

Comments
 (0)