我可以做些什么来加速 numpy 中的屏蔽数组吗?我有一个效率非常低的函数,我重新编写了它来使用屏蔽数组(我可以只屏蔽行,而不是像我所做的那样进行复制和删除行)。然而,我惊讶地发现 masked 函数慢了 10 倍,因为 masked 数组慢得多。
举个例子,如下(masked 对我来说慢了 6 倍多):
import timeit
import numpy as np
import numpy.ma as ma
def test(row):
return row[0] + row[1]
a = np.arange(1000).reshape(500, 2)
t = timeit.Timer('np.apply_along_axis(test, 1, a)','from __main__ import test, a, np')
print round(t.timeit(100), 6)
b = ma.array(a)
t = timeit.Timer('ma.apply_along_axis(test, 1, b)','from __main__ import test, b, ma')
print round(t.timeit(100), 6)
我不知道为什么屏蔽数组函数移动得如此缓慢,但由于听起来您正在使用屏蔽来选择行(而不是单个值),因此您可以从屏蔽行创建一个常规数组并使用 np 函数反而:
b.mask = np.zeros(500)
b.mask[498] = True
t = timeit.Timer('c=b.view(np.ndarray)[~b.mask[:,0]]; np.apply_along_axis(test, 1, c)','from __main__ import test, b, ma, np')
print round(t.timeit(100), 6)
更好的是,根本不要使用掩码数组;只需将数据和一维掩码数组维护为单独的变量:
a = np.arange(1000).reshape(500, 2)
mask = np.ones(a.shape[0], dtype=bool)
mask[498] = False
out = np.apply_along_axis(test, 1, a[mask])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)