深度学习Pytorch(十)——基于torchvision的目标检测模型

2023-11-15

深度学习Pytorch(十)——基于torchvision的目标检测模型


本节,将微调Penn-Fudan数据库中对行人检测和分割的已预先训练的Mask R-CNN模型。该数据集包含170个图像和345个行人实例。用该数据集说明如何在torchvision中使用新功能,以便在自定义数据集上训练实例分割模型

一、定义数据集

对于训练对象检测的脚本,实例分割和人员关键点检测要能够轻松支持添加新的自定义的数据,数据集应该从标准的类torch.utils.data.Dataset继承,并实现__len__和__getitem__
对于__getitem__应该返回:

  1. 图像:PIL图像大小(h,w)
  2. target:包含如下字段的字典
    (1)boxes(FloatTensor[N,4]):N边框坐标的格式[x0,x1,y0,y1],取值范围是(0,w),(0,h)
    (2)labels(Int64Tensor[N]):每个边框的标签
    (3)image_id(Int64Tensor[1]):图像识别器。它应该在数据集中的所有图像中是唯一的,并在评估期间使用
    (4)area(Tensor[N]):边框的面积,在使用CoCo指标进行评估时使用此项来分割小,中和大框之间的度量标准得分
    (5)iscrowed(UInt8Tensor[N,H,W]):在评估期间属性设置为iscrowed=True的实例会被忽略
    (6)masks(UInt8Tesor[N,H,W]):每个对象的分段掩码
    (7)keypoints (FloatTensor[N, K, 3]:对于N个对象中的每一个,包含[x,y,visibility]格式的k个关键点,用于定义对象。visibility=0表示关键点不可见。注:对于数据扩充,翻转关键点的概念取决于数据表示,应该调整reference/detection/transforms.py以用于新的关键点表示
    如果要在training阶段使用宽高比分组,还需要实现get_high_width方法,返回图像的高度和宽度。若没有该function,将通过__getitem__查询数据集的所有元素,这种function会将图像加载到内存中,但是比提供的自定义的方法要

二、为PennFudan编写自定义数据集

1、下载数据集

数据集——>我在这里
下载后解压放在工作路径下,先查看一下数据集吧,结构如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
(黑压压一大片,这没问题哈,是掩膜,哈哈哈哈哈~不要以为自己下载的不对!)
文件夹中每个文件的详细说明如下(readme中详细说明):
在这里插入图片描述

2、为数据集编写类

在这里插入代码片

三、定义模型

定义一个可以预测上述数据集的模型。本节,使用Mask R-CNN,它基于Faster R-CNN。Faster R-CNN是一个模型,可以预测图像中潜在对象的边界框和类别得分。(此处贴一个Faster R-CNN的详解——>我在这里
在这里插入图片描述
Mask R-CNN在Faster R-CNN中添加了一个额外的分支,预测每个实例的分割蒙版
在这里插入图片描述
有两种常见情况可能需要修改torchvision modelzoo中的一个可用模型。

  • 想要从预先训练的模型开始的时候,微调最后一层
  • 想用不同的模型替换模型的主干时(用于更快的预测)

以下是读这两种情况的处理:

Ⅰ 微调已经预训练的模型

假设你想从一个在CoCo(其实它叫COCO )上已预先训练过的模型开始,并希望为特定类进行微调,这是一种可行的方法:

#%%微调已经预训练的模型
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# 在CoCo上加载经过预训练的预训练模型
model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# 将分类器替换为具有用户定义的num_classes的新分类器
num_classes=2#1 (person)+background
# 获取分类器的输入参数的数量
in_features=model.roi_heads.box_predictor.cls_score.in_features
# 用新的头部替换预先训练好的头部
model.roi_heads.box_predictor=FastRCNNPredictor(in_features, num_classes)
Ⅱ 修改模型以添加不同的主干
#%% 修改模型以添加不同的主干
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# 加载预先训练的模型进行分类和返回
backbone=torchvision.models.mobilenet_v2(pretrained=True).features
# faster r-cnn需要知道骨干网络中的输出通道数量,对于mobilenet_v2,是1280,这里添加backbone.out_channels=1280
backbone.out_channels=1280
# 让RPN在每个空间位置生成5*3个锚点,具有5种不同的大小和3种不同的宽高比
# 有一个元组,每个特征映射可能具有不同的大小和宽高比
anchor_generator=AnchorGenerator(sizes=((32,64,128,256,512),),
                                 aspect_ratios=((0.5,1.0,2.0),))
# 定义用于执行区域裁剪的特征映射,以及重新缩放后裁剪的大小。如果主干返回Tensor,featmap_names应为[0]更一般的,主干应该返回OrderedDict[Tensor],并在feature_names中,选择要使用的功能映射
roi_pooler=torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], 
                                              output_size=7,
                                              sampling_ratio=2)
