Skip to content

Commit df78d88

Browse files
Update cifar10datamodule.py
1 parent c1d49f1 commit df78d88

1 file changed

Lines changed: 16 additions & 43 deletions

File tree

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
import lightning as L
1+
import pytorch_lightning as pl
22
from torch.utils.data import DataLoader, random_split
3-
from spotPython.light.csvdataset import CSVDataset
3+
from torchvision import transforms
4+
from torchvision.datasets import CIFAR10
45
from typing import Optional
56

67

7-
class CSVDataModule(L.LightningDataModule):
8+
class CIFAR10DataModule(pl.LightningDataModule):
89
"""
9-
A LightningDataModule for handling CSV data.
10+
A LightningDataModule for handling CIFAR10 data.
1011
1112
Args:
1213
batch_size (int): The size of the batch.
13-
DATASET_PATH (str): The path to the dataset. Defaults to "./data".
14+
data_dir (str): The directory where the data is stored. Defaults to "./data".
1415
num_workers (int): The number of workers for data loading. Defaults to 0.
1516
1617
Attributes:
@@ -19,58 +20,43 @@ class CSVDataModule(L.LightningDataModule):
1920
data_test (Dataset): The test dataset.
2021
"""
2122

22-
def __init__(self, batch_size: int, DATASET_PATH: str = "./data", num_workers: int = 0):
23+
def __init__(self, batch_size: int, data_dir: str = "./data", num_workers: int = 0):
2324
super().__init__()
2425
self.batch_size = batch_size
26+
self.data_dir = data_dir
2527
self.num_workers = num_workers
2628

2729
def prepare_data(self) -> None:
2830
"""Prepares the data for use."""
2931
# download
30-
pass
32+
CIFAR10(self.data_dir, train=True, download=True)
33+
CIFAR10(self.data_dir, train=False, download=True)
3134

3235
def setup(self, stage: Optional[str] = None) -> None:
3336
"""
3437
Sets up the data for use.
3538
3639
Args:
3740
stage (Optional[str]): The current stage. Defaults to None.
38-
Examples:
39-
>>> from spotPython.light import CSVDataModule
40-
>>> data_module = CSVDataModule(batch_size=64)
41-
>>> data_module.setup()
42-
>>> print(f"Training set size: {len(data_module.data_train)}")
43-
Training set size: 45000
44-
>>> print(f"Validation set size: {len(data_module.data_val)}")
45-
Validation set size: 5000
46-
>>> print(f"Test set size: {len(data_module.data_test)}")
47-
Test set size: 10000
4841
4942
"""
5043
# Assign train/val datasets for use in dataloaders
5144
if stage == "fit" or stage is None:
52-
data_full = CSVDataset(csv_file="./data/VBDP/train.csv", train=True)
53-
test_abs = int(len(data_full) * 0.6)
54-
self.data_train, self.data_val = random_split(data_full, [test_abs, len(data_full) - test_abs])
45+
transform = transforms.Compose([transforms.ToTensor()])
46+
cifar_full = CIFAR10(self.data_dir, train=True, transform=transform)
47+
self.data_train, self.data_val = random_split(cifar_full, [45000, 5000])
5548

5649
# Assign test dataset for use in dataloader(s)
57-
# TODO: Adapt this to the VBDP Situation
5850
if stage == "test" or stage is None:
59-
self.data_test = CSVDataset(csv_file="./data/VBDP/train.csv", train=True)
51+
transform = transforms.Compose([transforms.ToTensor()])
52+
self.data_test = CIFAR10(self.data_dir, train=False, transform=transform)
6053

6154
def train_dataloader(self) -> DataLoader:
6255
"""
6356
Returns the training dataloader.
6457
6558
Returns:
6659
DataLoader: The training dataloader.
67-
Examples:
68-
>>> from spotPython.light import CSVDataModule
69-
>>> data_module = CSVDataModule(batch_size=64)
70-
>>> data_module.setup()
71-
>>> train_dataloader = data_module.train_dataloader()
72-
>>> print(f"Training dataloader size: {len(train_dataloader)}")
73-
Training dataloader size: 704
7460
7561
"""
7662
return DataLoader(self.data_train, batch_size=self.batch_size, num_workers=self.num_workers)
@@ -81,13 +67,7 @@ def val_dataloader(self) -> DataLoader:
8167
8268
Returns:
8369
DataLoader: The validation dataloader.
84-
Examples:
85-
>>> from spotPython.light import CSVDataModule
86-
>>> data_module = CSVDataModule(batch_size=64)
87-
>>> data_module.setup()
88-
>>> val_dataloader = data_module.val_dataloader()
89-
>>> print(f"Validation dataloader size: {len(val_dataloader)}")
90-
Validation dataloader size: 79
70+
9171
9272
"""
9373
return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)
@@ -99,13 +79,6 @@ def test_dataloader(self) -> DataLoader:
9979
Returns:
10080
DataLoader: The test dataloader.
10181
102-
Examples:
103-
>>> from spotPython.light import CSVDataModule
104-
>>> data_module = CSVDataModule(batch_size=64)
105-
>>> data_module.setup()
106-
>>> test_dataloader = data_module.test_dataloader()
107-
>>> print(f"Test dataloader size: {len(test_dataloader)}")
108-
Test dataloader size: 704
10982
11083
"""
11184
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)

0 commit comments

Comments
 (0)