为什么 numpy.var 是 O(N) 空间?

2024-03-16

我有一个 ~13GB 的数组。我打电话numpy.var对其进行计算方差。然而,它又分配了约 13GB 来执行此操作。为什么需要 O(N) 空间?或者我打电话numpy.var以错误的方式?

import numpy as np
# data = ...
print('Variance: ', np.var(data))

NumPy 将创建一个中间数组来计算abs(data - data.mean()) ** 2为了计算方差。您可以使用循环编写自己的方差函数,并使用 Numba 使其快速运行:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def var_nb(a, ddof=0):
    n = len(a)
    s = a.sum()
    m = s / (n - ddof)
    v = 0
    for i in nb.prange(n):
        v += abs(a[i] - m) ** 2
    return v / (n - ddof)

np.random.seed(100)
a = np.random.rand(100_000)
print(np.var(a))
# 0.08349747560941487
print(var_nb(a))
# 0.08349747560941487

%timeit np.var(a)
# 143 µs ± 414 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit var_nb(a)
# 40.2 µs ± 530 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为什么 numpy.var 是 O(N) 空间? 的相关文章

随机推荐