MMSegmentation笔记06:推理

2023-11-12

1. 单张图像预测

"""
==========================================
@author: Seaton
@Time: 2023/8/19:15:38
@IDE: PyCharm
@Summary:使用训练好的模型进行单张图像推理
==========================================
"""

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mmengine import Config

from mmseg.apis import init_model, inference_model

cfg = Config.fromfile('mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py')
checkpoint_path = 'mmsegmentation/checkpoint/myUNet.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')

# 原图
img_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val/01bd15599c606aa801201794e1fa30.jpg'
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(8, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.show()

# 推理
result = inference_model(model, img_bgr)
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

# 显示语义分割结果
plt.figure(figsize=(10, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.imshow(pred_mask, alpha=0.55)  # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.savefig('mmsegmentation/outputs/K1-1.jpg')
plt.show()

# 各类别的配色方案(BGR)
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]

palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]
opacity = 0.3  # 透明度,越大越接近原图
# 将预测的整数ID,映射为对应类别的颜色
pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
for idx in palette_dict.keys():
    pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
pred_mask_bgr = pred_mask_bgr.astype('uint8')

# 将语义分割预测图和原图叠加显示
pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

cv2.imwrite('outputs/K1-3.jpg', pred_viz)
plt.figure(figsize=(8, 8))
plt.imshow(pred_viz[:, :, ::-1])
plt.show()

# 对比label和预测结果
label_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/ann_dir/val/01bd15599c606aa801201794e1fa30.png'
label = cv2.imread(label_path)
label_mask = label[:, :, 0]
# 真实为西瓜红瓤,预测为西瓜红壤取并集
TP = (label_mask == 1) & (pred_mask == 1)
plt.imshow(TP)
plt.show()

# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix

confusion_matrix_model = confusion_matrix(label_mask.flatten(), pred_mask.flatten())
import itertools


def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """
    传入混淆矩阵和标签名称列表,绘制混淆矩阵
    """
    plt.figure(figsize=(10, 10))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.colorbar() # 色条
    tick_marks = np.arange(len(classes))

    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # 设置类别文字大小
    plt.xticks(tick_marks, classes, rotation=90)  # 横轴文字旋转
    plt.yticks(tick_marks, classes)

    # 写数字
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()

    plt.savefig('mmsegmentation/outputs/K1-混淆矩阵.pdf', dpi=300)  # 保存图像
    plt.show()


from mmseg.datasets import ZihaoDataset

classes = ZihaoDataset.METAINFO['classes']
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')

本节的代码整理如上,基本是对子豪兄的代码进行路径上的修改,也就是在路径最前面加mmsegmentation/

没什么可展开讲的,主要流程可以总结如下:

  • 定义config文件和pth文件的路径

  • 基于config文件和pth文件通过init_model函数建立模型

  • 各种方法来绘制原图与结果

  • 绘制混淆矩阵

2. 视频预测

"""
==========================================
@author: Seaton
@Time: 2023/8/20:16:56
@IDE: PyCharm
@Summary:使用训练好的模型进行单张图像推理
==========================================
"""
import time
import numpy as np
from tqdm import tqdm
import cv2

import mmcv
from mmseg.apis import init_model, inference_model

config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'

from mmseg.apis import init_model

model = init_model(config_file, checkpoint_file, device='cuda:0')

palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]
palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]

opacity = 0.3  # 透明度,越大越接近原图


# 逐帧处理函数
def process_frame(img_bgr):
    # 记录该帧开始处理的时间
    start_time = time.time()

    # 语义分割预测
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # 将预测的整数ID,映射为对应类别的颜色
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx in palette_dict.keys():
        pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
    pred_mask_bgr = pred_mask_bgr.astype('uint8')

    # 将语义分割预测图和原图叠加显示
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

    return pred_viz


# 视频逐帧处理代码模板
# 不需修改任何代码,只需定义process_frame函数即可
# 同济子豪兄 2021-7-10

