我正在努力在 scipy 中拟合 3d 分布函数。我有一个 numpy 数组,其中包含 x 和 y 仓中的计数,我试图将其拟合到相当复杂的 3 维分布函数中。该数据适合 26 (!) 个参数,这些参数描述了其两个组成群体的形状。
我在这里了解到,当我调用 lesssq 时,我必须将 x 和 y 坐标作为“args”传递。 unutbu 提供的代码按照为我编写的方式工作,但是当我尝试将其应用于我的特定情况时,我收到错误“TypeError:leastsq()获得关键字参数'args'的多个值”
这是我的代码(抱歉太长了):
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize as spopt
from textwrap import wrap
import collections
cl = 0.5
ch = 3.5
rl = -23.5
rh = -18.5
mbins = 10
cbins = 10
def hist_data(mixed_data, mbins, cbins):
import numpy as np
H, xedges, yedges = np.histogram2d(mixed_data[:,1], mixed_data[:,2], bins = (mbins, cbins), weights = mixed_data[:,3])
x, y = 0.5 * (xedges[:-1] + xedges[1:]), 0.5 * (yedges[:-1] + yedges[1:])
return H.T, x, y
def gauss(x, s, mu, a):
import numpy as np
return a * np.exp(-((x - mu)**2. / (2. * s**2.)))
def tanhlin(x, p0, p1, q0, q1, q2):
import numpy as np
return p0 + p1 * (x + 20.) + q0 * np.tanh((x - q1)/q2)
def func3d(p, x, y):
import numpy as np
from sys import exit
rsp0, rsp1, rsq0, rsq1, rsq2, rmp0, rmp1, rmq0, rmq1, rmq2, rs, rm, ra, bsp0, bsp1, bsq0, bsq1, bsq2, bmp0, bmp1, bmq0, bmq1, bmq2, bs, bm, ba = p
x, y = np.meshgrid(coords[0], coords[1])
rs = tanhlin(x, rsp0, rsp1, rsq0, rsq1, rsq2)
rm = tanhlin(x, rmp0, rmp1, rmq0, rmq1, rmq2)
ra = schechter(x, rap, raa, ram) # unused
bs = tanhlin(x, bsp0, bsp1, bsq0, bsq1, bsq2)
bm = tanhlin(x, bmp0, bmp1, bmq0, bmq1, bmq2)
ba = schechter(x, bap, baa, bam) # unused
red_dist = ra / (rs * np.sqrt(2 * np.pi)) * gauss(y, rs, rm, ra)
blue_dist = ba / (bs * np.sqrt(2 * np.pi)) * gauss(y, bs, bm, ba)
result = red_dist + blue_dist
return result
def residual(p, coords, data):
import numpy as np
model = func3d(p, coords)
res = (model.flatten() - data.flatten())
# can put parameter restrictions in here
return res
def poiss_err(data):
import numpy as np
return np.where(np.sqrt(H) > 0., np.sqrt(H), 2.)
# =====
H, x, y = hist_data(mixed_data, mbins, cbins)
data = H
coords = x, y
# x and y will be the projected coordinates of the data H onto the plane z = 0
# x has bins of width 0.5, with centers at -23.25, -22.75, ... , -19.25, -18.75
# y has bins of width 0.3, with centers at 0.65, 0.95, ... , 3.05, 3.35
Param = collections.namedtuple('Param', 'rsp0 rsp1 rsq0 rsq1 rsq2 rmp0 rmp1 rmq0 rmq1 rmq2 rs rm ra bsp0 bsp1 bsq0 bsq1 bsq2 bmp0 bmp1 bmq0 bmq1 bmq2 bs bm ba')
p_guess = Param(rsp0 = 0.152, rsp1 = 0.008, rsq0 = 0.044, rsq1 = -19.91, rsq2 = 0.94, rmp0 = 2.279, rmp1 = -0.037, rmq0 = -0.108, rmq1 = -19.81, rmq2 = 0.96, rs = 1., rm = -20.5, ra = 10000., bsp0 = 0.298, bsp1 = 0.014, bsq0 = -0.067, bsq1 = -19.90, bsq2 = 0.58, bmp0 = 1.790, bmp1 = -0.053, bmq0 = -0.363, bmq1 = -20.75, bmq2 = 1.12, bs = 1., bm = -20., ba = 2000.)
opt, cov, infodict, mesg, ier = spopt.leastsq(residual, p_guess, poiss_err(H), args = coords, maxfev = 100000, full_output = True)
这是我的数据,只是数据箱较少:
[[ 1.00000000e+01 1.10000000e+01 2.10000000e+01 1.90000000e+01
1.70000000e+01 2.10000000e+01 2.40000000e+01 1.90000000e+01
2.80000000e+01 1.90000000e+01]
[ 1.40000000e+01 4.50000000e+01 6.00000000e+01 6.80000000e+01
1.34000000e+02 1.97000000e+02 2.23000000e+02 2.90000000e+02
3.23000000e+02 3.03000000e+02]
[ 3.00000000e+01 1.17000000e+02 3.78000000e+02 9.74000000e+02
1.71900000e+03 2.27700000e+03 2.39000000e+03 2.25500000e+03
1.85600000e+03 1.31000000e+03]
[ 1.52000000e+02 9.32000000e+02 2.89000000e+03 5.23800000e+03
6.66200000e+03 6.19100000e+03 4.54900000e+03 3.14600000e+03
2.09000000e+03 1.33800000e+03]
[ 5.39000000e+02 2.58100000e+03 6.51300000e+03 8.89900000e+03
8.52900000e+03 6.22900000e+03 3.55000000e+03 2.14300000e+03
1.19000000e+03 6.92000000e+02]
[ 1.49600000e+03 4.49200000e+03 8.77200000e+03 1.07610000e+04
9.76700000e+03 7.04900000e+03 4.23200000e+03 2.47200000e+03
1.41500000e+03 7.02000000e+02]
[ 2.31800000e+03 7.01500000e+03 1.28870000e+04 1.50840000e+04
1.35590000e+04 8.55600000e+03 4.15600000e+03 1.77100000e+03
6.57000000e+02 2.55000000e+02]
[ 1.57500000e+03 3.79300000e+03 5.20900000e+03 4.77800000e+03
3.26600000e+03 1.44700000e+03 5.31000000e+02 1.85000000e+02
9.30000000e+01 4.90000000e+01]
[ 7.01000000e+02 1.21600000e+03 1.17600000e+03 7.93000000e+02
4.79000000e+02 2.02000000e+02 8.80000000e+01 3.90000000e+01
2.30000000e+01 1.90000000e+01]
[ 2.93000000e+02 3.93000000e+02 2.90000000e+02 1.97000000e+02
1.18000000e+02 6.40000000e+01 4.10000000e+01 1.20000000e+01
1.10000000e+01 4.00000000e+00]]
非常感谢!