2020年tensorflow定制训练模型笔记(1)——object detection的安装

2023-11-13

自己看着网上的很多教程摸索了好几天,终于能够自己训练。事实上,网上关于这个API的教程还是非常多的,但我实际做起来发现其实在某些关键部分缺少点步骤,会把我这样的小白搞得一头雾水、无从下手,最后在无穷无尽的报错中崩溃。所以我决定写这篇笔记,一来帮助最初像我一样的小白轻松搞定,二来就是为自己做笔记 ,以后万一忘记了,可以回来看看回想一下。

电脑配置

  • cpu: i7-8750H
  • gpu: 1060 6G
  • 内存: 16G
  • 操作系统:win10

我的环境

  • anaconda 3.x
  • python 3.6
  • tensorflow 1.15

cpu或gpu版本都可,具体怎么安装我就不介绍了,2.0版本我也试过,只有一步骤我暂时没有办法用2.0解决所以不写2.0。

tf2.0与1.0的一些区别

报错:AttributeError: module 'tensorflow' has no attribute 'contrib'
这个库在测试环节和训练环节使用,而tf2.0移除了contrib。就是这个原因使得我无法用tf2.0成功训练模型。
其他步骤1.0和2.0通用。因为我用的是高版本的1.0,可以识别一些2.0的内容,所以我的代码都改成适用于2.0的内容。例如tf.Session ->tf.compat.v1.Session
当然tf2.0也有专门升级的代码:

tf_upgrade_v2 --infile foo.py --outfile foo-upgraded.py
#tf_upgrade_v2 --infile 1.*版.py文件 --outfile 生成的2.0版.py文件

正式开始前,我还想说明下,有些类似报错找不到这个库的问题,我这里就不多说了,直接去anaconda去下载缺少的库,一些特殊的库我会特别说一下的。
另外,我这篇是自己的学习笔记,我也用了大量的别人的代码,用了的我基本都会写明出处,如果不妥的,请手下留情。

1.object detection

1.1下载

首先,我们得从GitHub上下载object detection。下载地址如下:
https://github.com/tensorflow/models
下载后解压你就会得到models-master的文件夹(ps:他文件里的路径都叫models,不知道为什么解压后叫models-master,这里因为我自己是改成models了,为了方便我自己截图演示,后面出现的都是models)

1.2配置

激活tensorflow环境,来到models/research/object_detection/ 文件夹下你会看到一个object_detection_tutorial.ipynb文件,这是一个demo文件,在tensorflow环境下用jupyter打开它出现如下界面:(我现在时间是2020.2.26,这是现在最新的版本)在这里插入图片描述
我们按照他的install步骤一步步来:
第一行:安装tensorflow2.0,我们不要理他,1.0也是可以用的,把这句话用#注释掉或直接删掉。
第二行:安装pycocotools,windows安装pycocotools必须先得安装cython,所以我们先去下载cython到tensorflow的环境里。你可以使用命令行,但我没试过,我是在anaconda里完成的,我个人感觉很方便。
安装cython后,你还要去网上安装git。
然后你要么在jupyter里运行这一行代码。我实际操作中感觉太慢了,就去上网找了另一个方法。在tf环境下用命令行:

pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

这好像是哪个大佬写的支持 Windows 的 COCO 地址,原网址是:https://www.jianshu.com/p/8658cda3d553
需要注意的是这个网是国外的,所以国内下载还是会有点慢,你要不等着要不那个(不说了)
第三行:可运行可不运行,就是下载这个API
第四行:配置路径,在jupyter里运行这个代码有点慢,所以我直接在命令行里操作
这里还需要提前安装protoc

