Starting approach
Adapting the sparse solution from this post (https://stackoverflow.com/a/46457039/3293881) -
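import numpy as np

def sparse_matrix_mult_sparseX_mod1(X, rows):
    # Output has one row per group label and one column per column of X
    nrows = rows.max()+1
    ncols = X.shape[1]
    nelem = nrows * ncols

    # Row/column coordinates of the nonzero entries of X
    a,b = X.nonzero()
    # Flat (column-major) index of each entry in the output
    ids = rows[a] + b*nrows
    # Weighted bin count sums the values that land on the same output cell
    sums = np.bincount(ids, X[a,b].A1, minlength=nelem)
    # Reshape the flat column-major result back to (nrows, ncols)
    out = sums.reshape(ncols,-1).T
    return out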
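To see why the flattened index `rows[a] + b*nrows` works, here is a minimal sketch on a tiny dense array, using only NumPy. The array `X_dense`, the labels `rows`, and the helper name `group_sum_bincount` are made up for illustration; the result is checked against a dense indicator-matrix product, which is assumed here to be the reference computation.

```python
import numpy as np

def group_sum_bincount(X_dense, rows):
    """Sum the rows of X_dense that share a label in `rows` (illustrative helper)."""
    nrows = rows.max() + 1
    ncols = X_dense.shape[1]
    a, b = np.nonzero(X_dense)          # coordinates of the nonzero entries
    ids = rows[a] + b * nrows           # column-major flat index into (nrows, ncols)
    sums = np.bincount(ids, X_dense[a, b], minlength=nrows * ncols)
    return sums.reshape(ncols, -1).T    # undo the column-major layout

X_dense = np.array([[1., 0., 2.],
                    [0., 3., 0.],
                    [4., 0., 5.],
                    [0., 6., 0.]])
rows = np.array([0, 1, 0, 1])           # input rows 0,2 -> group 0; rows 1,3 -> group 1

out = group_sum_bincount(X_dense, rows)

# Reference: dense indicator matrix S with S[rows[j], j] = 1, then S @ X_dense
S = np.zeros((rows.max() + 1, len(rows)))
S[rows, np.arange(len(rows))] = 1.0
expected = S @ X_dense

print(out)
print(np.allclose(out, expected))
```

Because `ids` counts down each output column before moving to the next one, the flat `sums` array is in column-major order, which is why it is reshaped to `(ncols, -1)` and then transposed.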
Benchmarking
Original approach #1 -
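import numpy as np
import scipy.sparse

def app1(X, map_fn):
    # s (number of groups) and n (number of rows of X) are globals from the inputs setup
    col = np.arange(n)  # scipy.arange was a deprecated alias of np.arange
    val = np.ones(n)
    # S[map_fn[j], j] = 1, so S.dot(X) sums the rows of X that share a label
    S = scipy.sparse.csr_matrix( (val, (map_fn, col)), shape = (s,n))
    SX = S.dot(X)
    return SX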
Timings and verification -
In [209]: # Inputs setup
     ...: s=10000
     ...: n=1000000
     ...: d=1000
     ...: density=1.0/500
     ...:
     ...: X=scipy.sparse.rand(n,d,density=density,format="csr")
     ...: map_fn=np.random.randint(0, s, n)
     ...:
In [210]: out1 = app1(X, map_fn)
     ...: out2 = sparse_matrix_mult_sparseX_mod1(X, map_fn)
     ...: print np.allclose(out1.toarray(), out2)
     ...:
True
In [211]: %timeit app1(X, map_fn)
1 loop, best of 3: 517 ms per loop
In [212]: %timeit sparse_matrix_mult_sparseX_mod1(X, map_fn)
10 loops, best of 3: 147 ms per loop
To be fair, we should time the version of app1 that produces the final dense array -
In [214]: %timeit app1(X, map_fn).toarray()
1 loop, best of 3: 584 ms per loop
Porting to Numba
We can port the bin-counting step to numba, which might be beneficial for denser input matrices. One way to do so would be -
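import numpy as np
from numba import njit

@njit
def bincount_mod2(out, rows, r, C, V):
    # Scatter-add each nonzero value into its (group row, column) cell
    N = len(V)
    for i in range(N):
        out[rows[r[i]], C[i]] += V[i]
    return out

def sparse_matrix_mult_sparseX_mod2(X, rows):
    nrows = rows.max()+1
    ncols = X.shape[1]
    # Coordinates and values of the nonzero entries of X
    r,C = X.nonzero()
    V = X[r,C].A1
    # Dense accumulator, filled in place by the jitted loop
    out = np.zeros((nrows, ncols))
    return bincount_mod2(out, rows, r, C, V)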
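The `X.nonzero()` call followed by the `X[r,C]` fancy indexing may account for part of the run time. As a hedged sketch that is not part of the timings in this post, and assuming `X` is stored in CSR format, the same three arrays could be read from the CSR attributes instead. The arrays below are hand-built stand-ins for those attributes so the example needs only NumPy, `csr_triplets` is a hypothetical helper name, and `np.add.at` stands in for the jitted loop.

```python
import numpy as np

def csr_triplets(indptr, indices, data):
    """Expand CSR arrays into (row, col, value) triplets (illustrative helper)."""
    counts = np.diff(indptr)                       # stored entries per row
    r = np.repeat(np.arange(len(counts)), counts)  # row index of every stored entry
    return r, indices, data

# Hand-built CSR arrays for the 4x3 matrix
# [[1, 0, 2],
#  [0, 3, 0],
#  [4, 0, 5],
#  [0, 6, 0]]
indptr = np.array([0, 2, 3, 5, 6])
indices = np.array([0, 2, 1, 0, 2, 1])
data = np.array([1., 2., 3., 4., 5., 6.])
rows = np.array([0, 1, 0, 1])                      # group label for each input row

r, C, V = csr_triplets(indptr, indices, data)

# Same accumulation as the loop in bincount_mod2, written with np.add.at
out = np.zeros((rows.max() + 1, 3))
np.add.at(out, (rows[r], C), V)
print(out)
```

With a real `scipy.sparse.csr_matrix`, the corresponding inputs would be `X.indptr`, `X.indices`, and `X.data`. Explicitly stored zeros, if any, would be included, but they do not change the sums.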
Timings -
In [373]: # Inputs setup
     ...: s=10000
     ...: n=1000000
     ...: d=1000
     ...: density=1.0/100 # Denser now!
     ...:
     ...: X=scipy.sparse.rand(n,d,density=density,format="csr")
     ...: map_fn=np.random.randint(0, s, n)
     ...:
In [374]: %timeit app1(X, map_fn)
1 loop, best of 3: 787 ms per loop
In [375]: %timeit sparse_matrix_mult_sparseX_mod1(X, map_fn)
1 loop, best of 3: 906 ms per loop
In [376]: %timeit sparse_matrix_mult_sparseX_mod2(X, map_fn)
1 loop, best of 3: 705 ms per loop
With dense output from app1 -
In [379]: %timeit app1(X, map_fn).toarray()
1 loop, best of 3: 910 ms per loop