我有一个较大的 2d numpy 数组,我想提取每行的最低 10 个元素及其索引。由于我的数组较大,我不想对整个数组进行排序。
我听说过argpartition()
函数,用它我可以获得最低 10 个元素的索引:
top10indexes = np.argpartition(myBigArray,10)[:,:10]
注意argpartition()
默认情况下对轴-1进行分区,这就是我想要的。这里的结果与 myBigArray 具有相同的形状,包含相应行的索引,使得前 10 个索引指向 10 个最低值。
我现在如何提取元素myBigArray
对应那些指标?
明显的花哨索引就像myBigArray[top10indexes]
or myBigArray[:,top10indexes]
做一些完全不同的事情。我还可以使用列表理解,例如:
array([row[idxs] for row,idxs in zip(myBigArray,top10indexes)])
但这会导致迭代 numpy 行并将结果转换回数组时性能受到影响。
注意:我可以使用np.partition()
获取值,它们甚至可能对应于索引(或者可能不对应..),但如果可以避免的话,我不想进行两次分区。