cd models/research/
protoc object_detection/protos/*.proto --python_out=.

第五行:安装包,也是直接在命令行操作

pip install .

1.3测试

至此,我们也就完全搭好了环境,install的代码以后其基本都不会用了,全部注释掉或者删掉。接下来直接所有的运行代码,也许你会成功,但我是失败的,我到最后要出结果的时候内核死机了,一开始我为了这个问题苦苦寻求答案很久,直到最近我才发觉这可能是系统的问题,因为写这代码的人是在Linux上运行的,可能代码里的哪一部分与win有冲突,反正这demo代码里有块内容是运行不起来的,不是我们环境和电脑的问题。
办法嘛,是有的,我找了好多教程,才找到一个适合我们这种情况的代码

import os
import sys
import cv2
import numpy as np
import tensorflow as tf
sys.path.append("..")

from utils import label_map_util
from utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        # 这是用于对象检测的实际模型的路径,如果没有这个pb文件,说明你还未下载。可以用demo里的下载代码来替换。
        self.PATH_TO_CKPT = 'ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
        #  用于为每个框添加正确标签的字符串的列表的路径。
        self.PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

        self.NUM_CLASSES = 90

        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.compat.v1.GraphDef()
            with tf.io.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.compat.v1.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]  扩展维度,因为模型期望图像具有以下形状:[1,None, None, 3]
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box represents a part of the image where a particular object was detected.  每个框表示检测到特定对象的图像的一部分。
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score represent how level of confidence for each of the objects.  每个分数表示每个对象的置信度。
                # Score is shown on the result image, together with the class label.  分数与类标签一起显示在结果图像上。
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.  实际检测。
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.  检测结果的可视化。
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        while True:
            cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
            cv2.imshow("detection", image)
            if cv2.waitKey(110) & 0xff == 27:
                break


if __name__ == '__main__':
    image = cv2.imread('test_images/image1.jpg')#测试照片的路径
    detecotr = TOD()
    detecotr.detect(image)

该代码出自Tensorflow object detection API 搭建属于自己的物体识别模型(转载修改)这篇文章。我个人是非常感谢这位博主,它解决了我的测试不成功的问题,上面这个代码要注意的是三个路径,我在代码里已经标注出来了,这三个地方根据自己文件夹的情况自行更改。测试路径你也可以改成一个for循环测试一组图片。
运行此代码,这张带标注的照片终于出现了:
在这里插入图片描述

小结

因为之前被内核死掉的事崩溃太久了,我不仅新版本demo会崩溃,老版本的demo也是加载不出照片,只有这个用上cv的代码才可以。这张照片一出来,我当时才舒了一口气,顿时对后面的训练充满信心了有木有。
下一篇我们继续,接下来是生成训练文件的事了…
如果你知道具体的原因,可以的话就教教我吧,在下面的评论留个言或者私我。

2020年tensorflow定制训练模型笔记(1)——object detection的安装
2020年tensorflow定制训练模型笔记(2)——制作标签
2020年tensorflow定制训练模型笔记(3)——开始训练

目前官方已经更新了这个库,我提供的这个demo文件会报一些错误,如果你不可以解决这些错误,请使用官方的demo。
如果你和我一样无法使用官方的demo,你也不在乎指定用哪个库来完成目标检测,那就可以移步至我的关于yolov5的笔记
yolov5笔记(1)——安装pytorch_GPU(win10+anaconda3)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

2020年tensorflow定制训练模型笔记(1)——object detection的安装 的相关文章

随机推荐

  • 系统架构设计高级技能 · Web架构设计

    现在的一切都是为将来的梦想编织翅膀 让梦想在现实中展翅高飞 Now everything is for the future of dream weaving wings let the dream fly in reality 点击进入系
  • webpack多页面改名的注意事项

    今天在进行项目打包时 由于甲方新规定了文件的名字 需要我们对原先的文件名进行重命名 这个需求是不是很简单 确实很简单 但是一不注意 就会给自己造成找错半天 原来的名字 进行改名 webpack同步更改如下 满心欢喜的以为自己改完了 然后np
  • Volley 源码解析

    1 功能介绍 1 1 Volley Volley 是 Google 推出的 Android 异步网络请求框架和图片加载框架 在 Google I O 2013 大会上发布 名字由来 a burst or emission of many t
  • BeanUtil拷贝对象或集合时属性名不对应导致为空

    项目场景 源和目标实体类中的客户ID字段不对应 在使用Hutool的BeanUtil拷贝时字段为空 问题描述 源实体类属性 客户ID private String customerId 目标实体类属性 客户ID private String
  • jquery 小数计算保持精度,同时保留两位数

    点击打开链接 Num 3 Price 11 50 Number Price Num toFixed 2 34 50
  • pytorch 模型GPU推理时间探讨3——正确计算模型推理时间

    前言 上文说到 在统计pytorch模型的推理时间时发现每次的前几次推理耗时都非常多 而且在后面多次的推理中 其时间也呈现出很大的变化 后来经过调研 得知模型在GPU上推理时 需要对GPU进行一个warm up阶段 使得显卡达到工作状态 对
  • 串口拦截通信数据信息

    最近手头上有一个需要通信的外部设备 流量计 直接去看他的通信手册 里面没有例子 SO 刚开始看不太懂 官网上面有一个上位机软件 可以直接操作软件去设置参数 故 利用此上位机软件发送指令 然后在上位机和设备之间引出TX与RX 从而拦截二者串口
  • Qt类中使用函数指针

    使用函数指针有三步骤 1 声明一个函数指针 返回值类型和参数类型要与待指向的函数类型和参数一致 2 获取函数的地址 函数指针指向函数名 3 使用函数指针来调用所指向的函数 class Widget public QWidget public
  • CORE-ESP32C3

    目录 参考博文 源于网友oled eink aht10项目 源代码修改及复现说明 主要修改 显示效果 编辑硬件准备 软件版本 日志及soc下载工具 软件使用 接线说明 天气显示屏 硬件接线 温度采集 日期温度显示屏 正常初始化LOG 示例代
  • Spring Boot跨域问题简介

    什么是跨域问题 在Web开发中 跨域指的是在浏览器中访问一个不同于当前域名的资源 浏览器出于安全考虑 限制了这种跨域资源的访问 具体来说 当浏览器使用XMLHttpRequest或Fetch API发送跨域请求时 目标服务器必须在响应头中包
  • Python爬虫-11-response.text出现乱码的解决方案

    代码如下 这里是封装的一个下载url页面的方法 import requests def download page url user Agent None referer None print Downloading url headers
  • 前端xp单位和数值批量转换插件 编辑器正则匹配搜索

    因为要使手机端app自适应ipad端 所以要把项目中部分使用px的固定单位的改为相对单位 uniapp中规定了页面的宽度为750rpx 所以改起来还是很简单的 但是使用正则匹配修改px单位为rpx 编辑器可以按照正则匹配 但是因为没有运算功
  • 怎么禁用Windows Defender?

    如果你没有安装第三方杀毒软件 Windows10会自动激活其内置的Window Defender杀毒软件 虽然Windows Defender是Windows内置的 但是杀毒能力只能算比较平庸 并且在很多操作步骤和使用方法都不太符合用户的习
  • C++11Lambda表达式

    Lambda表达式 定义 可以理解为一个匿名函数 和函数一样 lambda表达式具有一个返回类型 一个参数列表和一个函数体 语法 capture list parameter list gt return type function bod
  • 使用tensorrt对keras-yolov3 模型进行低精度量化相关报错

    基本错误都是环境引起的 所以环境很重要 环境 python3 5 cuda10 0 cudnn 7 5 0 TensorRT 6 0 1 onnx 1 3 0 相关错误 错误1 NoneType object has no attribut
  • C++11--constexpr关键字

    关键字 constexpr 是在 C 11 中引入的 并在 C 14 中进行了改进 作用 它是用于表示 constant 常量 表达式的 常量表达式是指值不会改变并且在编译过程就能得到计算结果的表达式 使用常量表达式可以提高程序的执行效率
  • kali linux渗透测试之漏洞扫描

    主题内容就是进行漏洞扫描 文章目录 前言 一 Nikto 1 Nikto漏洞扫描介绍 2 Nikto使用 二 Nessus 1 Nessus介绍 2 安装nessus 3 nessus的简单使用 3 nessus扫描之advanced sc
  • 从浏览器地址栏输入url到显示页面的过程

    基本流程 1 用户在浏览器中输入url地址 2 浏览器解析域名得到服务器ip地址 浏览器会首先从缓存中找是否存在域名 如果存在就直接取出对应的ip地址 如果没有就开启一个DNS域名解析器 DNS域名解析器会首先访问顶级域名服务器 将对应的i
  • python编程入门书-最适合Python初学者的6本书籍推荐「必须收藏」

    Python是一种通用的解释型编程 主要用于Web开发 机器学习和复杂数据分析 Python对初学者来说是一种完美的语言 因为它易于学习和理解 随着这种语言的普及 Python程序员的机会也越来越大 如果你想学习Python编程 市场上就有
  • 2020年tensorflow定制训练模型笔记(1)——object detection的安装

    自己看着网上的很多教程摸索了好几天 终于能够自己训练 事实上 网上关于这个API的教程还是非常多的 但我实际做起来发现其实在某些关键部分缺少点步骤 会把我这样的小白搞得一头雾水 无从下手 最后在无穷无尽的报错中崩溃 所以我决定写这篇笔记 一