22from torch .utils .data import DataLoader , Subset
33from typing import Optional
44from sklearn .model_selection import KFold
5- from torch .utils .data import Dataset
5+ from torch .utils .data import Dataset , TensorDataset
6+ import torch
7+
68
79
810class LightCrossValidationDataModule (L .LightningDataModule ):
@@ -44,6 +46,7 @@ def __init__(
4446 data_dir : str = "./data" ,
4547 num_workers : int = 0 ,
4648 pin_memory : bool = False ,
49+ scaler : Optional [object ] = None ,
4750 ):
4851 super ().__init__ ()
4952 self .batch_size = batch_size
@@ -54,6 +57,7 @@ def __init__(
5457 self .split_seed = split_seed
5558 self .num_splits = num_splits
5659 self .pin_memory = pin_memory
60+ self .scaler = scaler
5761 self .save_hyperparameters (logger = False )
5862 assert 0 <= self .k < self .num_splits , "incorrect fold number"
5963
@@ -85,6 +89,21 @@ def setup(self, stage: Optional[str] = None) -> None:
8589 print (f"Train Dataset Size: { len (self .data_train )} " )
8690 self .data_val = Subset (dataset_full , val_indexes )
8791 print (f"Val Dataset Size: { len (self .data_val )} " )
92+
93+ if self .scaler is not None :
94+ # Fit the scaler on training data and transform both train and val data
95+ scaler_train_data = torch .stack ([self .data_train [i ][0 ] for i in range (len (self .data_train ))]).squeeze (1 )
96+ self .scaler .fit (scaler_train_data )
97+ self .data_train = [(self .scaler .transform (data ), target ) for data , target in self .data_train ]
98+ data_tensors_train = [data .clone ().detach () for data , target in self .data_train ]
99+ target_tensors_train = [target .clone ().detach () for data , target in self .data_train ]
100+ self .data_train = TensorDataset (
101+ torch .stack (data_tensors_train ).squeeze (1 ), torch .stack (target_tensors_train )
102+ )
103+ self .data_val = [(self .scaler .transform (data ), target ) for data , target in self .data_val ]
104+ data_tensors_val = [data .clone ().detach () for data , target in self .data_val ]
105+ target_tensors_val = [target .clone ().detach () for data , target in self .data_val ]
106+ self .data_val = TensorDataset (torch .stack (data_tensors_val ).squeeze (1 ), torch .stack (target_tensors_val ))
88107
89108 def train_dataloader (self ) -> DataLoader :
90109 """
0 commit comments