import torch
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import MyDataset
import FCN
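# MyDataset and FCN are local project modules: MyDataset wraps the liver image/mask
# pairs on disk, and FCN defines the fully convolutional segmentation network (FCN.FCNs).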
# Hyperparameters
BATCH_SIZE = 16
LR_DECAY = 0.005   # passed to Adam as weight_decay (L2 regularization), not an LR schedule
LR = 0.0001
# Seed every RNG so runs are reproducible
random.seed(3)
np.random.seed(3)
torch.manual_seed(3)
torch.cuda.manual_seed(3)
torch.cuda.manual_seed_all(3)
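# The CPU and CUDA RNGs are seeded above; exact run-to-run repeatability on GPU would
# additionally require torch.backends.cudnn.deterministic = True, which is not set here.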
# Data preparation
transform_x = transforms.Compose([
    transforms.ToTensor(),  # [0, 255] -> [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # per-channel normalization for the 3-channel input
])
transform_y = transforms.Compose([
    transforms.ToTensor(),  # [0, 255] -> [0, 1]
])
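# Note: ToTensor rescales 8-bit masks to [0, 1]. Assuming the masks are stored as
# binary 0/255 images, this yields 0/1 values that become class indices after the
# .long() cast in train().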
train_dataset = MyDataset.MyDataset('/home/rtx3090/storage/student1/cxb/nodecode/前期数据集/u-net_liver-master/train', transform=transform_x, target_transform=transform_y)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_dataset = MyDataset.MyDataset('/home/rtx3090/storage/student1/cxb/nodecode/前期数据集/u-net_liver-master/val', transform=transform_x, target_transform=transform_y)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)
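# The validation loader uses batch_size=1, so Dice() below accumulates one score per image.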
# Device and loss function
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.NLLLoss().to(device)
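# NLLLoss expects log-probabilities; together with the F.log_softmax applied in train()
# this is equivalent to CrossEntropyLoss on the raw logits.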
def train(model, epoch, optimizer):
    model.train()
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        # Masks arrive as float tensors in [0, 1]; cast to long class indices for NLLLoss.
        inputs, target = inputs.to(device), target.long().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = F.log_softmax(outputs, dim=1)  # NLLLoss expects log-probabilities
        # Squeeze the channel dimension (masks are [N, 1, 512, 512]) instead of
        # view(BATCH_SIZE, 512, 512), so a smaller final batch does not crash.
        loss = criterion(outputs, target.squeeze(1))
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:  # print the loss every 10 batches
            print("epoch:{}, batch_idx:{}, loss:{}".format(epoch, batch_idx, loss.item()))
def Dice(model, num_samples, dataloader=test_loader):
    model.eval()
    dices = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            pre_label = outputs.max(dim=1)[1].data.cpu().numpy()
            true_label = labels.view(-1, 512, 512).data.cpu().numpy()
            # Confusion-matrix counts over all pixels in the batch
            TP = (pre_label + true_label == 2).sum()
            TN = (pre_label + true_label == 0).sum()
            FN = (pre_label - true_label == -1).sum()
            FP = (pre_label - true_label == 1).sum()
            dice = 2 * TP / (2 * TP + FP + FN + 1e-8)  # epsilon guards against 0/0 on slices with no foreground
            dices += dice
    return dices / num_samples
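# With the batch_size-1 validation loader, dices accumulates one per-image score,
# so dividing by num_samples gives the mean Dice over the set.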
if __name__ == '__main__':
    model = FCN.FCNs(num_classes=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=LR_DECAY)
    best_dice = 0.
    for epoch in range(30):
        train(model, epoch, optimizer)
        # train_dice = Dice(model, len(train_dataset), dataloader=train_loader)
        test_dice = Dice(model, len(test_dataset), dataloader=test_loader)
        # print("the train_dice is:", train_dice.item())
        print("the test_dice is:", test_dice.item())
        if test_dice > best_dice:
            best_dice = test_dice
            torch.save(model, "FCN")  # saves the whole model object; loading requires the FCN module to be importable
            print("the best Dice is: {}".format(best_dice))