使用scatter_
将标签转换为one-hot
import torch
num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
one_hot = torch.zeros((len(label), num_class)).scatter_(1, label.long().reshape(-1, 1), 1)
print(one_hot)
"""
tensor([[1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0.]])
"""
又或者更简单的。。。
import torch.nn.functional as F
import torch
num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
one_hot = F.one_hot(label, num_classes=num_class )
print(one_hot)
"""
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0]])
"""
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)