Skip to content

Commit 460780d

Browse files
generic csv data set
1 parent 548537f commit 460780d

8 files changed

Lines changed: 232 additions & 57 deletions

File tree

makeSpot.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
cd ~/workspace/spotPython
33
rm -f dist/spotPython*; python -m build; python -m pip install dist/spotPython*.tar.gz
44
python -m mkdocs build
5-
pytest

notebooks/00_spotPython_tests.ipynb

Lines changed: 163 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -247,29 +247,29 @@
247247
},
248248
{
249249
"cell_type": "code",
250-
"execution_count": null,
250+
"execution_count": 1,
251251
"metadata": {},
252252
"outputs": [],
253253
"source": [
254-
"# from pyhcf.data.daten_sensitive import DatenSensitive\n",
255-
"# from pyhcf.utils.names import get_short_parameter_names\n",
256-
"# daten = DatenSensitive()\n",
257-
"# df = daten.load()\n",
258-
"# names = df.columns\n",
259-
"# names = get_short_parameter_names(names)\n",
260-
"# # rename columns with short names\n",
261-
"# df.columns = names\n",
262-
"# df.head()\n",
263-
"# # save the df as a csv file\n",
264-
"# df.to_csv('./data/spotPython/data_sensitive.csv', index=False)\n",
265-
"# # save the df as a pickle file\n",
266-
"# df.to_pickle('./data/spotPython/data_sensitive.pkl')\n",
267-
"# # remove all rows with NaN values\n",
268-
"# df = df.dropna()\n",
269-
"# # save the df as a csv file\n",
270-
"# df.to_csv('./data/spotPython/data_sensitive_rmNA.csv', index=False)\n",
271-
"# # save the df as a pickle file\n",
272-
"# df.to_pickle('./data/spotPython/data_sensitive_rmNA.pkl')\n"
254+
"from pyhcf.data.daten_sensitive import DatenSensitive\n",
255+
"from pyhcf.utils.names import get_short_parameter_names\n",
256+
"daten = DatenSensitive()\n",
257+
"df = daten.load()\n",
258+
"names = df.columns\n",
259+
"names = get_short_parameter_names(names)\n",
260+
"# rename columns with short names\n",
261+
"df.columns = names\n",
262+
"df.head()\n",
263+
"# save the df as a csv file\n",
264+
"df.to_csv('./data/spotPython/data_sensitive.csv', index=False)\n",
265+
"# save the df as a pickle file\n",
266+
"df.to_pickle('./data/spotPython/data_sensitive.pkl')\n",
267+
"# remove all rows with NaN values\n",
268+
"df = df.dropna()\n",
269+
"# save the df as a csv file\n",
270+
"df.to_csv('./data/spotPython/data_sensitive_rmNA.csv', index=False)\n",
271+
"# save the df as a pickle file\n",
272+
"df.to_pickle('./data/spotPython/data_sensitive_rmNA.pkl')\n"
273273
]
274274
},
275275
{
@@ -398,9 +398,9 @@
398398
"metadata": {},
399399
"outputs": [],
400400
"source": [
401-
"# from spotPython.light.pkldataset import PKLDataset\n",
402-
"# import torch\n",
403-
"# dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='A', feature_type=torch.long, rmNA=False)"
401+
"from spotPython.light.pkldataset import PKLDataset\n",
402+
"import torch\n",
403+
"dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='A', feature_type=torch.long, rmNA=False)"
404404
]
405405
},
406406
{
@@ -427,13 +427,13 @@
427427
},
428428
{
429429
"cell_type": "code",
430-
"execution_count": null,
430+
"execution_count": 3,
431431
"metadata": {},
432432
"outputs": [],
433433
"source": [
434434
"from spotPython.data.pkldataset import PKLDataset\n",
435435
"import torch\n",
436-
"dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)"
436+
"dataset = PKLDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)"
437437
]
438438
},
439439
{
@@ -467,9 +467,17 @@
467467
},
468468
{
469469
"cell_type": "code",
470-
"execution_count": null,
470+
"execution_count": 5,
471471
"metadata": {},
472-
"outputs": [],
472+
"outputs": [
473+
{
474+
"name": "stdout",
475+
"output_type": "stream",
476+
"text": [
477+
"11\n"
478+
]
479+
}
480+
],
473481
"source": [
474482
"from spotPython.data.lightdatamodule import LightDataModule\n",
475483
"from spotPython.data.csvdataset import CSVDataset\n",
@@ -482,7 +490,7 @@
482490
},
483491
{
484492
"cell_type": "code",
485-
"execution_count": null,
493+
"execution_count": 6,
486494
"metadata": {},
487495
"outputs": [],
488496
"source": [
@@ -491,36 +499,71 @@
491499
},
492500
{
493501
"cell_type": "code",
494-
"execution_count": null,
502+
"execution_count": 7,
495503
"metadata": {},
496-
"outputs": [],
504+
"outputs": [
505+
{
506+
"name": "stdout",
507+
"output_type": "stream",
508+
"text": [
509+
"full_train_size: 4\n",
510+
"val_size: 2\n",
511+
"train_size: 2\n",
512+
"test_size: 7\n"
513+
]
514+
}
515+
],
497516
"source": [
498517
"data_module.setup()"
499518
]
500519
},
501520
{
502521
"cell_type": "code",
503-
"execution_count": null,
522+
"execution_count": 8,
504523
"metadata": {},
505-
"outputs": [],
524+
"outputs": [
525+
{
526+
"name": "stdout",
527+
"output_type": "stream",
528+
"text": [
529+
"Training set size: 2\n"
530+
]
531+
}
532+
],
506533
"source": [
507534
"print(f\"Training set size: {len(data_module.data_train)}\")"
508535
]
509536
},
510537
{
511538
"cell_type": "code",
512-
"execution_count": null,
539+
"execution_count": 9,
513540
"metadata": {},
514-
"outputs": [],
541+
"outputs": [
542+
{
543+
"name": "stdout",
544+
"output_type": "stream",
545+
"text": [
546+
"Validation set size: 2\n"
547+
]
548+
}
549+
],
515550
"source": [
516551
"print(f\"Validation set size: {len(data_module.data_val)}\")"
517552
]
518553
},
519554
{
520555
"cell_type": "code",
521-
"execution_count": null,
556+
"execution_count": 10,
522557
"metadata": {},
523-
"outputs": [],
558+
"outputs": [
559+
{
560+
"name": "stdout",
561+
"output_type": "stream",
562+
"text": [
563+
"Test set size: 7\n"
564+
]
565+
}
566+
],
524567
"source": [
525568
"print(f\"Test set size: {len(data_module.data_test)}\")"
526569
]
@@ -541,7 +584,36 @@
541584
},
542585
{
543586
"cell_type": "code",
544-
"execution_count": 1,
587+
"execution_count": null,
588+
"metadata": {},
589+
"outputs": [],
590+
"source": [
591+
"from spotPython.utils.init import fun_control_init\n",
592+
"from spotPython.hyperparameters.values import set_data_module\n",
593+
"from spotPython.data.lightdatamodule import LightDataModule\n",
594+
"from spotPython.data.csvdataset import CSVDataset\n",
595+
"from spotPython.data.pkldataset import PKLDataset\n",
596+
"import torch\n",
597+
"fun_control = fun_control_init()\n",
598+
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
599+
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n",
600+
"dm.setup()\n",
601+
"set_data_module(fun_control=fun_control,\n",
602+
" data_module=dm)\n",
603+
"data_module = fun_control[\"data_module\"]\n",
604+
"print(f\"Test set size: {len(data_module.data_test)}\")\n"
605+
]
606+
},
607+
{
608+
"cell_type": "markdown",
609+
"metadata": {},
610+
"source": [
611+
"## same with the sensitive data set"
612+
]
613+
},
614+
{
615+
"cell_type": "code",
616+
"execution_count": 13,
545617
"metadata": {},
546618
"outputs": [
547619
{
@@ -555,25 +627,70 @@
555627
"name": "stdout",
556628
"output_type": "stream",
557629
"text": [
558-
"Loading data from /Users/bartz/miniforge3/envs/py311/lib/python3.11/site-packages/spotPython/data/data.csv\n",
559-
"full_train_size: 4\n",
560-
"val_size: 2\n",
561-
"train_size: 2\n",
562-
"test_size: 7\n",
563-
"Test set size: 7\n"
630+
"full_train_size: 56925\n",
631+
"val_size: 76\n",
632+
"train_size: 56849\n",
633+
"test_size: 77\n",
634+
"Test set size: 77\n"
564635
]
565636
}
566637
],
567638
"source": [
568639
"from spotPython.utils.init import fun_control_init\n",
569640
"from spotPython.hyperparameters.values import set_data_module\n",
570641
"from spotPython.data.lightdatamodule import LightDataModule\n",
571-
"from spotPython.data.csvdataset import CSVDataset\n",
572642
"from spotPython.data.pkldataset import PKLDataset\n",
573643
"import torch\n",
574644
"fun_control = fun_control_init()\n",
575-
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
576-
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n",
645+
"dataset = PKLDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)\n",
646+
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=77)\n",
647+
"dm.setup()\n",
648+
"set_data_module(fun_control=fun_control,\n",
649+
" data_module=dm)\n",
650+
"data_module = fun_control[\"data_module\"]\n",
651+
"print(f\"Test set size: {len(data_module.data_test)}\")\n"
652+
]
653+
},
654+
{
655+
"cell_type": "markdown",
656+
"metadata": {},
657+
"source": [
658+
"## same, but VBDO data set"
659+
]
660+
},
661+
{
662+
"cell_type": "code",
663+
"execution_count": 15,
664+
"metadata": {},
665+
"outputs": [
666+
{
667+
"name": "stderr",
668+
"output_type": "stream",
669+
"text": [
670+
"Seed set to 42\n"
671+
]
672+
},
673+
{
674+
"name": "stdout",
675+
"output_type": "stream",
676+
"text": [
677+
"full_train_size: 630\n",
678+
"val_size: 68\n",
679+
"train_size: 562\n",
680+
"test_size: 77\n",
681+
"Test set size: 77\n"
682+
]
683+
}
684+
],
685+
"source": [
686+
"from spotPython.utils.init import fun_control_init\n",
687+
"from spotPython.hyperparameters.values import set_data_module\n",
688+
"from spotPython.data.lightdatamodule import LightDataModule\n",
689+
"from spotPython.data.csvdataset import CSVDataset\n",
690+
"import torch\n",
691+
"fun_control = fun_control_init()\n",
692+
"dataset = CSVDataset(directory=\"/Users/bartz/workspace/spotPython/notebooks/data/VBDP/\", filename=\"train.csv\",target_column='prognosis', feature_type=torch.long)\n",
693+
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=77)\n",
577694
"dm.setup()\n",
578695
"set_data_module(fun_control=fun_control,\n",
579696
" data_module=dm)\n",

src/spotPython/data/csvdataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class CSVDataset(Dataset):
1717
target_type (torch.dtype): The data type of the targets. Defaults to torch.long.
1818
train (bool): Whether the dataset is for training or not. Defaults to True.
1919
rmNA (bool): Whether to remove rows with NA values or not. Defaults to True.
20+
dropId (bool): Whether to drop the "id" column or not. Defaults to False.
2021
**desc: Additional keyword arguments.
2122
2223
Attributes:
@@ -51,6 +52,7 @@ def __init__(
5152
target_type: torch.dtype = torch.long,
5253
train: bool = True,
5354
rmNA=True,
55+
dropId=False,
5456
**desc,
5557
) -> None:
5658
super().__init__()
@@ -61,6 +63,7 @@ def __init__(
6163
self.target_column = target_column
6264
self.train = train
6365
self.rmNA = rmNA
66+
self.dropId = dropId
6467
self.data, self.targets = self._load_data()
6568

6669
@property
@@ -81,6 +84,8 @@ def _load_data(self) -> tuple:
8184
# rm rows with NA
8285
if self.rmNA:
8386
df = df.dropna()
87+
if self.dropId:
88+
df = df.drop(columns=["id"])
8489
# Apply LabelEncoder to string columns
8590
le = LabelEncoder()
8691
df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)

src/spotPython/data/lightdatamodule.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ def setup(self, stage: Optional[str] = None) -> None:
8484
val_size = int(full_train_size * test_size / len(self.data_full))
8585
train_size = full_train_size - val_size
8686

87-
print(f"full_train_size: {full_train_size}")
88-
print(f"val_size: {val_size}")
89-
print(f"train_size: {train_size}")
90-
print(f"test_size: {test_size}")
87+
# print(f"full_train_size: {full_train_size}")
88+
# print(f"val_size: {val_size}")
89+
# print(f"train_size: {train_size}")
90+
# print(f"test_size: {test_size}")
9191

9292
# Assign train/val datasets for use in dataloaders
9393
if stage == "fit" or stage is None:

src/spotPython/hyperparameters/values.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,35 @@ def get_default_hyperparameters_for_core_model(fun_control) -> dict:
869869
return values
870870

871871

872+
def set_data_set(fun_control, data_set) -> dict:
873+
"""
874+
This function sets the lightning dataset in the fun_control dictionary.
875+
876+
Args:
877+
fun_control (dict):
878+
fun_control dictionary
879+
data_set (class): Dataset class from torch.utils.data
880+
881+
Returns:
882+
fun_control (dict):
883+
updated fun_control
884+
885+
Examples:
886+
>>> from spotPython.utils.init import fun_control_init
887+
from spotPython.utils.prepare import set_data_module
888+
from spotPython.data.lightdatamodule import LightDataModule
889+
from spotPython.data.csvdataset import CSVDataset
890+
from spotPython.data.pkldataset import PKLDataset
891+
import torch
892+
fun_control = fun_control_init()
893+
ds = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
894+
set_data_set(fun_control=fun_control,
895+
data_set=ds)
896+
fun_control["data_set"]
897+
"""
898+
fun_control.update({"data_set": data_set})
899+
900+
872901
def set_data_module(fun_control, data_module) -> dict:
873902
"""
874903
This function sets the lightning datamodule in the fun_control dictionary.

0 commit comments

Comments
 (0)