|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": null, |
| 15 | + "execution_count": 1, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
22 | 22 | }, |
23 | 23 | { |
24 | 24 | "cell_type": "code", |
25 | | - "execution_count": null, |
26 | | - "metadata": {}, |
27 | | - "outputs": [], |
| 25 | + "execution_count": 2, |
| 26 | + "metadata": {}, |
| 27 | + "outputs": [ |
| 28 | + { |
| 29 | + "data": { |
| 30 | + "text/plain": [ |
| 31 | + "'12-torch_p040025_1min_5init_2023-05-07_18-06-33'" |
| 32 | + ] |
| 33 | + }, |
| 34 | + "execution_count": 2, |
| 35 | + "metadata": {}, |
| 36 | + "output_type": "execute_result" |
| 37 | + } |
| 38 | + ], |
28 | 39 | "source": [ |
29 | 40 | "import pickle\n", |
30 | 41 | "import socket\n", |
|
57 | 68 | }, |
58 | 69 | { |
59 | 70 | "cell_type": "code", |
60 | | - "execution_count": null, |
61 | | - "metadata": {}, |
62 | | - "outputs": [], |
| 71 | + "execution_count": 3, |
| 72 | + "metadata": {}, |
| 73 | + "outputs": [ |
| 74 | + { |
| 75 | + "name": "stdout", |
| 76 | + "output_type": "stream", |
| 77 | + "text": [ |
| 78 | + "spotPython 0.0.60\n", |
| 79 | + "spotRiver 0.0.92\n", |
| 80 | + "Note: you may need to restart the kernel to use updated packages.\n" |
| 81 | + ] |
| 82 | + } |
| 83 | + ], |
63 | 84 | "source": [ |
64 | 85 | "pip list | grep \"spot[RiverPython]\"" |
65 | 86 | ] |
66 | 87 | }, |
67 | 88 | { |
68 | 89 | "cell_type": "code", |
69 | | - "execution_count": null, |
| 90 | + "execution_count": 4, |
70 | 91 | "metadata": {}, |
71 | 92 | "outputs": [], |
72 | 93 | "source": [ |
|
77 | 98 | }, |
78 | 99 | { |
79 | 100 | "cell_type": "code", |
80 | | - "execution_count": null, |
| 101 | + "execution_count": 5, |
81 | 102 | "metadata": {}, |
82 | 103 | "outputs": [], |
83 | 104 | "source": [ |
|
154 | 175 | }, |
155 | 176 | { |
156 | 177 | "cell_type": "code", |
157 | | - "execution_count": null, |
158 | | - "metadata": {}, |
159 | | - "outputs": [], |
| 178 | + "execution_count": 6, |
| 179 | + "metadata": {}, |
| 180 | + "outputs": [ |
| 181 | + { |
| 182 | + "name": "stdout", |
| 183 | + "output_type": "stream", |
| 184 | + "text": [ |
| 185 | + "2.0.0\n", |
| 186 | + "MPS device: mps\n" |
| 187 | + ] |
| 188 | + } |
| 189 | + ], |
160 | 190 | "source": [ |
161 | 191 | "print(torch.__version__)\n", |
162 | 192 | "# Check that MPS is available\n", |
|
183 | 213 | }, |
184 | 214 | { |
185 | 215 | "cell_type": "code", |
186 | | - "execution_count": null, |
| 216 | + "execution_count": 7, |
187 | 217 | "metadata": {}, |
188 | 218 | "outputs": [], |
189 | 219 | "source": [ |
|
208 | 238 | }, |
209 | 239 | { |
210 | 240 | "cell_type": "code", |
211 | | - "execution_count": null, |
| 241 | + "execution_count": 8, |
212 | 242 | "metadata": {}, |
213 | 243 | "outputs": [], |
214 | 244 | "source": [ |
|
229 | 259 | }, |
230 | 260 | { |
231 | 261 | "cell_type": "code", |
232 | | - "execution_count": null, |
233 | | - "metadata": {}, |
234 | | - "outputs": [], |
| 262 | + "execution_count": 9, |
| 263 | + "metadata": {}, |
| 264 | + "outputs": [ |
| 265 | + { |
| 266 | + "name": "stdout", |
| 267 | + "output_type": "stream", |
| 268 | + "text": [ |
| 269 | + "Files already downloaded and verified\n", |
| 270 | + "Files already downloaded and verified\n" |
| 271 | + ] |
| 272 | + }, |
| 273 | + { |
| 274 | + "data": { |
| 275 | + "text/plain": [ |
| 276 | + "((50000, 32, 32, 3), (10000, 32, 32, 3))" |
| 277 | + ] |
| 278 | + }, |
| 279 | + "execution_count": 9, |
| 280 | + "metadata": {}, |
| 281 | + "output_type": "execute_result" |
| 282 | + } |
| 283 | + ], |
235 | 284 | "source": [ |
236 | 285 | "train, test = load_data()\n", |
237 | 286 | "train.data.shape, test.data.shape" |
238 | 287 | ] |
239 | 288 | }, |
240 | 289 | { |
241 | 290 | "cell_type": "code", |
242 | | - "execution_count": null, |
| 291 | + "execution_count": 10, |
243 | 292 | "metadata": {}, |
244 | 293 | "outputs": [], |
245 | 294 | "source": [ |
|
264 | 313 | }, |
265 | 314 | { |
266 | 315 | "cell_type": "code", |
267 | | - "execution_count": null, |
| 316 | + "execution_count": 11, |
268 | 317 | "metadata": {}, |
269 | 318 | "outputs": [], |
270 | 319 | "source": [ |
|
307 | 356 | }, |
308 | 357 | { |
309 | 358 | "cell_type": "code", |
310 | | - "execution_count": null, |
| 359 | + "execution_count": 12, |
311 | 360 | "metadata": {}, |
312 | 361 | "outputs": [], |
313 | 362 | "source": [ |
|
436 | 485 | }, |
437 | 486 | { |
438 | 487 | "cell_type": "code", |
439 | | - "execution_count": null, |
440 | | - "metadata": {}, |
441 | | - "outputs": [], |
| 488 | + "execution_count": 13, |
| 489 | + "metadata": {}, |
| 490 | + "outputs": [ |
| 491 | + { |
| 492 | + "data": { |
| 493 | + "text/plain": [ |
| 494 | + "{'l1': {'type': 'int',\n", |
| 495 | + " 'default': 5,\n", |
| 496 | + " 'transform': 'transform_power_2_int',\n", |
| 497 | + " 'lower': 2,\n", |
| 498 | + " 'upper': 9},\n", |
| 499 | + " 'l2': {'type': 'int',\n", |
| 500 | + " 'default': 5,\n", |
| 501 | + " 'transform': 'transform_power_2_int',\n", |
| 502 | + " 'lower': 2,\n", |
| 503 | + " 'upper': 9},\n", |
| 504 | + " 'lr': {'type': 'float',\n", |
| 505 | + " 'default': 0.001,\n", |
| 506 | + " 'transform': 'None',\n", |
| 507 | + " 'lower': 1e-05,\n", |
| 508 | + " 'upper': 0.01},\n", |
| 509 | + " 'batch_size': {'type': 'int',\n", |
| 510 | + " 'default': 4,\n", |
| 511 | + " 'transform': 'transform_power_2_int',\n", |
| 512 | + " 'lower': 1,\n", |
| 513 | + " 'upper': 4},\n", |
| 514 | + " 'epochs': {'type': 'int',\n", |
| 515 | + " 'default': 3,\n", |
| 516 | + " 'transform': 'transform_power_2_int',\n", |
| 517 | + " 'lower': 1,\n", |
| 518 | + " 'upper': 4},\n", |
| 519 | + " 'k_folds': {'type': 'int',\n", |
| 520 | + " 'default': 2,\n", |
| 521 | + " 'transform': 'None',\n", |
| 522 | + " 'lower': 0,\n", |
| 523 | + " 'upper': 0}}" |
| 524 | + ] |
| 525 | + }, |
| 526 | + "execution_count": 13, |
| 527 | + "metadata": {}, |
| 528 | + "output_type": "execute_result" |
| 529 | + } |
| 530 | + ], |
442 | 531 | "source": [ |
443 | 532 | "fun_control = modify_hyper_parameter_bounds(fun_control, \"k_folds\", bounds=[0, 0])\n", |
444 | 533 | "fun_control[\"core_model_hyper_dict\"]" |
|
462 | 551 | }, |
463 | 552 | { |
464 | 553 | "cell_type": "code", |
465 | | - "execution_count": null, |
| 554 | + "execution_count": 14, |
466 | 555 | "metadata": {}, |
467 | 556 | "outputs": [], |
468 | 557 | "source": [ |
|
501 | 590 | }, |
502 | 591 | { |
503 | 592 | "cell_type": "code", |
504 | | - "execution_count": null, |
| 593 | + "execution_count": 15, |
505 | 594 | "metadata": {}, |
506 | 595 | "outputs": [], |
507 | 596 | "source": [ |
|
552 | 641 | }, |
553 | 642 | { |
554 | 643 | "cell_type": "code", |
555 | | - "execution_count": null, |
| 644 | + "execution_count": 16, |
556 | 645 | "metadata": {}, |
557 | 646 | "outputs": [], |
558 | 647 | "source": [ |
|
567 | 656 | }, |
568 | 657 | { |
569 | 658 | "cell_type": "code", |
570 | | - "execution_count": null, |
571 | | - "metadata": {}, |
572 | | - "outputs": [], |
| 659 | + "execution_count": 17, |
| 660 | + "metadata": {}, |
| 661 | + "outputs": [ |
| 662 | + { |
| 663 | + "name": "stdout", |
| 664 | + "output_type": "stream", |
| 665 | + "text": [ |
| 666 | + "| name | type | default | lower | upper | transform |\n", |
| 667 | + "|------------|--------|-----------|---------|---------|-----------------------|\n", |
| 668 | + "| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
| 669 | + "| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
| 670 | + "| lr | float | 0.001 | 1e-05 | 0.01 | None |\n", |
| 671 | + "| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n", |
| 672 | + "| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n", |
| 673 | + "| k_folds | int | 2 | 0 | 0 | None |\n" |
| 674 | + ] |
| 675 | + } |
| 676 | + ], |
573 | 677 | "source": [ |
574 | 678 | "print(gen_design_table(fun_control))" |
575 | 679 | ] |
|
587 | 691 | }, |
588 | 692 | { |
589 | 693 | "cell_type": "code", |
590 | | - "execution_count": null, |
591 | | - "metadata": {}, |
592 | | - "outputs": [], |
| 694 | + "execution_count": 18, |
| 695 | + "metadata": {}, |
| 696 | + "outputs": [ |
| 697 | + { |
| 698 | + "data": { |
| 699 | + "text/plain": [ |
| 700 | + "array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])" |
| 701 | + ] |
| 702 | + }, |
| 703 | + "execution_count": 18, |
| 704 | + "metadata": {}, |
| 705 | + "output_type": "execute_result" |
| 706 | + } |
| 707 | + ], |
593 | 708 | "source": [ |
594 | 709 | "from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n", |
595 | 710 | "hyper_dict=TorchHyperDict().load()\n", |
|
599 | 714 | }, |
600 | 715 | { |
601 | 716 | "cell_type": "code", |
602 | | - "execution_count": null, |
603 | | - "metadata": {}, |
604 | | - "outputs": [], |
| 717 | + "execution_count": 19, |
| 718 | + "metadata": {}, |
| 719 | + "outputs": [ |
| 720 | + { |
| 721 | + "name": "stdout", |
| 722 | + "output_type": "stream", |
| 723 | + "text": [ |
| 724 | + "Epoch: 1, Batch: 1000. Batch Size: 8. Training Loss: 2.061\n", |
| 725 | + "Epoch: 1, Batch: 2000. Batch Size: 8. Training Loss: 0.885\n", |
| 726 | + "Epoch: 1, Batch: 3000. Batch Size: 8. Training Loss: 0.549\n" |
| 727 | + ] |
| 728 | + } |
| 729 | + ], |
605 | 730 | "source": [ |
606 | 731 | "spot_torch = spot.Spot(fun=fun,\n", |
607 | 732 | " lower = lower,\n", |
|
0 commit comments