让我们定义一些名称:
import torch
mat = torch.Tensor(
[[1,2,3],
[4,5,6],
[7,8,9]])
indices = torch.LongTensor([0, 1, 2]) # Could also use arange in this specific scenario
首先,你可以创建一个像这样的张量
[[0, 0, 0],
[1, 1, 1],
[2, 2, 2]]
using
arange1 = torch.arange(3).view((3, 1)).repeat((1, 3))
现在,让我们创建一个目标索引的张量
[[0, 2, 1],
[1, 0, 2],
[2, 1, 0]]
with
arange2 = (arange1 - indices) % 3
最后,我们得到预期的输出
torch.gather(mat, 0, arange2)