UNet test/evaluation metric script

2023-11-07

The code below is copied from PaddleSeg and converted to PyTorch.

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.utils import plot_img_and_mask
import metrics  # defined in metrics.py, shown further below


# Task: evaluate a UNet model trained with PyTorch. The model outputs NCHW data and the
# ground-truth labels are NHW; compute the model's accuracy, class precision, class recall and kappa.

EPSILON = 1e-32

def calculate_area(pred, label, num_classes, ignore_index=255):
    """
    Calculate intersect, prediction and label area

    Args:
        pred (Tensor): The prediction by model.
        label (Tensor): The ground truth of image.
        num_classes (int): The unique number of target classes.
        ignore_index (int): Specifies a target value that is ignored. Default: 255.

    Returns:
        Tensor: The intersection area of prediction and ground truth on all classes.
        Tensor: The prediction area on all classes.
        Tensor: The ground truth area on all classes.
    """
    if len(pred.shape) == 4:
        pred = torch.squeeze(pred, dim=1)
    if len(label.shape) == 4:
        label = torch.squeeze(label, dim=1)
    if not pred.shape == label.shape:
        raise ValueError('Shape of `pred` and `label` should be equal, '
                         'but got {} and {}.'.format(pred.shape, label.shape))
    pred_area = []
    label_area = []
    intersect_area = []
    mask = label != ignore_index

    for i in range(num_classes):
        pred_i = torch.logical_and(pred == i, mask)
        label_i = label == i
        intersect_i = torch.logical_and(pred_i, label_i)
        pred_area.append(torch.sum(pred_i))  
        label_area.append(torch.sum(label_i))  
        intersect_area.append(torch.sum(intersect_i))  

    pred_area = torch.stack(pred_area)  
    label_area = torch.stack(label_area)  
    intersect_area = torch.stack(intersect_area)  

    return intersect_area, pred_area, label_area


def get_args():
    parser = argparse.ArgumentParser(description='Evaluate the UNet on images and target masks')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--load', '-f', type=str, default=None, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--root', '-r', type=str, required=True, help='Dataset root directory')
    parser.add_argument('--num', '-n', type=int, required=True, help='Number of classes')

    return parser.parse_args()


dir_img_path = 'imgs'
dir_mask_path = 'masks'


def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False,
              root_dir: str = '/data/yangbo/unet/datas/data1'):

    train_dir_img = os.path.join(root_dir, 'train', dir_img_path)
    train_dir_mask = os.path.join(root_dir, 'train', dir_mask_path)

    val_dir_img = os.path.join(root_dir, 'val', dir_img_path)
    val_dir_mask = os.path.join(root_dir, 'val', dir_mask_path)
    # 1. Create dataset
    try:
        train_dataset = CarvanaDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = CarvanaDataset(val_dir_img, val_dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        train_dataset = BasicDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = BasicDataset(val_dir_img, val_dir_mask, img_scale)

    n_val = len(val_dataset)
    n_train = len(train_dataset)

    # 2. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_args)  # keep every sample for evaluation


    # (Initialize logging)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # The optimizer, scheduler and AMP loss scaling are not needed for evaluation;
    # the commented lines below are leftovers from the original training script.
    # optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score

    # 3. Run evaluation over the validation set
    intersect_area_all = torch.zeros([1])
    pred_area_all = torch.zeros([1])
    label_area_all = torch.zeros([1])
    for idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)):
        images = batch['image']
        true_masks = batch['mask']

        assert images.shape[1] == net.n_channels, \
            f'Network has been defined with {net.n_channels} input channels, ' \
            f'but loaded images have {images.shape[1]} channels. Please check that ' \
            'the images are loaded correctly.'

        images = images.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=torch.long)
        with torch.no_grad():
            masks_pred = net(images)
            masks_pred = torch.argmax(masks_pred, dim=1, keepdim=True)
            intersect_area, pred_area, label_area = calculate_area(masks_pred, true_masks, net.n_classes)
            # accumulate on CPU so the totals can be fed to the numpy-based metrics
            intersect_area_all = intersect_area_all + intersect_area.cpu()
            pred_area_all = pred_area_all + pred_area.cpu()
            label_area_all = label_area_all + label_area.cpu()
    metrics_input = (intersect_area_all, pred_area_all, label_area_all)
    class_iou, miou = metrics.mean_iou(*metrics_input)
    acc, class_precision, class_recall = metrics.class_measurement(
        *metrics_input)
    kappa = metrics.kappa(*metrics_input)
    class_dice, mdice = metrics.dice(*metrics_input)
    infor="[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
            len(val_loader), miou, acc, kappa, mdice)
    print(infor)
    print("[EVAL] Class IoU: " + str(np.round(class_iou, 4)))
    print("[EVAL] Class Precision: " + str(
            np.round(class_precision, 4)))
    print("[EVAL] Class Recall: " + str(np.round(class_recall, 4)))

if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    # set the number of classes here (taken from --num)
    net = UNet(n_channels=3, n_classes=args.num, bilinear=True)
    net.eval()
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=0,
                  batch_size=args.batch_size,
                  learning_rate=0,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp,
                  root_dir=args.root)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')


metrics.py

import numpy as np
import torch

def mean_iou(intersect_area, pred_area, label_area):
    """
    Calculate iou.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: iou on all classes.
        float: mean iou of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()
    union = pred_area + label_area - intersect_area
    class_iou = []
    for i in range(len(intersect_area)):
        if union[i] == 0:
            iou = 0
        else:
            iou = intersect_area[i] / union[i]
        class_iou.append(iou)
    miou = np.mean(class_iou)
    return np.array(class_iou), miou

def class_measurement(intersect_area, pred_area, label_area):
    """
    Calculate accuracy, class precision and class recall.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: The mean accuracy.
        np.ndarray: The precision of all classes.
        np.ndarray: The recall of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()

    mean_acc = np.sum(intersect_area) / np.sum(pred_area)
    class_precision = []
    class_recall = []
    for i in range(len(intersect_area)):
        precision = 0 if pred_area[i] == 0 \
            else intersect_area[i] / pred_area[i]
        recall = 0 if label_area[i] == 0 \
            else intersect_area[i] / label_area[i]
        class_precision.append(precision)
        class_recall.append(recall)

    return mean_acc, np.array(class_precision), np.array(class_recall)

def kappa(intersect_area, pred_area, label_area):
    """
    Calculate kappa coefficient

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: kappa coefficient.
    """
    intersect_area = intersect_area.numpy().astype(np.float64)
    pred_area = pred_area.numpy().astype(np.float64)
    label_area = label_area.numpy().astype(np.float64)
    total_area = np.sum(label_area)
    po = np.sum(intersect_area) / total_area
    pe = np.sum(pred_area * label_area) / (total_area * total_area)
    kappa = (po - pe) / (1 - pe)
    return kappa

def dice(intersect_area, pred_area, label_area):
    """
    Calculate DICE.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: DICE on all classes.
        float: mean DICE of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()
    union = pred_area + label_area
    class_dice = []
    for i in range(len(intersect_area)):
        if union[i] == 0:
            dice = 0
        else:
            dice = (2 * intersect_area[i]) / union[i]
        class_dice.append(dice)
    mdice = np.mean(class_dice)
    return np.array(class_dice), mdice
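
A minimal sanity check of the converted pipeline, written as a sketch: it assumes the evaluation script above is saved as test2.py (with its unet/utils dependencies on the path) so that calculate_area can be imported, and that metrics.py sits next to it. The tensors are a single 4x4, 2-class mask in NHW layout:

import torch
import metrics
from test2 import calculate_area  # the evaluation script shown above

# One 4x4 image, 2 classes, NHW layout (what argmax over NCHW logits produces).
pred = torch.tensor([[[0, 0, 1, 1],
                      [0, 0, 1, 1],
                      [0, 0, 0, 0],
                      [0, 0, 0, 0]]])
label = torch.tensor([[[0, 0, 1, 1],
                       [0, 0, 0, 1],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0]]])

intersect, pred_area, label_area = calculate_area(pred, label, num_classes=2)
class_iou, miou = metrics.mean_iou(intersect, pred_area, label_area)
acc, class_precision, class_recall = metrics.class_measurement(intersect, pred_area, label_area)
kappa = metrics.kappa(intersect, pred_area, label_area)

# Hand check: class 1 has 3 label pixels, 4 predicted and 3 correct,
# so precision = 0.75, recall = 1.0, IoU = 0.75; overall Acc = 15/16 = 0.9375.
print(class_iou, miou, acc, class_precision, class_recall, kappa)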

Usage example

python .\test2.py --root D:\pic\23\0403\851-1003339-H01\bend --scale 0.25 --load C:\Users\Admin\Desktop\fsdownload\checkpoint_epoch485.pth --num 3
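
For reference, the directory passed via --root is expected to look roughly like this (inferred from the os.path.join calls in train_net; the imgs/masks folder names come from dir_img_path and dir_mask_path):

<root>/
    train/
        imgs/
        masks/
    val/
        imgs/
        masks/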

Example output

[EVAL] #Images: 74 mIoU: 0.5119 Acc: 0.9996 Kappa: 0.4405 Dice: 0.6002
[EVAL] Class IoU: [0.9997 0.4177 0.1183]
[EVAL] Class Precision: [0.9998 0.6767 0.1858]
[EVAL] Class Recall: [0.9998 0.5219 0.2456]
