【Augmentation Zoo】RetinaNet + VOC + KITTI的数据预处理-pytorch版

2023-11-16

整合前段时间看的数据增强方法,并测试其在VOC和KITTI数据上的效果。我的工作是完成了对VOC和KITTI数据的预处理,RetinaNet的模型代码来自pytorch-retinanet

该项目github仓库在:https://github.com/zzl-pointcloud/Data_Augmentation_Zoo_for_Object_Detection

 

目录

一、VOC数据预处理

二、KITTI数据预处理

三、Resizer类和collater()类

1. Resizer类

2. Collater类


整个代码的处理逻辑是:

  1. 继承torch.Dataset类定义新的数据集类,如VocDatasets类,KittiDatasets类,重写__getitem__(image_index)函数,其功能是,输入图片序号,返回一个sample = {'img': img, 'annots': annots}。类中其他函数均服务于__getitem__函数,如load_image(),load_annotations()等。
  2. 将transform传入Dataset中,transform.Compose([fun1(), fun2(), ...])。其中fun是object继承类,定义其中的__call__(),使得他们可以被作为函数使用。对每张图片顺序执行函数fun1(), fun2(), ...。这里的fun()就是数据增强方法的入口
  3.  sampler(从数据集中取样本的策略)处理后,数据集类转换为DataLoader对象。通过sampler中设置的yield,迭代返回每一次的数据。
  4. 至此数据预处理部分完成,送入模型开始训练。训练的数据流动我个人理解如下:

每个epoch将整个数据在模型中走一遍,至于取样本的策略由第3步的sampler决定。每个epoch中,数据集N = batch_size * iter_num,每个iter,前向传播、反向传播,验证集测试,保存模型等。

retinanet

optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)


for epoch_num in range(epochs):
    for iter_num, data in enumerate(dataloader_train):
        # 前向传播,求解loss
        retinanet.train()
        classification_loss, regression_loss = retinanet([data['img'].float, data['annot']])  
        classification_loss = classification_loss.mean()
        regression_loss = regression_loss.mean()
        loss = classification_loss + regression_loss
        
        #反向传播,更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    """
    validation part

    """

"""
test part

