import numpy as np
import torch
import torch.nn as nn
class OhemCELoss(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    Per-element cross-entropy is computed (``reduction='none'``). Only
    elements whose loss exceeds ``-log(thresh)`` ("hard" elements) are
    averaged. If there are fewer hard elements than
    ``n_min = (#labels != ignore_lb) // 16``, the ``n_min`` largest losses are
    averaged instead.

    Args:
        thresh: Probability threshold; since CE = -log(p), an element counts
            as hard when its loss is greater than ``-log(thresh)``.
        ignore_lb: Label value excluded from the loss (forwarded to
            ``nn.CrossEntropyLoss`` as ``ignore_index``).
    """

    def __init__(self, thresh, ignore_lb=255):
        super().__init__()
        # Convert the probability threshold into the equivalent loss threshold.
        self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float))
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        """Return the mean loss over the mined hard elements.

        Args:
            logits: Class scores accepted by ``nn.CrossEntropyLoss``
                (presumably (N, C, H, W) for segmentation - confirm at caller).
            labels: Integer class targets matching ``logits`` (same layout
                without the class dimension); entries equal to ``ignore_lb``
                contribute zero loss.

        Returns:
            A scalar tensor. If no element qualifies (e.g. every label is
            ignored), a zero that is still connected to the autograd graph is
            returned instead of the NaN that ``mean`` of an empty tensor gives.
        """
        # Keep at least 1/16 of the valid (non-ignored) elements.
        n_min = int((labels != self.ignore_lb).sum()) // 16
        loss = self.criteria(logits, labels).reshape(-1)
        loss_hard = loss[loss > self.thresh]
        if loss_hard.numel() < n_min:
            # Not enough hard elements: fall back to the n_min largest losses.
            loss_hard, _ = loss.topk(n_min)
        if loss_hard.numel() == 0:
            # torch.mean on an empty tensor is NaN, which would poison training;
            # return a zero that keeps the graph (and dtype/device) intact.
            return loss.sum() * 0.0
        return torch.mean(loss_hard)
if __name__ == "__main__":
    # Smoke test on random data: 2 images, 13 classes, 320x640 pixels.
    # logits: (N, C, H, W) = (2, 13, 320, 640)
    logit = torch.from_numpy(np.random.random((2, 13, 320, 640))).float()
    # targets: (N, H, W) = (2, 320, 640), class indices in [0, 13)
    target = torch.from_numpy(np.random.randint(0, 13, size=(2, 320, 640))).long()

    # Invoke the module via __call__ (not .forward) so nn.Module hooks run, and
    # avoid the name ``F``, which conventionally aliases torch.nn.functional.
    criterion = OhemCELoss(thresh=0.7)
    loss = criterion(logit, target)
    print(loss)
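# 本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)
# (English: Content contributed by users of the source site; copyright belongs to the original author. To report infringement, contact hwhale@tublm.com.)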