#import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dataset
import torchvision.transforms as transform
import torch.optim as optim
from torch.utils.data import DataLoader
#network
class NN(nn.Module):
    def __init__(self, input_num, num_class):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_num, 50)
        self.fc2 = nn.Linear(50, num_class)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
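# NN is a small fully connected baseline (it expects flattened inputs,
# e.g. 28*28 = 784 for MNIST); the rest of this script uses the CNN defined below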
#cnn network
class CNN(nn.Module):
    def __init__(self, in_channel=1, num_class=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # each max pool halves the 28x28 input, so the flattened feature size is 32 * 7 * 7
        self.fc = nn.Linear(32 * 7 * 7, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.maxpool(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.reshape(x.shape[0], -1)   # flatten to (batch_size, 32*7*7)
        x = self.fc(x)
        return x
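# quick shape sanity check, a minimal sketch assuming 28x28 MNIST inputs:
# the two 2x2 max pools shrink the spatial size 28 -> 14 -> 7, matching the fc layer
_x = torch.randn(64, 1, 28, 28)   # fake batch of 64 single-channel images
print(CNN()(_x).shape)            # torch.Size([64, 10])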
#hyperparameter
in_channel = 1
num_class = 10
learning_rate = 0.001
batch_size = 64
num_epoch = 2
load_model = True   # set to False on a first run, before a checkpoint file exists
#device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#initialize
model = CNN(in_channel=in_channel, num_class=num_class).to(device)
#dataset
train_dataset = dataset.MNIST(root='dataset/', train=True, transform=transform.ToTensor(), download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = dataset.MNIST(root='dataset/', train=False, transform=transform.ToTensor(), download=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
#loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
#save / load checkpoint
def save_model(check_point, filename="mycheckpoint.pth.tar"):
    print("-->> saving check_point")
    torch.save(check_point, filename)

def load_checkpoint(check_point):
    print("--> loading check_point")
    model.load_state_dict(check_point["state_dict"])
    optimizer.load_state_dict(check_point["optimizer"])
#train
if load_model:
    load_checkpoint(torch.load("mycheckpoint.pth.tar", map_location=device))
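# a minimal training-loop sketch, assuming the usual PyTorch pattern (this loop is not
# spelled out in the original script): forward pass, cross-entropy loss, backward pass,
# optimizer step, and a checkpoint save at the end of every epoch via save_model above
for epoch in range(num_epoch):
    for data, target in train_dataloader:
        data = data.to(device)
        target = target.to(device)
        score = model(data)               # forward pass: (batch_size, num_class) logits
        loss = criterion(score, target)
        optimizer.zero_grad()
        loss.backward()                   # backward pass
        optimizer.step()                  # update the CNN weights
    # save model and optimizer state so load_checkpoint can restore them later
    check_point = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    save_model(check_point)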
def check_accurate(dataloader, model):
    if dataloader.dataset.train:
        print("checking accuracy on training data")
    else:
        print("checking accuracy on test data")
    num_correct = 0
    num_all = 0
    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            score = model(x)
            _, prediction = score.max(1)
            num_correct += (prediction == y).sum().item()
            num_all += prediction.size(0)
    print(f"got {num_correct} / {num_all}, accuracy {float(num_correct) / float(num_all) * 100:.2f}%")
    model.train()
check_accurate(train_dataloader, model)
check_accurate(test_dataloader, model)