使用TensorFlow Object Detection API进行图像物体检测

2023-05-16

参考 https://github.com/tensorflow/models/tree/master/object_detection

使用TensorFlow Object Detection API进行图像物体检测

准备

  1. 安装TensorFlow

    参考 https://www.tensorflow.org/install/

    如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
  2. 配置TensorFlow Models

    • 下载TensorFlow Models
    git clone https://github.com/tensorflow/models.git
    • 编译protobuf
    
    # From tensorflow/models/
    
    protoc object_detection/protos/*.proto --python_out=.

    生成若干py文件在object_detection/protos/

    • 添加PYTHONPATH
    
    # From tensorflow/models/
    
    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
    • 测试
    
    # From tensorflow/models/
    
    python object_detection/builders/model_builder_test.py

    若成功,显示OK

  3. 准备数据

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md

    这里以PASCAL VOC 2012为例。

    • 下载并解压
    
    # From tensorflow/models
    
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    tar -xvf VOCtrainval_11-May-2012.tar
    • 生成TFRecord
    
    # From tensorflow/models
    
    mkdir VOC2012
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=train \
        --output_path=VOC2012/pascal_train.record
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=val \
        --output_path=VOC2012/pascal_val.record

    得到pascal_train.recordpascal_val.record

    如果需要用自己的数据,则参考create_pascal_tf_record.py编写处理数据生成TFRecord的脚本。可参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md

  4. (可选)下载模型

    官方提供了不少预训练模型( https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md ),这里以ssd_mobilenet_v1_coco以例。

    
    # From tensorflow/models
    
    wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
    tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz

训练

如果使用现有模型进行预测则不需要训练。

  1. 文件结构

    为了方便查看文件,使用以下文件结构。

    models
    ├── object_detection
    │   ├── VOC2012
    │   │   ├── ssd_mobilenet_train_logs
    │   │   ├── ssd_mobilenet_val_logs
    │   │   ├── ssd_mobilenet_v1_voc2012.config
    │   │   ├── pascal_label_map.pbtxt
    │   │   ├── pascal_train.record
    │   │   └── pascal_val.record
    │   ├── infer.py
    │   └── create_pascal_tf_record.py
    ├── eval_voc2012.sh
    └── train_voc2012.sh
  2. 配置

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md

    这里使用SSD w/MobileNet,把object_detection/samples/configs/ssd_mobilenet_v1_pets.config复制到object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config

    修改第9行为num_classes: 20

    修改第158行为fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"

    修改第177行为input_path: "object_detection/VOC2012/pascal_train.record"

    修改第179行和193行为label_map_path: "object_detection/data/pascal_label_map.pbtxt"

    修改第191行为input_path: "object_detection/VOC2012/pascal_val.record"

  3. 训练

    新建tensorflow/models/train_voc2012.sh,内容以下:

    python object_detection/train.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &

    进入tensorflow/models/,运行./train_voc2012.sh即可训练。

  4. 验证

    可一边训练一边验证,注意使用其它的GPU或合理分配显存。

    新建tensorflow/models/eval_voc2012.sh,内容以下:

    python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        --eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &

    进入tensorflow/models/,运行CUDA_VISIBLE_DEVICES="1" ./train_voc2012.sh即可验证(这里指定了第二个GPU)。

  5. 可视化log

    可一边训练一边可视化训练的log,可看到Loss趋势。

    tensorboard --logdir ssd_mobilenet_train_logs/

    可视化验证的log,可看到Precision/mAP@0.5IOU的趋势以及具体image的预测结果。

    tensorboard --logdir ssd_mobilenet_val_logs/ --port 6007

测试

  1. 导出模型

    训练完成后得到一些checkpoint文件在ssd_mobilenet_train_logs中,如:

    • graph.pbtxt
    • model.ckpt-200000.data-00000-of-00001
    • model.ckpt-200000.info
    • model.ckpt-200000.meta

    其中meta保存了graph和metadata,ckpt保存了网络的weights。

    而进行预测时只需模型和权重,不需要metadata,故可使用官方提供的脚本生成推导图。

    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
  2. 测试图片

    • 运行object_detection_tutorial.ipynb并修改其中的各种路径即可。

    • 或自写编译inference脚本,如tensorflow/models/object_detection/infer.py

      import sys
      sys.path.append('..')
      import os
      import time
      import tensorflow as tf
      import numpy as np
      from PIL import Image
      from matplotlib import pyplot as plt
      
      from utils import label_map_util
      from utils import visualization_utils as vis_util
      
      PATH_TEST_IMAGE = sys.argv[1]
      PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
      PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
      NUM_CLASSES = 21
      IMAGE_SIZE = (18, 12)
      
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      category_index = label_map_util.create_category_index(categories)
      
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              serialized_graph = fid.read()
              od_graph_def.ParseFromString(serialized_graph)
              tf.import_graph_def(od_graph_def, name='')
      
      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      
      with detection_graph.as_default():
          with tf.Session(graph=detection_graph, config=config) as sess:
              start_time = time.time()
              print(time.ctime())
              image = Image.open(PATH_TEST_IMAGE)
              image_np = np.array(image).astype(np.uint8)
              image_np_expanded = np.expand_dims(image_np, axis=0)
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              scores = detection_graph.get_tensor_by_name('detection_scores:0')
              classes = detection_graph.get_tensor_by_name('detection_classes:0')
              num_detections = detection_graph.get_tensor_by_name('num_detections:0')
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes, scores, classes, num_detections],
                  feed_dict={image_tensor: image_np_expanded})
              print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
                  category_index, use_normalized_coordinates=True, line_thickness=8)
              plt.figure(figsize=IMAGE_SIZE)
              plt.imshow(image_np)

      运行infer.py test_images/image1.jpg即可

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用TensorFlow Object Detection API进行图像物体检测 的相关文章

随机推荐

  • 上传文件超过限制,造成长时间无响应的解决方案

    在上传大文件 xff0c 造成长时间没有响应的情况的解决方案 xff1a 上传大文件时 xff0c 因为http协议的响应问题 xff0c 造成长时间不能向客户端发送响应请求头 解决方案 xff1a 1 向服务器发送上传大文件的reques
  • checkbox的jsTree的一个调用

    lt DOCTYPE HTML PUBLIC 34 W3C DTD HTML 4 01 Transitional EN 34 gt lt html gt lt head gt lt meta http equiv 61 34 Content
  • 灵活使用递归算法,生成Excel文件中的复合表头

    最近 xff0c 在开发中 xff0c 需要导出数据到excel文件 xff0c 文件的表头的格式是不一致的 有复合表头 xff0c 也有单表头 xff0c 那么如何灵活地生成excel文件中的复合表头 首先有一个JSON字符串格式的字段描
  • 在 ibm http server 和 websphere 之间配置 ssl

    在WebSphere的环境中 xff0c 配置SSL xff0c 有一些细节需要注意 xff1a 1 最好是先安装 ibm http server7 32bit xff0c websphere7 再安装插件 2 http server 需要
  • Ext4使用总结(二)简单的hbox布局

    布局的合理利用 xff1a 如图 xff1a xtype 39 container 39 margins 39 5 0 0 0 39 layout align 39 stretch 39 type 39 hbox 39
  • 软件开发者的精力管理(一)

    精力管理对于软件开发者来讲是非常重要的 不希望自己被长周期的项目拖垮 xff0c 不希望被连续的加班所累 我个人认为泛义的时间管理是涉及到多个方面的 而心理学 精力管理则是非常重要的 作为一名从事了多年软件开发的从业者 xff0c 我的一个
  • 如何高效能地学习和使用"工具"?

    在软件开发中 xff0c 应该注意工具的合理使用 xff0c 使得自己变得高效起来 1 工具也是产品 xff0c 有许多的工具是产品化的 既然是产品 xff0c 就很多的服务 xff0c 例如帮助文档 xff0c 论坛 xff0c 咨询人员
  • Ext4使用总结(十二) 采用 CellEditing 方式的Grid,如何取得修改的单元格数据值

    使用cellediting方式编辑数据的grid在保存数据时 xff0c 需要进行数据的处理 xff0c 所以数据处理的方式需要特别注意 cellEditing 插件的事件 listeners edit function editor e
  • 「Ubuntu」Ubuntu中的python终端配置(修改终端默认python配置,软连接,不同版本python环境配置)

    前言 通过这篇博客 xff08 Ubuntu安装Python xff09 安装完Python后 xff0c 想要在终端直接启动想启动的python版本 此时直接在终端输入python2或者python3 xff0c 发现系统已经配置好了py
  • [解题报告] CSDN竞赛第15期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 29 1 求并集 题目 由小到大输出两个单向有序链表的并集 如链表 A 1 gt 2 gt 5 gt 7 链表 B 3 gt 5 gt
  • JSP开发技术四——————EL表达式

    EL xff08 Expression Language xff09 表达式 xff0c 即正则表达式 用来操作字符串 用一些特定的字符来表示一些代码操作 xff0c 这样简化代码书写 学习正则表达式 xff0c 就是学习一些特殊符号的实用
  • [解题报告] CSDN竞赛第17期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 31 1 判断胜负 题目 已知两个字符串A B 连续进行读入n次 每次读入的字符串都为A B 输出读入次数最多的字符串 解题报告 模拟
  • [解题报告] CSDN竞赛第18期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 32 1 单链表排序 题目 单链表的节点定义如下 xff08 C 43 43 xff09 xff1a class Node publi
  • [解题报告] CSDN竞赛第22期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 36 1 c 43 43 难题 大数加法 题目 大数一直是一个c语言的一个难题 现在我们需要你手动模拟出大数加法过程 请你给出两个大整
  • [解题报告] CSDN竞赛第23期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 37 1 排查网络故障 题目 A地跟B地的网络中间有n个节点 xff08 不包括A地和B地 xff09 xff0c 相邻的两个节点是通
  • CSDN竞赛第24期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 38 这次写完第一道题时遇到一个奇怪的情况 xff1a 一直在 运行中 xff0c 然后发现每道题输入做任意代码都出现一直运行中 跟小
  • [Python开发] 使用python读取图片的EXIF

    使用python读取图片的EXIF 方法 使用PIL Image读取图片的EXIF 使用https pypi python org pypi ExifRead 读取图片的EXIF xff0c 得到EXIF标签 xff08 dict类型 xf
  • Partial Least Squares Regression 偏最小二乘法回归

    介绍 定义 偏最小二乘回归 多元线性回归分析 43 典型相关分析 43 主成分分析 输入 xff1a n m 的预测矩阵 X n p 的响应矩阵 Y 输出 X 和 Y 的投影 分数 矩阵 T U R n l 目标 xff1a 最大化 cor
  • 使用TensorFlow-Slim进行图像分类

    参考 https github com tensorflow models tree master slim 使用TensorFlow Slim进行图像分类 准备 安装TensorFlow 参考 https www tensorflow o
  • 使用TensorFlow Object Detection API进行图像物体检测

    参考 https github com tensorflow models tree master object detection 使用TensorFlow Object Detection API进行图像物体检测 准备 安装Tensor