使用 3d 数据和参数在 Scipy 中进行曲线拟合

2023-12-03

我正在努力在 scipy 中拟合 3d 分布函数。我有一个 numpy 数组,其中包含 x 和 y 仓中的计数,我试图将其拟合到相当复杂的 3 维分布函数中。该数据适合 26 (!) 个参数,这些参数描述了其两个组成群体的形状。

我在这里了解到,当我调用 lesssq 时,我必须将 x 和 y 坐标作为“args”传递。 unutbu 提供的代码按照为我编写的方式工作,但是当我尝试将其应用于我的特定情况时,我收到错误“TypeError:leastsq()获得关键字参数'args'的多个值”

这是我的代码(抱歉太长了):

import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize as spopt
from textwrap import wrap
import collections

cl = 0.5
ch = 3.5
rl = -23.5
rh = -18.5
mbins = 10
cbins = 10

def hist_data(mixed_data, mbins, cbins):
    import numpy as np
    H, xedges, yedges = np.histogram2d(mixed_data[:,1], mixed_data[:,2], bins = (mbins, cbins), weights = mixed_data[:,3])
    x, y = 0.5 * (xedges[:-1] + xedges[1:]), 0.5 * (yedges[:-1] + yedges[1:])
    return H.T, x, y

def gauss(x, s, mu, a):
    import numpy as np
    return a * np.exp(-((x - mu)**2. / (2. * s**2.)))

def tanhlin(x, p0, p1, q0, q1, q2):
    import numpy as np
    return p0 + p1 * (x + 20.) + q0 * np.tanh((x - q1)/q2)

def func3d(p, x, y):
    import numpy as np
    from sys import exit
    rsp0, rsp1, rsq0, rsq1, rsq2, rmp0, rmp1, rmq0, rmq1, rmq2, rs, rm, ra, bsp0, bsp1, bsq0, bsq1, bsq2, bmp0, bmp1, bmq0, bmq1, bmq2, bs, bm, ba = p
x, y = np.meshgrid(coords[0], coords[1])
    rs = tanhlin(x, rsp0, rsp1, rsq0, rsq1, rsq2)
    rm = tanhlin(x, rmp0, rmp1, rmq0, rmq1, rmq2)
    ra = schechter(x, rap, raa, ram) # unused
    bs = tanhlin(x, bsp0, bsp1, bsq0, bsq1, bsq2)
    bm = tanhlin(x, bmp0, bmp1, bmq0, bmq1, bmq2)
    ba = schechter(x, bap, baa, bam) # unused
    red_dist = ra / (rs * np.sqrt(2 * np.pi)) * gauss(y, rs, rm, ra)
    blue_dist = ba / (bs * np.sqrt(2 * np.pi)) * gauss(y, bs, bm, ba)
    result = red_dist + blue_dist
return result

def residual(p, coords, data):
    import numpy as np
    model = func3d(p, coords)
    res = (model.flatten() - data.flatten())
    # can put parameter restrictions in here
    return res

def poiss_err(data):
    import numpy as np
    return np.where(np.sqrt(H) > 0., np.sqrt(H), 2.)

# =====

H, x, y = hist_data(mixed_data, mbins, cbins)

data = H

coords = x, y
# x and y will be the projected coordinates of the data H onto the plane z = 0

# x has bins of width 0.5, with centers at -23.25, -22.75, ... , -19.25, -18.75
# y has bins of width 0.3, with centers at 0.65, 0.95, ... , 3.05, 3.35    

Param = collections.namedtuple('Param', 'rsp0 rsp1 rsq0 rsq1 rsq2 rmp0 rmp1 rmq0 rmq1 rmq2 rs rm ra bsp0 bsp1 bsq0 bsq1 bsq2 bmp0 bmp1 bmq0 bmq1 bmq2 bs bm ba')
p_guess = Param(rsp0 = 0.152, rsp1 = 0.008, rsq0 = 0.044, rsq1 = -19.91, rsq2 = 0.94, rmp0 = 2.279, rmp1 = -0.037, rmq0 = -0.108, rmq1 = -19.81, rmq2 = 0.96, rs = 1., rm = -20.5, ra = 10000., bsp0 = 0.298, bsp1 = 0.014, bsq0 = -0.067, bsq1 = -19.90, bsq2 = 0.58, bmp0 = 1.790, bmp1 = -0.053, bmq0 = -0.363, bmq1 = -20.75, bmq2 = 1.12, bs = 1., bm = -20., ba = 2000.)

