以下代码全部移植自 PaddleSeg,并改写为 PyTorch 实现。
import argparse
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
from torch.utils.data import DataLoader, random_split
from utils.data_loading import BasicDataset, CarvanaDataset
from tqdm import tqdm
import torch.nn.functional as F
# Task: evaluate a UNet model trained with PyTorch. The model outputs NCHW data and the ground-truth labels are NHW; compute accuracy, class precision, class recall and kappa.
# NOTE(review): EPSILON is not referenced by any code in this file; kept in case other modules import it.
EPSILON = 1e-32


def calculate_area(pred, label, num_classes, ignore_index=255):
    """
    Count, per class, the predicted, ground-truth and correctly predicted pixels.

    Pixels whose label equals ``ignore_index`` are excluded from all three counts,
    so the three areas are always computed over the same set of pixels.

    Args:
        pred (Tensor): Predicted class indices. A 4-D tensor is squeezed along dim 1
            (e.g. (N, 1, H, W) -> (N, H, W)).
        label (Tensor): Ground-truth class indices, same convention as ``pred``.
        num_classes (int): Counts are produced for classes 0 .. num_classes - 1.
        ignore_index (int): Label value excluded from the statistics. Default: 255.

    Returns:
        Tensor: Per-class intersection area (pixels where pred == label == class).
        Tensor: Per-class prediction area.
        Tensor: Per-class ground-truth area.
        Each has shape (num_classes,) and stays on the inputs' device.

    Raises:
        ValueError: If ``pred`` and ``label`` shapes differ after squeezing.
    """
    if pred.dim() == 4:
        pred = pred.squeeze(1)
    if label.dim() == 4:
        label = label.squeeze(1)
    if not pred.shape == label.shape:
        raise ValueError('Shape of `pred` and `label should be equal, '
                         'but there are {} and {}.'.format(pred.shape,
                                                           label.shape))
    pred_area = []
    label_area = []
    intersect_area = []
    mask = label != ignore_index
    for i in range(num_classes):
        pred_i = torch.logical_and(pred == i, mask)
        # Mask the label as well: otherwise, when ignore_index is a valid class id,
        # ignored pixels would be counted in label_area but not in pred/intersect.
        label_i = torch.logical_and(label == i, mask)
        intersect_i = torch.logical_and(pred_i, label_i)
        pred_area.append(torch.sum(pred_i))
        label_area.append(torch.sum(label_i))
        intersect_area.append(torch.sum(intersect_i))
    pred_area = torch.stack(pred_area)
    label_area = torch.stack(label_area)
    intersect_area = torch.stack(intersect_area)
    return intersect_area, pred_area, label_area
