torch.Tensor.scatter_
scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会
torch.Tensor.scatter_()
是torch.gather()
函数的方向反向操作。两个函数可以看成一对兄弟函数。gather
用来解码one hot,scatter_
用来编码one hot。
PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改
scatter_
(dim, index, src) → Tensor
参数:
- dim:沿着哪个维度进行索引
- index:用来 scatter 的元素索引
- src:用来 scatter 的源元素,可以是一个标量或一个张量
这个 scatter 可以理解成放置元素或者修改元素
简单说就是通过一个张量 src 来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定
官方文档给出了 3维张量 的具体操作说明,如下所示
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
x = torch.rand(2, 5)
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
具体地说,我们的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),一个二维张量,下面用图简单说明
我们是 2维 张量,一开始进行 self[index[0][0]][0]self[index[0][0]][0],其中 index[0][0]index[0][0] 的值是0,所以执行 self[0][0]=x[0][0]=0.1940self[0][0]=x[0][0]=0.1940
self[index[i][j]][j]=src[i][j]
再比如self[index[1][0]][0]self[index[1][0]][0],其中 index[1][0]index[1][0] 的值是2,所以执行 self[2][0]=x[1][0]=0.2078self[2][0]=x[1][0]=0.2078
计算过程:index[0,0]=0→self[0,0]→src[0,0] =0.1940
index[0,1]=1→self[1,1]→src[0,1] =0.3340
index[0,2]=2→self[2,2]→src[0,2] =0.8184
example:
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
#tensor([[7., 7., 7., 7., 7.],
# [0., 7., 0., 7., 0.],
# [7., 0., 7., 0., 7.]]
计算过程:index[0,0]=0→self[0,0]→src[0,0] =7
index[0,1]=1→self[1,1]→src[0,1] =7
index[0,2]=2→self[2,2]→src[0,2] =7
scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子
用于产生one hot编码的向量
example:
class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
indices = torch.tensor(list(range(5))).view(5,1)
indices
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)
当没有src
值时,则所有用于填充的值均为value
值。
需要注意的时候,这个时候index.shape[dim]
必须与result.shape[dim]
相等,否则会报错。
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 3, 1, 2],
[2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)
输出为
tensor([[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[0., 1., 1., 1., 1.]])
参考资料
https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
https://www.cnblogs.com/dogecheng/p/11938009.html
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)