如何在 PyTorch 中以不同偏移量移动张量中的列(或行)?

2023-11-26

在 PyTorch 中,内置torch.roll函数只能以相同的偏移量移动列(或行)。但我想用不同的偏移量来移动列。假设输入张量是

[[1,2,3],
 [4,5,6],
 [7,8,9]]

比方说,我想通过偏移量进行移动i对于第 i 列。因此,预期输出是

[[1,8,6],
 [4,2,9],
 [7,5,3]]

这样做的一个选项是使用单独移动每列torch.roll并连接它们中的每一个。但出于有效性和代码紧凑性的考虑,我不想引入循环结构。有没有更好的办法?


让我们定义一些名称:

import torch

mat = torch.Tensor(
[[1,2,3],
 [4,5,6],
 [7,8,9]])

indices = torch.LongTensor([0, 1, 2]) # Could also use arange in this specific scenario

首先,你可以创建一个像这样的张量

[[0, 0, 0],
 [1, 1, 1],
 [2, 2, 2]]

using

arange1 = torch.arange(3).view((3, 1)).repeat((1, 3))

现在,让我们创建一个目标索引的张量

[[0, 2, 1],
 [1, 0, 2],
 [2, 1, 0]]

with

arange2 = (arange1 - indices) % 3

最后,我们得到预期的输出

torch.gather(mat, 0, arange2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 PyTorch 中以不同偏移量移动张量中的列(或行)? 的相关文章

随机推荐