import torch
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import MyDataset
import UNet
import torch.nn.functional as F
import numpy as np
import six
import matplotlib.pyplot as plt
import FCN
# Model / training hyperparameters.
BATCH_SIZE = 4
# NOTE: despite its name, LR_DECAY is passed to Adam as `weight_decay`
# (an L2 penalty) in the __main__ block; it is not a learning-rate schedule.
LR_DECAY = 0.005
LR = 0.0001
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# CUDA) so that runs are repeatable.
random.seed(3)
np.random.seed(3)
torch.manual_seed(3)
torch.cuda.manual_seed(3)
torch.cuda.manual_seed_all(3)
# Data preparation.
# Input images: convert to a tensor, then normalise each of the three channels
# with mean 0.5 and std 0.5.
transform_x = transforms.Compose([
transforms.ToTensor(), # [0, 255] -> [0, 1]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # three channels
])
# Target masks: only converted to a tensor (no normalisation).
transform_y = transforms.Compose([
transforms.ToTensor(), # [0, 255] -> [0, 1]
])
train_dataset = MyDataset.MyDataset('/home/rtx3090/storage/student1/cxb/nodecode/前期数据集/u-net_liver-master/train', transform=transform_x, target_transform=transform_y)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_dataset = MyDataset.MyDataset('/home/rtx3090/storage/student1/cxb/nodecode/前期数据集/u-net_liver-master/val', transform=transform_x, target_transform=transform_y)
# The validation loader is not shuffled and yields one sample per batch.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NLLLoss consumes log-probabilities; train() applies log_softmax to the model
# output before calling this criterion.
criterion = torch.nn.NLLLoss().to(device)
def train(model, epoch, optimizer):
    """Train ``model`` for one pass over the module-level ``train_loader``.

    For every batch the model output is turned into log-probabilities with
    ``log_softmax`` over dim 1 and scored against the integer target mask with
    the module-level ``criterion``. The loss is printed every 10 batches.

    Args:
        model: Network to optimise; its output must carry the class scores on
            dim 1.
        epoch: Index of the current epoch, used only in the progress message.
        optimizer: Optimizer that updates the parameters of ``model``.
    """
    # The mode does not change inside the loop, so switch once up front.
    model.train()
    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.long().to(device)
        optimizer.zero_grad()
        log_probs = F.log_softmax(model(inputs), dim=1)
        # Reshape with the *actual* batch size: the last batch is smaller than
        # BATCH_SIZE whenever the dataset size is not a multiple of it, and a
        # hard-coded BATCH_SIZE would make view() raise there.
        loss = criterion(log_probs, target.view(inputs.size(0), 512, 512))
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:  # report the loss every 10 iterations
            print("epoch:{}, batch_idx:{}, loss:{}".format(epoch, batch_idx, loss.item()))
def calc_semantic_segmentation_confusion(pred_labels, gt_labels, n_class=12):
    """Accumulate a confusion matrix over pairs of label maps.

    The matrix starts with ``n_class`` rows and columns and is enlarged on the
    fly whenever a label id greater than or equal to the current size is
    encountered, so the result has at least ``n_class`` classes.

    Args:
        pred_labels (iterable of numpy.ndarray): Predicted label maps, each of
            shape ``(H, W)``.
        gt_labels (iterable of numpy.ndarray): Ground-truth label maps, each
            with the same shape as the corresponding prediction. Pixels whose
            ground-truth value is negative (e.g. ``-1``) are ignored.
        n_class (int): Initial (minimum) number of classes. Defaults to 12.

    Returns:
        numpy.ndarray:
        An ``int64`` confusion matrix of shape ``(n, n)`` with ``n >= n_class``.
        Element ``(i, j)`` is the number of pixels labelled class ``i`` by the
        ground truth and class ``j`` by the prediction.

    Raises:
        ValueError: If a label map is not two-dimensional, if a prediction and
            its ground truth differ in shape, or if the two iterables do not
            have the same length.
    """
    from itertools import zip_longest

    # A unique sentinel lets us detect a length mismatch in either direction.
    # A plain zip() silently swallows one extra prediction before stopping.
    exhausted = object()
    confusion = np.zeros((n_class, n_class), dtype=np.int64)
    for pred_label, gt_label in zip_longest(
            pred_labels, gt_labels, fillvalue=exhausted):
        if pred_label is exhausted or gt_label is exhausted:
            raise ValueError('Length of input iterables need to be same')
        if pred_label.ndim != 2 or gt_label.ndim != 2:
            raise ValueError('ndim of labels should be two.')
        if pred_label.shape != gt_label.shape:
            raise ValueError(
                'Shape of ground truth and prediction should be same.')
        if pred_label.size == 0:
            # Nothing to count, and max() of an empty array would raise.
            continue

        pred_label = pred_label.flatten()
        gt_label = gt_label.flatten()

        # Grow the matrix if a label id does not fit. Cast to a Python int so
        # that float-typed label maps still give a valid array shape.
        lb_max = int(max(pred_label.max(), gt_label.max()))
        if lb_max >= n_class:
            expanded_confusion = np.zeros(
                (lb_max + 1, lb_max + 1), dtype=np.int64)
            expanded_confusion[0:n_class, 0:n_class] = confusion
            n_class = lb_max + 1
            confusion = expanded_confusion

        # Count only pixels with a valid (non-negative) ground-truth label.
        # Encoding each pair as ``n_class * gt + pred`` makes one bincount
        # produce the whole matrix: rows are ground truth, columns are
        # predictions, and the diagonal holds the correctly classified pixels.
        mask = gt_label >= 0
        flat_index = (n_class * gt_label[mask].astype(np.int64)
                      + pred_label[mask].astype(np.int64))
        confusion += np.bincount(
            flat_index, minlength=n_class ** 2).reshape((n_class, n_class))
    return confusion
