|
247 | 247 | }, |
248 | 248 | { |
249 | 249 | "cell_type": "code", |
250 | | - "execution_count": null, |
| 250 | + "execution_count": 1, |
251 | 251 | "metadata": {}, |
252 | 252 | "outputs": [], |
253 | 253 | "source": [ |
254 | | - "# from pyhcf.data.daten_sensitive import DatenSensitive\n", |
255 | | - "# from pyhcf.utils.names import get_short_parameter_names\n", |
256 | | - "# daten = DatenSensitive()\n", |
257 | | - "# df = daten.load()\n", |
258 | | - "# names = df.columns\n", |
259 | | - "# names = get_short_parameter_names(names)\n", |
260 | | - "# # rename columns with short names\n", |
261 | | - "# df.columns = names\n", |
262 | | - "# df.head()\n", |
263 | | - "# # save the df as a csv file\n", |
264 | | - "# df.to_csv('./data/spotPython/data_sensitive.csv', index=False)\n", |
265 | | - "# # save the df as a pickle file\n", |
266 | | - "# df.to_pickle('./data/spotPython/data_sensitive.pkl')\n", |
267 | | - "# # remove all rows with NaN values\n", |
268 | | - "# df = df.dropna()\n", |
269 | | - "# # save the df as a csv file\n", |
270 | | - "# df.to_csv('./data/spotPython/data_sensitive_rmNA.csv', index=False)\n", |
271 | | - "# # save the df as a pickle file\n", |
272 | | - "# df.to_pickle('./data/spotPython/data_sensitive_rmNA.pkl')\n" |
| 254 | + "from pyhcf.data.daten_sensitive import DatenSensitive\n", |
| 255 | + "from pyhcf.utils.names import get_short_parameter_names\n", |
| 256 | + "daten = DatenSensitive()\n", |
| 257 | + "df = daten.load()\n", |
| 258 | + "names = df.columns\n", |
| 259 | + "names = get_short_parameter_names(names)\n", |
| 260 | + "# rename columns with short names\n", |
| 261 | + "df.columns = names\n", |
| 262 | + "df.head()\n", |
| 263 | + "# save the df as a csv file\n", |
| 264 | + "df.to_csv('./data/spotPython/data_sensitive.csv', index=False)\n", |
| 265 | + "# save the df as a pickle file\n", |
| 266 | + "df.to_pickle('./data/spotPython/data_sensitive.pkl')\n", |
| 267 | + "# remove all rows with NaN values\n", |
| 268 | + "df = df.dropna()\n", |
| 269 | + "# save the df as a csv file\n", |
| 270 | + "df.to_csv('./data/spotPython/data_sensitive_rmNA.csv', index=False)\n", |
| 271 | + "# save the df as a pickle file\n", |
| 272 | + "df.to_pickle('./data/spotPython/data_sensitive_rmNA.pkl')\n" |
273 | 273 | ] |
274 | 274 | }, |
275 | 275 | { |
|
398 | 398 | "metadata": {}, |
399 | 399 | "outputs": [], |
400 | 400 | "source": [ |
401 | | - "# from spotPython.light.pkldataset import PKLDataset\n", |
402 | | - "# import torch\n", |
403 | | - "# dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='A', feature_type=torch.long, rmNA=False)" |
| 401 | + "from spotPython.light.pkldataset import PKLDataset\n", |
| 402 | + "import torch\n", |
| 403 | + "dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='A', feature_type=torch.long, rmNA=False)" |
404 | 404 | ] |
405 | 405 | }, |
406 | 406 | { |
|
427 | 427 | }, |
428 | 428 | { |
429 | 429 | "cell_type": "code", |
430 | | - "execution_count": null, |
| 430 | + "execution_count": 3, |
431 | 431 | "metadata": {}, |
432 | 432 | "outputs": [], |
433 | 433 | "source": [ |
434 | 434 | "from spotPython.data.pkldataset import PKLDataset\n", |
435 | 435 | "import torch\n", |
436 | | - "dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)" |
| 436 | + "dataset = PKLDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)" |
437 | 437 | ] |
438 | 438 | }, |
439 | 439 | { |
|
467 | 467 | }, |
468 | 468 | { |
469 | 469 | "cell_type": "code", |
470 | | - "execution_count": null, |
| 470 | + "execution_count": 5, |
471 | 471 | "metadata": {}, |
472 | | - "outputs": [], |
| 472 | + "outputs": [ |
| 473 | + { |
| 474 | + "name": "stdout", |
| 475 | + "output_type": "stream", |
| 476 | + "text": [ |
| 477 | + "11\n" |
| 478 | + ] |
| 479 | + } |
| 480 | + ], |
473 | 481 | "source": [ |
474 | 482 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
475 | 483 | "from spotPython.data.csvdataset import CSVDataset\n", |
|
482 | 490 | }, |
483 | 491 | { |
484 | 492 | "cell_type": "code", |
485 | | - "execution_count": null, |
| 493 | + "execution_count": 6, |
486 | 494 | "metadata": {}, |
487 | 495 | "outputs": [], |
488 | 496 | "source": [ |
|
491 | 499 | }, |
492 | 500 | { |
493 | 501 | "cell_type": "code", |
494 | | - "execution_count": null, |
| 502 | + "execution_count": 7, |
495 | 503 | "metadata": {}, |
496 | | - "outputs": [], |
| 504 | + "outputs": [ |
| 505 | + { |
| 506 | + "name": "stdout", |
| 507 | + "output_type": "stream", |
| 508 | + "text": [ |
| 509 | + "full_train_size: 4\n", |
| 510 | + "val_size: 2\n", |
| 511 | + "train_size: 2\n", |
| 512 | + "test_size: 7\n" |
| 513 | + ] |
| 514 | + } |
| 515 | + ], |
497 | 516 | "source": [ |
498 | 517 | "data_module.setup()" |
499 | 518 | ] |
500 | 519 | }, |
501 | 520 | { |
502 | 521 | "cell_type": "code", |
503 | | - "execution_count": null, |
| 522 | + "execution_count": 8, |
504 | 523 | "metadata": {}, |
505 | | - "outputs": [], |
| 524 | + "outputs": [ |
| 525 | + { |
| 526 | + "name": "stdout", |
| 527 | + "output_type": "stream", |
| 528 | + "text": [ |
| 529 | + "Training set size: 2\n" |
| 530 | + ] |
| 531 | + } |
| 532 | + ], |
506 | 533 | "source": [ |
507 | 534 | "print(f\"Training set size: {len(data_module.data_train)}\")" |
508 | 535 | ] |
509 | 536 | }, |
510 | 537 | { |
511 | 538 | "cell_type": "code", |
512 | | - "execution_count": null, |
| 539 | + "execution_count": 9, |
513 | 540 | "metadata": {}, |
514 | | - "outputs": [], |
| 541 | + "outputs": [ |
| 542 | + { |
| 543 | + "name": "stdout", |
| 544 | + "output_type": "stream", |
| 545 | + "text": [ |
| 546 | + "Validation set size: 2\n" |
| 547 | + ] |
| 548 | + } |
| 549 | + ], |
515 | 550 | "source": [ |
516 | 551 | "print(f\"Validation set size: {len(data_module.data_val)}\")" |
517 | 552 | ] |
518 | 553 | }, |
519 | 554 | { |
520 | 555 | "cell_type": "code", |
521 | | - "execution_count": null, |
| 556 | + "execution_count": 10, |
522 | 557 | "metadata": {}, |
523 | | - "outputs": [], |
| 558 | + "outputs": [ |
| 559 | + { |
| 560 | + "name": "stdout", |
| 561 | + "output_type": "stream", |
| 562 | + "text": [ |
| 563 | + "Test set size: 7\n" |
| 564 | + ] |
| 565 | + } |
| 566 | + ], |
524 | 567 | "source": [ |
525 | 568 | "print(f\"Test set size: {len(data_module.data_test)}\")" |
526 | 569 | ] |
|
541 | 584 | }, |
542 | 585 | { |
543 | 586 | "cell_type": "code", |
544 | | - "execution_count": 1, |
| 587 | + "execution_count": null, |
| 588 | + "metadata": {}, |
| 589 | + "outputs": [], |
| 590 | + "source": [ |
| 591 | + "from spotPython.utils.init import fun_control_init\n", |
| 592 | + "from spotPython.hyperparameters.values import set_data_module\n", |
| 593 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 594 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 595 | + "from spotPython.data.pkldataset import PKLDataset\n", |
| 596 | + "import torch\n", |
| 597 | + "fun_control = fun_control_init()\n", |
| 598 | + "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
| 599 | + "dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n", |
| 600 | + "dm.setup()\n", |
| 601 | + "set_data_module(fun_control=fun_control,\n", |
| 602 | + " data_module=dm)\n", |
| 603 | + "data_module = fun_control[\"data_module\"]\n", |
| 604 | + "print(f\"Test set size: {len(data_module.data_test)}\")\n" |
| 605 | + ] |
| 606 | + }, |
| 607 | + { |
| 608 | + "cell_type": "markdown", |
| 609 | + "metadata": {}, |
| 610 | + "source": [ |
| 611 | + "## same with the sensitive data set" |
| 612 | + ] |
| 613 | + }, |
| 614 | + { |
| 615 | + "cell_type": "code", |
| 616 | + "execution_count": 13, |
545 | 617 | "metadata": {}, |
546 | 618 | "outputs": [ |
547 | 619 | { |
|
555 | 627 | "name": "stdout", |
556 | 628 | "output_type": "stream", |
557 | 629 | "text": [ |
558 | | - "Loading data from /Users/bartz/miniforge3/envs/py311/lib/python3.11/site-packages/spotPython/data/data.csv\n", |
559 | | - "full_train_size: 4\n", |
560 | | - "val_size: 2\n", |
561 | | - "train_size: 2\n", |
562 | | - "test_size: 7\n", |
563 | | - "Test set size: 7\n" |
| 630 | + "full_train_size: 56925\n", |
| 631 | + "val_size: 76\n", |
| 632 | + "train_size: 56849\n", |
| 633 | + "test_size: 77\n", |
| 634 | + "Test set size: 77\n" |
564 | 635 | ] |
565 | 636 | } |
566 | 637 | ], |
567 | 638 | "source": [ |
568 | 639 | "from spotPython.utils.init import fun_control_init\n", |
569 | 640 | "from spotPython.hyperparameters.values import set_data_module\n", |
570 | 641 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
571 | | - "from spotPython.data.csvdataset import CSVDataset\n", |
572 | 642 | "from spotPython.data.pkldataset import PKLDataset\n", |
573 | 643 | "import torch\n", |
574 | 644 | "fun_control = fun_control_init()\n", |
575 | | - "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
576 | | - "dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n", |
| 645 | + "dataset = PKLDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)\n", |
| 646 | + "dm = LightDataModule(dataset=dataset, batch_size=5, test_size=77)\n", |
| 647 | + "dm.setup()\n", |
| 648 | + "set_data_module(fun_control=fun_control,\n", |
| 649 | + " data_module=dm)\n", |
| 650 | + "data_module = fun_control[\"data_module\"]\n", |
| 651 | + "print(f\"Test set size: {len(data_module.data_test)}\")\n" |
| 652 | + ] |
| 653 | + }, |
| 654 | + { |
| 655 | + "cell_type": "markdown", |
| 656 | + "metadata": {}, |
| 657 | + "source": [ |
| 658 | + "## same, but VBDO data set" |
| 659 | + ] |
| 660 | + }, |
| 661 | + { |
| 662 | + "cell_type": "code", |
| 663 | + "execution_count": 15, |
| 664 | + "metadata": {}, |
| 665 | + "outputs": [ |
| 666 | + { |
| 667 | + "name": "stderr", |
| 668 | + "output_type": "stream", |
| 669 | + "text": [ |
| 670 | + "Seed set to 42\n" |
| 671 | + ] |
| 672 | + }, |
| 673 | + { |
| 674 | + "name": "stdout", |
| 675 | + "output_type": "stream", |
| 676 | + "text": [ |
| 677 | + "full_train_size: 630\n", |
| 678 | + "val_size: 68\n", |
| 679 | + "train_size: 562\n", |
| 680 | + "test_size: 77\n", |
| 681 | + "Test set size: 77\n" |
| 682 | + ] |
| 683 | + } |
| 684 | + ], |
| 685 | + "source": [ |
| 686 | + "from spotPython.utils.init import fun_control_init\n", |
| 687 | + "from spotPython.hyperparameters.values import set_data_module\n", |
| 688 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 689 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 690 | + "import torch\n", |
| 691 | + "fun_control = fun_control_init()\n", |
| 692 | + "dataset = CSVDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/VBDP/\", filename=\"train.csv\",target_column='prognosis', feature_type=torch.long)\n", |
| 693 | + "dm = LightDataModule(dataset=dataset, batch_size=5, test_size=77)\n", |
577 | 694 | "dm.setup()\n", |
578 | 695 | "set_data_module(fun_control=fun_control,\n", |
579 | 696 | " data_module=dm)\n", |
|
0 commit comments