To compute the class weights for your classes, use sklearn.utils.class_weight.compute_class_weight(class_weight, *, classes, y).
Read the documentation here: https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
This returns an array of weights, one per class.
For example:
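import numpy as np
import torch
from sklearn.utils import class_weight

x = torch.randn(20, 5)
y = torch.randint(0, 5, (20,))  # class labels in [0, 5)
# classes and y are keyword-only in the signature above
class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(y.numpy()), y=y.numpy())
class_weights = torch.tensor(class_weights, dtype=torch.float)
print(class_weights)  # e.g. tensor([1.0000, 1.0000, 4.0000, 1.0000, 0.5714]); varies with the random labels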
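To see where those numbers come from: the scikit-learn documentation describes the 'balanced' mode as n_samples / (n_classes * np.bincount(y)). The sketch below reproduces that formula with NumPy only. The label counts (4, 4, 1, 4, 7) are an assumption chosen so that the result matches the sample output printed above; they are not taken from the original run.

```python
import numpy as np

# Hypothetical label distribution: 20 samples over 5 classes,
# with counts 4, 4, 1, 4, 7 (chosen to match the sample output above).
counts = np.array([4, 4, 1, 4, 7])
y = np.repeat(np.arange(5), counts)

n_samples = y.shape[0]
n_classes = np.unique(y).shape[0]

# 'balanced' heuristic: rarer classes get proportionally larger weights.
weights = n_samples / (n_classes * np.bincount(y))
print(np.round(weights, 4))
```

Note that this only lines up with a 5-class model if every class appears at least once in y; if a class is missing, the weight array will have fewer entries than the model has outputs.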
Then pass it to the weight argument of nn.CrossEntropyLoss:
from torch import nn
criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
loss = criterion(...)
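As a rough end-to-end illustration, here is a minimal sketch of what the weighted loss computes. The random logits stand in for a model's output and the weight values are hard-coded for illustration; both are assumptions, not part of the answer above. With reduction='mean', PyTorch divides the weighted sum of per-sample losses by the sum of the weights of the target classes, not by the batch size, which the manual check below confirms.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_classes = 5

# Stand-ins for real data: random logits instead of model(x), random targets.
logits = torch.randn(20, num_classes)
targets = torch.randint(0, num_classes, (20,))

# Illustrative weights (hard-coded here; normally from compute_class_weight).
class_weights = torch.tensor([1.0, 1.0, 4.0, 1.0, 0.5714])

criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
loss = criterion(logits, targets)

# Manual check: weighted average of the unweighted per-sample losses.
per_sample = nn.functional.cross_entropy(logits, targets, reduction='none')
w = class_weights[targets]            # weight of each sample's target class
manual = (w * per_sample).sum() / w.sum()

print(loss.item(), manual.item())
```

Because of this normalization, scaling all class weights by a constant leaves the mean-reduced loss unchanged; only the relative weights matter.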