TensorFlow:如何根据历元设置学习率衰减?

2024-01-13

学习率衰减函数tf.train.exponential_decay需要一个decay_steps范围。每隔一段时间降低学习率num_epochs,你会设置decay_steps = num_epochs * num_train_examples / batch_size。然而,当读取数据时.tfrecords文件中,你不知道里面有多少个训练样例。

To get num_train_examples, 你可以:

  • 设置一个tf.string_input_producer with num_epochs=1.
  • 运行这个tf.TFRecordReader/tf.parse_single_example.
  • 循环并计算在停止之前产生一些输出的次数。

然而,这不是很优雅。

有没有更简单的方法来获取训练样本的数量.tfrecords根据时期而不是步骤来归档或设置学习率衰减?


您可以使用以下代码来获取a中的记录数.tfrecords file :

def get_num_records(tf_record_file):
  return len([x for x in tf.python_io.tf_record_iterator(tf_record_file)])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow:如何根据历元设置学习率衰减? 的相关文章

随机推荐