-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
84 lines (67 loc) · 2.05 KB
/
train.py
File metadata and controls
84 lines (67 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Trains a PyTorch image classification model using device-agnostic code.
"""
import torch, pandas as pd
from torchvision import transforms
import data_setup, engine, model_builder, utils
from data_setup import CustomDataset
# Load the training table; the "label" column is read further down this script.
dataframe = pd.read_csv("./data/train.csv")
# "Unnamed: 0" is the header pandas assigns to an unnamed leading column,
# typically an index that was saved alongside the data. errors="ignore" keeps
# the script working (instead of raising KeyError) when the CSV has no such
# column.
dataframe = dataframe.drop(columns=["Unnamed: 0"], errors="ignore")
# Run configuration.
NUM_EPOCHS = 2
BATCH_SIZE = 32
HIDDEN_UNITS = 10  # passed to the Baseline/TinyVGG builders (currently commented out)
LEARNING_RATE = 0.001

# Echo the configuration so each run's log records what it trained with.
for _label, _setting in (
    ("Number of Epochs", NUM_EPOCHS),
    ("Batch Size", BATCH_SIZE),
    ("Hidden Units", HIDDEN_UNITS),
    ("Learning Rate", LEARNING_RATE),
):
    print(f"{_label} : {_setting}")
del _label, _setting  # keep the module namespace free of loop temporaries
# Every image is resized to HEIGHT x WIDTH before batching.
HEIGHT, WIDTH = 64, 64

# Per-sample preprocessing: resize first, then convert to a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(HEIGHT, WIDTH)),
        transforms.ToTensor(),
    ]
)
# Wrap the dataframe in the project's Dataset, applying data_transform per sample.
custom_dataset = CustomDataset(dataframe, transformer=data_transform)

# Build train/validation loaders via data_setup.py.
# NOTE(review): split_test=True with test_size=0.1 presumably holds out 10% of
# samples for validation — confirm in data_setup.create_dataloaders.
train_dataloader, val_dataloader = data_setup.create_dataloaders(
    custom_dataset,
    BATCH_SIZE,
    split_test=True,
    test_size=0.1,
)

# nunique(dropna=False) counts exactly the entries of unique(), NaN included.
print(f"Num of classes : {dataframe['label'].nunique(dropna=False)}")
# Model selection: one builder from model_builder.py is active at a time.
# Alternatives tried previously (swap in as needed):
#   model_builder.Baseline(input_shape=3, hidden_units=HIDDEN_UNITS,
#                          output_shape=len(dataframe["label"].unique()))
#   model_builder.TinyVGG(input_shape=3, hidden_units=HIDDEN_UNITS,
#                         output_shape=len(dataframe["label"].unique()))
#   model_builder.VisionTransformer(num_classes=2, image_size=HEIGHT)
# NOTE(review): unlike the alternatives, EfficientNet() is given no class
# count here — confirm its output size matches the number of labels in the data.
model = model_builder.EfficientNet()

# Cross-entropy loss, optimized with Adam over every model parameter.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Run the training loop from engine.py (its return value is not used here).
engine.train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=NUM_EPOCHS,
)

# Persist the trained weights via utils.py.
# NOTE(review): the file is named "baseline_model.pth" although the active
# model is EfficientNet; consider a model-specific name once any consumer of
# this path has been checked.
utils.save_model(
    model=model,
    target_dir="models",
    model_name="baseline_model.pth",
)