def generate_video(input_path='videos/robot.mp4'):
    filehead = input_path.split('/')[-1]
    output_path = "out-" + filehead

    print('视频开始处理', input_path)

    # 获取视频总帧数
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    while (cap.isOpened()):
        success, frame = cap.read()
        frame_count += 1
        if not success:
            break
    cap.release()
    print('视频总帧数为', frame_count)

    # cv2.namedWindow('Crack Detection and Measurement Video Processing')
    cap = cv2.VideoCapture(input_path)
    frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    # fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)

    out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))

    # 进度条绑定视频总帧数
    with tqdm(total=frame_count - 1) as pbar:
        try:
            while (cap.isOpened()):
                success, frame = cap.read()
                if not success:
                    break

                # 处理帧
                # frame_path = './temp_frame.png'
                # cv2.imwrite(frame_path, frame)
                try:
                    frame = process_frame(frame)
                except:
                    # print('报错!', error)
                    pass

                if success == True:
                    # cv2.imshow('Video Processing', frame)
                    out.write(frame)

                    # 进度条更新一帧
                    pbar.update(1)

                # if cv2.waitKey(1) & 0xFF == ord('q'):
                # break
        except:
            print('中途中断')
            pass

    cv2.destroyAllWindows()
    out.release()
    cap.release()
    print('视频已保存', output_path)


generate_video(input_path='demo/test.mp4')

本节整理代码如上,基本原理与单张预测几乎一样,多了一步就是将视频拆成单帧,进行预测后再拼合成视频并保存。

3. 整个文件夹图片预测

"""
==========================================
@author: Seaton
@Time: 2023/8/20:18:37
@IDE: PyCharm
@Summary:使用训练好的模型进行文件夹下所有图像推理
==========================================
"""
import os
import numpy as np
import cv2
from tqdm import tqdm

from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv

import matplotlib.pyplot as plt

# 模型 config 配置文件
config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
# 模型权重文件
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'

# 计算硬件
device = 'cuda:0'

model = init_model(config_file, checkpoint_file, device=device)

# 每个类别的 BGR 配色
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]

palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]

if not os.path.exists('mmsegmentation/outputs/testset-pred'):
    os.mkdir('mmsegmentation/outputs/testset-pred')

PATH_IMAGE = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val'
opacity = 0.3  # 透明度,越大越接近原图


def process_single_img(img_path, save=False):
    img_bgr = cv2.imread(img_path)

    # 语义分割预测
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # 将预测的整数ID,映射为对应类别的颜色
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx in palette_dict.keys():
        pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
    pred_mask_bgr = pred_mask_bgr.astype('uint8')

    # 将语义分割预测图和原图叠加显示
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

    # 保存图像至 outputs/testset-pred 目录
    if save:
        save_path = os.path.join('../', '../', '../', 'outputs', 'testset-pred', 'pred-' + img_path.split('/')[-1])
        cv2.imwrite(save_path, pred_viz)
        print('已保存')


os.chdir(PATH_IMAGE)
# for each in tqdm(os.listdir()):
# process_single_img(each, save=True)


# 批量可视化
os.chdir('../../../outputs/testset-pred')
# n 行 n 列可视化
n = 4

fig, axes = plt.subplots(nrows=n, ncols=n, figsize=(16, 10))

