我有以下代码:
import numpy as np
from numba import jit
Nx = 15
Ny = 1000
v = np.ones((Nx,Ny))
v = np.reshape(v,(Nx*Ny))
A = np.random.rand(Nx*Ny,Nx*Ny,5)
B = np.random.rand(Nx*Ny,Nx*Ny,5)
C = np.random.rand(Nx*Ny,5)
@jit(nopython=True)
def dotplus(B, v, C):
return np.dot(B, v) + C
k = 2
D = dotplus(B[:,:,k], v, C[:,k])
我收到以下警告,我猜它指的是数组B[:,:,k]
and v
:
NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 2d, A), array(float64, 1d, C))
return np.dot(B, v0) + C
有没有办法让两个数组连续,这样Numba就可以加速代码?
PS,如果您想知道其含义k
,请注意这只是 MRE。在实际代码中,dotplus
在一个内部被多次调用for
循环不同的值k
(因此,不同的切片B
and C
). The for
循环更新的值v
, but B
and C
不要改变。