我需要过滤一个数组以删除低于某个阈值的元素。我当前的代码是这样的:
threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题是,这会创建一个临时列表,使用带有 lambda 函数的过滤器(慢)。
由于这是一个非常简单的操作,也许有一个 numpy 函数可以有效地完成它,但我一直找不到它。
我认为实现此目的的另一种方法可以是对数组进行排序,找到阈值的索引并从该索引开始返回一个切片,但即使这对于小输入来说会更快(而且无论如何都不会被注意到),随着输入大小的增加,它的效率肯定会逐渐降低。
Update:我也进行了一些测量,当输入为 100.000.000 个条目时,排序+切片仍然是纯 python 过滤器的两倍。
r = numpy.random.uniform(0, 1, 100000000)
%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop
%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop
%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop
b = a[a>threshold]
这应该做
我测试如下:
import numpy as np, datetime
# array of zeros and ones interleaved
lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten()
t0 = datetime.datetime.now()
flt = lrg[lrg==0]
print datetime.datetime.now() - t0
t0 = datetime.datetime.now()
flt = np.array(filter(lambda x:x==0, lrg))
print datetime.datetime.now() - t0
I got
$ python test.py
0:00:00.028000
0:00:02.461000
http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)