我也遇到过这个问题。我的解决方法是:先用简单的虚拟变量(Dummy)替换需要收集的项(例如偏导数),再基于这些简单变量调用 `collect`,
最后把原来的复杂表达式代回去。可能还有一些没考虑到的边界情况,但在我这里是可行的。
from sympy import Dummy, collect, symarray
def mycollect(expr, var_list, evaluate=True, **kwargs):
    """Collect ``expr`` with respect to ``var_list``, like ``sympy.collect``.

    Each entry of ``var_list`` (for example a partial ``Derivative``) is first
    replaced by a fresh ``Dummy`` symbol. ``collect`` is then applied with
    respect to those plain symbols, and the original objects are substituted
    back afterwards. This makes ``collect`` usable with partial derivatives.
    Matrix expressions are collected element-wise.

    Parameters
    ----------
    expr : sympy expression or Matrix
        Expression to collect. Matrices are only supported with
        ``evaluate=True``.
    var_list : expression or sequence of expressions
        Term(s) to collect on. A single object without ``__len__`` is wrapped
        in a list.
    evaluate : bool, optional
        Forwarded to ``collect``. If False, a dict mapping each collected term
        (after back-substitution) to its coefficient is returned.
    **kwargs
        Extra keyword arguments forwarded to ``sympy.collect``.

    Returns
    -------
    The collected expression (or Matrix) when ``evaluate`` is True, otherwise
    a dict ``{term: coefficient}``.

    Raises
    ------
    ValueError
        If ``expr`` is a matrix and ``evaluate`` is False.
    """
    if not hasattr(var_list, '__len__'):
        var_list = [var_list]

    # Fresh Dummy symbols can never clash with a symbol already present in
    # ``expr``, unlike named symbols such as ``DUM_0``.
    dummies = [Dummy(f'DUM{i}') for i in range(len(var_list))]
    var_to_dummy = list(zip(var_list, dummies))
    dummy_to_var = [(dummy, var) for var, dummy in var_to_dummy]

    # Evaluate first, then expand, so that sums created by doit() (e.g. from
    # the product rule) are also split into terms that collect can group.
    expr = expr.doit().expand()
    expr = expr.subs(var_to_dummy)

    if hasattr(expr, '__len__'):
        # Matrix input: collect each element separately.
        if not evaluate:
            raise ValueError(
                "mycollect: evaluate=False is not supported for matrix "
                "expressions")
        collected = expr.applyfunc(
            lambda elem: collect(elem, dummies, **kwargs))
        return collected.subs(dummy_to_var)

    collected = collect(expr, dummies, evaluate=evaluate, **kwargs)
    if evaluate:
        return collected.subs(dummy_to_var)
    # evaluate=False: collect returned {term: coeff}; restore the originals
    # in both keys and values.
    return {key.subs(dummy_to_var): coeff.subs(dummy_to_var)
            for key, coeff in collected.items()}
对于你的例子:
mycollect(expr6, psi.diff(x), evaluate=False)
mycollect(expr7, psi.diff(x), evaluate=False)
返回结果:
{Derivative(psi(x, y, z, t), (x, 2)): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}
{Derivative(psi(x, y, z, t), x, y): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}