tensorflow SSD实战:基于深度学习的多目标识别

2023-11-01

SSD(SSD: Single Shot MultiBox Detector)是采用单个深度神经网络模型实现目标检测和识别的方法。如图2所示,该方法是综合了Faster R-CNN的anchor box和YOLO单个神经网络检测思路(YOLOv2也采用了类似的思路,详见YOLO升级版:YOLOv2和YOLO9000解析),既有Faster R-CNN的准确率又有YOLO的检测速度,可以实现高准确率实时检测。在300*300分辨率,SSD在VOC2007数据集上准确率为74.3%mAP,59FPS;512*512分辨率,SSD获得了超过Fast R-CNN,获得了80%mAP/19fps的结果,如图0-2所示。SSD关键点分为两类:模型结构和训练方法。模型结构包括:多尺度特征图检测网络结构和anchor boxes生成;训练方法包括:ground truth预处理和损失函数。本文解析的是SSD的tensorflow实现源码,来源balancap/SSD-Tensorflow。本文结构如下:

1,多尺度特征图检测网络结构;

2,anchor boxes生成;

3,ground truth预处理;

4,目标函数;

5,总结

 

2 SSD与MultiBox,Faster R-CNN,YOLO原理

(此图来源于作者在eccv2016的PPT)

 

图0-2 SSD检测速度与精确度。(此图来源于作者在eccv2016的PPT)

1 多尺度特征图检测网络结构

SSD的网络模型如图3所示。

3 SSD模型结构。(此图来源于原论文)

模型建立源代码包含于ssd_vgg_300.py中。模型多尺度特征图检测如图1-2所示。模型选择的特征图包括:38×38(block4),19×19(block7),10×10(block8),5×5(block9),3×3(block10),1×1(block11)。对于每张特征图,生成采用3×3卷积生成 默认框的四个偏移位置和21个类别的置信度。比如block7,默认框(def boxes)数目为6,每个默认框包含4个偏移位置和21个类别置信度(4+21)。因此,block7的最后输出为(19*19)*6*(4+21)。

 

 

所需的环境:

  • Anaconda3(64bit)

  • CUDA-9.2

  • CuDNN-7.15

  • Python-3.6

  • TensorFlow 或者 TensorFlow-gpu由于考虑到项目工程量比较大,所以本文选用GPU版本进行训练

步骤:

安装Anaconda3:在anaconda官网下下载对应windows版本的anaconda安装包,下载完成后按照提示进行安装,anaconda3里包含很多python IDE。

 

图4

 

安装cuda和cuDNN:安装Tensorflow-gpu版本必须先把cuda环境配置好,这样才能成功安装,如何查看cuda安装成功呢?可以在vs2017里面进行查看。

  1. 数据集制作:利用LabelImg工具对预先得到的需要训练的图片进行框选目标的最小包围框,LabelImg会自动将框选的目标生成所需要的xml文件

  2.  

 

图5

Xml文件中记录了人和帽子以及马甲的最小包围框。

  1. 格式转换:由于SSD-TensorFlow不能直接对图片进行训练,需要先将得到的图片转换为.tfrecord格式,在转换的过程中可以选择将多少张图片转为一个.tfrecord。

启动训练:在训练之前需要修改pascalvoc_2007.py和ssd_300_vgg.py文件中的识别目标数,改成自己需要识别的种类,然后在train_ssd_network.py文件中修改相应的参数。修改完成即可运行train_ssd_network.py文件,在console窗口可以查看相关的运行信息,当然在运行期间也可以查看tensorboard下的各个参数指标。

检验模型:模型检验在eval_ssd_network.py文件下进行检验,得到的输出信息mAP,为模型检验的准确率。

详细流程

Pascal voc2007数据集:

使用SSD-tensorflow的标准数据集包括Pascal voc2007和Pascal voc2012可以选择一起训练,也可以对两种数据集分开训练。在对视频进行处理时利用Matlab可以很容易将视频按自己选择的帧数截成图片。

