我找到了两种在 Tensorflow 中保存模型的方法:tf.train.Saver()
and SavedModelBuilder
。然而,我找不到有关使用该模型的文档以第二种方式加载后。
注:我想用SavedModelBuilder
方式,因为我用 Python 训练模型,并将在另一种语言(Go)的服务时间使用它,看起来SavedModelBuilder
在这种情况下是唯一的方法。
这非常适合tf.train.Saver()
(第一种方式):
model = tf.add(W * x, b, name="finalnode")
# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")
# load
saver.restore(sess, "/tmp/model")
# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.
model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})
tf.saved_model.builder.SavedModelBuilder()
定义在Readme https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/saved_model/但加载模型后tf.saved_model.loader.load(sess, [], export_dir)
),我找不到有关返回节点的文档(请参阅"finalnode"
在上面的代码中)
缺少的是signature
# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs= {"x": x},
outputs= {"finalnode": model})
})
builder.save()
# loading
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["tag"], export_dir)
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("x:0")
model = graph.get_tensor_by_name("finalnode:0")
print(sess.run(model, {x: [5, 6, 7, 8]}))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)