【损失函数】图像分割损失CELoss中添加 OHEM

2023-05-16

        语义分割中常用交叉熵损失CE,在应用中通常添加OHEM以获取更好的收敛(经验阈值是0.7),这里OHEM思想的来源是topk loss,其介绍参考【论文-损失函数】Learning with Average Top-k Loss_there2belief的博客-CSDN博客

        这种添加OHEM的损失叫OhemCELoss(Online hard example mining cross-entropy loss),其中 Online hard example mining 的意思是,在训练过程中关注 hard example ,对其施加更高权重的一种训练策略。cross-entropy loss 就是普通的交叉熵损失函数。下面具体看下如何添加。        

cross-entropy loss

        首先回顾一下多分类问题的 cross-entropy loss 的公式:

        其中 为预测值, 为真实类别,大括号里面计算了样本在真实类别上的概率,这个概率越大(越接近 1),其负对数就越接近 0;反之,这个概率越小,其负对数就越接近正无穷。

pytorch 下的接口为:

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

  • weight (Tensor, optional) 表示每个类别设置的权重,用在类别不平衡的数据集上
  • size_average (bool, optional) ,True 计算整个 batch 上的平均值作为 loss,False 计算batch loss 之和。
  • ignore_index (int, optional) 某个 ground truth 值被忽略,即当 ground truth 为它时,损失直接设为 0。
  • reduce (bool, optional) 默认情况下,计算的 loss 为一个 batch 中每个元素的 loss 之和或者平均,设置为 False 后,返回每个元素的 loss。
  • reduction (string, optional) 设置不同的输出格式,可以设为 'none' | 'mean' | 'sum',由于它的存在,size_averagereduce 将被淘汰(size_average=True 相当于 'mean',size_average=False 相当于 'sum',reduce=False 相当于 'none')。

        代码示例如下:

loss = nn.CrossEntropyLoss()
x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
out = loss(x, label)

        为了深入理解原理,也可以自己实现:

x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
exp_x = torch.exp(x)
prob_x = exp_x/exp_x.sum(dim=1, keepdim=True)
prob_x = torch.gather(prob_x, dim=1, index=label.unsqueeze(1))
# reduction = 'none'
print(-torch.log(prob_x))
# reduction = 'sum'
print(-torch.log(prob_x).sum())
# reduction = 'mean'
print(-torch.log(prob_x).mean())

OhemCELoss

先贴代码

class OhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        N, C, H, W = logits.size()

        # OHEM here
        loss = self.criteria(logits, labels).view(-1)  # to 1-D
        loss, _ = torch.sort(loss, descending=True)  # sort
        if loss[self.n_min] > self.thresh:
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]

        return torch.mean(loss)

        该损失函数有几个超参数:threshn_min,其中 thresh 表示的是,损失函数大于多少的时候,会被用来做反向传播。n_min 表示的是,在一个 batch 中,最少需要考虑多少个样本。
需要注意一点,参数 thresh 是概率,即 小于这个概率的预测值会参与计算损失。

  • self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)) 将概率转化为其对应的 loss 。
  • 代码中的 logits 维度为 NCHW,labels 维度为NHW 。
  • self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 设置 reduction 为 none,保留每个元素的损失,返回的维度为 NHW。
  • loss = self.criteria(logits, labels).view(-1) 将预测的损失拉平为一个长向量,每个元素为一个 pixel 的损失。
  • loss, _ = torch.sort(loss, descending=True) 将长向量中每个 pixel 的损失按从大到小排序。
  • if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]
  • 否则,计算前 n_min 个损失:loss = loss[:self.n_min]
  • 最后,求这些 hard example 的损失的均值作为最终损失:torch.mean(loss)

总结

        在图像分割问题中,以每个 pixel 的损失为最小单元,而不是 batch 中每张图片。因此排序时需要把 batch 中所有 pixel 拉成一个长向量,再取其中大于阈值的 pixel 作为 hard example。同时,n_min 的设置保证了每个 batch 中都有至少 n_min 个pixel 参与训练,从而一定程度巩固了训练结果,让前向传播不至于空耗。


修改自:https://www.jianshu.com/p/24376b18e5c7

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【损失函数】图像分割损失CELoss中添加 OHEM 的相关文章

随机推荐