opt, cov, infodict, mesg, ier = spopt.leastsq(residual, p_guess, poiss_err(H), args = coords, maxfev = 100000, full_output = True)

这是我的数据,只是数据箱较少:

[[  1.00000000e+01   1.10000000e+01   2.10000000e+01   1.90000000e+01
1.70000000e+01   2.10000000e+01   2.40000000e+01   1.90000000e+01
2.80000000e+01   1.90000000e+01]
[  1.40000000e+01   4.50000000e+01   6.00000000e+01   6.80000000e+01
1.34000000e+02   1.97000000e+02   2.23000000e+02   2.90000000e+02
3.23000000e+02   3.03000000e+02]
[  3.00000000e+01   1.17000000e+02   3.78000000e+02   9.74000000e+02
1.71900000e+03   2.27700000e+03   2.39000000e+03   2.25500000e+03
1.85600000e+03   1.31000000e+03]
[  1.52000000e+02   9.32000000e+02   2.89000000e+03   5.23800000e+03
6.66200000e+03   6.19100000e+03   4.54900000e+03   3.14600000e+03
2.09000000e+03   1.33800000e+03]
[  5.39000000e+02   2.58100000e+03   6.51300000e+03   8.89900000e+03
8.52900000e+03   6.22900000e+03   3.55000000e+03   2.14300000e+03
1.19000000e+03   6.92000000e+02]
[  1.49600000e+03   4.49200000e+03   8.77200000e+03   1.07610000e+04
9.76700000e+03   7.04900000e+03   4.23200000e+03   2.47200000e+03
1.41500000e+03   7.02000000e+02]
[  2.31800000e+03   7.01500000e+03   1.28870000e+04   1.50840000e+04
1.35590000e+04   8.55600000e+03   4.15600000e+03   1.77100000e+03
6.57000000e+02   2.55000000e+02]
[  1.57500000e+03   3.79300000e+03   5.20900000e+03   4.77800000e+03
3.26600000e+03   1.44700000e+03   5.31000000e+02   1.85000000e+02
9.30000000e+01   4.90000000e+01]
[  7.01000000e+02   1.21600000e+03   1.17600000e+03   7.93000000e+02
4.79000000e+02   2.02000000e+02   8.80000000e+01   3.90000000e+01
2.30000000e+01   1.90000000e+01]
[  2.93000000e+02   3.93000000e+02   2.90000000e+02   1.97000000e+02
1.18000000e+02   6.40000000e+01   4.10000000e+01   1.20000000e+01
1.10000000e+01   4.00000000e+00]]

非常感谢!


So what leastsq所做的是尝试:

“最小化一组方程的平方和” -scipy 文档

正如它所说,它正在最小化一组函数,因此如果您查看参数,实际上不会以最简单的方式获取任何 x 或 y 数据输入here所以你可以按照你喜欢的方式去做并传递一个残差函数,但是,使用它要容易得多curve_fit它会为你做这件事:)并创建必要的方程

为了适合您应该使用:curve_fit如果您对他们使用的通用残差没问题,这实际上是您传递自身的函数res = leastsq(func, p0, args=args, full_output=1, **kw)如果你看代码在这里。

例如如果我将 Rosenbrock 函数拟合为 2d 并猜测 y 参数:

from scipy.optimize import curve_fit
from itertools import imap
import numpy as np
# use only an even number of arguments
def rosen2d(x,a):
    return (1-x)**2 + 100*(a - (x**2))**2
#generate some random data slightly off

datax = np.array([.01*x for x in range(-10,10)])
datay = 2.3
dataz = np.array(map(lambda x: rosen2d(x,datay), datax))
optimalparams, covmatrix = curve_fit(rosen2d, datax, dataz)
print 'opt:',optimalparams

在 4d 中拟合 colville 函数:

