最近的一篇论文(here http://ydwen.github.io/papers/WenECCV16.pdf)引入了一种称为中心损失的二次损失函数。它基于批次中嵌入之间的距离以及每个相应类的运行平均嵌入。 TF Google 群组中有一些讨论(here https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/0Am9FCdFAxg)关于如何计算和更新此类嵌入中心。我在下面的答案中整理了一些代码来生成类平均嵌入。
这是最好的方法吗?
对于像中心损失这样的情况来说,之前发布的方法过于简单,随着模型变得更加精细,嵌入的预期值会随着时间的推移而变化。这是因为之前的中心查找例程对自启动以来的所有实例进行平均,因此跟踪预期值的变化非常缓慢。相反,移动窗口平均值是首选。指数移动窗口变体如下:
def get_embed_centers(embed_batch, label_batch):
''' Exponential moving window average. Increase decay for longer windows [0.0 1.0]
'''
decay = 0.95
with tf.variable_scope('embed', reuse=True):
embed_ctrs = tf.get_variable("ctrs")
label_batch = tf.reshape(label_batch, [-1])
old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
return embed_ctrs_batch
with tf.Session() as sess:
with tf.variable_scope('embed'):
embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
initializer=tf.constant_initializer(0), trainable=False)
label_batch_ph = tf.placeholder(tf.int32)
embed_batch_ph = tf.placeholder(tf.float32)
embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
sess.run(tf.initialize_all_variables())
tf.get_default_graph().finalize()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)