1- import lightning as L
1+ import pytorch_lightning as pl
22from torch .utils .data import DataLoader , random_split
3- from spotPython .light .csvdataset import CSVDataset
3+ from torchvision import transforms
4+ from torchvision .datasets import CIFAR10
45from typing import Optional
56
67
7- class CSVDataModule ( L .LightningDataModule ):
8+ class CIFAR10DataModule ( pl .LightningDataModule ):
89 """
9- A LightningDataModule for handling CSV data.
10+ A LightningDataModule for handling CIFAR10 data.
1011
1112 Args:
1213 batch_size (int): The size of the batch.
13- DATASET_PATH (str): The path to the dataset. Defaults to "./data".
14+ data_dir (str): The directory where the data is stored. Defaults to "./data".
1415 num_workers (int): The number of workers for data loading. Defaults to 0.
1516
1617 Attributes:
@@ -19,58 +20,43 @@ class CSVDataModule(L.LightningDataModule):
1920 data_test (Dataset): The test dataset.
2021 """
2122
22- def __init__ (self , batch_size : int , DATASET_PATH : str = "./data" , num_workers : int = 0 ):
23+ def __init__ (self , batch_size : int , data_dir : str = "./data" , num_workers : int = 0 ):
2324 super ().__init__ ()
2425 self .batch_size = batch_size
26+ self .data_dir = data_dir
2527 self .num_workers = num_workers
2628
2729 def prepare_data (self ) -> None :
2830 """Prepares the data for use."""
2931 # download
30- pass
32+ CIFAR10 (self .data_dir , train = True , download = True )
33+ CIFAR10 (self .data_dir , train = False , download = True )
3134
3235 def setup (self , stage : Optional [str ] = None ) -> None :
3336 """
3437 Sets up the data for use.
3538
3639 Args:
3740 stage (Optional[str]): The current stage. Defaults to None.
38- Examples:
39- >>> from spotPython.light import CSVDataModule
40- >>> data_module = CSVDataModule(batch_size=64)
41- >>> data_module.setup()
42- >>> print(f"Training set size: {len(data_module.data_train)}")
43- Training set size: 45000
44- >>> print(f"Validation set size: {len(data_module.data_val)}")
45- Validation set size: 5000
46- >>> print(f"Test set size: {len(data_module.data_test)}")
47- Test set size: 10000
4841
4942 """
5043 # Assign train/val datasets for use in dataloaders
5144 if stage == "fit" or stage is None :
52- data_full = CSVDataset ( csv_file = "./data/VBDP/train.csv" , train = True )
53- test_abs = int ( len ( data_full ) * 0.6 )
54- self .data_train , self .data_val = random_split (data_full , [test_abs , len ( data_full ) - test_abs ])
45+ transform = transforms . Compose ([ transforms . ToTensor ()] )
46+ cifar_full = CIFAR10 ( self . data_dir , train = True , transform = transform )
47+ self .data_train , self .data_val = random_split (cifar_full , [45000 , 5000 ])
5548
5649 # Assign test dataset for use in dataloader(s)
57- # TODO: Adapt this to the VBDP Situation
5850 if stage == "test" or stage is None :
59- self .data_test = CSVDataset (csv_file = "./data/VBDP/train.csv" , train = True )
51+ transform = transforms .Compose ([transforms .ToTensor ()])
52+ self .data_test = CIFAR10 (self .data_dir , train = False , transform = transform )
6053
6154 def train_dataloader (self ) -> DataLoader :
6255 """
6356 Returns the training dataloader.
6457
6558 Returns:
6659 DataLoader: The training dataloader.
67- Examples:
68- >>> from spotPython.light import CSVDataModule
69- >>> data_module = CSVDataModule(batch_size=64)
70- >>> data_module.setup()
71- >>> train_dataloader = data_module.train_dataloader()
72- >>> print(f"Training dataloader size: {len(train_dataloader)}")
73- Training dataloader size: 704
7460
7561 """
7662 return DataLoader (self .data_train , batch_size = self .batch_size , num_workers = self .num_workers )
@@ -81,13 +67,7 @@ def val_dataloader(self) -> DataLoader:
8167
8268 Returns:
8369 DataLoader: The validation dataloader.
84- Examples:
85- >>> from spotPython.light import CSVDataModule
86- >>> data_module = CSVDataModule(batch_size=64)
87- >>> data_module.setup()
88- >>> val_dataloader = data_module.val_dataloader()
89- >>> print(f"Validation dataloader size: {len(val_dataloader)}")
90- Validation dataloader size: 79
70+
9171
9272 """
9373 return DataLoader (self .data_val , batch_size = self .batch_size , num_workers = self .num_workers )
@@ -99,13 +79,6 @@ def test_dataloader(self) -> DataLoader:
9979 Returns:
10080 DataLoader: The test dataloader.
10181
102- Examples:
103- >>> from spotPython.light import CSVDataModule
104- >>> data_module = CSVDataModule(batch_size=64)
105- >>> data_module.setup()
106- >>> test_dataloader = data_module.test_dataloader()
107- >>> print(f"Test dataloader size: {len(test_dataloader)}")
108- Test dataloader size: 704
10982
11083 """
11184 return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
0 commit comments