scikit learn:与 GridSearchCV 兼容的自定义分类器

2024-03-18

我已经实现了自己的分类器,现在我想对其运行网格搜索,但出现以下错误:estimator.fit(X_train, y_train, **fit_params) TypeError: fit() takes 2 positional arguments but 3 were given

我跟着本教程 http://danielhnyk.cz/creating-your-own-estimator-scikit-learn/并使用这个模板 https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/template.py由...提供scikit 的官方文档 http://scikit-learn.org/stable/developers/contributing.html。我的类定义如下:

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr=lr

    def fit(self, X, y):
        # Some code
        return self
    def predict(self, X):
        # Some code
        return y_pred
    def get_params(self, deep=True)
        return {'lr'=self.lr}
    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

我正在尝试网格搜索,如下所示:

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

EDIT I

我就是这样称呼它的: gs.fit(['你好世界', '尝试', '你好世界', '尝试', '你好世界', '尝试', '你好世界', '尝试'], ['我','Z','我','Z','我','Z','我','Z'])

结束编辑一

错误是由以下原因产生的_fit_and_score文件中的方法python3.5/site-packages/sklearn/model_selection/_validation.py

它在呼唤estimator.fit(X_train, y_train, **fit_params)有 3 个参数,但我的估计器只有两个,所以这个错误对我来说是有意义的,但我不知道如何解决它......我还尝试添加一些虚拟参数fit方法但没有成功。

EDIT II

完整的错误输出:

Traceback (most recent call last):
  File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>
    ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit
    cv.split(X, y, groups)))
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given

结束编辑二

SOLVED谢谢大家,我犯了一个愚蠢的错误:有两个不同的函数具有相同的名称(适合),(我使用不同的参数实现了另一个用于自定义目的,一旦我重命名了“自定义适合”,它就正常工作了。)

谢谢你并抱歉


以下代码对我有用:

class MyClassifier(BaseEstimator, ClassifierMixin):
     def __init__(self, lr=0.1):
         self.lr = lr
         # Some code
         pass
     def fit(self, X, y):
         # Some code
         pass
     def predict(self, X):
         # Some code
         return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)

我能想到的最好的办法就是你正在将一些东西传递给gs.fit超越方法x and y或你的MyClassifier.fit方法缺少 self 参数。

The fit_params仅当您将 kwarg 传递给gs.fit方法,否则它是一个空字典({}) and **fit_params不会抛出参数错误。要测试这一点,请创建分类​​器的实例并传递**{}。例如:

clf = MyClassifier()
clf.fit(x, y, **{})

这不会引发位置参数错误。

因此,再次除非将某些内容传递给gs.fit e.g. gs.fit(x, y, some_arg=123)在我看来,您缺少定义中的位置参数之一MyClassifier.fit。您所包含的错误消息似乎支持这一假设,因为它指出fit() takes 2 positional arguments but 3 were given。如果您按如下方式定义 fit ,则需要 3 个位置参数:

def fit(self, X, y): ...
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

scikit learn:与 GridSearchCV 兼容的自定义分类器 的相关文章

