pytorch中Tensor的花式索引

2023-05-16

偶然发现pytorch的tensor除了像numpy数组那样简单索引或者切片,还有一种花式索引,也就是用tensor对tensor索引,可以广播原tensor。下面给出示例和转为numpy版本的写法。

示例

i n a . s h a p e = [ b , c , h , w ] in_a.shape=[b,c,h,w] ina.shape=[b,c,h,w]
i n b . s h a p e = [ m , n ] in_b.shape= [m,n] inb.shape=[m,n]
采用in_b对in_a索引: o u t = a [ : , : , b , : ] out = a[:,:, b,:] out=a[:,:,b,:]
则得到的out的shape: o u t . s h a p e = [ b , c , m , n , w ] out.shape=[b,c,m,n,w] out.shape=[b,c,m,n,w]

举个例子:

>>> in_a = torch.randn(1,1,4,5)
>>> in_b = torch.tensor([[2,0],[1,3],[2,3]])
>>> in_a
tensor([[[[ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646],
          [-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
          [ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
          [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]])
>>> in_b.shape
torch.Size([3, 2])
>>> in_a[:,:,in_b,:].shape
torch.Size([1, 1, 3, 2, 5])
>>> in_a[:,:,in_b,:]
tensor([[[[[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646]],
           
          [[-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]],
           
          [[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]]])

也就是在 i n _ a in\_a in_a d i m = 2 dim= 2 dim=2 上索引,依次取index= [ 2 , 0 ] , [ 1 , 3 ] , [ 2 , 3 ] [2,0],[1,3],[2,3] [2,0],[1,3],[2,3]的tensor填充。特别要注意:index的数值不能超出dim=2的最大维度, 比如例子中,in_a的shape为 [ 1 , 1 , 4 , 5 ] [1,1,4,5] [1,1,4,5],在dim=2维度索引, 则索引的值只能是 0 , 1 , 2 , 3 0,1,2,3 0,1,2,3.

再举个栗子:

>>> in_a[:,:,:,in_b].shape
torch.Size([1, 1, 4, 3, 2])

>>> in_a[:,:,:,in_b]
tensor([[[[[ 0.5563,  0.2668],
           [ 0.5453,  0.7396],
           [ 0.5563,  0.7396]],
           
          [[ 0.8947, -0.1059],
           [ 0.8955, -3.0298],
           [ 0.8947, -3.0298]],
           
          [[ 0.4827,  0.8145],
           [ 0.3670,  0.1327],
           [ 0.4827,  0.1327]],
           
          [[-0.8810,  1.3698],
           [-0.8281,  1.6670],
           [-0.8810,  1.6670]]]]])

用numpy写花式索引

目前只想到很愚蠢的遍历读取再赋值:

import numpy as np
num_a = in_a.numpy()  # [1,1,4,5]
num_b = in_b.numpy()  # [3,2]
[b,c,h,w] = num_a.shape
[m,n] = num_b.shape

out_ny = np.zeros([b,c,m,n,w])  # [1,1,3,2,5]
for i in range(m):
     for j in range(n):
         out_ny[:,:,i,j,:] = num_a[:,:, num_b[i,j],:]
out_ny
array([[[[[ 0.81448293,  0.36703789,  0.48273084,  0.13274327,
           -0.94368148],
          [ 0.26677063,  0.54529017,  0.55633378,  0.73956281,
           -1.16463828]],
         [[-0.10586801,  0.89547068,  0.89467597, -3.02978396,
           -2.09123206],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 ,
           -1.87361884]],
         [[ 0.81448293,  0.36703789,  0.48273084,  0.13274327,
           -0.94368148],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 ,
           -1.87361884]]]]])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch中Tensor的花式索引 的相关文章

随机推荐