得到图片后将图片以统一格式命名,放在JPEGImages文件夹下,接着将图片中的目标框选出来,框选图片有两种方法,一种是利用代码得到xml文件,不过这种方法每次都要改变代码里边框的大小信息,比较麻烦。还有一张使用自动化工具LabelImg。

 

Step1:点击open打开需要框选的图片的路径:。

Step2:点击Create RectBox,在图片上框选人所在位置。

Step3:在弹出的对话框中输入person。

Step4:对于hat,以及vest的框选重复步骤1,2,3即可直到所有的图片框选完。

注意:对同一幅图片.jpg文件和.xml命名要一样。

ssd-tensorflow-model工程相关参数修改:

  1. 转换为tfrecord格式:

首先将数据集保存在工程路径下,方便进行数据格式转换,以及模型训练。

Step1:打开pascalvoc_to_tfrecords.py文件,修改画框的参数,将SAMPLES_PER_FILES,改为自己需要将多少张图片转换成一个.tfrecord文件

比如我将其改为400,也就是是400张图片转换为一个.tfrecord文件。

Step2:tf_convert_data.py中,dataset_dir改为刚刚自己转换的图片所在的文件夹所在路径。output_dir改为自己需要保存的.tfrecord文件路径。

Step3:开始运行此.py文件代码,生成如下图所示文件

 

相关.py文件参数修改:

数据格式处理好了之后,下一步就要根据自己的目标种数修改大量的.py文件了,由于原项目工程为根据官网的标准数据集训练的共有20种目标需要训练,而自己的数据集训练目标为3种,所以对应的信息要修改。

Step1:修改pascalvoc2007.py相关信息,原项目21种目标,此处改为自己的目标种数NUM_CLASSES = 3,TRAIN_STATISTICS 第一个参数为目标出现在图片的张数,第二个参数为目标出现的总次数。SPLITS_TO_SIZES 为将数据集划分为训练数据集与测试数据集的张数。

Step2:修改pascalvoc_common.py相关信息

Step3:在ssd_vgg_300.py文件中修改如下参数。改为目标数加一(加一为背景还要算一种)。

  1. 开始训练模型:

为了开始训练模型,首先要将train_ssd_network.py文件相关信息修改正确。

Step1:train_dir改为自己需要将模型保存的路径下。

Step2:log_every_n_steps改为自己需要多少步保存显示相关信息。

Step3:save_summaries_secs为多少秒保存一次日志。

Step4:save_interval_secs为多少秒保存一次得到的模型。

Step5:gpu_memory_fraction为训练时需要占用多少百分百的显存。

Step6:leaning_rate为设置需要多大的学习率,一般开始学习率大一点以便不困在局部最优解,当loss比较小时learning_rate设置小一点降低loss。

Step7:end_learning_rate为训练结束时的学习率。

Step8:dataset_name修改为pascalvoc_2007,因为数据集为pascalvoc2007数据集。

Step9:num_classes为自己的目标数加一。

Step10:dataset_split_name改为train,划分tfrecord文件。

Step11:dataset_dir为转换的tfrecord文件路径。

Step12:model_name改为ssd_300_vgg,因为我们的图片resize为大小了。

Step13:batch_size根据自己显卡内存大小适当修改,太大了会报显存不足。

Step14:max_number_of_steps为最大训练步数,一般50000-100000均可。

Step15:checkpoint_path是否在预训练模型的基础上训练,此处我们可以在ssd_300_vgg模型的基础上训练我们的模型。

训练时输出窗口得到如下信息:

当loss值比较稳定而且比较小时说明模型训练的差不多了,可以停下来检验模型训练的好坏。

本次训练由于各种设备上的原因导致此次训练时间较为长,在ssd_300_vgg模型的基础上我训练了50000-60000次,得到了模型检测效果还比较好,接着为了得到更好的识别率,我将learning_rate改为了0.00001,继续训练了50000-60000步,模型精度提高了20%左右。

 

 

下图为在tensorboard下看到的训练过程中相关参数的变化,total_loss为训练过程中的损失度,可以看到损失度比较小,训练比较顺利。

 

 

