|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": null, |
| 15 | + "execution_count": 1, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
26 | 26 | }, |
27 | 27 | { |
28 | 28 | "cell_type": "code", |
29 | | - "execution_count": null, |
| 29 | + "execution_count": 2, |
30 | 30 | "metadata": {}, |
31 | | - "outputs": [], |
| 31 | + "outputs": [ |
| 32 | + { |
| 33 | + "data": { |
| 34 | + "text/plain": [ |
| 35 | + "'12-torch_p040025_1min_5init_2023-05-07_10-00-44'" |
| 36 | + ] |
| 37 | + }, |
| 38 | + "execution_count": 2, |
| 39 | + "metadata": {}, |
| 40 | + "output_type": "execute_result" |
| 41 | + } |
| 42 | + ], |
32 | 43 | "source": [ |
33 | 44 | "import pickle\n", |
34 | 45 | "import socket\n", |
|
61 | 72 | }, |
62 | 73 | { |
63 | 74 | "cell_type": "code", |
64 | | - "execution_count": null, |
| 75 | + "execution_count": 3, |
65 | 76 | "metadata": {}, |
66 | | - "outputs": [], |
| 77 | + "outputs": [ |
| 78 | + { |
| 79 | + "name": "stdout", |
| 80 | + "output_type": "stream", |
| 81 | + "text": [ |
| 82 | + "spotPython 0.0.56\n", |
| 83 | + "spotRiver 0.0.92\n", |
| 84 | + "Note: you may need to restart the kernel to use updated packages.\n" |
| 85 | + ] |
| 86 | + } |
| 87 | + ], |
67 | 88 | "source": [ |
68 | 89 | "pip list | grep \"spot[RiverPython]\"" |
69 | 90 | ] |
70 | 91 | }, |
71 | 92 | { |
72 | 93 | "cell_type": "code", |
73 | | - "execution_count": null, |
| 94 | + "execution_count": 4, |
74 | 95 | "metadata": {}, |
75 | 96 | "outputs": [], |
76 | 97 | "source": [ |
|
81 | 102 | }, |
82 | 103 | { |
83 | 104 | "cell_type": "code", |
84 | | - "execution_count": null, |
| 105 | + "execution_count": 5, |
85 | 106 | "metadata": {}, |
86 | 107 | "outputs": [], |
87 | 108 | "source": [ |
|
158 | 179 | }, |
159 | 180 | { |
160 | 181 | "cell_type": "code", |
161 | | - "execution_count": null, |
| 182 | + "execution_count": 6, |
162 | 183 | "metadata": {}, |
163 | | - "outputs": [], |
| 184 | + "outputs": [ |
| 185 | + { |
| 186 | + "name": "stdout", |
| 187 | + "output_type": "stream", |
| 188 | + "text": [ |
| 189 | + "2.0.0\n", |
| 190 | + "MPS device: mps\n" |
| 191 | + ] |
| 192 | + } |
| 193 | + ], |
164 | 194 | "source": [ |
165 | 195 | "print(torch.__version__)\n", |
166 | 196 | "# Check that MPS is available\n", |
|
187 | 217 | }, |
188 | 218 | { |
189 | 219 | "cell_type": "code", |
190 | | - "execution_count": null, |
| 220 | + "execution_count": 7, |
191 | 221 | "metadata": {}, |
192 | 222 | "outputs": [], |
193 | 223 | "source": [ |
|
212 | 242 | }, |
213 | 243 | { |
214 | 244 | "cell_type": "code", |
215 | | - "execution_count": null, |
| 245 | + "execution_count": 8, |
216 | 246 | "metadata": {}, |
217 | 247 | "outputs": [], |
218 | 248 | "source": [ |
|
241 | 271 | }, |
242 | 272 | { |
243 | 273 | "cell_type": "code", |
244 | | - "execution_count": null, |
| 274 | + "execution_count": 9, |
245 | 275 | "metadata": {}, |
246 | 276 | "outputs": [], |
247 | 277 | "source": [ |
|
262 | 292 | }, |
263 | 293 | { |
264 | 294 | "cell_type": "code", |
265 | | - "execution_count": null, |
| 295 | + "execution_count": 10, |
266 | 296 | "metadata": {}, |
267 | 297 | "outputs": [], |
268 | 298 | "source": [ |
|
298 | 328 | }, |
299 | 329 | { |
300 | 330 | "cell_type": "code", |
301 | | - "execution_count": null, |
| 331 | + "execution_count": 11, |
302 | 332 | "metadata": {}, |
303 | 333 | "outputs": [], |
304 | 334 | "source": [ |
|
319 | 349 | }, |
320 | 350 | { |
321 | 351 | "cell_type": "code", |
322 | | - "execution_count": null, |
| 352 | + "execution_count": 12, |
323 | 353 | "metadata": {}, |
324 | | - "outputs": [], |
| 354 | + "outputs": [ |
| 355 | + { |
| 356 | + "name": "stdout", |
| 357 | + "output_type": "stream", |
| 358 | + "text": [ |
| 359 | + "Files already downloaded and verified\n", |
| 360 | + "Files already downloaded and verified\n" |
| 361 | + ] |
| 362 | + }, |
| 363 | + { |
| 364 | + "data": { |
| 365 | + "text/plain": [ |
| 366 | + "((50000, 32, 32, 3), (10000, 32, 32, 3))" |
| 367 | + ] |
| 368 | + }, |
| 369 | + "execution_count": 12, |
| 370 | + "metadata": {}, |
| 371 | + "output_type": "execute_result" |
| 372 | + } |
| 373 | + ], |
325 | 374 | "source": [ |
326 | 375 | "train, test = load_data()\n", |
327 | 376 | "train.data.shape, test.data.shape" |
328 | 377 | ] |
329 | 378 | }, |
330 | 379 | { |
331 | 380 | "cell_type": "code", |
332 | | - "execution_count": null, |
| 381 | + "execution_count": 13, |
333 | 382 | "metadata": {}, |
334 | 383 | "outputs": [], |
335 | 384 | "source": [ |
|
352 | 401 | }, |
353 | 402 | { |
354 | 403 | "cell_type": "code", |
355 | | - "execution_count": null, |
| 404 | + "execution_count": 14, |
356 | 405 | "metadata": {}, |
357 | 406 | "outputs": [], |
358 | 407 | "source": [ |
|
378 | 427 | }, |
379 | 428 | { |
380 | 429 | "cell_type": "code", |
381 | | - "execution_count": null, |
| 430 | + "execution_count": 15, |
382 | 431 | "metadata": {}, |
383 | 432 | "outputs": [], |
384 | 433 | "source": [ |
|
407 | 456 | }, |
408 | 457 | { |
409 | 458 | "cell_type": "code", |
410 | | - "execution_count": null, |
| 459 | + "execution_count": 16, |
411 | 460 | "metadata": {}, |
412 | 461 | "outputs": [], |
413 | 462 | "source": [ |
|
425 | 474 | }, |
426 | 475 | { |
427 | 476 | "cell_type": "code", |
428 | | - "execution_count": null, |
429 | | - "metadata": {}, |
430 | | - "outputs": [], |
431 | | - "source": [ |
432 | | - "fun_control" |
433 | | - ] |
434 | | - }, |
435 | | - { |
436 | | - "cell_type": "code", |
437 | | - "execution_count": null, |
| 477 | + "execution_count": 17, |
438 | 478 | "metadata": {}, |
439 | 479 | "outputs": [], |
440 | 480 | "source": [ |
|
477 | 517 | }, |
478 | 518 | { |
479 | 519 | "cell_type": "code", |
480 | | - "execution_count": null, |
| 520 | + "execution_count": 18, |
481 | 521 | "metadata": {}, |
482 | 522 | "outputs": [], |
483 | 523 | "source": [ |
484 | 524 | "fun = HyperTorch(seed=123, log_level=50).fun_torch\n", |
485 | 525 | "weights = 1.0\n", |
486 | 526 | "shuffle = True\n", |
| 527 | + "eval = \"train_hold_out\"\n", |
487 | 528 | "\n", |
488 | 529 | "fun_control.update({\n", |
489 | 530 | " \"data_dir\": None,\n", |
|
497 | 538 | " \"metric\": None,\n", |
498 | 539 | " \"metric_sklearn\": None,\n", |
499 | 540 | " \"shuffle\": shuffle,\n", |
| 541 | + " \"eval\": eval\n", |
500 | 542 | " })" |
501 | 543 | ] |
502 | 544 | }, |
|
526 | 568 | }, |
527 | 569 | { |
528 | 570 | "cell_type": "code", |
529 | | - "execution_count": null, |
| 571 | + "execution_count": 19, |
530 | 572 | "metadata": {}, |
531 | 573 | "outputs": [], |
532 | 574 | "source": [ |
|
541 | 583 | }, |
542 | 584 | { |
543 | 585 | "cell_type": "code", |
544 | | - "execution_count": null, |
| 586 | + "execution_count": 20, |
545 | 587 | "metadata": {}, |
546 | | - "outputs": [], |
| 588 | + "outputs": [ |
| 589 | + { |
| 590 | + "name": "stdout", |
| 591 | + "output_type": "stream", |
| 592 | + "text": [ |
| 593 | + "| name | type | default | lower | upper | transform |\n", |
| 594 | + "|------------|--------|-----------|---------|---------|-----------------------|\n", |
| 595 | + "| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
| 596 | + "| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
| 597 | + "| lr | float | 0.001 | 1e-05 | 0.01 | None |\n", |
| 598 | + "| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n", |
| 599 | + "| epochs | int | 3 | 1 | 4 | transform_power_2_int |\n", |
| 600 | + "| k_folds | int | 2 | 0 | 0 | None |\n" |
| 601 | + ] |
| 602 | + } |
| 603 | + ], |
547 | 604 | "source": [ |
548 | 605 | "print(gen_design_table(fun_control))" |
549 | 606 | ] |
|
561 | 618 | }, |
562 | 619 | { |
563 | 620 | "cell_type": "code", |
564 | | - "execution_count": null, |
| 621 | + "execution_count": 21, |
565 | 622 | "metadata": {}, |
566 | | - "outputs": [], |
| 623 | + "outputs": [ |
| 624 | + { |
| 625 | + "data": { |
| 626 | + "text/plain": [ |
| 627 | + "array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])" |
| 628 | + ] |
| 629 | + }, |
| 630 | + "execution_count": 21, |
| 631 | + "metadata": {}, |
| 632 | + "output_type": "execute_result" |
| 633 | + } |
| 634 | + ], |
567 | 635 | "source": [ |
568 | 636 | "from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n", |
569 | 637 | "hyper_dict=TorchHyperDict().load()\n", |
|
573 | 641 | }, |
574 | 642 | { |
575 | 643 | "cell_type": "code", |
576 | | - "execution_count": null, |
| 644 | + "execution_count": 22, |
577 | 645 | "metadata": {}, |
578 | | - "outputs": [], |
| 646 | + "outputs": [ |
| 647 | + { |
| 648 | + "name": "stdout", |
| 649 | + "output_type": "stream", |
| 650 | + "text": [ |
| 651 | + "[1, 2000] loss: 2.306\n" |
| 652 | + ] |
| 653 | + } |
| 654 | + ], |
579 | 655 | "source": [ |
580 | 656 | "spot_torch = spot.Spot(fun=fun,\n", |
581 | 657 | " lower = lower,\n", |
|
799 | 875 | "min(spot_torch.y), max(spot_torch.y)" |
800 | 876 | ] |
801 | 877 | }, |
| 878 | + { |
| 879 | + "cell_type": "code", |
| 880 | + "execution_count": null, |
| 881 | + "metadata": {}, |
| 882 | + "outputs": [], |
| 883 | + "source": [ |
| 884 | + "trainset = fun_control[\"train\"]" |
| 885 | + ] |
| 886 | + }, |
| 887 | + { |
| 888 | + "cell_type": "code", |
| 889 | + "execution_count": null, |
| 890 | + "metadata": {}, |
| 891 | + "outputs": [], |
| 892 | + "source": [ |
| 893 | + "model_default.evaluate_hold_out(dataset=trainset, shuffle=False, test_dataset=testset)" |
| 894 | + ] |
| 895 | + }, |
| 896 | + { |
| 897 | + "cell_type": "code", |
| 898 | + "execution_count": null, |
| 899 | + "metadata": {}, |
| 900 | + "outputs": [], |
| 901 | + "source": [ |
| 902 | + "model_spot.evaluate_hold_out(dataset=trainset, shuffle=False, test_dataset=testset)" |
| 903 | + ] |
| 904 | + }, |
802 | 905 | { |
803 | 906 | "attachments": {}, |
804 | 907 | "cell_type": "markdown", |
|
0 commit comments