Pytorch tutorial (5)
CNN with nn.Module
Pytorch tutorial (5)
This post is for practicing the PyTorch Module structure,
so data preprocessing and data augmentation are omitted.
CIFAR-10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import matplotlib.pyplot as plt
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import ToTensor
# Load the CIFAR-10 training and test splits, downloading them if needed.
# The dataset root is the current working directory ("./"). The previous
# value "/." resolves to the filesystem root, which is normally not writable
# for an unprivileged user.
training_data = CIFAR10(
    root="./",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = CIFAR10(
    root="./",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Preview the first nine training images in a 3x3 grid.
# `training_data.data[i]` is the raw image array, not the ToTensor() output.
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(training_data.data[i])
plt.show()
Model
Basic Block
We define a basic block to make building the model easier.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
# basic block
class BasicBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU, then a 2x2 max-pool.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the second convolution.
        hidden_dim: Number of channels between the two convolutions.
    """

    def __init__(self, in_channels, out_channels, hidden_dim):
        super().__init__()
        # Both convolutions share the same kernel geometry.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, **conv_kwargs)
        self.conv2 = nn.Conv2d(hidden_dim, out_channels, **conv_kwargs)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> max-pool to ``x``."""
        hidden = self.relu(self.conv1(x))
        activated = self.relu(self.conv2(hidden))
        return self.pool(activated)
CNN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# model
class CNN(nn.Module):
    """Image classifier: three ``BasicBlock`` stages and a three-layer MLP head.

    The first linear layer expects 4096 flattened features. With 256 output
    channels from the last block this equals 256 * 4 * 4, which corresponds
    to 32x32 inputs (e.g. CIFAR-10) after three stride-2 poolings.

    Args:
        num_classes: Size of the final output layer (one logit per class).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor: 3 -> 32 -> 128 -> 256 channels.
        self.block1 = BasicBlock(in_channels=3, out_channels=32, hidden_dim=16)
        self.block2 = BasicBlock(in_channels=32, out_channels=128, hidden_dim=64)
        self.block3 = BasicBlock(in_channels=128, out_channels=256, hidden_dim=128)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=4096, out_features=2048)
        self.fc2 = nn.Linear(in_features=2048, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=num_classes)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw class scores for the batch ``x`` (no softmax applied)."""
        features = x
        for block in (self.block1, self.block2, self.block3):
            features = block(features)

        # Keep the batch dimension and flatten everything else.
        flat = features.flatten(start_dim=1)

        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
Train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# model train
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam
# Training configuration.
num_classes = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 20

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

model = CNN(num_classes).to(device)
lr = 1e-3
optim = Adam(model.parameters(), lr=lr)
# Build the loss module once instead of re-instantiating it on every batch.
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    for data, label in train_loader:
        optim.zero_grad()
        preds = model(data.to(device))
        loss = criterion(preds, label.to(device))
        loss.backward()
        optim.step()
    # Report on the first epoch and every tenth epoch afterwards.
    # The printed value is the loss of the last batch of that epoch.
    if epoch == 0 or epoch % 10 == 9:
        print(f"epoch : {epoch+1}, loss : {loss.item()}")

# Persist only the learned parameters, not the whole module object.
torch.save(model.state_dict(), "CIFAR.pt")
reference - 텐초의 파이토치 딥러닝 특강
This post is licensed under CC BY 4.0 by the author.
