# OHEM loss 源代码 (OHEM loss: source code)
#
# 2023-05-16

import numpy as np
import torch
import torch.nn as nn

class OhemCELoss(nn.Module):
    """Cross-entropy loss with Online Hard Example Mining (OHEM).

    Per-pixel cross-entropy is computed without reduction. Among the valid
    (non-ignored) pixels, only the "hard" ones whose loss exceeds
    ``-log(thresh)`` are averaged. If fewer than
    ``n_valid // n_min_divisor`` pixels are hard, the
    ``n_valid // n_min_divisor`` largest valid losses are averaged instead.

    Args:
        thresh: Probability threshold. A pixel counts as hard when its
            cross-entropy is above ``-log(thresh)``, i.e. when the predicted
            probability of its true class is below ``thresh``.
        ignore_lb: Label value excluded from the loss and from the
            valid-pixel count.
        n_min_divisor: The minimum number of kept pixels is
            ``n_valid // n_min_divisor``. Defaults to 16. Must be positive.

    Raises:
        ValueError: If ``n_min_divisor`` is not positive.
    """

    def __init__(self, thresh, ignore_lb=255, n_min_divisor=16):
        super().__init__()
        if n_min_divisor <= 0:
            raise ValueError(f"n_min_divisor must be positive, got {n_min_divisor}")
        # Stored as a loss-space threshold so it can be compared to CE values.
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
        self.ignore_lb = ignore_lb
        self.n_min_divisor = n_min_divisor
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        """Return the mean cross-entropy over the mined hard pixels.

        Args:
            logits: Class scores, in the layout ``nn.CrossEntropyLoss``
                expects for class-index targets.
            labels: Integer class-index targets matching ``logits``. Pixels
                equal to ``ignore_lb`` are excluded.

        Returns:
            A scalar tensor. If nothing can be selected (every pixel is
            ignored, or there are too few valid pixels to force a minimum
            and none is hard), the result is a zero that stays attached to
            the autograd graph instead of NaN.
        """
        loss_all = self.criteria(logits, labels)
        valid = labels != self.ignore_lb
        # Select among valid pixels only, so ignored pixels (whose loss is 0)
        # can never be counted as "hard", even when the threshold is <= 0.
        loss = loss_all[valid]
        n_min = loss.numel() // self.n_min_divisor

        loss_hard = loss[loss > self.thresh]
        if loss_hard.numel() < n_min:
            loss_hard, _ = loss.topk(n_min)
        if loss_hard.numel() == 0:
            # Averaging an empty tensor would yield NaN.
            return loss_all.sum() * 0.0
        return torch.mean(loss_hard)

if __name__ == "__main__":
    # Smoke test: random logits of shape [2, 13, 320, 640] (N, C, H, W) and
    # random integer targets of shape [2, 320, 640] with values in [0, 13).
    logit = torch.tensor(np.random.random((2, 13, 320, 640)))
    target = torch.tensor(np.random.randint(0, 13, size=(2, 320, 640))).long()

    # Call the module itself rather than .forward() so nn.Module hooks run.
    criterion = OhemCELoss(thresh=0.7)
    loss = criterion(logit, target)
    print(loss)

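# 本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)
# (Content contributed by users; copyright belongs to the original author. To report infringement, contact the address above, replacing # with @.)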

# OHEM loss 源代码的相关文章 (Related articles)

# 随机推荐 (Random recommendations)