随机推荐

  • Windows批处理脚本获取当前驱动器名称

    我有一个批处理文件 位于 USB 密钥上 我需要知道批次所在的驱动器名称 例如 如果它是 E mybatch bat 则打开时应该找到 E 与 F G 等相同的内容 我怎样才能在批处理脚本中做到这一点 视窗 CD 这就是您正在寻找的 它打印
  • Azure Redis 缓存 - GET 调用超时

    我们在 Azure 中有多个 Web 和辅助角色通过 StackExchange Redis 库连接到我们的 Azure Redis 缓存 并且我们经常收到超时 这使得我们的端到端解决方案陷入停滞 其中之一的示例如下 System Time
  • Flutter WebView 位置

    我正在创建该网站的 WebViewhttps nearxt com https nearxt com 它在打开时询问位置 但是当我使用此链接在 flutter 中创建 webview 时 那么它就无法定位 我还在应用程序中定义了位置 但 w
  • 如何在 CSV 文件中写入多行?

    我怎样才能创建一个 csv文件 在这个 csv我想写数据包的信息 这是我的代码 https www tcpdump org sniffex c https www tcpdump org sniffex c我想写入我的文件 csv一些印刷品
  • Prolog 中不带双精度的列表的所有组合

    有没有一种简单的方法可以获取列表的所有组合而无需双精度 没有双打我的意思是也没有彼此的排列 所以不行 a b c and c a b or c b a 因此对于输入 a b c 输出将是 a b c a b a c b c a b c 我只
  • 是否可以在不知道 Firebase 数据库中的两个自动生成的键内获取值?

    我在每个自动生成的键下添加了一些值 并添加了与每个键对应的子项 示例学生 然后我添加了自动生成的密钥与这个孩子 学生 和值 现在的问题是我如何从这个序列中获取值 我正在为学生使用模型课程 可以吗 要读取此数据 然后在您的应用程序中处理它 您
  • 如何删除名称以点(“.”)结尾的文件夹?

    我收到了一些由恶意软件创建的文件夹 其名称以点结尾 例如C a or C b etc 我找到了一个可以使用命令删除此类文件夹的解决方案rd q s C a 但如果我调用 win APIRemoveDirectory http msdn mi
  • 第一次加载页面时出现“无法在框架中查看此内容”错误

    我开发了一个搜索表单 托管在我公司的本地服务器 iis net core 网站 中 该网站是托管在另一台服务器 apache wamp 上的 Wordpress 该服务器也在公司内 两者都有不同的公共IP 但两者都托管在同一域的子域下 比如
  • Rails 每个循环每 6 个项目插入标签?

    我有 X 个图像对象 需要在视图中循环遍历 并希望每 6 个对象左右创建一个新的 div 对于画廊 我看过cycle 但它似乎改变了所有其他记录 有谁知道每 6 次向视图中插入代码的方法吗 我可能可以用嵌套循环来做到这一点 但我对这个有点难
  • C++ 输入运算符重载

    我正在尝试重载我创建的 UserLogin 类上的输入运算符 不会引发编译时错误 但也不会设置值 一切都在运行 但 ul 的内容仍然存在 字符串 id 是 sally 登录时间为00 00 注销时间为 00 00 入口点 include
  • 如何使 videojs 标记可滑动或可移动

    我想移动我的markers每当它随着搜索一起滑动时 我希望我的标记准确无误slidable as jqueryui 滑块 问题 我想要我的markers 两者 一样可滑动jqueryui range滑块如以下示例中的视频所示 var pla
  • Retrofit+OkHttp 发送 GET 请求时可以,但发送 POST 时给出 SocketTimetout

    我从 Retrofit 开始 可以成功执行 GET 请求 但是当我尝试执行 POST 或 PUT 请求时 出现 SocketTimeOut 异常 我根据以下内容将 OkHttp 添加到我的 libs 文件夹中这个问题 https stack
  • @JsonInclude(Ininclude.NON_NULL) 未按预期工作

    我已经添加了 JsonInclude Include NON NULL Response 类上的注释 JsonInclude Include NON NULL public class Response JsonProperty priva
  • Nuxt - 将脚本添加到头部和主体

    我正在尝试在我的 Nuxt 应用程序中使用此脚本 但不知道如何操作 在基本的 HTML 文件中 它工作得很好 这是代码
  • Hibernate Envers 修订信息(更改列表)

    我想在我的项目中添加修订更改列表 单击信息图标 例如 Revision X added fieldA entry modified fieladB from B to BB removed fieldC entry 哪个是最好的方法 ps
  • Xcode 连接到 MS SQL 数据库

    我有一个现有数据库已在远程启动并运行MS SQL server 并且我希望能够与该数据库进行通信和交互Xcode 我正在写一份申请OS X in Swift以及应用程序应使用的数据存储在该远程数据库中 问题是我好像找不到Swift可以连接到
  • 如何传递 bquote 的符号字符串以在 ggplot 中求值?

    我在函数中创建的 ggplot 的轴标签有所不同 有些标签有上标 下标 而另一些则没有 例子 m data lt data frame x runif 10 y runif 10 x labs lt c rain mm light W m
  • array_walk 匿名函数

    有没有办法让我用匿名函数来获取这个数组来设置值 url array dog cat fish array walk url function value key url key str replace dog value echo pre
  • Azure CLI 运行命令使用参数调用 RunPowerShellScript

    我一直在尝试在 Azure VM 上运行一个脚本 该脚本需要像这样传递参数 az vm run command invoke g
  • scikit learn:与 GridSearchCV 兼容的自定义分类器

    我已经实现了自己的分类器 现在我想对其运行网格搜索 但出现以下错误 estimator fit X train y train fit params TypeError fit takes 2 positional arguments bu