方法#1
我们可以用np.einsum https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html一次性进行求和减少 -
result = np.einsum('ij,ik,il->jkl',a,b,c).ravel()
另外,玩一下optimize
标记在np.einsum
通过将其设置为True
使用BLAS。
方法#2
我们可以用broadcasting
执行邮政编码中提到的第一步,然后利用张量矩阵乘法np.tensordot
-
def broadcast_dot(a,b,c):
first_multi = a[...,None] * b[:,None]
return np.tensordot(first_multi,c, axes=(0,0)).ravel()
我们还可以使用numexpr module https://github.com/pydata/numexpr/blob/master/doc/user_guide.rst#enabling-intel-vml-support支持多核处理并实现更好的内存效率first_multi
。这给了我们一个修改后的解决方案,就像这样 -
import numexpr as ne
def numexpr_broadcast_dot(a,b,c):
first_multi = ne.evaluate('A*B',{'A':a[...,None],'B':b[:,None]})
return np.tensordot(first_multi,c, axes=(0,0)).ravel()
给定数据集大小的随机浮点数据的计时 -
In [36]: %timeit np.einsum('ij,ik,il->jkl',a,b,c).ravel()
4.57 s ± 75.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [3]: %timeit broadcast_dot(a,b,c)
270 ms ± 103 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [4]: %timeit numexpr_broadcast_dot(a,b,c)
172 ms ± 63.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
只是为了给人一种进步的感觉numexpr
-
In [7]: %timeit a[...,None] * b[:,None]
80.4 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [8]: %timeit ne.evaluate('A*B',{'A':a[...,None],'B':b[:,None]})
25.9 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
当将此解决方案扩展到更多数量的输入时,这应该是很重要的。