更新(2018/03/19):我写了一个博客文章 https://omoindrot.github.io/triplet-loss详细介绍了如何在 TensorFlow 中实现三元组损失。
您需要自己实现对比损失或三元组损失,但是一旦您知道对或三元组,这就很容易了。
对比损失
假设您有数据对及其标签(正或负,即同一类或不同类)作为输入。例如,您有大小为 28x28x1 的图像作为输入:
left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
label = tf.placeholder(tf.int32, [None, 1]). # 0 if same, 1 if different
margin = 0.2
left_output = model(left) # shape [None, 128]
right_output = model(right) # shape [None, 128]
d = tf.reduce_sum(tf.square(left_output - right_output), 1)
d_sqrt = tf.sqrt(d)
loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d
loss = 0.5 * tf.reduce_mean(loss)
三重态损失
与对比损失相同,但具有三元组(锚定、正、负)。这里不需要标签。
anchor_output = ... # shape [None, 128]
positive_output = ... # shape [None, 128]
negative_output = ... # shape [None, 128]
d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)
loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)
在 TensorFlow 中实现三元组损失或对比损失时真正的麻烦是如何对三元组或对进行采样。我将专注于生成三元组,因为它比生成对更难。
最简单的方法是在 Tensorflow 图之外(即在 Python 中)生成它们,并通过占位符将它们提供给网络。基本上,您一次选择 3 个图像,前两个来自同一类别,第三个来自另一个类别。然后,我们对这些三元组执行前馈,并计算三元组损失。
这里的问题是生成三元组很复杂。我们希望他们成为有效的三元组,具有正损失的三元组(否则损失为 0 并且网络不会学习)。
要知道一个三元组是否好,您需要计算它的损失,因此您已经通过网络进行了一个前馈......
显然,在 Tensorflow 中实现三元组损失很困难,并且有一些方法可以使其比在 Python 中采样更有效,但解释它们需要整篇博客文章!