Reference: https://github.com/tensorflow/models/tree/master/object_detection
Image object detection with the TensorFlow Object Detection API
Preparation
Install TensorFlow
Reference: https://www.tensorflow.org/install/
For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:
wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
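The wheel filename records which interpreter and platform it was built for (PEP 427: name-version-pythontag-abitag-platform.whl). That is why the file above only installs on CPython 2.7 under 64-bit Linux. The minimal sketch below checks this before installing. parse_wheel_filename and matches_interpreter are hypothetical helpers, not part of pip, and the sketch assumes the filename has no optional build tag.

```python
# Sketch: split a wheel filename into its PEP 427 tags and compare the
# Python tag with the running interpreter.
import sys


def parse_wheel_filename(filename):
    """Return (name, version, python_tag, abi_tag, platform_tag).

    Assumes no optional build tag, as in the TensorFlow wheel above.
    """
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: %s" % filename)
    parts = filename[:-len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError("unexpected wheel filename: %s" % filename)
    return tuple(parts)


def matches_interpreter(python_tag, version_info=None):
    """True if a CPython tag such as 'cp27' matches the major/minor version."""
    if version_info is None:
        version_info = sys.version_info
    return python_tag == "cp%d%d" % (version_info[0], version_info[1])


if __name__ == "__main__":
    tags = parse_wheel_filename(
        "tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl")
    print(tags)
    print("matches this interpreter: {}".format(matches_interpreter(tags[2])))
```

On a Python 3 interpreter the second line prints False. In that case, choose the wheel whose cpXY tag matches the output of python --version.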
Configure TensorFlow Models
git clone https://github.com/tensorflow/models.git
cd models
protoc object_detection/protos/*.proto --python_out=.
This generates a set of *_pb2.py files in object_detection/protos/.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/builders/model_builder_test.py
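If the setup is correct, the test prints OK. The export line only affects the current shell, so run it again (from models/) in every new terminal.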
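If model_builder_test.py fails with an ImportError naming a *_pb2 module, the usual cause is that the protoc step did not run, or ran from the wrong directory. protoc --python_out writes one <name>_pb2.py for each <name>.proto, so comparing the two sets of files is a quick check. missing_pb2 is a hypothetical helper written for this note.

```python
# Sketch: confirm that protoc produced a <name>_pb2.py for every <name>.proto.
import os


def missing_pb2(filenames):
    """Return the .proto files in `filenames` with no matching _pb2.py."""
    names = set(filenames)
    missing = []
    for name in sorted(names):
        if name.endswith(".proto"):
            stem = name[:-len(".proto")]
            if stem + "_pb2.py" not in names:
                missing.append(name)
    return missing


if __name__ == "__main__":
    proto_dir = "object_detection/protos"  # relative to models/
    if os.path.isdir(proto_dir):
        gaps = missing_pb2(os.listdir(proto_dir))
        print("missing: " + ", ".join(gaps) if gaps else "all protos compiled")
```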
Prepare the data
Reference: https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md
PASCAL VOC 2012 is used here as an example.
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xvf VOCtrainval_11-May-2012.tar
mkdir object_detection/VOC2012
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2012 --set=train \
--output_path=object_detection/VOC2012/pascal_train.record
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2012 --set=val \
--output_path=object_detection/VOC2012/pascal_val.record
This produces pascal_train.record and pascal_val.record.
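Both commands pass --label_map_path so that the script can map each VOC class name to an integer id. The script appears to build this mapping with label_map_util.get_label_map_dict, which returns a {name: id} dict. The stdlib sketch below reads the same pbtxt file with regular expressions instead of protobuf. It assumes the simple item { id: ... name: '...' } layout used in pascal_label_map.pbtxt, and parse_label_map is a hypothetical helper.

```python
# Sketch: read a label map .pbtxt into {class_name: id} without protobuf.
# Assumes flat "item { id: N name: '...' }" blocks with no nested braces.
import os
import re

_ITEM_RE = re.compile(r"item\s*\{(.*?)\}", re.DOTALL)
_ID_RE = re.compile(r"\bid\s*:\s*(\d+)")
# \b keeps "display_name" from matching as "name".
_NAME_RE = re.compile(r"\bname\s*:\s*['\"]([^'\"]*)['\"]")


def parse_label_map(text):
    """Return {class_name: integer_id} for every item block in `text`."""
    mapping = {}
    for body in _ITEM_RE.findall(text):
        id_match = _ID_RE.search(body)
        name_match = _NAME_RE.search(body)
        if id_match and name_match:
            mapping[name_match.group(1)] = int(id_match.group(1))
    return mapping


if __name__ == "__main__":
    path = "object_detection/data/pascal_label_map.pbtxt"  # relative to models/
    if os.path.exists(path):
        with open(path) as f:
            print(parse_label_map(f.read()))
```

Ids start at 1 because id 0 is reserved for the background class. This is why the 20 VOC classes map to ids 1 to 20.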
To use your own data, write a script that converts it into TFRecord files, using create_pascal_tf_record.py as a template. See also https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md
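For each image, such a script mainly has to read the annotation and express each box relative to the image size, because the records store normalized coordinates. Below is a minimal sketch for one PASCAL VOC style XML annotation that uses only xml.etree. voc_objects is a hypothetical helper. It ignores extra fields that create_pascal_tf_record.py also handles, such as difficult, and it does not write the tf.train.Example itself.

```python
# Sketch: parse one VOC XML annotation into normalized boxes.
import xml.etree.ElementTree as ET


def voc_objects(xml_text):
    """Return [(class_name, ymin, xmin, ymax, xmax), ...] with each
    coordinate divided by the image height or width, i.e. in [0, 1]."""
    root = ET.fromstring(xml_text)
    width = float(root.find("size/width").text)
    height = float(root.find("size/height").text)
    objects = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        objects.append((
            obj.find("name").text,
            float(box.find("ymin").text) / height,
            float(box.find("xmin").text) / width,
            float(box.find("ymax").text) / height,
            float(box.find("xmax").text) / width,
        ))
    return objects
```

To try it, pass the contents of any file under VOCdevkit/VOC2012/Annotations/. For your own format, replace the XML lookups and keep the division by width and height.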
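(Optional) Download a pretrained model
The official model zoo (https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md) provides a number of pretrained models. ssd_mobilenet_v1_coco is used here as an example.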
wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz -C object_detection
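Training
Training is not needed if you only want to run predictions with an existing model.
File layout
The following layout keeps the files easy to find.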
models
├── object_detection
│ ├── VOC2012
│ │ ├── ssd_mobilenet_train_logs
│ │ ├── ssd_mobilenet_val_logs
│ │ ├── ssd_mobilenet_v1_voc2012.config
│ │ ├── pascal_label_map.pbtxt
│ │ ├── pascal_train.record
│ │ └── pascal_val.record
│ ├── infer.py
│ └── create_pascal_tf_record.py
├── eval_voc2012.sh
└── train_voc2012.sh
Configuration
Reference: https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md
SSD with MobileNet is used here. Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config to object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config and change the following lines. The line numbers refer to the sample config at the time of writing and may differ in other versions.
- Line 9: num_classes: 20
- Line 158: fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"
- Line 177: input_path: "object_detection/VOC2012/pascal_train.record"
- Lines 179 and 193: label_map_path: "object_detection/data/pascal_label_map.pbtxt"
- Line 191: input_path: "object_detection/VOC2012/pascal_val.record"
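Line numbers shift between versions of the sample config, so it can be safer to edit these fields by name. The sketch below applies the same changes with regular expressions. patch_pipeline_config is a hypothetical helper. It assumes, as in the sample configs, exactly two input_path fields, with train_input_reader before eval_input_reader, and it replaces every label_map_path.

```python
# Sketch: apply the pipeline config edits by field name instead of line number.
import os
import re


def patch_pipeline_config(text, num_classes, checkpoint, train_record,
                          val_record, label_map):
    """Return `text` with num_classes, fine_tune_checkpoint, both
    input_path fields and every label_map_path replaced."""
    input_re = r'input_path:\s*"[^"]*"'
    if len(re.findall(input_re, text)) != 2:
        raise ValueError("expected exactly two input_path fields")
    # Callables are used as replacements so backslashes in paths are not
    # treated as regex escapes.
    text = re.sub(r"num_classes:\s*\d+",
                  lambda m: "num_classes: %d" % num_classes, text)
    text = re.sub(r'fine_tune_checkpoint:\s*"[^"]*"',
                  lambda m: 'fine_tune_checkpoint: "%s"' % checkpoint, text)
    # re.sub visits matches left to right: train first, then eval.
    paths = iter([train_record, val_record])
    text = re.sub(input_re, lambda m: 'input_path: "%s"' % next(paths), text)
    text = re.sub(r'label_map_path:\s*"[^"]*"',
                  lambda m: 'label_map_path: "%s"' % label_map, text)
    return text


if __name__ == "__main__":
    # Run from tensorflow/models/ after copying the pets sample config.
    path = "object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config"
    if os.path.exists(path):
        with open(path) as f:
            patched = patch_pipeline_config(
                f.read(), 20,
                "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt",
                "object_detection/VOC2012/pascal_train.record",
                "object_detection/VOC2012/pascal_val.record",
                "object_detection/data/pascal_label_map.pbtxt")
        with open(path, "w") as f:
            f.write(patched)
```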
Run training
Create tensorflow/models/train_voc2012.sh with the following content:
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
--train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &
From tensorflow/models/, run ./train_voc2012.sh to start training. Make the script executable first with chmod +x train_voc2012.sh.
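The training script tees its output to object_detection/VOC2012/ssd_mobilenet_train_logs.txt, so you can check the loss without TensorBoard. The sketch below assumes loss lines look like INFO:tensorflow:global step 120: loss = 4.3210 (0.512 sec/step), which is the format expected from the slim training loop. That line is an illustrative format, not real output, so check your own log and adjust the regex if it differs. parse_losses is a hypothetical helper.

```python
# Sketch: pull (step, loss) pairs out of the tee'd training log.
import re

# Assumed line shape: "... global step <N>: loss = <float> (...)"
_STEP_RE = re.compile(r"global step (\d+): loss = ([0-9.]+)")


def parse_losses(lines):
    """Return [(step, loss), ...] for every log line that matches."""
    losses = []
    for line in lines:
        m = _STEP_RE.search(line)
        if m:
            losses.append((int(m.group(1)), float(m.group(2))))
    return losses
```

For example, open the log file and print parse_losses(f)[-5:] to see the most recent values while training runs in the background.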
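Evaluation
Evaluation can run while training is still in progress. Run it on a different GPU, or split GPU memory sensibly between the two jobs.
Create tensorflow/models/eval_voc2012.sh with the following content: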
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
--checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
--eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &
From tensorflow/models/, run CUDA_VISIBLE_DEVICES="1" ./eval_voc2012.sh to start evaluation. This selects the second GPU.
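CUDA_VISIBLE_DEVICES controls which GPUs a process can see, and CUDA renumbers the visible GPUs from 0. With the setting above, the evaluation job sees GPU 1 as its only device and does not claim memory on GPU 0, where training runs. The sketch below launches the script from Python under the same assumption. eval_env and launch_eval are hypothetical helpers.

```python
# Sketch: start the evaluation script with only one GPU visible.
import os
import subprocess


def eval_env(gpu_index, base_env=None):
    """Return a copy of the environment exposing only GPU `gpu_index`.

    Inside the child process, that GPU appears as device 0.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


def launch_eval(gpu_index=1):
    """Start ./eval_voc2012.sh on the given GPU (run from tensorflow/models/)."""
    return subprocess.Popen(["./eval_voc2012.sh"], env=eval_env(gpu_index))
```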
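Visualizing the logs
While training runs, you can view the training logs in TensorBoard to follow the loss trend:
tensorboard --logdir object_detection/VOC2012/ssd_mobilenet_train_logs/
The evaluation logs show the trend of Precision/mAP@0.5IOU and the predictions on individual images. TensorBoard uses port 6006 by default, so give this instance a different port so it can run alongside the first one:
tensorboard --logdir object_detection/VOC2012/ssd_mobilenet_val_logs/ --port 6007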
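The "@0.5IOU" in the metric name refers to the matching rule: a detection can only count as correct if its box overlaps a ground-truth box with an intersection over union (IoU) of at least 0.5. The sketch below computes IoU for boxes in the (ymin, xmin, ymax, xmax) order that the API uses. It covers only this matching criterion. Full mAP also ranks detections by score and averages precision over recall per class, and the evaluator's handling of the exact 0.5 boundary may differ from the >= used here. iou and is_true_positive are hypothetical helpers.

```python
# Sketch: intersection over union for two boxes (ymin, xmin, ymax, xmax).


def iou(box_a, box_b):
    """Return the IoU of two axis-aligned boxes, between 0.0 and 1.0."""
    ymin = max(box_a[0], box_b[0])
    xmin = max(box_a[1], box_b[1])
    ymax = min(box_a[2], box_b[2])
    xmax = min(box_a[3], box_b[3])
    # Non-overlapping boxes give a negative extent, clamped to 0.
    inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def is_true_positive(pred_box, gt_box, threshold=0.5):
    """Matching rule behind "mAP@0.5IOU" (boundary handling assumed)."""
    return iou(pred_box, gt_box) >= threshold
```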
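Testing
Export the model
After training, ssd_mobilenet_train_logs contains checkpoint files such as:
- graph.pbtxt
- model.ckpt-200000.data-00000-of-00001
- model.ckpt-200000.index
- model.ckpt-200000.meta
The .meta file stores the graph and its metadata. The .data and .index files store the network weights.
Prediction only needs the graph and the weights, not the training metadata. You can therefore use the official script to export an inference graph, which it writes as frozen_inference_graph.pb in the output directory.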
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
--trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
--output_directory object_detection/VOC2012
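The command above hard-codes step 200000. If training stopped at a different step, pass the newest checkpoint prefix instead. In TensorFlow 1.x, tf.train.latest_checkpoint(train_dir) returns it by reading the checkpoint state file in that directory. The stdlib sketch below does the same thing from file names, assuming each saved step has a model.ckpt-N.index file. latest_checkpoint_prefix is a hypothetical helper.

```python
# Sketch: choose the highest-step checkpoint prefix from a directory listing.
import os
import re

_INDEX_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(filenames, train_dir):
    """Return '<train_dir>/model.ckpt-<max step>', or None if none exist."""
    steps = [int(m.group(1)) for m in map(_INDEX_RE.match, filenames) if m]
    if not steps:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % max(steps))
```

For example, call latest_checkpoint_prefix(os.listdir(d), d) with d set to object_detection/VOC2012/ssd_mobilenet_train_logs and pass the result to --trained_checkpoint_prefix.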
Test images
Run object_detection_tutorial.ipynb after adjusting the paths in it.
Alternatively, write your own inference script, such as tensorflow/models/object_detection/infer.py:
import sys
# Put models/ on the path so object_detection.* imports inside utils resolve.
# Assumes the script is run from object_detection/.
sys.path.append('..')
import time

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image

from utils import label_map_util
from utils import visualization_utils as vis_util

# Paths are relative to object_detection/.
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'  # written by export_inference_graph.py
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 20  # PASCAL VOC classes, same as num_classes in the config
IMAGE_SIZE = (18, 12)  # matplotlib figure size in inches

# Map class ids to display names for drawing labels.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen inference graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Allocate GPU memory on demand instead of reserving all of it up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
        start_time = time.time()
        print(time.ctime())
        # The graph expects a uint8 batch of shape [1, height, width, 3];
        # convert('RGB') also handles grayscale and RGBA inputs.
        image = Image.open(PATH_TEST_IMAGE).convert('RGB')
        image_np = np.array(image).astype(np.uint8)
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Input and output tensors of the exported graph.
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
        # Draw the boxes (normalized coordinates) and labels onto image_np.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
            category_index, use_normalized_coordinates=True, line_thickness=8)
        plt.figure(figsize=IMAGE_SIZE)
        plt.imshow(image_np)
        plt.show()  # needed when running as a script rather than in a notebook
From object_detection/, run python infer.py test_images/image1.jpg to see the detections.
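To use the results in code instead of only drawing them, you typically drop low-scoring detections and convert the normalized boxes to pixels. The sketch below does this with the squeezed boxes, scores and classes from infer.py and its category_index, which maps id to {'id': ..., 'name': ...}. detections_to_pixels is a hypothetical helper, and the 0.5 threshold is an assumed default that you may want to tune.

```python
# Sketch: filter detections by score and convert boxes to pixel coordinates.


def detections_to_pixels(boxes, scores, classes, category_index,
                         image_width, image_height, min_score=0.5):
    """Return [(name, score, left, top, right, bottom), ...].

    boxes are normalized (ymin, xmin, ymax, xmax) rows for one image,
    i.e. after np.squeeze; lists or NumPy arrays both work.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        name = category_index.get(int(cls), {}).get("name", "N/A")
        results.append((name, float(score),
                        int(round(xmin * image_width)),
                        int(round(ymin * image_height)),
                        int(round(xmax * image_width)),
                        int(round(ymax * image_height))))
    return results
```

In infer.py you would call it with np.squeeze(boxes), np.squeeze(scores) and np.squeeze(classes), plus image.size for the width and height.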