"""
# 保存模型
torch.save(retinanet, "model_final.pt")

一、VOC数据预处理

class VocDataset(Dataset):
    def __init__(self,
                 root_dir,
                 image_set='train',         # train/val/test
                 years=['2007', '2012'],    # 默认2007+2012
                 transform=None,
                 keep_difficult=False
                 ):
        self.root_dir = root_dir
        self.years = years
        self.image_set = image_set
        self.transform = transform
        self.keep_difficult = keep_difficult

        self.categories = VOC_CLASSES

        self.name_2_label = dict(
            zip(self.categories, range(len(self.categories)))
        )
        self.label_2_name = {
            v: k
            for k, v in self.name_2_label.items()
        }
        self.ids = list()
        self.find_file_list()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, image_index):

        img = self.load_image(image_index)
        annots = self.load_annotations(image_index)
        sample = {'img':img, 'annot':annots}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def find_file_list(self):
        for year in self.years:
            if not (year == '2012' and self.image_set == 'test'):
                root_path = os.path.join(self.root_dir, 'VOC' + year)
                file_path = os.path.join(root_path, 'ImageSets', 'Main', self.image_set + '.txt')
                for line in open(file_path):
                    self.ids.append((root_path, line.strip()))

    def load_image(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        image_path = os.path.join(image_root_dir,
                                 'JPEGImages', img_idx + '.jpg')
        img = cv2.imread(image_path)
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32)/255.0

    def load_annotations(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        anna_path = os.path.join(image_root_dir,
                                'Annotations', img_idx + '.xml')
        annotations = []
        target = ET.parse(anna_path).getroot()
        for obj in target.iter("object"):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']

            bndbox = []
            for pt in pts:
                cut_pt = bbox.find(pt).text
                bndbox.append(float(cut_pt))
            name = obj.find('name').text.lower().strip()
            label = self.name_2_label[name]
            bndbox.append(label)
            annotations += [bndbox]
        annotations = np.array(annotations)

        return annotations

    def label_to_name(self, voc_label):
        return self.label_2_name[voc_label]

    def name_to_label(self, voc_name):
        return self.name_2_label[voc_name]

    def image_aspect_ratio(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        image_path = os.path.join(image_root_dir,
                                  'JPEGImages', img_idx + '.jpg')
        img = cv2.imread(image_path)
        return float(img.shape[1] / float(img.shape[0]))

    def num_classes(self):
        return 20

二、KITTI数据预处理

对KITTI的数据预处理代码上与VOC相似,但在初始化KittiDataset类之前,需要先将KITTI数据集人工划分为训练/验证集,并生成类似于VOC中的train.txt和val.txt文件。因此我又实现了SplitKittiDataset类(在tools.py中),大概思路是:

1. 获得文件名list,及len

2. 用range(len)生成一个index,打乱后,按划分比例取train_index和val_index,然后从list中取对应的文件名

3. 保存到txt文件中。

class KittiDataset(Dataset):
    def __init__(self,
                 root_dir,
                 sets,
                 transform=None,
                 keep_difficult=False
                 ):
        self.root_dir = root_dir
        self.sets = sets
        self.transform = transform
        self.keep_difficult = keep_difficult

        self.categories = KITTI_CLASSES

        self.name_2_label = dict(
            zip(self.categories, range(len(self.categories)))
        )
        self.label_2_name = {
            v: k
            for k, v in self.name_2_label.items()
        }
        self.ids = list()
        self.find_file_list()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, image_index):
        img = self.load_image(image_index)
        annot = self.load_annotations(image_index)
        sample = {'img':img, 'annot':annot}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def find_file_list(self):
        file_path = os.path.join(self.root_dir, self.sets + '.txt')
        for line in open(file_path):
            self.ids.append(line.strip())

    def load_image(self, image_index):
        img_idx = self.ids[image_index]
        image_path = os.path.join(self.root_dir,
                                 'image_2', img_idx + '.png')
        img = cv2.imread(image_path)
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32)/255.0

    def load_annotations(self, image_index):
        img_idx = self.ids[image_index]
        anna_path = os.path.join(self.root_dir,
                                'label_2', img_idx + '.txt')
        annotations = []
        with open(anna_path) as file:
            lines = file.readlines()
            for line in lines:
                items = line.split(" ")
                name = items[0].lower().strip()
                if name == 'dontcare':
                    continue
                else:
                    bndbox = [float(items[i+4]) for i in range(4)]
                    label = self.name_2_label[name]
                    bndbox.append(int(label))
                annotations.append(bndbox)
        annotations = np.array(annotations)
        return annotations

    def label_to_name(self, voc_label):
        return self.label_2_name[voc_label]

    def name_to_label(self, voc_name):
        return self.name_2_label[voc_name]

    def image_aspect_ratio(self, image_index):
        img_idx = self.ids[image_index]
        image_path = os.path.join(self.root_dir,
                                  'image_2', img_idx + '.png')
        img = cv2.imread(image_path)
        return float(img.shape[1] / float(img.shape[0]))

    def num_classes(self):
        return 8

三、Resizer类和collater()类

分别用于将图片修改为限定大小和对齐。

1. Resizer类

设置短边上限和长边上限,如608/1024
scale = 短边上限 / 短边
if 长边 * scale > 长边上限:
    scale = 长边上限 / 长边
resized_image = cv2.resize(image, 长边 * scale, 短边 * scale)

将resized_image长宽填充为32的倍数

2. Collater类

图片填充为数据集最长宽和最长高(如:长边上限 * 长边上限),图片从左上角对齐,其余部分填充为0。

annots填充则是以sample为单位,都扩充到最多目标数,其余填充-1

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Augmentation Zoo】RetinaNet + VOC + KITTI的数据预处理-pytorch版 的相关文章

  • 【计算机视觉

    文章目录 一 检测相关 1篇 1 1 SegmentAnything helps microscopy images based automatic and quantitative organoid detection and analy
  • 图像目标检测之cascade-rcnn实践

    最近一直在调试目标检测方面的模型 其中mmdetection中就集成了许多的目标检测模型 其中表现比较好的模型中有cascade rcnn 因此也趁这个机会具体了解一下这个模型的发展脉络 1 模型原理 在two stage模型中 常见都会预
  • 【论文笔记_目标检测_2022】Cross Domain Object Detection by Target-Perceived Dual Branch Distillation

    基于目标感知双分支提取的跨域目标检测 摘要 在野外 跨领域目标检测是一项现实而具有挑战性的任务 由于数据分布的巨大变化和目标域中缺乏实例级注释 它的性能会下降 现有的方法主要关注这两个困难中的任何一个 即使它们在跨域对象检测中紧密耦合 为了
  • 利用Albumentations工具包进行图像的数据增强(以yolo数据标注格式为例)

    最近在看数据增强方法时 看到了这个有趣的工具包 研究了下并以yolo数据标注格式为例写了一个示例脚本 该工具最大的好处是会根据你使用的数据增强方法自动修改标注框信息 import albumentations as A import cv2
  • 特定场景小众领域数据集之——焊缝质量检测数据集

    写这篇文章最大的初衷就是最近频繁的有很多人私信问我相关的数据集的问题 基本上都是从我前面的目标检测专栏里面的这篇文章过来的 感兴趣的话可以看下 轻量级模型YOLOv5 Lite基于自己的数据集 焊接质量检测 从零构建模型超详细教程 保姆级的
  • 睿智的目标检测24——Keras搭建Mobilenet-SSD目标检测平台

    睿智的目标检测24 Keras搭建Mobilenet SSD目标检测平台 更新说明 学习前言 什么是SSD目标检测算法 源码下载 SSD实现思路 一 预测部分 1 主干网络介绍 2 从特征获取预测结果 3 预测结果的解码 4 在原图上进行绘
  • 动手学CV-目标检测入门教程4:模型结构

    3 4 模型结构 本文来自开源组织 DataWhale CV小组创作的目标检测入门教程 对应开源项目 动手学CV Pytorch 的第3章的内容 教程中涉及的代码也可以在项目中找到 后续会持续更新更多的优质内容 欢迎 如果使用我们教程的内容
  • 传统目标检测方法研究(一)

    1 传统算法目标检测 区域选择 gt 特征提取 gt 特征分类 1 1 区域选择 python 实现 图像滑动窗口 区域选取 首先选取图像中可能出现物体的位置 由于物体位置 大小都不固定 因此传统算法通常使用滑动窗口 Sliding Win
  • [YOLO专题-26]:YOLO V5 - ultralytics代码解析-detect.py程序的流程图与对应的plantUML源码

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 122443972 目录 第1章 det
  • 目标检测之YOLOv1算法分析

    网络结构 卷积层 池化层 全连接层 输入 448 448 448 448 448 448大小的图片 输出 7 7
  • 【计算机视觉

    文章目录 一 检测相关 8篇 1 1 Impact of Image Context for Single Deep Learning Face Morphing Attack Detection 1 2 A Theoretical and
  • IA-YOLO项目中DIP模块的初级解读

    IA YOLO项目源自论文Image Adaptive YOLO for Object Detection in Adverse Weather Conditions 其提出端到端方式联合学习CNN PP和YOLOv3 这确保了CNN PP
  • 【计算机视觉

    文章目录 一 检测相关 11篇 1 1 Follow Anything Open set detection tracking and following in real time 1 2 YOLO MS Rethinking Multi
  • 【计算机视觉

    文章目录 一 检测相关 11篇 1 1 Perspective aware Convolution for Monocular 3D Object Detection 1 2 SCoRD Subject Conditional Relati
  • 使用Stable Diffusion图像修复来生成自己的目标检测数据集

    点击上方 AI公园 关注公众号 选择加 星标 或 置顶 作者 R dig par Gabriel Guerin 编译 ronghuaiyang 导读 有些情况下 收集各种场景下的数据很困难 本文给出了一种方法 深度学习模型需要大量的数据才能
  • 睿智的目标检测60——Tensorflow2 Focal loss详解与在YoloV4当中的实现

    睿智的目标检测60 Tensorflow2 Focal loss详解与在YoloV4当中的实现 学习前言 什么是Focal Loss 一 控制正负样本的权重 二 控制容易分类和难分类样本的权重 三 两种权重控制方法合并 实现方式 学习前言
  • YOLO算法v1-v3原理通俗理解

    YOLO算法v1 v3原理通俗理解 深度学习检测方法简述 我们所使用的目标检测 其实就是让机器在图片找到对应的目标 然后给图片上的目标套上一个框框 并贴上标签 比如如果图片上有人 就把人框起来并标注一个 person 使用深度学习进行目标检
  • QueryDet:级联稀疏query加速高分辨率下的小目标检测

    论文 https arxiv org abs 2103 09136 代码 已开源 https github com ChenhongyiYang QueryDet PyTorch 计算机视觉研究院专栏 作者 Edison G 虽然深度学习的
  • 仅使用卷积!BEVENet:面向自动驾驶BEV空间的高效3D目标检测

    点击下方 卡片 关注 自动驾驶之心 公众号 ADAS巨卷干货 即可获取 gt gt 点击进入 自动驾驶之心 BEV感知 技术交流群 论文作者 Yuxin Li 编辑 自动驾驶之心 写在前面 个人理解 BEV空间中的3D检测已成为自动驾驶领域
  • 什么是概率匹配

    概率匹配是一种在信息论和统计学中常用的方法 用于将一个随机事件的概率分布与另一个概率分布进行匹配或逼近 它在数据处理 编码 压缩和模型选择等领域具有重要的应用 为我们理解和处理复杂的概率分布提供了一种有效的工具 首先 让我们来了解概率匹配的

随机推荐

  • 头插法和尾插法的详细区别

    浅析线性表 链表 的头插法和尾插法的区别及优缺点 线性表作为数据结构中比较重要的一种 具有操作效率高 内存利用率高 结构简单 使用方便等特点 今天我们一起交流一下单向线性表的头插法和尾插法的区别及优缺点 线性表因为每个元素都包含一个指向下一
  • IDE0006 加载项目时遇到了错误,已禁用了某些项目功能,例如用于失败项目和依赖于失败项目的其他项目的完整解决方案分析。

    重新打开vs2017就好了 原因猜测 vs来大姨妈了 现象是catch ex 后面是e message 单纯少个x vs没检测出来 辛辛苦苦搜个半天 可能太依赖vs了 懒人专属编辑器
  • npm私有化docker方式部署及使用说明

    一 部署nexus 本文采用docker方式部署nexus 安装docker yum install y docker 拉取nexus镜像 docker pull sonatype nexus3 准备本地映射目录 以便本地化持续存储数据 目
  • python No module named numpy. distutils._msvccompiler in numpy. distutils; trying from distutils

    在cmd 中输入 python setup py install 报错 No module named numpy distutils msvccompiler in numpy distutils trying from distutil
  • Android 报错 : FATAL EXCEPTION:main 解决方法

    今天安卓开发课上碰到的新问题 前景提示 老师让我们自己试一下那个两个页面跳转的效果 于是我就开始写了 然后报错 解决方法 逐一排查 首先要看你mainfest xml里面有没有增加Activity 当然我是加了 但是他还报错 具体代码界面
  • Ubuntu18.04 windows10双系统安装解决grub引导问题

    最近给服务器的电脑升级了ubuntu18 直接用u盘安装 老是说grub引导问题 网上有很多教程真的坑人 说的含含糊糊的 不知道在卖弄什么关子 我参照这两个教程解决了安装问题 十分钟就装好了 感谢你们 https blog csdn net
  • vue Tesseract的 ocr 文字识别

    npm结果页 https www npmjs com package tesseract js tesseract官网地址 https tesseract projectnaptha com npm结果页 npm结果页 tesseract官
  • 如何优雅的统计代码耗时

    点击上方 小强的进阶之路 选择 星标 公众号 优质文章 及时送达 预计阅读时间 16分钟 作者 Jitwxs 原文链接 底部链接可直达 https jitwxs cn 5aa91d10 html 一 前言 代码耗时统计在日常开发中算是一个十
  • R语言—列表

    文章目录 列表 定义 创建列表 List 列表 List 元素的引用 列表 List 元素的修改 访问列表元素和值 去列表化 在列表上使用apply系列函数 递归型列表 列表 R语言的6种模式 向量 矩阵 数组 数据框 列表 因子 向量 矩
  • SQLite如何删除,修改、重命名列

    今天在SQLite数据库中添加了一列 后来发现列名写错了 于是使用SQL语句来修改列名 可是根本不管用 首先 请放弃alter吧 sqlite官方说明如下 SQLite supports a limited subset of ALTER
  • 【JS】JavaScript时间与时间戳相互转换

    时间与时间戳相互转换 1 2 时间 JS常用时间类型 1 2 1 GMT 格林尼治标准时 1 2 2 UTC 协调世界时 1 2 3 中国标准时间 1 2 4 ISO8601标准时间格式 1 2 5 时间戳 timestamp 1 时间戳转
  • spring boot项目自动加载引入外部bean

    前言 spring boot项目简化了对外部项目的引入 使我们能够狠方便的构建一个web项目 我们通常在开发的过程中会开发出一些公用的模块组件 这样在项目找那个引入后能够直接使用 减少了轮子的重复构造 同时服务引入的模块化操作 能够更多的节
  • CV学习:OpenCv快速入门(python版)

    本文代码全部可运行 笔者运行环境 python3 7 pycharm opencv4 6 此文是学习记录 记录opencv的入门知识 对各知识点并不做深入探究 文章的目的是让阅读者在极短的时间达到入门水平 在学习过程中 我们应养成 查询op
  • pygame用blit()实现动画效果

    pygame的的实现动画的方法有很多 但是都是围绕着表面进行的 也就是说实现动画的方式不同 但是本质其实都是对表面的不同处理方式而已 原理其实很简单 有点像我们做地铁的时候隧道里的广告一样 我们设置一个窗口 然后让窗口在一个画着很多帧图像的
  • 约束综合中的逻辑互斥时钟(Logically Exclusive Clocks)

    注 本文翻译自Constraining Logically Exclusive Clocks in Synthesis 逻辑互斥时钟的定义 逻辑互斥时钟是指设计中活跃 activate 但不彼此影响的时钟 常见的情况是 两个时钟作为一个多路
  • IDEA从安装到使用--相关配置详解

    IDEA从安装到使用 相关配置详解 作为一个技术小白 刚开始学习使用Intellij IDEA 入门时踩了很多的坑 这里写下我的第一篇博客 分享相关IDEA的配置方法 希望能为各位提供一点帮助 IDEA2018安装及破解 作者 志哥的成长笔
  • 大数(四则运算)

    四则运算 大数加法 高精度加法 大数减法 大数乘法 大数乘法 幂运算 大数乘法 高精度幂运算 大数除法 大数加法 思路 从后往前算 即由低位向高位运算 计算的结果依次添加到结果中去 最后将结果字符串反转 输入的时候两个数都是以字符串的形式输
  • 网站架构演变

    网站架构演变 大型网站介绍 与传统企业应用系统相比 大型互联网网站系统具有以下特点 1 大流量 高并发 这一点往往是传统企业应用系统根本就不会遇到的问题 比如Goole每日访问量都是几十亿 如果服务器端处理不好早就被压的宕机了 2 高可用
  • 环形缓冲区(1)

    声明 参考韦东山视频教程 如若侵权请告知 马上删帖致歉 个人总结 如有不对 欢迎指正 环形缓冲区 环形缓冲区的几个基本操作 申请内存空间 写操作 读操作 环形缓冲区小结 判断缓冲区是否为空 判断缓冲区是否写满 构建环形缓冲区 在 h文件中声
  • 【Augmentation Zoo】RetinaNet + VOC + KITTI的数据预处理-pytorch版

    整合前段时间看的数据增强方法 并测试其在VOC和KITTI数据上的效果 我的工作是完成了对VOC和KITTI数据的预处理 RetinaNet的模型代码来自pytorch retinanet 该项目github仓库在 https github