from scipy.optimize import curve_fit
import numpy as np

# 4 dimensional colville function
# definition from http://www.sfu.ca/~ssurjano/colville.html
def colville(x,x3,x4):
    x1,x2 = x[:,0],x[:,1]
    return 100*(x1**2 - x2)**2 + (x1-1)**2 + (x3-1)**2 + \
            90*(x3**2 - x4)**2 + \
            10.1*((x2 - 1)**2 + (x4 - 1)**2) + \
            19.8*(x2 - 1)*(x4 - 1)
#generate some random data slightly off

datax = np.array([[x,x] for x in range(-10,10)])
#add gaussian noise
datax+= np.random.rand(*datax.shape)
#set 2 of the 4 parameters to constants
x3 = 3.5
x4 = 4.5
#calculate the function
dataz = colville(datax, x3, x4)
#fit the function
optimalparams, covmatrix = curve_fit(colville, datax, dataz)
print 'opt:',optimalparams

使用自定义残差函数:

from scipy.optimize import leastsq
import numpy as np

# 4 dimensional colville function
# definition from http://www.sfu.ca/~ssurjano/colville.html
def colville(x,x3,x4):
    x1,x2 = x[:,0],x[:,1]
    return 100*(x1**2 - x2)**2 + (x1-1)**2 + (x3-1)**2 + \
            90*(x3**2 - x4)**2 + \
            10.1*((x2 - 1)**2 + (x4 - 1)**2) + \
            19.8*(x2 - 1)*(x4 - 1)
#generate some random data slightly off


datax = np.array([[x,x] for x in range(-10,10)])
#add gaussian noise
datax+= np.random.rand(*datax.shape)
#set 2 of the 4 parameters to constants
x3 = 3.5
x4 = 4.5

def residual(p, x, y):
    return y - colville(x,*p)
#calculate the function
dataz = colville(datax, x3, x4)
#guess some initial parameter values
p0 = [0,0]
#calculate a minimization of the residual
optimalparams = leastsq(residual, p0, args=(datax, dataz))[0]
print 'opt:',optimalparams

编辑:您使用了位置和关键字 argargs: 如果你看一下docs你会看到它使用位置 3,但也可以用作关键字参数。你用过both这意味着该功能符合预期,令人困惑。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 3d 数据和参数在 Scipy 中进行曲线拟合 的相关文章