# 将这些pieces放在FasterRCNN模型中
model=FastRCNNPredictor(backbone,
                        num_classes=2,
                        rpn_anchor_generator=anchor_generator,
                        box_roi_pool=roi_pooler)

1、PennFudan数据集的实例分割模型

我们从预先训练的模型进行微调,因为数据集非常小,所以将遵循上述第一种情况。还要计算实例分割掩膜,因此将使用Mask R-CNN:

#%%定义模型
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model_instance_segmentation(num_classes):
    # 加载在CoCo上预训练的实例分割模型
    model=torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    # 获取分类器的输入参数的数量
    in_features=model.roi_heads.box_predictor.cls_score.in_features
    # 用新的头部替换预先训练好的头部
    model.roi_heads.box_predictor=FastRCNNPredictor(in_features, num_classes)
    # 获取掩膜分类器的输入特征数
    in_features_mask=model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer=256
    # 用新的掩膜预测期替换掩膜预测期
    model.roi_heads.mask_predictor=MaskRCNNPredictor(in_features_mask, 
                                                     hidden_layer, 
                                                     num_classes)
    return model

四、整合

在**reference/detection/**中有许多辅助函数来简化训练和评估检测模型。此处将使用reference/detection/engine.py,reference/detection/util.py和reference/detection/transforms.py。只需要将他们复制到文件夹并在工作路径使用——>我在这里

1、为数据扩充/转换编写辅助函数

#为数据扩充/转换编写辅助函数
import sys
sys.path.append("D:\Python\Pytorch")
import transforms as t

def get_transform(train):
    transforms=[]
    transforms.append(t.ToTensor())
    if train:
        transforms.append(t.RandomHorizontalFlip(0.5))
    return t.Compose(transforms)

2、编写执行训练和验证的主要功能

#编写执行训练和验证的主要功能
from engine import train_one_epoch,evaluate
import utils

def main():
    # 在GPU训练,没有使用CPU
    device=torch.device('cuda')if torch.cuda.is_available() else torch.device('cpu')
    # 数据集只有两类人:背景和人
    num_classes=2
    # 使用数据集和定义的转换
    dataset=PennFudanDataset('PennFudanPed', get_transform(train=True))
    dataset_test=PennFudanDataset('PennFudanPed', get_transform(train=False))
    
    # 在训练和测试集拆分数据集
    indices=torch.randperm(len(dataset)).tolist()
    dataset=torch.utils.data.Subset(dataset,indices[:-50])
    dataset_test=torch.utils.data.Subset(dataset_test, indices[-50:])
    # 定义训练和验证数据加载器
    data_loader=torch.utils.data.DataLoader(
        dataset,batch_size=2,shuffle=True,num_workers=4,
        collate_fn=utils.collate_fn)
    data_loader_test=torch.utils.data.DataLoader(
        dataset,batch_size=1,shuffle=False,num_workers=4,
        collate_fn=utils.collate_fn)
    # 使用辅助函数获取模型
    model=get_model_instance_segmentation(num_classes)
    # 将模型迁移到合适的设备
    model.to(device)
    # 构造一个优化器
    params=[p for p in model.parameters() if p.requires_grad]
    optimizer=torch.optim.SGD(params, lr=0.005,
                              momentum=0.9,weight_decay=0.0005)
    # 和学习率调度程序
    lr_scheduler=torch.optim.lr_scheduler.StepLR(optimizer, step_size=3,gamma=0.1)
    # 训练10个epoch
    num_epochs=10
    
    for epoch in range(num_epochs):
        # 训练一个epoch,每10此迭代打印一次
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # 更新学习速率
        lr_scheduler.step()
        # 在测试集上评价
        evaluate(model, data_loader_test, device=device)
    print("That‘s all")

在第一个epoch训练后可以得到一下结果:
在这里插入图片描述
因此,在一个epoch训练之后,获得了CoCo-style mAP为64.8,mask mAP为67.6——如下图
在这里插入图片描述
经过训练10个epoch后,得到如下指标:
在这里插入图片描述
小结明天再写,太晚了,我要赶紧回宿舍洗洗睡了。明天见,猿友们~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习Pytorch(十)——基于torchvision的目标检测模型 的相关文章

随机推荐

  • 13 款炫酷的 MySQL 可视化管理工具!好用到爆!!

    MySQL 的管理维护工具非常多 除了系统自带的命令行管理工具之外 还有许多其他的图形化管理工具 工具好用是一方面 个人的使用习惯也很重要 这里介绍 13 款 MySQL 图形化管理工具 供大家参考 1 DBeaver DBeaver 是一
  • 分数运算(类+构造)

    题目描述 用C 定义和实现一个分数类 并根据要求完成分数对象的运用 分数类包含分子和分母两个属性 操作包括 各属性的get方法 构造函数 初始化分子分母 相加运算 该运算接收两个分数对象的分子和分母 然后进行分数相加 结果保存在自己的分子和
  • 利用 ViewBinding 和反射封装的基类,从此再也不用 findViewById 了

    code小生 一个专注大前端领域的技术平台公众号回复Android加入安卓技术群 作者 段颖超丨乐拼链接 https www jianshu com p ea395a83c666声明 本文已获段颖超丨乐拼授权发表 转发等请联系原作者授权 今
  • 计算机图形学入门(一)-线性代数部分知识1

    本部分主要介绍了向量的点乘与叉乘在图形学中的基本应用 介绍了图形学中常用的2D矩阵变换 例如缩放 对称 切变换 旋转 平移 逆变换 组合变换和分解变换 还有在图形学中为了简化操作而采取的添加维度的方法 主要的学习过程来自下面的视频 本文只会
  • select case when语句

    今天看见一公司的数据库面试题 其中有道一开始没想起怎么做 后来摸索了下终于做出来了 题目是 两个表联合查询 当表2的记录在表A里没有时 将其设置为0 mysql gt select from t1 id name 1 bbs 2 bb 3
  • folly库安装(2)openssl升级、python3.8安装

    openssl是必须要升级到openssl1 1 1的 python3 8可以选择不安装 因为folly官网提供了一种用python3快速安装的方法 但这个方法在国内不太顺利 被墙的原因 很多包是不能自动下载的 但了解下也是好的 用pyth
  • 面试官:生成订单30分钟未支付,则自动取消,该怎么实现?

    了解需求 方案 1 数据库轮询 方案 2 JDK 的延迟队列 方案 3 时间轮算法 方案 4 redis 缓存 方案 5 使用消息队列 了解需求 在开发中 往往会遇到一些关于延时任务的需求 例如 生成订单 30 分钟未支付 则自动取消 生成
  • Centos下服务异常停止,log无任何异常体现localhost kernel: TCP: request_sock_TCP: Possible SYN flooding on port 8080

    背景 这两天项目出了奇怪的问题 某服务不明原因的停止运行 停止前一切正常 解决 重启服务解决 但不多久又出现类型问题 分析 开发和运维给不出原因 无法向客户交待 头大时刻想到系统运行日志 var log message应该有记录相应log
  • MySQL必知必会 学习笔记 第二十二章 使用视图

    MySQL 5添加了对视图的支持 视图是虚拟的表 它包含的是一个查询的结果 它本身不含数据 只是用来查看存储在别处的数据的一种设施 视图返回的数据是从其他表中检索出来的 在添加或更改这些表中的数据时 视图将返回改变过的数据 视图的应用 1
  • java数据结构基础名词解释

    第一章 绪论 数据与数据结构 数据 信息的载体 数据元素 数据中的一个 个体 是数据的基本组织单位 数据项 简单数据项 例如 姓名 年龄 组合数据项 例如 出生年月日 包含年 月 日三个简单数据项 数据对象 属性相同的数据元素的集合 数据结
  • 2022.11.29(面经五,笔试+技术面)

    2022 11 29 面经五 笔试题目不难 多刷力扣就成 1 什么是面向对象 面向对象 是把构成问题的事务分解成各个对象 而建立对象的目的也不是为了完成一个个步骤 而是为了描述某个事物在解决整个问题的过程中所发生的行为 附加 面向过程 分析
  • 【自我提高】树莓派GPIO的几种语言控制方法 C 篇

    使用C语言控制 GPIO 18 首先知道树莓派外置IO的关系对照表 我这里的树莓派是 PI 3B V1 2 关系对照表如下 C 语言下使用 wiringPi GPIO 进行编程 要安装 wiringPi pi raspberrypi sud
  • perl 入门推荐

    整理了一些perl链接 perl没有太多复杂概念 了解基础后 就可以编写各种需求脚本了 perl 相同功能 实现的方法有很多 只需了解最最基本的那种方法就好 剩下的就是百度 熟能生巧 举一反三 理解消化 perl语言 一个视频全解决 在线播
  • 合宙Air700E/4G模块使用AT指令查询基础信息

    Air700E使用AT指令查询基础信息 前言 AT指令使用 AT 确认AT固件 AT CGMR 请求制造商版本 AT CGMM 返回制造商型号编码 AT CGSN x 查询产品序列号 AT CGSN 1 查询IMEI AT CGSN 2 查
  • Flask项目(三)定义登录装饰器、图片服务、缓存机制、celery

    Flask项目 定义登录装饰器 redis文档 图片服务 封装七牛方法 城区数据下拉列表 缓存 用户认证相关 发布房源 map 函数 celery 基本使用 房屋管理 定义登录装饰器 utils commons py from werkze
  • x86汇编_MUL/IMUL乘法指令_笔记52

    32位模式下整数乘法可以实现32 16或8位的操作 64位下还可以使用64位操作数 MUL执行无符号乘法 IMUL执行有符号乘法 MUL指令 无符号数乘法 32 位模式下 MUL 无符号数乘法 指令有三种类型 执行 8 位操作数与 AL 寄
  • gitee删除上传到的远程分支的提交记录

    在实际开发中可能也经常会遇到写完代码后提交到远程分支但发现写的提交信息有误 不符合规范 由于自己的gitee账号可能没有修改提交记录的权限 因此最佳的解决方法是 撤销本地分支当前的提交记录 将代码回滚到上一个版本 提交前 重新强制再提交一版
  • 大数据挖掘、分析与应用

    第一讲 基础知识 大数据指无法在可承受的时间范围内用常规软件工具进行捕捉 管理和处理的数据集合 是需要新处理模式才能具有更强的决策力 洞察力和流程优化能力的海量高增长率和多样化的信息资产 数据挖掘 DataMining 是有组织有目的地收集
  • 安装SQLServer2008出现[HKLM\Software\Microsoft\Fusion!EnableLog] (DWORD)设置为 1

    问题 当我们卸载SQLServer2008后再重新安装后会出现以下问题 原因是卸载有时不能完全清理文件 解决方法 找到文件C Users user name AppData Local Microsoft Corporation删除Land
  • 深度学习Pytorch(十)——基于torchvision的目标检测模型

    深度学习Pytorch 十 基于torchvision的目标检测模型 文章目录 深度学习Pytorch 十 基于torchvision的目标检测模型 一 定义数据集 二 为PennFudan编写自定义数据集 1 下载数据集 2 为数据集编写