|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": 1, |
| 15 | + "execution_count": null, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
22 | 22 | }, |
23 | 23 | { |
24 | 24 | "cell_type": "code", |
25 | | - "execution_count": 2, |
| 25 | + "execution_count": null, |
26 | 26 | "metadata": {}, |
27 | | - "outputs": [ |
28 | | - { |
29 | | - "data": { |
30 | | - "text/plain": [ |
31 | | - "'11-torch_p040025_1min_5init_2023-05-06_20-54-49'" |
32 | | - ] |
33 | | - }, |
34 | | - "execution_count": 2, |
35 | | - "metadata": {}, |
36 | | - "output_type": "execute_result" |
37 | | - } |
38 | | - ], |
| 27 | + "outputs": [], |
39 | 28 | "source": [ |
40 | 29 | "import pickle\n", |
41 | 30 | "import socket\n", |
|
68 | 57 | }, |
69 | 58 | { |
70 | 59 | "cell_type": "code", |
71 | | - "execution_count": 3, |
| 60 | + "execution_count": null, |
72 | 61 | "metadata": {}, |
73 | | - "outputs": [ |
74 | | - { |
75 | | - "name": "stdout", |
76 | | - "output_type": "stream", |
77 | | - "text": [ |
78 | | - "spotPython 0.0.51\n", |
79 | | - "spotRiver 0.0.92\n", |
80 | | - "Note: you may need to restart the kernel to use updated packages.\n" |
81 | | - ] |
82 | | - } |
83 | | - ], |
| 62 | + "outputs": [], |
84 | 63 | "source": [ |
85 | 64 | "pip list | grep \"spot[RiverPython]\"" |
86 | 65 | ] |
87 | 66 | }, |
88 | 67 | { |
89 | 68 | "cell_type": "code", |
90 | | - "execution_count": 4, |
| 69 | + "execution_count": null, |
91 | 70 | "metadata": {}, |
92 | 71 | "outputs": [], |
93 | 72 | "source": [ |
|
98 | 77 | }, |
99 | 78 | { |
100 | 79 | "cell_type": "code", |
101 | | - "execution_count": 5, |
| 80 | + "execution_count": null, |
102 | 81 | "metadata": {}, |
103 | 82 | "outputs": [], |
104 | 83 | "source": [ |
|
175 | 154 | }, |
176 | 155 | { |
177 | 156 | "cell_type": "code", |
178 | | - "execution_count": 6, |
| 157 | + "execution_count": null, |
179 | 158 | "metadata": {}, |
180 | | - "outputs": [ |
181 | | - { |
182 | | - "name": "stdout", |
183 | | - "output_type": "stream", |
184 | | - "text": [ |
185 | | - "2.0.0\n", |
186 | | - "MPS device: mps\n" |
187 | | - ] |
188 | | - } |
189 | | - ], |
| 159 | + "outputs": [], |
190 | 160 | "source": [ |
191 | 161 | "print(torch.__version__)\n", |
192 | 162 | "# Check that MPS is available\n", |
|
213 | 183 | }, |
214 | 184 | { |
215 | 185 | "cell_type": "code", |
216 | | - "execution_count": 7, |
| 186 | + "execution_count": null, |
217 | 187 | "metadata": {}, |
218 | 188 | "outputs": [], |
219 | 189 | "source": [ |
|
230 | 200 | }, |
231 | 201 | { |
232 | 202 | "cell_type": "code", |
233 | | - "execution_count": 8, |
| 203 | + "execution_count": null, |
234 | 204 | "metadata": {}, |
235 | 205 | "outputs": [], |
236 | 206 | "source": [ |
|
254 | 224 | }, |
255 | 225 | { |
256 | 226 | "cell_type": "code", |
257 | | - "execution_count": 9, |
| 227 | + "execution_count": null, |
258 | 228 | "metadata": {}, |
259 | | - "outputs": [ |
260 | | - { |
261 | | - "data": { |
262 | | - "text/plain": [ |
263 | | - "(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))" |
264 | | - ] |
265 | | - }, |
266 | | - "execution_count": 9, |
267 | | - "metadata": {}, |
268 | | - "output_type": "execute_result" |
269 | | - } |
270 | | - ], |
| 229 | + "outputs": [], |
271 | 230 | "source": [ |
272 | 231 | "train, test = load_data()\n", |
273 | 232 | "train.data.shape, test.data.shape" |
274 | 233 | ] |
275 | 234 | }, |
276 | 235 | { |
277 | 236 | "cell_type": "code", |
278 | | - "execution_count": 10, |
| 237 | + "execution_count": null, |
279 | 238 | "metadata": {}, |
280 | 239 | "outputs": [], |
281 | 240 | "source": [ |
|
298 | 257 | }, |
299 | 258 | { |
300 | 259 | "cell_type": "code", |
301 | | - "execution_count": 11, |
| 260 | + "execution_count": null, |
302 | 261 | "metadata": {}, |
303 | 262 | "outputs": [], |
304 | 263 | "source": [ |
|
324 | 283 | }, |
325 | 284 | { |
326 | 285 | "cell_type": "code", |
327 | | - "execution_count": 12, |
| 286 | + "execution_count": null, |
328 | 287 | "metadata": {}, |
329 | 288 | "outputs": [], |
330 | 289 | "source": [ |
|
353 | 312 | }, |
354 | 313 | { |
355 | 314 | "cell_type": "code", |
356 | | - "execution_count": 13, |
| 315 | + "execution_count": null, |
357 | 316 | "metadata": {}, |
358 | 317 | "outputs": [], |
359 | 318 | "source": [ |
|
371 | 330 | }, |
372 | 331 | { |
373 | 332 | "cell_type": "code", |
374 | | - "execution_count": 14, |
| 333 | + "execution_count": null, |
375 | 334 | "metadata": {}, |
376 | 335 | "outputs": [], |
377 | 336 | "source": [ |
|
414 | 373 | }, |
415 | 374 | { |
416 | 375 | "cell_type": "code", |
417 | | - "execution_count": 15, |
| 376 | + "execution_count": null, |
418 | 377 | "metadata": {}, |
419 | 378 | "outputs": [], |
420 | 379 | "source": [ |
|
463 | 422 | }, |
464 | 423 | { |
465 | 424 | "cell_type": "code", |
466 | | - "execution_count": 16, |
| 425 | + "execution_count": null, |
467 | 426 | "metadata": {}, |
468 | 427 | "outputs": [], |
469 | 428 | "source": [ |
|
478 | 437 | }, |
479 | 438 | { |
480 | 439 | "cell_type": "code", |
481 | | - "execution_count": 17, |
| 440 | + "execution_count": null, |
482 | 441 | "metadata": {}, |
483 | | - "outputs": [ |
484 | | - { |
485 | | - "name": "stdout", |
486 | | - "output_type": "stream", |
487 | | - "text": [ |
488 | | - "| name | type | default | lower | upper | transform |\n", |
489 | | - "|------------|--------|-----------|---------|---------|-----------------------|\n", |
490 | | - "| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
491 | | - "| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
492 | | - "| lr | float | 0.001 | 0.0001 | 0.1 | None |\n", |
493 | | - "| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n", |
494 | | - "| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n", |
495 | | - "| k_folds | int | 2 | 0 | 0 | None |\n" |
496 | | - ] |
497 | | - } |
498 | | - ], |
| 442 | + "outputs": [], |
499 | 443 | "source": [ |
500 | 444 | "print(gen_design_table(fun_control))" |
501 | 445 | ] |
|
513 | 457 | }, |
514 | 458 | { |
515 | 459 | "cell_type": "code", |
516 | | - "execution_count": 18, |
| 460 | + "execution_count": null, |
517 | 461 | "metadata": {}, |
518 | | - "outputs": [ |
519 | | - { |
520 | | - "data": { |
521 | | - "text/plain": [ |
522 | | - "array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])" |
523 | | - ] |
524 | | - }, |
525 | | - "execution_count": 18, |
526 | | - "metadata": {}, |
527 | | - "output_type": "execute_result" |
528 | | - } |
529 | | - ], |
| 462 | + "outputs": [], |
530 | 463 | "source": [ |
531 | 464 | "X0 = get_default_values(fun_control)\n", |
532 | 465 | "river_hyper_dict_default = fun_control[\"core_model_hyper_dict\"]\n", |
|
538 | 471 | }, |
539 | 472 | { |
540 | 473 | "cell_type": "code", |
541 | | - "execution_count": 19, |
| 474 | + "execution_count": null, |
542 | 475 | "metadata": {}, |
543 | | - "outputs": [ |
544 | | - { |
545 | | - "data": { |
546 | | - "text/plain": [ |
547 | | - "array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])" |
548 | | - ] |
549 | | - }, |
550 | | - "execution_count": 19, |
551 | | - "metadata": {}, |
552 | | - "output_type": "execute_result" |
553 | | - } |
554 | | - ], |
| 476 | + "outputs": [], |
555 | 477 | "source": [ |
556 | 478 | "from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n", |
557 | 479 | "hyper_dict=TorchHyperDict().load()\n", |
|
561 | 483 | }, |
562 | 484 | { |
563 | 485 | "cell_type": "code", |
564 | | - "execution_count": 20, |
565 | | - "metadata": {}, |
566 | | - "outputs": [ |
567 | | - { |
568 | | - "name": "stdout", |
569 | | - "output_type": "stream", |
570 | | - "text": [ |
571 | | - "[1, 2000] loss: 2.452\n", |
572 | | - "[1, 4000] loss: 1.163\n", |
573 | | - "[2, 2000] loss: 2.328\n", |
574 | | - "[2, 4000] loss: 1.163\n", |
575 | | - "[3, 2000] loss: 2.328\n", |
576 | | - "[3, 4000] loss: 1.164\n", |
577 | | - "[4, 2000] loss: 2.329\n", |
578 | | - "[4, 4000] loss: 1.163\n", |
579 | | - "Accuracy on hold-out set: 0.10116666666666667\n", |
580 | | - "Loss on hold-out set: 2.320516836643219\n", |
581 | | - "[1, 2000] loss: 2.351\n", |
582 | | - "[1, 4000] loss: 1.166\n", |
583 | | - "[1, 6000] loss: 0.777\n", |
584 | | - "[1, 8000] loss: 0.583\n", |
585 | | - "[1, 10000] loss: 0.466\n", |
586 | | - "[1, 12000] loss: 0.388\n", |
587 | | - "[1, 14000] loss: 0.333\n", |
588 | | - "[1, 16000] loss: 0.291\n", |
589 | | - "[1, 18000] loss: 0.259\n", |
590 | | - "[2, 2000] loss: 2.330\n", |
591 | | - "[2, 4000] loss: 1.164\n", |
592 | | - "[2, 6000] loss: 0.776\n", |
593 | | - "[2, 8000] loss: 0.583\n", |
594 | | - "[2, 10000] loss: 0.467\n", |
595 | | - "[2, 12000] loss: 0.388\n", |
596 | | - "[2, 14000] loss: 0.333\n", |
597 | | - "[2, 16000] loss: 0.291\n", |
598 | | - "[2, 18000] loss: 0.259\n", |
599 | | - "[3, 2000] loss: 2.332\n", |
600 | | - "[3, 4000] loss: 1.165\n", |
601 | | - "[3, 6000] loss: 0.777\n", |
602 | | - "[3, 8000] loss: 0.582\n", |
603 | | - "[3, 10000] loss: 0.467\n", |
604 | | - "[3, 12000] loss: 0.389\n", |
605 | | - "[3, 14000] loss: 0.333\n", |
606 | | - "[3, 16000] loss: 0.292\n", |
607 | | - "[3, 18000] loss: 0.259\n", |
608 | | - "[4, 2000] loss: 2.334\n", |
609 | | - "[4, 4000] loss: 1.165\n", |
610 | | - "[4, 6000] loss: 0.778\n", |
611 | | - "[4, 8000] loss: 0.583\n", |
612 | | - "[4, 10000] loss: 0.466\n", |
613 | | - "[4, 12000] loss: 0.389\n", |
614 | | - "[4, 14000] loss: 0.333\n", |
615 | | - "[4, 16000] loss: 0.291\n", |
616 | | - "[4, 18000] loss: 0.259\n", |
617 | | - "[5, 2000] loss: 2.327\n", |
618 | | - "[5, 4000] loss: 1.169\n", |
619 | | - "[5, 6000] loss: 0.778\n", |
620 | | - "[5, 8000] loss: 0.583\n", |
621 | | - "[5, 10000] loss: 0.466\n", |
622 | | - "[5, 12000] loss: 0.389\n", |
623 | | - "[5, 14000] loss: 0.333\n", |
624 | | - "[5, 16000] loss: 0.291\n", |
625 | | - "[5, 18000] loss: 0.259\n", |
626 | | - "[6, 2000] loss: 2.335\n", |
627 | | - "[6, 4000] loss: 1.166\n", |
628 | | - "[6, 6000] loss: 0.778\n", |
629 | | - "[6, 8000] loss: 0.582\n", |
630 | | - "[6, 10000] loss: 0.466\n", |
631 | | - "[6, 12000] loss: 0.389\n", |
632 | | - "[6, 14000] loss: 0.333\n", |
633 | | - "[6, 16000] loss: 0.292\n", |
634 | | - "[6, 18000] loss: 0.259\n", |
635 | | - "[7, 2000] loss: 2.330\n", |
636 | | - "[7, 4000] loss: 1.166\n", |
637 | | - "[7, 6000] loss: 0.776\n", |
638 | | - "[7, 8000] loss: 0.582\n", |
639 | | - "[7, 10000] loss: 0.466\n", |
640 | | - "[7, 12000] loss: 0.389\n", |
641 | | - "[7, 14000] loss: 0.333\n", |
642 | | - "[7, 16000] loss: 0.291\n", |
643 | | - "[7, 18000] loss: 0.259\n", |
644 | | - "[8, 2000] loss: 2.327\n", |
645 | | - "[8, 4000] loss: 1.167\n", |
646 | | - "[8, 6000] loss: 0.777\n", |
647 | | - "[8, 8000] loss: 0.582\n", |
648 | | - "[8, 10000] loss: 0.466\n", |
649 | | - "[8, 12000] loss: 0.389\n", |
650 | | - "[8, 14000] loss: 0.334\n", |
651 | | - "[8, 16000] loss: 0.292\n", |
652 | | - "[8, 18000] loss: 0.259\n" |
653 | | - ] |
654 | | - } |
655 | | - ], |
| 486 | + "execution_count": null, |
| 487 | + "metadata": {}, |
| 488 | + "outputs": [], |
656 | 489 | "source": [ |
657 | 490 | "spot_torch = spot.Spot(fun=fun,\n", |
658 | 491 | " lower = lower,\n", |
|
0 commit comments