Saver类位于tf.train中,属于训练过程中要用到的方法,主要作用就是保存和加载save & restore ckpt。
最简单的保存应用举例:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
正常训练的过程应用:
saver = tf.train.Saver(...variables...)
sess = tf.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
一、Saver构造函数__init__
众多参数可传入,但用到的不多:
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
二、Properties
Saver的Properties仅有一个:last_checkpoints,返回目前还未删除的ckpt的文件列表,按照从旧到新的顺序。
三、Methods
1.save()
保存变量,此函数需要一个session,其中含有launched graph,所有变量应已经被初始化。
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False
)
保存后目录下会有几个文件:
- checkpoint :训练过程中自动生成的文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model。
- MyModel.meta:包含全部graph信息。这是一个序列化的MetaGraphDef protocol buffer,包含数据流、变量的annotations、input pipelines,以及其他相关信息。
- MyModel.data-00000-of-00001:包含所有变量的值(weights, biases, placeholders, gradients, hyper-parameters etc),也就是模型训练好参数和其他值
- MyModel.index:?
其中meta信息含有graph信息,因此在重新导入图是不需要手工从头开始构建图,而是直接导入meta信息。具体做法如下:
saver = tf.compat.v1.train.Saver(...variables...)
tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
train_op = tf.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
更多使用方法:[🔗]
2.restore()
直接运行装载变量的op,需要一个包含launched graph的session。graph中的变量可以不被初始化,装载操作可以算做初始化。
restore(
sess,
save_path
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)