我有一个大小的张量[3, 15, 136]
, where:
3 is batch size
-
15 - sequence length
and
136 is tokens
我想使用中的概率来单热我的张量tokens
维度(136)。为此,我想提取序列长度中每个字母的标记维度并放入1
最大的可能性并将所有其他标记标记为0
.
您可以使用 PyTorch 的one_hot https://pytorch.org/docs/master/generated/torch.nn.functional.one_hot.html函数来实现这一点:
import torch.nn.functional as F
t = torch.rand(3, 15, 136)
F.one_hot(t.argmax(dim=2), 136)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)