def get_args():
    """
    Parse the command-line options of the evaluation script.

    ``--root`` and ``--num`` are required: without them the script would fail later
    with an obscure error (``os.path.join(False, ...)`` / ``UNet(n_classes=False)``),
    so argparse rejects the invocation up front with a usage message instead.

    Returns:
        argparse.Namespace: with attributes batch_size, load, scale, val, amp, root, num.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--root', '-r', type=str, required=True, help='root dir')
    parser.add_argument('--num', '-n', type=int, required=True, help='num of classes')
    return parser.parse_args()
# Sub-folder names for images and masks, joined under <root_dir>/train and <root_dir>/val.
dir_img_path, dir_mask_path = 'imgs', 'masks'
import metrics
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False, root_dir: str = '/data/yangbo/unet/datas/data1'):
    """
    Evaluate ``net`` on ``<root_dir>/val`` and print segmentation metrics.

    Despite its name this function does not train. ``epochs``, ``learning_rate``,
    ``save_checkpoint`` and ``amp`` are only logged, and ``val_percent`` is unused.
    They are kept for signature compatibility with the training script.

    Args:
        net: Segmentation model exposing ``n_channels`` and ``n_classes``. Its output is
            reduced with argmax over dim 1, so it presumably returns (N, n_classes, H, W)
            scores -- TODO confirm for n_classes == 1 models, where argmax is always 0.
        device: torch.device that images, masks and the model run on.
        batch_size: Validation batch size.
        img_scale: Image down-scaling factor passed to the dataset.
        root_dir: Directory containing ``train/`` and ``val/`` sub-folders.
    """
    train_dir_img = os.path.join(root_dir, 'train', dir_img_path)
    train_dir_mask = os.path.join(root_dir, 'train', dir_mask_path)
    val_dir_img = os.path.join(root_dir, 'val', dir_img_path)
    val_dir_mask = os.path.join(root_dir, 'val', dir_mask_path)

    # 1. Create datasets (Carvana-style naming first, generic loader as fallback).
    try:
        train_dataset = CarvanaDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = CarvanaDataset(val_dir_img, val_dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        train_dataset = BasicDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = BasicDataset(val_dir_img, val_dir_mask, img_scale)
    n_val = len(val_dataset)
    n_train = len(train_dataset)

    # 2. Validation loader. drop_last must stay False: dropping the final partial batch
    #    would silently exclude images whenever n_val is not a multiple of batch_size.
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False,
                            batch_size=batch_size, num_workers=4, pin_memory=True)

    logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {learning_rate}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_checkpoint}
Device: {device.type}
Images scaling: {img_scale}
Mixed Precision: {amp}
''')

    # 3. Accumulate per-class pixel counts over the whole validation set.
    #    int64 CPU accumulators: exact for any pixel count (float32 is not past 2**24),
    #    and on the device the metrics module's .numpy() calls require.
    num_classes = net.n_classes
    intersect_area_all = torch.zeros(num_classes, dtype=torch.int64)
    pred_area_all = torch.zeros(num_classes, dtype=torch.int64)
    label_area_all = torch.zeros(num_classes, dtype=torch.int64)
    n_images = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            images = batch['image']
            true_masks = batch['mask']
            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'
            images = images.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.long)
            masks_pred = torch.argmax(net(images), dim=1, keepdim=True)
            intersect_area, pred_area, label_area = calculate_area(masks_pred, true_masks, num_classes)
            intersect_area_all += intersect_area.cpu()
            pred_area_all += pred_area.cpu()
            label_area_all += label_area.cpu()
            n_images += images.shape[0]

    # 4. Derive and report the metrics.
    metrics_input = (intersect_area_all, pred_area_all, label_area_all)
    class_iou, miou = metrics.mean_iou(*metrics_input)
    acc, class_precision, class_recall = metrics.class_measurement(*metrics_input)
    kappa = metrics.kappa(*metrics_input)
    _, mdice = metrics.dice(*metrics_input)
    infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
        n_images, miou, acc, kappa, mdice)
    print(infor)
    print("[EVAL] Class IoU: " + str(np.round(class_iou, 4)))
    print("[EVAL] Class Precision: " + str(
        np.round(class_precision, 4)))
    print("[EVAL] Class Recall: " + str(np.round(class_recall, 4)))
if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # n_channels=3 for RGB images; n_classes (--num) must match the class count
    # the checkpoint was trained with.
    net = UNet(n_channels=3, n_classes=args.num, bilinear=True)
    net.eval()
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')
    else:
        # Metrics of an untrained network are meaningless; make that visible.
        logging.warning('No checkpoint given (--load); evaluating randomly initialised weights')

    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=0,
                  batch_size=args.batch_size,
                  learning_rate=0,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp,
                  root_dir=args.root)
    except KeyboardInterrupt:
        # Evaluation never changes the weights, so there is nothing to checkpoint.
        # Saving here would only overwrite any existing INTERRUPTED.pth (e.g. one left by
        # an interrupted training run) with an unchanged copy of the loaded weights.
        logging.info('Evaluation interrupted')
metrics.py
import numpy as np
import torch
import sklearn.metrics as skmetrics
def mean_iou(intersect_area, pred_area, label_area):
    """
    Calculate per-class IoU and its mean.

    Inputs may be torch tensors (on any device, with or without grad) or array-likes;
    they are detached, copied to host memory and converted to float64.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: IoU of every class; 0 for a class absent from both prediction and label.
        float: Mean IoU over all classes (absent classes contribute 0).
    """
    intersect_area, pred_area, label_area = (
        torch.as_tensor(area).detach().cpu().numpy().astype(np.float64)
        for area in (intersect_area, pred_area, label_area)
    )
    union = pred_area + label_area - intersect_area
    # `where` leaves 0 in slots whose union is empty instead of producing nan.
    class_iou = np.divide(intersect_area, union, out=np.zeros_like(union), where=union != 0)
    miou = np.mean(class_iou)
    return class_iou, miou
def class_measurement(intersect_area, pred_area, label_area):
    """
    Calculate overall accuracy, per-class precision and per-class recall.

    Inputs may be torch tensors (on any device, with or without grad) or array-likes;
    they are detached, copied to host memory and converted to float64.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: Pixel accuracy, sum(intersect) / sum(pred) (nan if nothing was predicted).
        np.ndarray: Precision of every class (0 where the class was never predicted).
        np.ndarray: Recall of every class (0 where the class never occurs in the labels).
    """
    intersect_area, pred_area, label_area = (
        torch.as_tensor(area).detach().cpu().numpy().astype(np.float64)
        for area in (intersect_area, pred_area, label_area)
    )
    mean_acc = np.sum(intersect_area) / np.sum(pred_area)
    class_precision = np.divide(intersect_area, pred_area,
                                out=np.zeros_like(pred_area), where=pred_area != 0)
    class_recall = np.divide(intersect_area, label_area,
                             out=np.zeros_like(label_area), where=label_area != 0)
    return mean_acc, class_precision, class_recall
def kappa(intersect_area, pred_area, label_area):
    """
    Calculate Cohen's kappa coefficient from per-class areas.

    po is the observed agreement, sum(intersect) / N. pe is the chance agreement,
    sum(pred * label) / N**2. N is the total ground-truth area.
    Inputs may be torch tensors (on any device, with or without grad) or array-likes.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: kappa coefficient. When pe == 1 (e.g. a single class fills both prediction
        and label) the denominator is 0 and numpy yields nan/inf with a RuntimeWarning.
    """
    intersect_area, pred_area, label_area = (
        torch.as_tensor(area).detach().cpu().numpy().astype(np.float64)
        for area in (intersect_area, pred_area, label_area)
    )
    total_area = np.sum(label_area)
    po = np.sum(intersect_area) / total_area
    pe = np.sum(pred_area * label_area) / (total_area * total_area)
    kappa = (po - pe) / (1 - pe)
    return kappa
def dice(intersect_area, pred_area, label_area):
    """
    Calculate per-class Dice coefficient, 2 * intersect / (pred + label), and its mean.

    Inputs may be torch tensors (on any device, with or without grad) or array-likes;
    they are detached, copied to host memory and converted to float64.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: Dice of every class; 0 for a class absent from both prediction and label.
        float: Mean Dice over all classes (absent classes contribute 0).
    """
    intersect_area, pred_area, label_area = (
        torch.as_tensor(area).detach().cpu().numpy().astype(np.float64)
        for area in (intersect_area, pred_area, label_area)
    )
    union = pred_area + label_area
    # `where` leaves 0 in slots whose combined area is empty instead of producing nan.
    class_dice = np.divide(2 * intersect_area, union, out=np.zeros_like(union), where=union != 0)
    mdice = np.mean(class_dice)
    return class_dice, mdice
使用示例
python .\test2.py --root D:\pic\23\0403\851-1003339-H01\bend --scale 0.25 --load C:\Users\Admin\Desktop\fsdownload\checkpoint_epoch485.pth --num 3
结果展示
[EVAL] #Images: 74 mIoU: 0.5119 Acc: 0.9996 Kappa: 0.4405 Dice: 0.6002
[EVAL] Class IoU: [0.9997 0.4177 0.1183]
[EVAL] Class Precision: [0.9998 0.6767 0.1858]
[EVAL] Class Recall: [0.9998 0.5219 0.2456]