如何将 numpy.argpartition 的输出应用于二维数组?

2024-04-23

我有一个较大的 2d numpy 数组,我想提取每行的最低 10 个元素及其索引。由于我的数组较大,我不想对整个数组进行排序。

我听说过argpartition()函数,用它我可以获得最低 10 个元素的索引:

top10indexes = np.argpartition(myBigArray,10)[:,:10]

注意argpartition()默认情况下对轴-1进行分区,这就是我想要的。这里的结果与 myBigArray 具有相同的形状,包含相应行的索引,使得前 10 个索引指向 10 个最低值。

我现在如何提取元素myBigArray对应那些指标?

明显的花哨索引就像myBigArray[top10indexes] or myBigArray[:,top10indexes]做一些完全不同的事情。我还可以使用列表理解,例如:

array([row[idxs] for row,idxs in zip(myBigArray,top10indexes)])

但这会导致迭代 numpy 行并将结果转换回数组时性能受到影响。

注意:我可以使用np.partition()获取值,它们甚至可能对应于索引(或者可能不对应..),但如果可以避免的话,我不想进行两次分区。


您可以通过执行以下操作来避免使用扁平副本以及提取所有值的需要:

num = 10
top = np.argpartition(myBigArray, num, axis=1)[:, :num]
myBigArray[np.arange(myBigArray.shape[0])[:, None], top]

对于 NumPy >= 1.9.0 这将非常有效并且可以与np.take().

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将 numpy.argpartition 的输出应用于二维数组? 的相关文章

随机推荐