Skip to content

Commit 548537f

Browse files
set data module
1 parent a8668b8 commit 548537f

6 files changed

Lines changed: 114 additions & 56 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -467,18 +467,9 @@
467467
},
468468
{
469469
"cell_type": "code",
470-
"execution_count": 1,
470+
"execution_count": null,
471471
"metadata": {},
472-
"outputs": [
473-
{
474-
"name": "stdout",
475-
"output_type": "stream",
476-
"text": [
477-
"Loading data from /Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/data/data.csv\n",
478-
"11\n"
479-
]
480-
}
481-
],
472+
"outputs": [],
482473
"source": [
483474
"from spotPython.data.lightdatamodule import LightDataModule\n",
484475
"from spotPython.data.csvdataset import CSVDataset\n",
@@ -491,7 +482,7 @@
491482
},
492483
{
493484
"cell_type": "code",
494-
"execution_count": 7,
485+
"execution_count": null,
495486
"metadata": {},
496487
"outputs": [],
497488
"source": [
@@ -500,82 +491,96 @@
500491
},
501492
{
502493
"cell_type": "code",
503-
"execution_count": 8,
494+
"execution_count": null,
504495
"metadata": {},
505-
"outputs": [
506-
{
507-
"name": "stdout",
508-
"output_type": "stream",
509-
"text": [
510-
"full_train_size: 4\n",
511-
"val_size: 2\n",
512-
"train_size: 2\n",
513-
"test_size: 7\n"
514-
]
515-
}
516-
],
496+
"outputs": [],
517497
"source": [
518498
"data_module.setup()"
519499
]
520500
},
521501
{
522502
"cell_type": "code",
523-
"execution_count": 9,
503+
"execution_count": null,
524504
"metadata": {},
525-
"outputs": [
526-
{
527-
"name": "stdout",
528-
"output_type": "stream",
529-
"text": [
530-
"Training set size: 2\n"
531-
]
532-
}
533-
],
505+
"outputs": [],
534506
"source": [
535507
"print(f\"Training set size: {len(data_module.data_train)}\")"
536508
]
537509
},
538510
{
539511
"cell_type": "code",
540-
"execution_count": 10,
512+
"execution_count": null,
541513
"metadata": {},
542-
"outputs": [
543-
{
544-
"name": "stdout",
545-
"output_type": "stream",
546-
"text": [
547-
"Validation set size: 2\n"
548-
]
549-
}
550-
],
514+
"outputs": [],
551515
"source": [
552516
"print(f\"Validation set size: {len(data_module.data_val)}\")"
553517
]
554518
},
555519
{
556520
"cell_type": "code",
557-
"execution_count": 11,
521+
"execution_count": null,
522+
"metadata": {},
523+
"outputs": [],
524+
"source": [
525+
"print(f\"Test set size: {len(data_module.data_test)}\")"
526+
]
527+
},
528+
{
529+
"cell_type": "code",
530+
"execution_count": null,
531+
"metadata": {},
532+
"outputs": [],
533+
"source": []
534+
},
535+
{
536+
"cell_type": "markdown",
537+
"metadata": {},
538+
"source": [
539+
"# Set the DataModule in fun_control "
540+
]
541+
},
542+
{
543+
"cell_type": "code",
544+
"execution_count": 1,
558545
"metadata": {},
559546
"outputs": [
547+
{
548+
"name": "stderr",
549+
"output_type": "stream",
550+
"text": [
551+
"Seed set to 42\n"
552+
]
553+
},
560554
{
561555
"name": "stdout",
562556
"output_type": "stream",
563557
"text": [
558+
"Loading data from /Users/bartz/miniforge3/envs/py311/lib/python3.11/site-packages/spotPython/data/data.csv\n",
559+
"full_train_size: 4\n",
560+
"val_size: 2\n",
561+
"train_size: 2\n",
562+
"test_size: 7\n",
564563
"Test set size: 7\n"
565564
]
566565
}
567566
],
568567
"source": [
569-
"print(f\"Test set size: {len(data_module.data_test)}\")"
568+
"from spotPython.utils.init import fun_control_init\n",
569+
"from spotPython.hyperparameters.values import set_data_module\n",
570+
"from spotPython.data.lightdatamodule import LightDataModule\n",
571+
"from spotPython.data.csvdataset import CSVDataset\n",
572+
"from spotPython.data.pkldataset import PKLDataset\n",
573+
"import torch\n",
574+
"fun_control = fun_control_init()\n",
575+
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
576+
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n",
577+
"dm.setup()\n",
578+
"set_data_module(fun_control=fun_control,\n",
579+
" data_module=dm)\n",
580+
"data_module = fun_control[\"data_module\"]\n",
581+
"print(f\"Test set size: {len(data_module.data_test)}\")\n"
570582
]
571583
},
572-
{
573-
"cell_type": "code",
574-
"execution_count": null,
575-
"metadata": {},
576-
"outputs": [],
577-
"source": []
578-
},
579584
{
580585
"cell_type": "code",
581586
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.46"
10+
version = "0.6.47"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/csvdataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _repr_content(self):
7676
return content
7777

7878
def _load_data(self) -> tuple:
79-
print(f"Loading data from {self.path}")
79+
# print(f"Loading data from {self.path}")
8080
df = pd.read_csv(self.path, index_col=False)
8181
# rm rows with NA
8282
if self.rmNA:

src/spotPython/hyperparameters/values.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,3 +867,34 @@ def get_default_hyperparameters_for_core_model(fun_control) -> dict:
867867
values = convert_keys(values, fun_control["var_type"])
868868
values = transform_hyper_parameter_values(fun_control=fun_control, hyper_parameter_values=values)
869869
return values
870+
871+
872+
def set_data_module(fun_control, data_module) -> dict:
    """
    This function sets the lightning datamodule in the fun_control dictionary.

    The data module is stored under the key "data_module". The dictionary is
    modified in place and additionally returned, so the function can be used
    either for its side effect or in an assignment.

    Args:
        fun_control (dict):
            fun_control dictionary
        data_module (object):
            data module to store, e.g., a LightDataModule instance as
            in the example below. The object is stored as-is; it is neither
            copied nor validated.

    Returns:
        fun_control (dict):
            updated fun_control (the same dictionary object that was passed in)

    Examples:
        >>> from spotPython.utils.init import fun_control_init
            from spotPython.hyperparameters.values import set_data_module
            from spotPython.data.lightdatamodule import LightDataModule
            from spotPython.data.csvdataset import CSVDataset
            from spotPython.data.pkldataset import PKLDataset
            import torch
            fun_control = fun_control_init()
            dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
            dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)
            dm.setup()
            set_data_module(fun_control=fun_control,
                            data_module=dm)
            fun_control["data_module"]
    """
    fun_control.update({"data_module": data_module})
    # Return the dictionary so the behavior matches the `-> dict` annotation
    # and the documented return value; callers that ignore the result are
    # unaffected because the update above already happened in place.
    return fun_control

src/spotPython/utils/init.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def fun_control_init(
119119
"_L_out": _L_out,
120120
"data": None,
121121
"data_dir": "./data",
122+
"data_module": None,
122123
"device": device,
123124
"enable_progress_bar": enable_progress_bar,
124125
"eval": None,

test/test_set_data_module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
from spotPython.utils.init import fun_control_init
3+
from spotPython.hyperparameters.values import set_data_module
4+
from spotPython.data.lightdatamodule import LightDataModule
5+
from spotPython.data.csvdataset import CSVDataset
6+
import torch
7+
8+
9+
def test_set_data_module():
    """Check that set_data_module stores the data module in fun_control["data_module"]."""
    fun_control = fun_control_init()
    dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
    dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)
    dm.setup()
    set_data_module(fun_control=fun_control,
                    data_module=dm)
    data_module = fun_control["data_module"]
    # The entry must be the very object that was passed in; comparing lengths
    # alone would also pass for a different module with an equally sized test split.
    assert data_module is dm
    # if assigned correctly, the length of the data_test should be the same as the length of the dataset dm:
    assert len(dm.data_test) == len(data_module.data_test)
19+
20+
if __name__ == "__main__":
21+
pytest.main(["-v", __file__])

0 commit comments

Comments
 (0)