将您正在做的事情矢量化非常容易:
import numpy as np
#generate dummy data
nrows=6
ncols=11
nframes=3
threshold=0.3
data=np.random.rand(nrows,ncols,nframes)
CM_tilde = np.mean(data, axis=1)
N = data.shape[1]
all_CMs2 = np.mean(np.where(data < (CM_tilde[:,None,:]+threshold),data,CM_tilde[:,None,:]),axis=1)
data_cm2 = data - all_CMs2[:,None,:]
将此与您的原件进行比较:
In [684]: (data_cm==data_cm2).all()
Out[684]: True
In [685]: (all_CMs==all_CMs2).all()
Out[685]: True
逻辑是我们使用大小的数组[nrows,ncols,nframes]
同时地。主要技巧是利用Python的广播,通过转动CM_tilde
大小的[nrows,nframes]
into CM_tilde[:,None,:]
大小的[nrows,1,nframes]
。然后,Python 将为每一列使用相同的值,因为这是此修改后的单一维度CM_tilde
.
通过使用np.where
我们选择(基于threshold
) 是否要获取对应的值data
,或者,再次,广播值CM_tilde
。一个新的用途np.mean
允许我们计算all_CMs2
.
在最后一步中,我们通过直接减去这个新的来利用广播all_CMs2
从相应的元素data
.
通过查看临时变量的隐式索引,可能有助于以这种方式矢量化代码。我的意思是你的临时变量CM
生活在一个循环中[nrows,nframes]
,并且其值在每次迭代时都会重置。这意味着CM
实际上是一个数量CM[row,frame]
(后来显式分配给二维数组all_CMs
),从这里很容易看出,您可以通过总结适当的CMtmp[row,col,frames]
沿其列尺寸的数量。如果有帮助,您可以命名np.where(...)
部分作为CMtmp
为此目的,然后计算np.mean(CMtmp,axis=1)
从那。显然,结果相同,但可能更透明。