for i, file_name in enumerate(os.listdir()[:n ** 2]):
    img_bgr = cv2.imread(file_name)

    # 可视化
    axes[i // n, i % n].imshow(img_bgr[:, :, ::-1])
    axes[i // n, i % n].axis('off')  # 关闭坐标轴显示
fig.suptitle('Semantic Segmentation Predictions', fontsize=30)
# plt.tight_layout()
plt.savefig('../K3.jpg')
plt.show()

本节也是照猫画虎,终点在于os库的应用,官方代码有一处需要修改,即79行,将os.chdir('outputs/testset-pred')修改为os.chdir('../../../outputs/testset-pred')

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MMSegmentation笔记06:推理 的相关文章

  • 笔记&代码

    可视化前三步走 数据类型 分析目的 实现工具 2 1 类别数据可视化 显示各类别的绝对频数及百分比等 条形图 饼图等 2 1 1 条形图及其变种 垂直条形图 类别在x轴 水平条形图 类别在y轴 简单条形图 并列条形图 堆叠条形图 1 简单条
  • 前端知识——css 之 flex 布局

    目录 一 认识 flex 布局 1 flex 布局的重要概念 二 flex 相关属性 1 flex container 中的属性 1 1 flex direction item 的排布方向 1 2 flex wrap 排布是否换行 1 3
  • Java多线程下载文件

    Java多线程下载文件 优化 合理利用服务器资源 将资源利用最大化 加快下载速度 一般有两种方式 线程池里面有N个线程 多线程下载单个文件 将网络路径的文件流切割成多快 每个线程下载一小部分 然后写入到文件里面 组成一个文件 当有很多个文件
  • MQ队列消息怎么保证100%不丢失

    面试官在面试候选人时 如果发现候选人的简历中写了在项目中使用了 MQ 技术 如 Kafka RabbitMQ RocketMQ 基本都会抛出一个问题 在使用 MQ 的时候 怎么确保消息 100 不丢失 这个问题在实际工作中很常见 既能考察候
  • javaScript基础面试题 --- new操作符具体做了什么?

    当我们使用new操作符调用函数时 背后发生了很多事情 这里是简单的new操作符的行为 创建一个新的空对象 将这个空对象的原型链接到构造函数的prototype对象 使用这个新对象作为上下文 即this的值 调用该构造函数 如果构造函数返回一
  • Yii Framework 开发教程(25) 数据库-Query Builder示例

    上一篇介绍PHP使用DAO 数据库访问对象接口 访问数据库的方法 使用DAO需要程序员编写SQL语句 对于一些复杂的SQL语句 Yii提供了Query Builder来帮助程序员生成SQL语句 Query Builder提供了一中面向对象的

随机推荐

  • Windows7安装docker以及使用docker安装centos7

    目录 一 WIN7安装DOCKER 二 docker安装centos7 1 查看可用的 CentOS 版本 2 拉取指定版本的 CentOS 镜像 3 查看本地镜像 4 运行容器 并且可以通过 exec 命令进入 CentOS 容器 5 安
  • STM32内部参考电压+DMA精准采集电池电压

    最近项目又遇到了电池电压采集 锂电池的电压范围是4 2到2 8一般 当锂电池低于3 3V时 单片机供电电压会小于3 3V 那么电池电压参考计算4096就不能对应3 3 所以必须采用内部参考电压 我项目中用到的是RP104N331 LDO 实
  • openwrt上nginx启动报错nginx: [emerg] getpwnam("www") failed

    检查nginx的配置文件 etc nginx nginx conf 里面配置里确实有这一项 user nobody nogroup user www www worker processes 2 系统的用户又没有www这个用户 这就尴尬了
  • 本地缓存技术分享

    本地缓存 缓存分为本地缓存与分布式缓存 本地缓存为了保证线程安全问题 一般使用ConcurrentMap的方式保存在内存之中 而常见的分布式缓存则有Redis MongoDB等 一致性 本地缓存由于数据存储于内存之中 每个实例都有自己的副本
  • 深度学习&强化学习&进化计算 入门资源整理

    深度学习 强化学习 进化计算 入门资源整理 深度学习 在线课程 在线书籍 学习Python 强化学习 在线课程 在线书籍 更多资源 进化计算 后记 深度学习 在线课程 深度学习是机器学习领域的一个分支 想要入门深度学习 最好先对机器学习的一
  • 学会了,不会ps也能更换自己的证件照底色,制作自己的证件照

    证件照经常会由于背景色与要求不符而不能用 再去拍一组浪费时间和金钱 如何省时省力的把照片背景色修改成我们所需要的底色呢 说到修改照片背景色 首先想到大家常用的证件照 根据不同用处会要求 白 蓝 红 底色 在过去大家可能需要去图片社重新照 或
  • [631]一行js代码识别Selenium+Webdriver

    文章目录 一行js代码识别Selenium Webdriver 如何正确移除Selenium中的 window navigator webdriver 最新版 附一些网站检测selenium的示例 driver execute script
  • ESP32 SIM800L:发送带有传感器读数的文本消息(SMS警报)

    在这个项目中 我们将使用T Call ESP32 SIM800L模块创建一个SMS通知系统 当传感器读数高于或低于特定阈值时 该模块会发送SMS 在此示例中 我们将使用DS18B20温度传感器 并在温度高于28 C时发送短信 一旦温度降低到
  • uniapp使用scroll-view实现左右,上下滑动

    uniapp使用scroll view实现左右 上下滑动 阐述 我们在项目中往往都能遇到实现左右滑动跟上下滑动的需求 不需要安装better scroll uniapp 自带的scroll view 就可以实现了 实现左右滑动
  • 开源项目,源码

    GitHub 优秀的 Android 开源项目 转自 http blog csdn net shulianghan article details 18046021 主要介绍那些不错个性化的View 包括ListView ActionBar
  • java基础03:final

    说明 final是java的一个关键字 是最终的意思 final 表示 最后的 最终的 含义 变量一旦赋值后 不能被重新赋值 被 final 修饰的实例变量 就是已经实例化的对象 必须显式指定初始值 final 修饰符通常和 static
  • Flash钓鱼->CS上线(免杀过火绒、360等)

    先看结果 访问钓鱼页面 点击立即升级即把马儿下载下来了 这个马儿是rar压缩的 做成的rar解压自启动 所以是个exe的文件 然后这里为了像一点 把图标给改了 双击运行 查看效果 首先CS是没东西的 解压路径现在也是没东西的 这里我把解压路
  • C#值参数和引用参数

    C 值参数和引用参数 一 值参数 未用ref或out修饰符声明的参数为值参数 使用值参数 通过将实参的值复制到形参的方式 把数据传递到方法 方法被调用时 系统做如下操作 在栈中为形参分配空间 复制实参到形参 值参数的实参不一定是变量 它可以
  • 几年的Unity学习总结

    stream 其中类Stream为抽象类 由此有三个派生类 需要引入命名空间 using System IO MemoryStream 对内存进行读取与写入 BufferedStream 对缓冲器进行读取 写入 FileStream 对文件
  • access统计班级人数_使用ACCESS查询统计分数段人数

    不少人都知道使用电子表格 excel 进行分数段统计 使用access的人也可以用它设计查询进行分数段人数统计 这里假设你有一个access表 也可以是基表的查询 名叫tblScore 当然可以是中文名称 只不过代码内也要作相应修改 表内是
  • 大数据挖掘简介

    大数据挖掘涉及如下的课程 机器学习 统计学 人工智能 数据库等 但是更多的注重如下的特性 1 可扩展性 Scalability 大数据 2 算法和架构 3 自动的处理大数据 我们需要学习挖掘不同类型的数据 1 高维的数据 2 图数据 3 无
  • Vue技术 v-cloak指令(用于在 Vue 实例加载和编译之前隐藏元素)

    1 v cloak 指令的用法 v cloak 指令通常与 CSS 配合使用 用于在 Vue 实例加载和编译之前隐藏元素 通过给元素添加 v cloak 属性 然后在 CSS 中定义对应的样式 可以确保在 Vue 实例加载完成前 该元素的内
  • flex布局——flex-direction属性

    1 flex布局原理 1 flex是flexible Box的缩写 意为 弹性布局 用来为盒状模型提供最大的灵活性 任何一个容器都可以指定为flex布局 当我们为父盒子设为flex布局以后 子元素的float clear 和vertical
  • CentOS 7.9搭建Discuz 3.5论坛(LNMP)

    这里写目录标题 安装规格 安装nginx 安装依赖 编译配置Nginx 安装MySQL 设置MySQL Yum源并安装MySQL 查看MySQL初始密码并修改 安装并配置PHP 下载并解压Discuz 3 5 安装Discuz 安装规格 安
  • MMSegmentation笔记06:推理

    1 单张图像预测 author Seaton Time 2023 8 19 15 38 IDE PyCharm Summary 使用训练好的模型进行单张图像推理 import cv2 import matplotlib pyplot as