我在运行 python / numypy 代码时遇到速度问题。我不知道如何让它更快,也许其他人?
假设有一个表面有两个三角剖分,一个是细三角剖分 (..._fine),有 M 个点,一个是粗剖分,有 N 个点。此外,还有每个点的粗网格数据(N 个浮点数)。我正在尝试执行以下操作:
对于细网格上的每个点,找到粗网格上最接近的 k 个点并获取平均值。简而言之:从粗到细插值数据。
我的代码现在是这样的。对于大数据(在我的例子中,M = 2e6,N = 1e4),代码运行大约 25 分钟,猜测是由于显式 for 循环没有进入 numpy。有什么想法如何通过智能索引来解决这个问题吗? M x N 阵列耗尽了 RAM..
import numpy as np
p_fine.shape => m x 3
p.shape => n x 3
data_fine = np.empty((m,))
for i, ps in enumerate(p_fine):
data_fine[i] = np.mean(data_coarse[np.argsort(np.linalg.norm(ps-p,axis=1))[:k]])
Cheers!
首先感谢您的详细帮助。
首先,Divakar,您的解决方案大大提高了速度。根据我的数据,代码运行时间略低于 2 分钟,具体取决于块大小。
我也尝试过 sklearn 并最终得到
def sklearnSearch_v3(p, p_fine, k):
neigh = NearestNeighbors(k)
neigh.fit(p)
return data_coarse[neigh.kneighbors(p_fine)[1]].mean(axis=1)
最终速度相当快,对于我的数据大小,我得到以下结果
import numpy as np
from sklearn.neighbors import NearestNeighbors
m,n = 2000000,20000
p_fine = np.random.rand(m,3)
p = np.random.rand(n,3)
data_coarse = np.random.rand(n)
k = 3
yields
%timeit sklearv3(p, p_fine, k)
1 loop, best of 3: 7.46 s per loop
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)