-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_setup.py
More file actions
27 lines (19 loc) · 900 Bytes
/
data_setup.py
File metadata and controls
27 lines (19 loc) · 900 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.utils.data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Default preprocessing pipeline used by create_dataset when the caller does
# not pass its own: resize every image to 64x64, then convert it to a tensor.
DEFAULT_TRANSFORMER = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)
def create_dataset(
    images_path: str,
    test_train_split_enabled: bool = False,
    test_size: float = 0.3,
    transformer=DEFAULT_TRANSFORMER,
):
    """Build an ``ImageFolder`` dataset rooted at *images_path*.

    Args:
        images_path: Directory handed to ``datasets.ImageFolder``.
        test_train_split_enabled: When true, the dataset is passed through
            ``train_test_split`` and the result of that call is returned
            instead of the dataset itself.
        test_size: Fraction forwarded to ``train_test_split``; ignored when
            splitting is disabled.
        transformer: Transform passed to ``ImageFolder`` as ``transform``.

    Returns:
        The ``ImageFolder`` dataset, or whatever ``train_test_split`` returns
        for it when splitting is enabled.
    """
    image_folder = datasets.ImageFolder(images_path, transform=transformer)
    if not test_train_split_enabled:
        return image_folder
    return train_test_split(image_folder, test_size)
def train_test_split(dataset, test_size: float):
    """Randomly split *dataset* into a training part and a validation part.

    Args:
        dataset: A sized, indexable dataset accepted by
            ``torch.utils.data.random_split``.
        test_size: Fraction of the samples assigned to the validation part.
            Must lie in ``[0, 1]``. The validation count is
            ``int(len(dataset) * test_size)``, i.e. truncated toward zero.

    Returns:
        The result of ``random_split`` for the lengths ``[train, validation]``:
        the training subset first, the validation subset second.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]`` (NaN included).
    """
    # The two lengths below always sum to len(dataset), so random_split's own
    # sum check cannot catch a negative length; an out-of-range fraction would
    # otherwise produce silently wrong subsets. Reject it up front.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(
            f"test_size must be between 0 and 1, got {test_size!r}"
        )

    total = len(dataset)
    val_count = int(total * test_size)
    train_count = total - val_count
    return torch.utils.data.random_split(
        dataset=dataset, lengths=[train_count, val_count]
    )
def create_dataloader(dataset, batch_size, shuffle: bool = False) -> DataLoader:
    """Wrap *dataset* in a ``DataLoader``.

    Args:
        dataset: Dataset to load batches from.
        batch_size: Passed to ``DataLoader`` as ``batch_size``.
        shuffle: Passed to ``DataLoader`` as ``shuffle``; off by default.

    Returns:
        The constructed ``DataLoader``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return loader