图片检测效果分析:

ssd-tensorflow-model相比于Google API的缺点在于对小目标检测效果比较差,而此次训练得到的结果也能反映出这一点,当目标在图片中所占比例比较大时,训练得到的模型在检测时其效果比较好,寻找目标的位置较为准确,目标识别概率正确度高,如下图所示。

 

总结

 优点:SSD(Single Shot MultiBox Detector)在训练时,这种算法对于不同横纵比的object的检测都有效,这是因为算法对于每个feature map cell都使用多种横纵比的default boxes,这也是此算法的核心。另外本文的default box做法是很类似Faster RCNN中的anchor的做法的。最后本文也强调了增加数据集的作用,包括随机裁剪,旋转,对比度调整等等,而且此模型训练速度较YOLO更加迅速,可以获得更加明显的效果。 
  缺点:文中作者提到该算法对于小的object的detection比大的object要差。作者认为原因在于这些小的object在网络的顶层所占的信息量太少,所以增加输入图像的尺寸对于小的object的检测有帮助。另外增加数据集对于小的object的检测也有帮助,原因在于随机裁剪后的图像相当于“放大”原图像,所以这样的裁剪操作不仅增加了图像数量,也放大了图像。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow SSD实战:基于深度学习的多目标识别 的相关文章

  • Tensorflow 中的自定义资源

    由于某些原因 我需要为 Tensorflow 实现自定义资源 我试图从查找表实现中获得灵感 如果我理解得好的话 我需要实现3个TF操作 创建我的资源 资源的初始化 例如 在查找表的情况下填充哈希表 执行查找 查找 查询步骤 为了促进实施 我
  • GradientTape 根据损失函数是否被 tf.function 修饰给出不同的梯度

    我发现计算的梯度取决于 tf function 装饰器的相互作用 如下所示 首先 我为二元分类创建一些合成数据 tf random set seed 42 np random seed 42 x tf random normal 2 1 y
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs
  • 错误:分配具有形状的张量时出现 OOM

    在使用 Apache JMeter 进行性能测试期间 我面临着初始模型的问题 错误 分配形状为 800 1280 3 和类型的张量时出现 OOM 通过分配器浮动在 job localhost replica 0 task 0 device
  • TensorFlow的./configure在哪里以及如何启用GPU支持?

    在我的 Ubuntu 上安装 TensorFlow 时 我想将 GPU 与 CUDA 结合使用 但我却停在了这一步官方教程 http www tensorflow org get started os setup md 这到底是哪里 con
  • TensorFlow HVX 加速支持

    我成功构建并运行了测试应用程序https github com tensorflow tensorflow tree master tensorflow contrib hvx https github com tensorflow ten
  • 在 Keras 模型中删除然后插入新的中间层

    给定一个预定义的 Keras 模型 我尝试首先加载预先训练的权重 然后删除一到三个模型内部 非最后几层 层 然后用另一层替换它 我似乎找不到任何有关的文档keras io https keras io 即将做这样的事情或从预定义的模型中删除
  • Tensorflow 对 Python3.11 的支持

    我在 Windows10 PC 上安装了 Python3 11 0 尝试使用以下命令安装张量流 pip install tensorflow 给出错误 访问tensorflow网站后 我意识到它仅支持3 7 3 10 我应该降级 pytho
  • Tensorflow 与 Keras 的兼容性

    我正在使用 Python 3 6 和 Tensorflow 2 0 并且有一些 Keras 代码 import keras from keras models import Sequential from keras layers impo
  • 将 tf.contrib.layers.xavier_initializer() 更改为 2.0.0

    我该如何改变 tf contrib layers xavier initializer tf 版本 gt 2 0 0 所有代码 W1 tf get variable W1 shape self input size h size initi
  • TensorFlow 2.0:在自定义训练循环中显示进度条

    我正在为音频分类任务训练 CNN 并且使用带有自定义训练循环的 TensorFlow 2 0 RC 如中所述本指南 https www tensorflow org beta guide keras training and evaluat
  • 如何在 Tensorflow 对象检测 API 中查找边界框坐标

    我正在使用 Tensorflow 对象检测 API 代码 我训练了我的模型并获得了很高的检测百分比 我一直在尝试获取边界框坐标 但它不断打印出 100 个奇怪数组的列表 经过在线广泛搜索后 我发现数组中的数字意味着什么 边界框坐标相对于底层
  • 在c++中的嵌入式python中导入tensorflow时出错

    我的问题是关于在 C 程序中嵌入 Python 3 5 解释器以从 C 接收图像 并将其用作我训练的张量流模型的输入 当我在 python 代码中导入tensorflow库时 出现错误 其他库工作正常 简化后的代码如下 include
  • 为什么平均百分比误差(mape)非常高?

    我已获得代码掌握机器学习 https machinelearningmastery com time series prediction lstm recurrent neural networks python keras 我修改了mod
  • Tensorflow 中使用 Adam Optimizer 时损失突然增加

    I am using a CNN for a regression task I use Tensorflow and the optimizer is Adam The network seems to converge perfectl
  • Keras 中批量大小可变的batch_dot

    我正在尝试编写一个层来合并 2 个张量formula https i stack imgur com I49aj png x 0 和x 1 的形状都是 1 500 M是500 500的矩阵 我希望输出为 500 500 我认为这在理论上是可
  • Tensorflow - 获取队列中的样本数量?

    对于性能监控 我想关注当前排队的示例 我正在平衡用于填充队列的线程数量和队列的最佳最大大小 我如何获得这些信息 我正在使用一个tf train batch 但我猜这些信息可能在下面的某个地方FIFOQueue 我本以为这是一个局部变量 但我
  • 缩小轴 1 的形状为空 [x,0]

    我正在尝试训练 SVHN 街景门牌号码 数据集 用于张量流中的对象检测 对数字进行一些基本的 OCR 到目前为止 我已经成功地遵循了对象检测张量流指南中的宠物训练示例 当我基于样本 fast rcnn resnet101 config 训练
  • 如何使用 Keras ImageDataGenerator 预测单个图像?

    我已经训练 CNN 对图像进行 3 类分类 在训练模型时 我使用 keras 的 ImageDataGenerator 类对图像应用预处理功能并重新缩放它 现在我的网络在测试集上训练得非常准确 但我不知道如何在单图像预测上应用预处理功能 如

