According to the documentation, tf.nn.softmax_cross_entropy_with_logits (https://www.tensorflow.org/api_docs/python/nn/classification#softmax_cross_entropy_with_logits) must be called with a valid probability distribution for labels, or the computation will be incorrect, and calling tf.nn.sparse_softmax_cross_entropy_with_logits (https://www.tensorflow.org/api_docs/python/nn/classification#sparse_softmax_cross_entropy_with_logits), which might be more convenient in your case, with negative labels will either raise an error or return NaN values. I wouldn't rely on either of them to ignore some labels.
What I would do is, in those pixels where the correct class is the ignored one, replace the logit of the ignored class with "infinity" (a big enough number), so those pixels contribute essentially nothing to the loss:
ignore_label = ...
# Make zeros everywhere except for the ignored-label channel,
# which keeps the one-hot ground truth of the ignored class
input_batch_ignored = tf.concat(
    [tf.zeros_like(input_batch[:, :, :, :ignore_label]),
     tf.expand_dims(input_batch[:, :, :, ignore_label], -1),
     tf.zeros_like(input_batch[:, :, :, ignore_label + 1:])],
    axis=-1)
# Make the corresponding logits "infinity" (a big enough number)
# (tf.select was renamed to tf.where in TF 1.0)
predictions_fix = tf.where(input_batch_ignored > 0,
                           1e30 * tf.ones_like(predictions), predictions)
# Compute the loss with the fixed logits
loss = tf.nn.softmax_cross_entropy_with_logits(labels=input_batch,
                                               logits=predictions_fix)
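To see why the big logit works, here is the same arithmetic as a standalone plain-NumPy sketch (with made-up logits, not the tensors above): once the ignored class's logit is pushed to a huge value wherever that class is the ground truth, the softmax puts essentially all its mass on that class and the cross-entropy for that pixel collapses to zero.

```python
import numpy as np

def softmax_xent(logits, onehot):
    """Per-pixel softmax cross-entropy, numerically stable."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(onehot * log_softmax).sum(axis=-1)

ignore_label = 2
# Two pixels, three classes; the second pixel belongs to the ignored class
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0]])
logits = np.array([[2.0, 0.5, 0.1],
                   [0.3, 0.2, 0.4]])

# Replace the ignored-class logit with a huge number where it is the truth
fixed = logits.copy()
mask = labels[:, ignore_label] > 0
fixed[mask, ignore_label] = 1e30

loss = softmax_xent(fixed, labels)
print(loss)  # first pixel unchanged, second pixel's loss is 0
```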
The only problem with this is that you are treating pixels of the ignored class as always predicted correctly, which means that the loss for images containing many of those pixels will be artificially smaller. Depending on the case this may or may not be significant, but if you want to be really accurate, you would have to weight the loss of each image according to its number of non-ignored pixels, instead of just taking the mean:
# Count relevant (non-ignored) pixels on each image; collapsing the
# channel dimension turns input_batch_ignored into a per-pixel 0/1 mask
input_batch_relevant = 1 - tf.reduce_sum(input_batch_ignored, axis=-1)
input_batch_weight = tf.reduce_sum(input_batch_relevant, axis=[1, 2])
# Compute relative weights
input_batch_weight = input_batch_weight / tf.reduce_sum(input_batch_weight)
# Reduce the per-pixel loss to one value per image, then weight
loss_per_image = tf.reduce_mean(loss, axis=[1, 2])
reduced_loss = tf.reduce_sum(loss_per_image * input_batch_weight)
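The weighting step can also be checked in plain NumPy. This sketch uses made-up shapes (a batch of 2 images of 4×4 pixels, all names hypothetical) and assumes the ignored pixels already carry no loss: the image with fewer relevant pixels gets proportionally less weight than a plain mean over images would give it.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, h, w = 2, 4, 4

# Per-pixel losses; ignored pixels are assumed to contribute no loss
loss = rng.uniform(0.1, 1.0, size=(batch, h, w))
# 0/1 mask of relevant (non-ignored) pixels per image
relevant = np.ones((batch, h, w))
relevant[1, :2, :] = 0          # second image has half its pixels ignored
loss = loss * relevant          # zero out the ignored pixels' loss

# Count relevant pixels per image and turn counts into relative weights
weight = relevant.sum(axis=(1, 2))
weight = weight / weight.sum()  # counts [16, 8] -> weights [2/3, 1/3]

# Weighted reduction instead of a plain mean over images
loss_per_image = loss.mean(axis=(1, 2))
reduced_loss = (loss_per_image * weight).sum()
print(weight, reduced_loss)
```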