利用谷歌的预训练模型实现目标检测object_detection_tutorial.ipynb

2023-11-12

环境准备

运行这个预训练的模型需要准备一些环境
首先需要下载谷歌的models-master.zip
地址在https://github.com/Master-Chen/models
在这里插入图片描述
下载完成后我们需要的是research/objection_detection这个项目
在运行这个项目之前还需要下载谷歌的protoc3.4.0
下载结束后只需要将bin目录里的protoc.exe文件放在有环境变量的一个目录下即可
之后在research路径下打开命令行 运行 protoc objection_detection/protocs/*.proto --python_out=.
这里运行后会在object_detection\protos路径下生成许多py文件,相当于把原来的proto文件编译成了py文件
至此,环境准备基本完成。注意的是,这里使用的tensorflow1.13.1-cpu

运行模型

准备工作完成后,在objection_detection路径下启动jupyter notebook,找到
在这里插入图片描述
进入这个笔记本
在这里插入图片描述
可以看到,这个笔记本将引导使用者运行这个预训练的目标检测模型

  • 导入相关模块
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

# tf版本需要大于1.9 
if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
  • 在jupyter中显示图片
# 在jupyter里面显示图片
%matplotlib inline
  • 导入模块
from utils import label_map_util

from utils import visualization_utils as vis_util
  • 指定模型的相关配置,譬如模型名称,下载地址,对应得pb文件存放路径,数据集label映射文件路径
    这里使用的是SSD模型,在coco数据集上训练的,其他模型文件可以在github下载
# 模型名称
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
# 下载地址
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# pb模型存放位置
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

# coco数据集的label映射文件
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
  • 下载模型文件
# 下载文件
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
# 解压文件
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(file, os.getcwd())

这里运行结束后会在object_detection路径下生成在这里插入图片描述
并且会解压,且只解压出对应的pb文件,因为这里只使用模型,不重训练模型
在这里插入图片描述
-这里下载大概率会因网络问题无法成功,可以手动下载解压
地址 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

  • 载入训练好的模型
# 载入训练好的pb模型
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
  • 得到一个类别号和对于类别描述的字典
# 得到一个保存编号和类别描述映射关系的字典
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
print(category_index)

在这里插入图片描述

  • 定义一个方法,把图片读取出三维数据,类型转换为uint8
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
  • 定义目标检测的函数传入图像,返回检测结果
# 目标检测
def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # 获得图中所有op
            ops = tf.get_default_graph().get_operations()
            # 获得输出tensor的名字
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
              'num_detections', 'detection_boxes', 'detection_scores',
              'detection_classes',
            ]:
                tensor_name = key + ':0'
                # 如果tensor_name在all_tensor_names中
                if tensor_name in all_tensor_names:
                    # 则获取到该tensor
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                      tensor_name)
            # 图片输入的tensor
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

            # 传入图片运行模型获得结果
            output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: image})

            # 所有的结果都是float32类型的,有些数据需要做数据格式转换
            # 检测到目标的数量
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            # 目标的类型
            output_dict['detection_classes'] = output_dict[
              'detection_classes'][0].astype(np.uint8)
            # 预测框坐标
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            # 预测框置信度
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
    return output_dict
  • 遍历测试图像,输出检测结果,测试图像路径在test_iamges,将要测试的图像放进该路径即可
for root,dirs,files in os.walk('test_images/'):
    for image_path in files:
        # 读取图片
        image = Image.open(os.path.join(root,image_path))
        # 把图片数据变成3维的数据,定义数据类型为uint8
        image_np = load_image_into_numpy_array(image)
        # 增加一个维度,数据变成: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # 目标检测
        output_dict = run_inference_for_single_image(image_np_expanded, detection_graph)
        # 给原图加上预测框,置信度和类别信息
        vis_util.visualize_boxes_and_labels_on_image_array(
          image_np,
          output_dict['detection_boxes'],
          output_dict['detection_classes'],
          output_dict['detection_scores'],
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8)
        # 画图
        plt.figure(figsize=(12,8))
        plt.imshow(image_np)
        plt.axis('off')
        plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
编译好的项目文件可以从这里下载:https://download.csdn.net/download/cyj5201314/18171589

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

利用谷歌的预训练模型实现目标检测object_detection_tutorial.ipynb 的相关文章

  • python的_random是什么?

    如果你打开random py看看它是如何工作的 它的类Random子类 random Random import random class Random random Random Random number generator base
  • Python 将列表中的字符串转换为数字

    我遇到了以下错误消息 以 10 为基数的 int 的文字无效 2 2 外部用单引号括起来 内部用双引号括起来 该数据位于primes列出使用print primes 0 样本数据在primes list 2 3 5 7 The primes
  • Python 按文件夹模块导入

    我有一个目录结构 example py templates init py a py b py a py and b py只有一个类 名称与文件相同 因为它们是猎豹模板 纯粹出于风格原因 我希望能够在中导入和使用这些类example py像
  • 地图与星图的性能?

    我试图对两个序列进行纯Python 没有外部依赖 逐元素比较 我的第一个解决方案是 list map operator eq seq1 seq2 然后我发现starmap函数来自itertools 这看起来和我很相似 但事实证明 在最坏的情
  • Pygame 玩家精灵没有出现

    我一直在为学校计算机课做这个项目 但无法让玩家精灵出现 有人可以帮忙吗 当我运行主游戏循环时 除了玩家精灵之外 所有内容都正确显示 它应该由于箭头输入而在屏幕上移动并受到重力的影响 当我删除图像并仅使用对象类和矩形时 该代码也有效 impo
  • 将列表传递给 PyCrypto 中的 AES 密钥生成器

    我尝试使用 Pycrypto 生成 AES 密钥 但收到以下错误 类型错误 列表 不支持缓冲区接口 对于以下声明 aescipher AES new mykey AES MODE ECB mykey 属于类型list并包含 18854347
  • int 对象在尝试对数字的数字求和时不可迭代? [复制]

    这个问题在这里已经有答案了 我有这个代码 inp int input Enter a number for i in inp n n i print n 但它抛出一个错误 int object is not iterable 我想通过将每个
  • Python变量赋值问题

    a b 0 1 while b lt 50 print b a b b a b 输出 1 2 4 8 16 32 wheras a b 0 1 while b lt 50 print b a b b a b 输出 正确的斐波那契数列 1 1
  • Python 函数可能会引发哪些异常? [复制]

    这个问题在这里已经有答案了 Python 中有什么方法可以确定 内置 函数可能引发哪些异常 例如 文档 http docs python org lib built in funcs html http docs python org li
  • PyPI 项目页面中的“Py 版本”是什么意思?这有关系吗?

    我注意到 大多数在 PyPI 上发布的项目在其项目页面中都包含 Py 版本 元数据 但它们的值各不相同 如果包不是通用包或不是纯 python 包 那么它们的值是不同的 这是可以理解的 以便表示它们的目标平台 例如鼻页 https pypi
  • 在 Ubuntu 上使用 Python 获取显示器分辨率

    对于 Ubuntu win32api 中是否有与 GetSystemMetrics 相当的代码 我需要获取显示器的宽度和高度 以像素为单位 我可以建议一些可以使用的方法 不过我还没有使用过 xlib 版本 1 xlib Python 程序的
  • 将 csv 文件按多列拆分为 panda 数据框

    我有一个包含多列的 tsv 文件 有 10 多列 但对我来说重要的列是名称为 user name shift id url id 的列 我想创建一个数据框 首先根据用户名分隔整个 csv 文件 即只有具有相同用户名的行才会分组在一起 从该块
  • Python3.1中的视图?

    Python3 1中的视图到底是什么 它们的行为方式似乎与迭代器类似 并且它们也可以具体化为列表 迭代器和视图有何不同 据我所知 视图仍然附加到创建它的对象上 对原始对象的修改会影响视图 来自docs http docs python or
  • 如何替换被测模块的文件访问引用

    pyfakefs https code google com p pyfakefs 听起来非常有用 它 最初是作为核心 Python 模块的一个适度的假实现来开发的 以支持中等复杂的文件系统交互 并于 2006 年 9 月在 Google
  • Python unittest - 与assertRaises相反?

    我想编写一个测试来确定在给定情况下不会引发异常 测试是否有异常很简单is上调 sInvalidPath AlwaysSuppliesAnInvalidPath self assertRaises PathIsNotAValidOne MyO
  • 在字典理解中为 locals() 添加下标失败并出现 KeyError [重复]

    这个问题在这里已经有答案了 我对 Python 的奇怪行为感到困惑locals 基本上我想从字典中获取一个项目locals 在字典理解中 但它失败了 这是一个非常基本的事情 所以 gt gt gt foo 123 gt gt gt bar
  • Networkx 中 Louvain 分区的可视化

    请帮助我更改 Louvain 聚类算法结果的可视化 我从网站上获取了代码https github com taynaud python louvain https github com taynaud python louvain我可以重写
  • 使用 Pandas 和 Group By 绘制堆叠直方图

    我正在使用如下所示的数据集 Gender Height Width Male 23 4 4 4 Female 45 4 4 5 我想可视化高度和宽度的堆叠直方图 我希望每个图有两个堆叠的直方图 每个性别一个 这是文档中的堆叠直方图 如果存在
  • python pandas如何在多个条件下过滤字符串

    我有以下数据框 import pandas as pd data 5Star FiveStar five star fiv estar data pd DataFrame data columns columnName 当我尝试用一 种条件
  • Pandas 替换特定列上的值

    我知道这两个类似的问题 熊猫替换值 https stackoverflow com questions 27117773 pandas replace values Pandas 替换数据框中的列值 https stackoverflow

随机推荐

  • 逐行读取csv文件的某一列以及写入数据

    1 在Python中 你可以使用内置的csv模块来读取CSV文件 并逐行读取指定的某一列 下面是一个示例代码 展示如何逐行读取CSV文件的某一列 import csv 打开CSV文件 with open your file csv r as
  • webpack4之代码分割splitChunks和压缩优化

    我们打包出来的js文件 只要修改或增加了内容 就会导致入口js文件的hash变化 从而重新打包 为了提高打包速度 每次变化仅仅是重新打包自定义代码部分 webpack4提供了optimization splitChunks 回顾一下 web
  • 【Linux之shell脚本实战】批量上传docker镜像到华为云容器镜像仓库

    Linux之shell脚本实战 批量上传docker镜像到华为云容器镜像仓库 一 脚本要求 二 检查本地环境 1 检查系统版本 2 检查系统内核 三 检查本地容器镜像 四 shell注释模板配置 1 配置 vimrc 2 查看注释模板效果
  • MediaCodec问题汇总

    参考 http blog csdn net mincheat article details 51385144 MediaCodec的基本用法 网上一大把 这里就不写了 1 获取支持分辨率问题 Camera Parameters param
  • 设计模式-责任链模式(Java)

    设计模式 责任链模式 在极客学院的视频中学习了一种设计模式的方式 责任链模式 在博客园中发现了这篇文章 讲的很详细 就把它的一些内容转载过来了 本文中 我们将介绍设计模式中的行为型模式职责链模式 职责链模式的结果看上去很简单 但是也很复杂
  • MySql存储过程

    一 Mysql存储过程概述 存储过程是数据库的一个重要对象 对象还包括 索引 触发器 视图等 可以封装sql语句集 用来完成比较复杂的业务逻辑 并且还可以入参 出参 存储过程创建时会进行预编译进行保存 当下次调用时不需要再进行编译 优点 在
  • STM32设置IO口输入上拉下拉

    1 按键分类 WK UP按键按下时将高电平信号输入给STM32的IO 即高电平有效 不被按下时 由于干扰 可能高也可能是低信号输入 KEY0按键按下时将低信号输入给STM32的IO 即低电平有效 不被按下时 由于干扰 可能高也可能是低信号输
  • Java基础-学习笔记(三)

    本节记录和学习Java的一种引用数据类型 数组 静态方法的声明 字符串的基本概念和使用 1 数组 array 是具有相同数据元素的有序集合 Java中的数组是引用数据类型 一个数组变量采用应用方式保存多个数组元素 Java的数组都是动态数组
  • Unity内存管理

    文章目录 为什么要进行内存管理 为什么会有Mono和IL2CPP 托管语言 托管代码 Mono IL2CPP 参考 Unity游戏优化第2版 为什么要进行内存管理 内存管理是性能优化的一个重要方面 可能造成性能问题的原因有2个 不必要的内存
  • frp实现内网穿透

    文章目录 一 frp是什么 二 使用步骤 1 需要两台服务器 2 下载frp 和go语言 基于 1 通过自定义域名访问内网的 Web 服务 启动 windows下安装frpc ini 2 配置token才能访问 3 配置udp 4 通过 S
  • 字符数组与字符指针的区别

    字符数组与字符指针的区别 在 C 语言中 可以用两种方法表示和存放字符串 1 用字符数组存放一个字符串 char str IloveChina 2 用字符指针指向一个字符串 char str IloveChina 那么这两种表示方式有什么区
  • 内网渗透之信息收集

    一 内网信息收集概述 渗透测试人员进人内网后 面对的是一片 黑暗森林 所以 渗透测试人员首先需要对当前所处的网络环境进行判断 判断涉及如下三个方面 我是谁 一对 当前机器角色的判断 这是哪 一对 当前机器所处网络环境的拓扑结构进行分析和判断
  • Stm32最小系统板电路图设计、PCB设计

    目录 一 电路设计 1 复位电路 2 时钟电路 3 电源电路 4 SWD接口电路 5 BOOT启动电路 二 原理图绘制 1 工程的建立 2 原理图的绘制 2 1 使用已有库绘制原理图 2 2 构建原理图库 2 3 整体原理图 三 PCB绘制
  • Java堆和栈

    Java堆和栈是Java程序中两个重要的数据结构 它们在程序的运行过程中发挥着重要的作用 本文将介绍Java堆和栈的基本概念 区别 操作以及应用场景 帮助读者更好地理解和应用这两个数据结构 一 基本概念 Java堆 Heap 和栈 Stac
  • vue+elementui 登录页面

    vue elementui 登录页面 html代码
  • Windows 终端 Terminal 配置

    文章目录 Windows 终端 Terminal 配置 修改默认启动的命令 添加 cmder 到 Windows Terminal 添加 git bash 到 Windows Terminal 为Windows PowerShell 配置别
  • vue3.0+elementplus table动态添加column

  • 【Vuex】前后端分离Vue路由拦截器与登录cookie保存

    文章目录 1 Vuex 初探 1 1 vuex 介绍 1 2 store 的使用 2 localStorage使用 2 1 localStorage介绍 2 2 localStorage语法 3 路由钩子函数 导航守卫 3 1 导航守卫介绍
  • 固定资产预算怎么管理的

    在现代企业管理中 固定资产预算的管理是一项至关重要的任务 它不仅关系到企业的经济效益 更关系到企业的长远发展 那么 如何进行有效的固定资产预算管理呢 明确固定资产预算的目标和原则 我们需要明确固定资产预算的目标和原则 固定资产预算的目标应该
  • 利用谷歌的预训练模型实现目标检测object_detection_tutorial.ipynb

    环境准备 运行这个预训练的模型需要准备一些环境 首先需要下载谷歌的models master zip 地址在https github com Master Chen models 下载完成后我们需要的是research objection