随机推荐

  • 使用json从PHP-MySql服务器获取图像到Android

    我正在开发一个应用程序 它从 php 服务器下载图像并在图像视图中显示图像 但是当我从 php 页面接收图像时 if empty result if mysql num rows result gt 0 result mysql fetch
  • 如何根据部分文件名检查文件是否存在?

    我试图检查我的文件夹中是否存在文件 但我只有部分文件名来检查它 有没有办法检查它 例如我有以下内容 God of War 文件名实际上称为 God of War PSP USA rar file exists 函数是否有某种类似的功能 或者
  • 如何使用python OpenCV找到单通道图像中与特定值匹配的最大连通分量?

    因此 我有一个主要为 0 背景 的单通道图像 以及前景像素的一些值 如 20 21 22 非零前景像素大多与具有相同值的其他前景像素聚集在一起 然而 图像中存在一些噪点 为了消除噪音 我想使用连通分量分析 并且对于每个值 在本例中为 20
  • 纸浆 LP 最小化配制“选择一种类型”约束

    下面的代码用于运行 LP 最小化问题 其中我们有某些食物 它们的营养价值和成本 该代码当前在所呈现的状态下工作 我正在尝试添加另一种类型的约束 我将所有食物分为不同的类别 早餐 午餐 晚餐 零食 我想创建一个约束 其中Only 1早餐 午餐
  • 在 BIML 中的数据流之前创建表

    我正在使用 BIML 和 BIDSHelper 创建 SSIS 包 我正在尝试将数据从 csv 导入到 sql server 我想在数据流发生之前在目标数据库中创建表 这是我的代码
  • NSUserDefaults 与 sqlite3

    我有一个小型 iPhone 应用程序 用于存储对象列表 用户可以添加和删除对象 但此列表将仍然相当小 大多数用户将有 10 30 个对象 NSUserDefaults看起来更容易合作 但会sqlite3能更快吗 只有30条 记录 会有什么明
  • 是否可以嵌套 Angular2 应用程序

    假设我想构建一个可以插入第三方页面的 Angular2 应用程序 第三方页面可能已经使用 Angular2 可能有不同的版本 是否可以在第三方应用程序中引导我的 Angular 2 应用程序
  • Java 交叉编译 - 最新 JDK 的好处

    我维护一个 Java 应用程序 在分发它之前我总是使用 JDK 1 6 进行编译 因为这是我的应用程序所需的最低版本 我不使用任何更新的功能 我不认为在更高版本中编译它有什么意义 否则旧的 JRE 将无法运行它 即使用 Java 1 6 的
  • 变量的前后增量操作在 TC 和 gcc 上给出不同的输出[重复]

    这个问题在这里已经有答案了 这是我的简单代码 include
  • 如何从javascript同步调用indexeddb方法

    我有一种方法说method1在 javascript 中 有另一种方法说method2 call method2在 method2 调用之后返回 method1 中需要的一个值 var userObj first Key 1 value s
  • 尝试保存图像时出现异常

    启动 Java 应用程序时 我在尝试保存图像时遇到异常 然而 在 Eclipse 中 一切都工作正常 该应用程序是使用 fatjar 构建的 并且还选择了导出所需的库 jar imageio jar 和 ij jar 我尝试使用 Image
  • Java - 如何更改“本地”?事件监听器中的变量

    有一个简短的问题 我希望有人能回答 基本上我有一个字符串变量 需要根据组合框中的值进行更改 该组合框中附加了一个事件侦听器 但是 如果我将字符串定为最终的 那么它就无法更改 但如果我不将其定为最终的 那么 Eclipse 就会抱怨它不是最终
  • Alexa 是否/可以替换其为链接用户生成的 UserId?

    我们有一个利用 Alexa 技能的应用程序 其中包含用户详细信息的帐户链接 根据 Alexa 的帐户关联 文档 我们的技能是为帐户链接设置的 帐户链接又引用第三方 或者可能是内部 身份管理系统 IMS 进行用户身份验证 我们的应用程序 以及
  • R:当列数为素数时分割数据框

    我有一个data frame有 131 列 我需要将其分成大约 10 到 15 个变量的组 即按列拆分 而不是按行拆分 显然 由于 131 是素数 因此并非所有新数据帧的长度都可以相同 我在帖子中寻找答案 如何在R中将数据切成偶数块 在 R
  • 带步骤选项的 math.random 函数? [关闭]

    Closed 这个问题需要细节或清晰度 目前不接受答案 一个自定义函数 它将返回一个随机数 并带有可用的步骤选项 如for环形 例子 for i 1 10 2 do print i end 你的意思是这样吗 function randomW
  • Spring Boot 与 Intellij IDE 的热插拔

    我有一个 Spring Boot 应用程序在 Intellij IDE 上运行良好 即我启动了具有委托给 SpringApplication run 的 main 方法的 Application 类 除了热插拔之外 一切都很好 当我更改源时
  • 为什么两种情况下的输出不同?

    即使变量已被覆盖 为什么在以下情况下输出不同 public class A int a 500 void get System out println a is this a public class B extends A int a 1
  • 强制 cpp_dec_float 向下舍入

    我在用 str n std ios base scientific 打印ccp dec floats 我注意到它四舍五入了 我在用cpp dec float对于会计 所以我需要向下舍入 如何才能做到这一点 它没有四舍五入 事实上 它是银行家
  • 带 if 语句的 auto 函数不会返回值

    我做了一个模板和一个auto比较 2 个值并返回最小的值的函数 这是我的代码 include
  • 使用 3d 数据和参数在 Scipy 中进行曲线拟合

    我正在努力在 scipy 中拟合 3d 分布函数 我有一个 numpy 数组 其中包含 x 和 y 仓中的计数 我试图将其拟合到相当复杂的 3 维分布函数中 该数据适合 26 个参数 这些参数描述了其两个组成群体的形状 我在这里了解到 当我