利用Google已经训练好的inception_v3模型,修改最后一层,训练我们自己的模型。
在学习inception_v3的过程中找了很多资料,我把这些资料有用的地方进行了总结。
1.训练
<下载retrain.py文件>
这是下载链接 https://github.com/tensorflow/tensorflow/tree/r1.1
根据自己的tensorflow的版本在branch中选择自己的版本。并按照tensorflow—examples—image_retraining—retrain.py寻找即可
注:将此py文件的几处根据自己的需要更改下
训练完的结果:在tmp文件夹下有.pb文件
2.测试
```python
"""
use_output_graph
使用retrain所训练的迁移后的inception模型来测试
"""
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
model_name = './tmp/output_graph.pb'
image_dir = './test/'
label_filename = './tmp/output_labels.txt'
def create_graph():
with tf.gfile.FastGFile( model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
def load_labels(label_file_dir):
labels = []
if not tf.gfile.Exists(label_file_dir):
tf.logging.fatal('File does not exist %s', label_file_dir)
else:
labels = tf.gfile.GFile(label_file_dir).readlines()
for i in range(len(labels)):
labels[i] = labels[i].strip('\n')
return labels
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
for root, dirs, files in os.walk(image_dir):
for file in files:
image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
image_path = os.path.join(root, file)
print(image_path)
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
top_5 = predictions.argsort()[-5:][::-1]
for label_index in top_5:
label_name = load_labels(label_filename)[label_index]
label_score = predictions[label_index]
print('%s (score = %.5f)' % (label_name, label_score))
print()
3.结果
我做的是一个猫狗的识别,所以训练的依次用了200多张的猫和200多张的狗图片,利用迁移学习inception_v3训练自己的模型,要比自己训练自己的模型识别率高很多。
这是我测试的一个图片:
测试结果如下:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)