|
467 | 467 | }, |
468 | 468 | { |
469 | 469 | "cell_type": "code", |
470 | | - "execution_count": 1, |
| 470 | + "execution_count": null, |
471 | 471 | "metadata": {}, |
472 | | - "outputs": [ |
473 | | - { |
474 | | - "name": "stdout", |
475 | | - "output_type": "stream", |
476 | | - "text": [ |
477 | | - "Loading data from /Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/data/data.csv\n", |
478 | | - "11\n" |
479 | | - ] |
480 | | - } |
481 | | - ], |
| 472 | + "outputs": [], |
482 | 473 | "source": [ |
483 | 474 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
484 | 475 | "from spotPython.data.csvdataset import CSVDataset\n", |
|
491 | 482 | }, |
492 | 483 | { |
493 | 484 | "cell_type": "code", |
494 | | - "execution_count": 7, |
| 485 | + "execution_count": null, |
495 | 486 | "metadata": {}, |
496 | 487 | "outputs": [], |
497 | 488 | "source": [ |
|
500 | 491 | }, |
501 | 492 | { |
502 | 493 | "cell_type": "code", |
503 | | - "execution_count": 8, |
| 494 | + "execution_count": null, |
504 | 495 | "metadata": {}, |
505 | | - "outputs": [ |
506 | | - { |
507 | | - "name": "stdout", |
508 | | - "output_type": "stream", |
509 | | - "text": [ |
510 | | - "full_train_size: 4\n", |
511 | | - "val_size: 2\n", |
512 | | - "train_size: 2\n", |
513 | | - "test_size: 7\n" |
514 | | - ] |
515 | | - } |
516 | | - ], |
| 496 | + "outputs": [], |
517 | 497 | "source": [ |
518 | 498 | "data_module.setup()" |
519 | 499 | ] |
520 | 500 | }, |
521 | 501 | { |
522 | 502 | "cell_type": "code", |
523 | | - "execution_count": 9, |
| 503 | + "execution_count": null, |
524 | 504 | "metadata": {}, |
525 | | - "outputs": [ |
526 | | - { |
527 | | - "name": "stdout", |
528 | | - "output_type": "stream", |
529 | | - "text": [ |
530 | | - "Training set size: 2\n" |
531 | | - ] |
532 | | - } |
533 | | - ], |
| 505 | + "outputs": [], |
534 | 506 | "source": [ |
535 | 507 | "print(f\"Training set size: {len(data_module.data_train)}\")" |
536 | 508 | ] |
537 | 509 | }, |
538 | 510 | { |
539 | 511 | "cell_type": "code", |
540 | | - "execution_count": 10, |
| 512 | + "execution_count": null, |
541 | 513 | "metadata": {}, |
542 | | - "outputs": [ |
543 | | - { |
544 | | - "name": "stdout", |
545 | | - "output_type": "stream", |
546 | | - "text": [ |
547 | | - "Validation set size: 2\n" |
548 | | - ] |
549 | | - } |
550 | | - ], |
| 514 | + "outputs": [], |
551 | 515 | "source": [ |
552 | 516 | "print(f\"Validation set size: {len(data_module.data_val)}\")" |
553 | 517 | ] |
554 | 518 | }, |
555 | 519 | { |
556 | 520 | "cell_type": "code", |
557 | | - "execution_count": 11, |
| 521 | + "execution_count": null, |
| 522 | + "metadata": {}, |
| 523 | + "outputs": [], |
| 524 | + "source": [ |
| 525 | + "print(f\"Test set size: {len(data_module.data_test)}\")" |
| 526 | + ] |
| 527 | + }, |
| 528 | + { |
| 529 | + "cell_type": "code", |
| 530 | + "execution_count": null, |
| 531 | + "metadata": {}, |
| 532 | + "outputs": [], |
| 533 | + "source": [] |
| 534 | + }, |
| 535 | + { |
| 536 | + "cell_type": "markdown", |
| 537 | + "metadata": {}, |
| 538 | + "source": [ |
| 539 | + "# Set the DataModule in fun_control " |
| 540 | + ] |
| 541 | + }, |
| 542 | + { |
| 543 | + "cell_type": "code", |
| 544 | + "execution_count": 1, |
558 | 545 | "metadata": {}, |
559 | 546 | "outputs": [ |
| 547 | + { |
| 548 | + "name": "stderr", |
| 549 | + "output_type": "stream", |
| 550 | + "text": [ |
| 551 | + "Seed set to 42\n" |
| 552 | + ] |
| 553 | + }, |
560 | 554 | { |
561 | 555 | "name": "stdout", |
562 | 556 | "output_type": "stream", |
563 | 557 | "text": [ |
| 558 | + "Loading data from /Users/bartz/miniforge3/envs/py311/lib/python3.11/site-packages/spotPython/data/data.csv\n", |
| 559 | + "full_train_size: 4\n", |
| 560 | + "val_size: 2\n", |
| 561 | + "train_size: 2\n", |
| 562 | + "test_size: 7\n", |
564 | 563 | "Test set size: 7\n" |
565 | 564 | ] |
566 | 565 | } |
567 | 566 | ], |
568 | 567 | "source": [ |
569 | | - "print(f\"Test set size: {len(data_module.data_test)}\")" |
| 568 | + "from spotPython.utils.init import fun_control_init\n", |
| 569 | + "from spotPython.hyperparameters.values import set_data_module\n", |
| 570 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 571 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 572 | + "from spotPython.data.pkldataset import PKLDataset\n", |
| 573 | + "import torch\n", |
| 574 | + "fun_control = fun_control_init()\n", |
| 575 | + "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
| 576 | + "dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n", |
| 577 | + "dm.setup()\n", |
| 578 | + "set_data_module(fun_control=fun_control,\n", |
| 579 | + " data_module=dm)\n", |
| 580 | + "data_module = fun_control[\"data_module\"]\n", |
| 581 | + "print(f\"Test set size: {len(data_module.data_test)}\")\n" |
570 | 582 | ] |
571 | 583 | }, |
572 | | - { |
573 | | - "cell_type": "code", |
574 | | - "execution_count": null, |
575 | | - "metadata": {}, |
576 | | - "outputs": [], |
577 | | - "source": [] |
578 | | - }, |
579 | 584 | { |
580 | 585 | "cell_type": "code", |
581 | 586 | "execution_count": null, |
|
0 commit comments