Based on keveman's answer, I created a Python script that you can run to rename the variables of any TensorFlow checkpoint:
https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96
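It can replace a substring in every variable name and add a prefix to every name. Call the script with

```
python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir
```

and the optional arguments

```
--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run
```

`--replace_from` and `--replace_to` only take effect when both are given. With `--dry_run`, the script only prints what would be renamed and does not save anything.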
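For reference, here is how these flags could be parsed. This is a minimal sketch using Python's standard `argparse` module; the function name `parse_args` is hypothetical, and the gist itself may parse its arguments differently.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags described above (not the gist's own code)."""
    parser = argparse.ArgumentParser(
        description='Rename variables in a TensorFlow checkpoint.')
    parser.add_argument('--checkpoint_dir', required=True,
                        help='Directory containing the checkpoint to rewrite.')
    parser.add_argument('--replace_from', default=None,
                        help='Substring to replace in every variable name.')
    parser.add_argument('--replace_to', default=None,
                        help='Replacement for the --replace_from substring.')
    parser.add_argument('--add_prefix', default=None,
                        help='Prefix prepended to every variable name.')
    # store_true: the flag takes no value and defaults to False
    parser.add_argument('--dry_run', action='store_true',
                        help='Only print the planned renames; do not save.')
    return parser.parse_args(argv)
```

The parsed values map one-to-one onto the parameters of the script's core function, e.g. `rename(args.checkpoint_dir, args.replace_from, args.replace_to, args.add_prefix, args.dry_run)`.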
Here is the core function of the script:
```python
import tensorflow as tf  # TensorFlow 1.x only: tf.contrib does not exist in 2.x


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable's value from the checkpoint
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name: substring replacement first, then the prefix
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable by creating a new graph variable
                # under the new name, initialized with the loaded value
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save all variables under their new names, overwriting
            # the latest checkpoint in checkpoint_dir
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
```
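The naming rule inside the loop does not depend on TensorFlow, so it can be pulled out and checked on its own before touching a real checkpoint. A minimal sketch, assuming the hypothetical helper names `new_variable_name` and `plan_renames` (they are not part of the gist): it applies the same substring replacement followed by the prefix and returns the (old, new) pairs that a dry run would print.

```python
def new_variable_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Same rule as rename(): substring replacement first, then the prefix."""
    new_name = var_name
    # The replacement only happens when both replace_from and replace_to are given
    if None not in [replace_from, replace_to]:
        new_name = new_name.replace(replace_from, replace_to)
    # A missing or empty prefix is ignored
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def plan_renames(var_names, replace_from=None, replace_to=None, add_prefix=None):
    """Return (old, new) name pairs, like a dry run but without TensorFlow."""
    return [(name, new_variable_name(name, replace_from, replace_to, add_prefix))
            for name in var_names]


for old, new in plan_renames(['scope1/Variable1', 'scope2/Variable2'],
                             'scope1', 'scope1/model', 'abc/'):
    print('%s would be renamed to %s.' % (old, new))
```

In a real run, the list of names would come from `tf.contrib.framework.list_variables(checkpoint_dir)` as in the function above.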
Example:

```
python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/
```

This renames the variable `scope1/Variable1` to `abc/scope1/model/Variable1`.
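Keep in mind that `str.replace` matches the substring anywhere in the name, so with the flags above a variable under a scope such as `scope10` would be changed as well. If that matters for your checkpoint, one option is to match only whole `/`-separated scope components. This is a hypothetical variant (the function `replace_scope_component` is made up for illustration), not something the script does:

```python
def replace_scope_component(var_name, old_scope, new_scope):
    """Hypothetical stricter variant: replace only whole '/'-separated components."""
    parts = var_name.split('/')
    return '/'.join(new_scope if part == old_scope else part for part in parts)


name = 'scope10/Variable1'
# Plain substring replacement, as the script does it
print(name.replace('scope1', 'scope1/model'))                   # scope1/model0/Variable1
# Component-wise replacement leaves scope10 untouched
print(replace_scope_component(name, 'scope1', 'scope1/model'))  # scope10/Variable1
```

Running with `--dry_run` first is the simplest way to spot such unintended matches before the checkpoint is overwritten.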