The following code (copy/paste runnable) illustrates the use of tf.layers.batch_normalization:
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
# => []  (the UPDATE_OPS collection is empty)
With TF 1.5, the documentation (quoted below, https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization) explicitly states that UPDATE_OPS should not be empty in this case:
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency of the train_op. For example:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
Just switch your code to training mode (by setting the training flag to True), as stated in the quote from https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization:

Note: when training, moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as dependencies of the train_op.
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
will output:
[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
<tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]
Gamma and beta end up in the TRAINABLE_VARIABLES collection:
print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
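Putting the two pieces together, here is a minimal end-to-end sketch of training with batch normalization: training=True fills UPDATE_OPS, and tf.control_dependencies ties those update ops to the train_op as the docs recommend. (The inputs, loss, and learning rate are made-up placeholders for illustration; it is written against the tf.compat.v1 API so it also runs under TF 2.x — under TF 1.x you can use `import tensorflow as tf` directly.)

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # use graph mode, as in TF 1.x

# Dummy data for illustration
x = tf.constant([[0.0], [1.0]])
y_true = tf.constant([[0.0], [1.0]])

# training=True puts the moving-average update ops into UPDATE_OPS
out = tf.layers.batch_normalization(x, training=True)
loss = tf.reduce_mean(tf.square(out - y_true))

# Make train_op depend on the update ops so moving_mean/moving_variance
# are refreshed on every training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(train_op)  # each step also runs the moving-average updates
```

At inference time you would build the layer with training=False so it uses the accumulated moving statistics instead of the batch statistics.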