如何使用 PyTorch 沿特定维度进行热编码?

2024-06-19

我有一个大小的张量[3, 15, 136], where:

  • 3 is batch size
  • 15 - sequence length and
  • 136 is tokens

我想使用中的概率来单热我的张量tokens维度(136)。为此,我想提取序列长度中每个字母的标记维度并放入1最大的可能性并将所有其他标记标记为0.


您可以使用 PyTorch 的one_hot https://pytorch.org/docs/master/generated/torch.nn.functional.one_hot.html函数来实现这一点:

import torch.nn.functional as F

t = torch.rand(3, 15, 136)

F.one_hot(t.argmax(dim=2), 136)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用 PyTorch 沿特定维度进行热编码? 的相关文章

随机推荐