We use those indices to index into the first axis of the input array. To get a 2D
output, we just need to permute the axes and then reshape. So, one approach would be with np.transpose
/np.swapaxes (https://docs.scipy.org/doc/numpy/reference/generated/numpy.swapaxes.html) and np.reshape (https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html), like so -
mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
Sample run -
In [83]: mats
Out[83]:
array([[[1, 1],
        [7, 1]],

       [[6, 6],
        [5, 8]],

       [[7, 1],
        [6, 0]],

       [[2, 7],
        [0, 4]]])
In [84]: idxs
Out[84]:
array([[2, 3],
       [0, 3],
       [1, 2]])
In [85]: mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
Out[85]:
array([[7, 1, 2, 7],
       [6, 0, 0, 4],
       [1, 1, 2, 7],
       [7, 1, 0, 4],
       [6, 6, 7, 1],
       [5, 8, 6, 0]])
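To make the axis bookkeeping explicit, here is a minimal sketch that splits the one-liner into steps on the same sample data. The intermediate names (picked, swapped, out_t) are only for illustration and are not part of the original answer.

```python
import numpy as np

mats = np.array([[[1, 1], [7, 1]],
                 [[6, 6], [5, 8]],
                 [[7, 1], [6, 0]],
                 [[2, 7], [0, 4]]])
idxs = np.array([[2, 3], [0, 3], [1, 2]])

# Index the first axis: axes are (idx row, idx col, mat row, mat col)
picked = mats[idxs]               # shape (3, 2, 2, 2)

# Swap axes 1 and 2: axes become (idx row, mat row, idx col, mat col)
swapped = picked.swapaxes(1, 2)   # shape (3, 2, 2, 2)

# Merge the first two axes into rows and the last two into columns
out = swapped.reshape(-1, mats.shape[-1] * idxs.shape[-1])   # shape (6, 4)

# np.transpose with an explicit axis order expresses the same permutation
out_t = picked.transpose(0, 2, 1, 3).reshape(out.shape)

print(out)
```

Each output row therefore holds one row of every matrix selected by one row of idxs, laid side by side.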
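Performance boost with np.take (https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html) for repeated indices
For repeated indices, we are better off performance-wise using np.take to index along axis=0. Let's list out both approaches and time them with an idxs that has many repeated indices.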
Function definitions -
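import numpy as np

def simply_indexing_based(mats, idxs):
    # Fancy-index the first axis, then bring the matrix rows forward and flatten to 2D
    ncols = mats.shape[-1]*idxs.shape[-1]
    return mats[idxs].swapaxes(1,2).reshape(-1,ncols)

def take_based(mats, idxs):
    # Same as above, but the gather along axis=0 is done with np.take
    ncols = mats.shape[-1]*idxs.shape[-1]
    return np.take(mats,idxs,axis=0).swapaxes(1,2).reshape(-1,ncols)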
Runtime test -
In [156]: mats = np.random.randint(0,9,(10,2,2))
In [157]: idxs = np.random.randint(0,10,(1000,1000))
# This ensures many repeated indices
In [158]: out1 = simply_indexing_based(mats, idxs)
In [159]: out2 = take_based(mats, idxs)
In [160]: np.allclose(out1, out2)
Out[160]: True
In [161]: %timeit simply_indexing_based(mats, idxs)
10 loops, best of 3: 41.2 ms per loop
In [162]: %timeit take_based(mats, idxs)
10 loops, best of 3: 27.3 ms per loop
So, we are seeing an overall improvement of 1.5x+.
Just to get a sense of the improvement from np.take, let's time the indexing part alone -
In [168]: %timeit mats[idxs]
10 loops, best of 3: 22.8 ms per loop
In [169]: %timeit np.take(mats,idxs,axis=0)
100 loops, best of 3: 8.88 ms per loop
For these data sizes, it's 2.5x+. Not bad!
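The timings above come from IPython's %timeit. As a rough sketch for reproducing the comparison in a plain script, the standard-library timeit module can be used instead. The helper best_ms, the fixed seed and the repeat counts are assumptions made for this illustration, and the absolute numbers will vary with machine and NumPy version.

```python
import timeit
import numpy as np

def simply_indexing_based(mats, idxs):
    ncols = mats.shape[-1] * idxs.shape[-1]
    return mats[idxs].swapaxes(1, 2).reshape(-1, ncols)

def take_based(mats, idxs):
    ncols = mats.shape[-1] * idxs.shape[-1]
    return np.take(mats, idxs, axis=0).swapaxes(1, 2).reshape(-1, ncols)

rng = np.random.default_rng(0)                  # fixed seed, an arbitrary choice
mats = rng.integers(0, 9, size=(10, 2, 2))
idxs = rng.integers(0, 10, size=(1000, 1000))   # only 10 distinct values, so many repeats

def best_ms(func, repeat=3, number=5):
    # Hypothetical helper: best average time per call, in milliseconds
    times = timeit.repeat(lambda: func(mats, idxs), repeat=repeat, number=number)
    return min(times) / number * 1e3

t_index = best_ms(simply_indexing_based)
t_take = best_ms(take_based)
print(f"indexing: {t_index:.1f} ms, take: {t_take:.1f} ms")
```

Taking the minimum over several repeats mirrors the "best of 3" figure that %timeit reports in the session above.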