随机推荐

  • 初学Java该学哪些知识?这6大知识必学

    目前 Java是开发人员的热宠 很多论坛都有不少热爱Java的开发人员 也有不少想成为Java程序员 但苦于不知道该如何学习Java 也不清楚该学些什么知识才能成为一个Java程序员 小千在这里抛砖引玉 和大家讨论初学Java应该掌握的知识
  • gitee配置ssh后仍需要密码

    gitee创建仓库后默认提供的是https链接需要修改为ssh才能免密登录 1 查看远程仓库链接 git remote v 删除远程仓库 git remote rm origin 重新添加远程仓库 ssh地址 git remote add
  • 关于hive中从hdfs上load数据到表中而HDFS上的数据却消失的若干问题

    原链接 https blog csdn net shuaikang666 article details 80357075 今天偶然间发现hive中一个我之前没有注意到的一个小细节 我怀疑你们之前也可能没有注意到 那就是当我们试图从HDFS
  • Adding New Functions to MySQL(User-Defined Function Interface UDF、Native Function)

    catalog 1 How to Add New Functions to MySQL 2 Features of the User Defined Function Interface 3 User Defined Function 4
  • postgres数据库相关使用说明

    默认的数据库和用户名是postgres 登录 psql U postgres d postgres ctrl c q 退出数据库交互模式 创建新用户 gwp createuser U postgres P d gwp 输入密码 mxq123
  • 路由器和交换机工作原理

    路由器工作原理 路由器 三层设备 同时基于二层设备工作 当数据包进到路由器时 首先查看的是二层报头 查看的是目标MAC 目标MAC分为三种 广播 组播 单播 广播地址 解封装到三层报头 组播地址 每一个组播地址均存在自己的MAC地址 基于目
  • 华为OD题目: 任务总执行时长

    package com darling boot order od od10 import com sun org apache bcel internal generic IF ACMPEQ import java util 任务总执行时
  • 几种I/O编程实践

    1 传统的BIO编程 网络编程的基本模型是Client Server模型 也就是两个进程间相互通信 其中 服务端提供位置信息 绑定的IP地址和监听端口 客户端提供连接操作向服务端监听的地址发起连接请求 通过三次握手建立连接 如果连接建立成功
  • Burpsuite在Firefox中无法抓取DVWA本地数据包解决方案+导入证书

    前言 这几天重装了系统 软件也大部分重新安装 在使用bp时 遇到了不能抓取dvwa数据包的情况 解决方案 猜想 可能是浏览器自动将127 0 0 1与localhost默认选择不使用代理服务 无法修改 反正我没找到 方案 将url栏中的12
  • java计算下一个整5分钟时间点

    需求 需要获取当前时间的下一个整点时间 如13 23 获取的下一个时间为 13 25 代码 获取下一个分钟值以0或者5结尾的时间点 单位 毫秒 return public static long getNextMillisEndWithMi
  • 机器数——源码、反码、补码

    机器数 源码 反码 补码 基本定义 1 机器数是将符号 数字化 的数 是数字在计算机中的二进制表示形式 表示一个机器数 应该考虑以下三个因素 1 机器数的范围 2 机器数的符号 3 机器数中小数点的位置 我们这里只讨论二进制整数在计算机中的
  • 【Java筑基】IO流基础之常见工具流和进程通信

    前 言 作者简介 半旧518 长跑型选手 立志坚持写10年博客 专注于java后端 专栏简介 深入 全面 系统的介绍java的基础知识 文章简介 本文将深入全面介绍IO流知识 建议收藏备用 创作不易 敬请三连哦 大厂真题 大厂面试真题大全
  • Python3 入门及基础语法

    文章目录 解释型语言 解释型语言优缺点 和编译性语言的区别 Python 简介 优点 缺点 和其他语言区别 Python 入门 Python 解释器安装 Python 继承开发环境安装 第一个 Python 程序 Python 基础 注释
  • MySql的时区(serverTimezone)引发的血案

    前言 mysql8 x的jdbc升级了 增加了时区 serverTimezone 属性 并且不允许为空 血案现场 配置jdbc的URL jdbc mysql IP PORT DB characterEncoding utf8 useSSL
  • Unity-人物移动

    Unity 人物移动 人物模型 参考以下视频 如何在Unity中导入pmx格式的MMD模型 哔哩哔哩 bilibili 用的是原神模型 这里要注意导入后把人物模型的Rig换为Humanoid 人物动作 使用的Unity Chan Model
  • iOS设备分辨率和icon尺寸

    经常需要告诉设计关于iPhone的分辨和icon的需要的尺寸 有时候自己也忘记了 都是从文档 Human Interface Guidelines 中取的 mark一下 icon相关 Device or context Icon size
  • Ubuntu 22 Server安装docker

    系统版本 Ubuntu 22 Server 按照如下文章进行了安装 Ubuntu 22 安装Docker环境
  • 升级go1.18版本json-iterator coredump问题

    unexpected fault address 0x0 fatal error fault signal SIGSEGV segmentation violation code 0x80 addr 0x0 pc 0x46639f goro
  • sqlserver千万数据查询分页

    sqlserver千万数据查询分页 前言废话 sqlserver 作业调用 mysql 前言废话 人生开始感受到无力 我不是没心没肺的人 可是我心里真的不舒服 sqlserver 新建一个表 if OBJECT ID test is not
  • tensorflow SSD实战:基于深度学习的多目标识别

    SSD SSD Single Shot MultiBox Detector 是采用单个深度神经网络模型实现目标检测和识别的方法 如图2所示 该方法是综合了Faster R CNN的anchor box和YOLO单个神经网络检测思路 YOLO