我正在尝试使用 numpy 的 einsum 计算以下内容:
import numpy as np
tmp_ee = np.ones((2,4,4))
tmp_ij = np.ones((2,2,4,2,4,2))
print(tmp_ee.shape)
print(tmp_ij.shape)
np.einsum('naq,nnbpqp->nab', tmp_ee, tmp_ij, optimize=True)
但是,我遇到了以下错误:
ValueError:折叠索引“q”的操作数 0 中的维度不匹配 (4 != 2)
如果您查看上面的代码,就会发现两个数组中索引“q”的维度都等于 4。其他重复索引也具有一致的维度。因此,我不明白这个错误的原因。除此之外,当我的第二个数组 temp_ij 的尺寸为 (2,2,3,3,3,3) 或 (2,2,1,1,1,1) 或 (2,2,2, 2,2,2),即(2,2,x,x,x,x),则不会出现上述错误,代码运行顺利。我尝试通过使用 einsum 分两步执行操作来解决此问题,如下所示:
tmp_ee = np.ones((2,4,4))
tmp_ij = np.ones((2,2,4,2,4,2))
print(tmp_ee.shape)
print(tmp_ij.shape)
temp = np.einsum('naq,nnbrqp->nabrp', tmp_ee, tmp_ij, optimize=True)
np.einsum('nabpp->nab', temp, optimize = True)
这确实not给出像之前的代码一样的错误。这工作正常并给了我结果。虽然此解决方法似乎工作正常,但它显着减慢了我的代码速度,因为我必须在代码中执行数百万次此计算。对于上面的错误有解释吗?
Thanks!