Skip to content

Commit 6709f68

Browse files
committed
scaler for cross validation
1 parent 4080784 commit 6709f68

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

src/spotPython/data/lightcrossvalidationdatamodule.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from torch.utils.data import DataLoader, Subset
33
from typing import Optional
44
from sklearn.model_selection import KFold
5-
from torch.utils.data import Dataset
5+
from torch.utils.data import Dataset, TensorDataset
6+
import torch
7+
68

79

810
class LightCrossValidationDataModule(L.LightningDataModule):
@@ -44,6 +46,7 @@ def __init__(
4446
data_dir: str = "./data",
4547
num_workers: int = 0,
4648
pin_memory: bool = False,
49+
scaler: Optional[object] = None,
4750
):
4851
super().__init__()
4952
self.batch_size = batch_size
@@ -54,6 +57,7 @@ def __init__(
5457
self.split_seed = split_seed
5558
self.num_splits = num_splits
5659
self.pin_memory = pin_memory
60+
self.scaler = scaler
5761
self.save_hyperparameters(logger=False)
5862
assert 0 <= self.k < self.num_splits, "incorrect fold number"
5963

@@ -85,6 +89,21 @@ def setup(self, stage: Optional[str] = None) -> None:
8589
print(f"Train Dataset Size: {len(self.data_train)}")
8690
self.data_val = Subset(dataset_full, val_indexes)
8791
print(f"Val Dataset Size: {len(self.data_val)}")
92+
93+
if self.scaler is not None:
94+
# Fit the scaler on the training fold only (so no validation statistics leak in), then transform both the train and val folds
95+
scaler_train_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
96+
self.scaler.fit(scaler_train_data)
97+
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
98+
data_tensors_train = [data.clone().detach() for data, target in self.data_train]
99+
target_tensors_train = [target.clone().detach() for data, target in self.data_train]
100+
self.data_train = TensorDataset(
101+
torch.stack(data_tensors_train).squeeze(1), torch.stack(target_tensors_train)
102+
)
103+
self.data_val = [(self.scaler.transform(data), target) for data, target in self.data_val]
104+
data_tensors_val = [data.clone().detach() for data, target in self.data_val]
105+
target_tensors_val = [target.clone().detach() for data, target in self.data_val]
106+
self.data_val = TensorDataset(torch.stack(data_tensors_val).squeeze(1), torch.stack(target_tensors_val))
88107

89108
def train_dataloader(self) -> DataLoader:
90109
"""

src/spotPython/light/cvmodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def cv_model(config: dict, fun_control: dict) -> float:
7171
num_workers=fun_control["num_workers"],
7272
batch_size=config["batch_size"],
7373
data_dir=fun_control["DATASET_PATH"],
74+
scaler=fun_control["scaler"],
7475
)
76+
dm.setup()
7577
dm.prepare_data()
7678

7779
# TODO: Check if this is necessary:

0 commit comments

Comments
 (0)