def calc_semantic_segmentation_iou(confusion):
    """Compute per-class Intersection over Union from a confusion matrix.

    Args:
        confusion (numpy.ndarray): Confusion matrix of shape
            ``(n_class, n_class)`` whose ``(i, j)`` entry counts the pixels
            labelled class ``i`` by the ground truth and class ``j`` by the
            prediction.

    Returns:
        numpy.ndarray:
        IoU of every class except the last one, i.e. an array of shape
        ``(n_class - 1,)``. The last class is treated as "unlabelled" and is
        dropped.
    """
    # Intersection: pixels on the diagonal are the correctly classified ones.
    intersection = np.diag(confusion)
    gt_pixels = confusion.sum(axis=1)
    pred_pixels = confusion.sum(axis=0)
    # Union: everything labelled or predicted as the class, counted once.
    union = gt_pixels + pred_pixels - intersection
    per_class_iou = intersection / union
    return per_class_iou[:-1]
def eval_semantic_segmentation(pred_labels, gt_labels):
    """Evaluate the metrics commonly used in semantic segmentation.

    Args:
        pred_labels (iterable of numpy.ndarray): Predicted label maps, e.g. a
            list ``[label_0, label_1, ...]`` where each ``label_i`` has shape
            ``(H_i, W_i)``.
        gt_labels (iterable of numpy.ndarray): Ground-truth label maps, each
            with the same shape as the corresponding prediction. A pixel with
            value ``-1`` is ignored during evaluation.

    Returns:
        dict:
        * iou (numpy.ndarray): Per-class IoU with the last ("unlabelled")
          class dropped; shape ``(n_class - 1,)``.
        * miou (float): Mean of ``iou``, ignoring NaN entries.
        * pixel_accuracy (float): Fraction of counted pixels that were
          classified correctly.
        * class_accuracy (numpy.ndarray): Per-class accuracy, shape
          ``(n_class,)``. Classes without ground-truth pixels report 0.
        * mean_class_accuracy (float): Mean of ``class_accuracy`` over the
          classes that have ground-truth pixels, excluding the last class.
          NaN when no such class exists.

    Evaluation code is based on
    https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/
    score.py #L37
    """
    confusion = calc_semantic_segmentation_confusion(pred_labels, gt_labels)
    iou = calc_semantic_segmentation_iou(confusion)

    correct = np.diag(confusion)
    gt_pixels = np.sum(confusion, axis=1)
    pixel_accuracy = correct.sum() / confusion.sum()
    # The epsilon keeps absent classes at 0 instead of raising a 0/0 warning.
    class_accuracy = correct / (gt_pixels + 1e-10)

    # Because of that epsilon, absent classes are 0 rather than NaN, so a
    # nanmean would average them in and deflate the score. Average explicitly
    # over the classes that occur in the ground truth, mirroring how miou
    # skips NaN classes. The last class is excluded, as it is for IoU.
    present = gt_pixels[:-1] > 0
    if present.any():
        mean_class_accuracy = class_accuracy[:-1][present].mean()
    else:
        mean_class_accuracy = float('nan')

    return {'iou': iou, 'miou': np.nanmean(iou),
            'pixel_accuracy': pixel_accuracy,
            'class_accuracy': class_accuracy,
            'mean_class_accuracy': mean_class_accuracy}
def Dice(model, len, dataloader=test_loader):
    """Evaluate ``model`` on ``dataloader``.

    The Dice coefficient of the foreground class (label 1) is computed per
    sample, summed, and divided by ``len``. The segmentation metrics are
    computed once over every sample seen, not just the final batch.

    Args:
        model: Network to evaluate; the predicted label is the argmax of its
            output over dim 1.
        len: Number of samples to average the Dice score over (callers pass
            ``len(dataset)``). The name shadows the builtin but is kept
            because it is part of the existing signature.
        dataloader: Loader yielding ``(images, labels)`` batches. Defaults to
            the module-level ``test_loader``.

    Returns:
        tuple: ``(mean_dice, eval_metrix)`` where ``mean_dice`` is a NumPy
        float and ``eval_metrix`` is the dict returned by
        ``eval_semantic_segmentation``.
    """
    model.eval()
    # A NumPy scalar keeps ``.item()`` available on the result for callers.
    dice_sum = np.float64(0.0)
    all_preds = []
    all_truths = []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            pre_label = outputs.max(dim=1)[1].cpu().numpy()
            # Cast to integer labels exactly as train() does for its targets.
            true_label = labels.long().view(-1, 512, 512).cpu().numpy()
            # Score each sample separately so the result is correct for any
            # batch size, not only for batch_size=1.
            for pred, truth in zip(pre_label, true_label):
                tp = np.logical_and(pred == 1, truth == 1).sum()
                fp = np.logical_and(pred == 1, truth == 0).sum()
                fn = np.logical_and(pred == 0, truth == 1).sum()
                denominator = 2 * tp + fp + fn
                # With no foreground in either map the prediction is a
                # perfect match; score 1 instead of a 0/0 NaN that would
                # poison the running sum.
                dice_sum += 2 * tp / denominator if denominator > 0 else 1.0
                all_preds.append(pred)
                all_truths.append(truth)
    eval_metrix = eval_semantic_segmentation(all_preds, all_truths)
    return dice_sum / len, eval_metrix
if __name__ == '__main__':
    # Train a 2-class U-Net for 30 epochs, evaluate after every epoch, and
    # keep track of the best value seen for each metric.
    model = UNet.UNet(in_channel=3, out_channel=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=LR_DECAY)
    best_dice = 0.
    best_miou = 0
    best_pix_accuracy = 0
    best_mean_class_accuracy = 0
    for epoch in range(30):
        train(model, epoch, optimizer)
        test_dice, eval_metrix = Dice(model, len(test_dataset), dataloader=test_loader)
        print("the test_dice is:", test_dice.item())
        if test_dice > best_dice:
            best_dice = test_dice
            # NOTE: the checkpoint file is named "FCN" although the model is
            # a UNet; the name is kept so existing consumers still find it.
            torch.save(model, "FCN")
        if best_miou < eval_metrix['miou']:
            best_miou = eval_metrix['miou']
        if best_pix_accuracy < eval_metrix['pixel_accuracy']:
            best_pix_accuracy = eval_metrix['pixel_accuracy']
        if best_mean_class_accuracy < eval_metrix['mean_class_accuracy']:
            best_mean_class_accuracy = eval_metrix['mean_class_accuracy']
        # Report the tracked best values, not the current epoch's metrics.
        print("the best Dice is: {}".format(best_dice))
        print("the best miou is:", best_miou)
        print("the best pix_acuracy is:", best_pix_accuracy)
        print("the best mean_class_accuracy is:", best_mean_class_accuracy)
    # Final summary after the last epoch.
    print("the best Dice is: {}".format(best_dice))
    print("the best miou is:", best_miou)
    print("the best pix_acuracy is:", best_pix_accuracy)
    print("the best mean_class_accuracy is:", best_mean_class_accuracy)
# Reviewer remark: "Bro, you wrote a lot of comments."