子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器

2023-12-01

我正在尝试创建一个子类sklearn.svm.LinearSVC用作估计器sklearn.model_selection.GridSearchCV。子类有一个额外的函数,在本例中不执行任何操作。然而,当我运行这个时,我最终遇到了一个我似乎无法调试的错误。如果您复制粘贴代码并运行,它应该会重现以以下结尾的完整错误ValueError: Input contains NaN, infinity or a value too large for dtype('float64')

一旦我让他工作,我希望为该方法添加更多功能transform_this().

有人可以告诉我哪里出了问题吗?基于this我首先认为这是由于我的数据存在一些问题。然而,由于我使用 sklearn 内置数据集重现了它,所以情况似乎并非如此。另外,我相信我根据我对上一个问题的回答正确地对此进行了子类化here。另外,我了解到 GridSearchCV 似乎没有以不同的方式初始化估计器(不知何故,它首先使用默认参数,正如我从这个帖子)

from sklearn.datasets import load_breast_cancer
from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV

RANDOM_STATE = 123


class LinearSVCSub(LinearSVC):
    def __init__(self, penalty='l2', loss='squared_hinge', additional_parameter1=1, additional_parameter2=100,
                 dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1,
                 class_weight=None, verbose=0, random_state=None, max_iter=1000):
        super(LinearSVCSub, self).__init__(penalty=penalty, loss=loss, dual=dual, tol=tol,
                                           C=C, multi_class=multi_class, fit_intercept=fit_intercept,
                                           intercept_scaling=intercept_scaling, class_weight=class_weight,
                                           verbose=verbose, random_state=random_state, max_iter=max_iter)

        self.additional_parameter1 = additional_parameter1
        self.additional_parameter2 = additional_parameter2

    def fit(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).fit(X, y, sample_weight)

    def predict(self, X):
        X = self.transform_this(X)
        super(LinearSVCSub, self).predict(X)

    def score(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).score(X, y, sample_weight)

    def decision_function(self, X):
        X = self.transform_this(X)
        super(LinearSVCSub, self).decision_function(X)

    def transform_this(self, X):
        return X


if __name__ == '__main__':
    data = load_breast_cancer()
    X, y = data.data, data.target

    # Parameter tuning with custom LinearSVC
    param_grid = {'C': [0.00001, 0.0001, 0.0005],
                      'dual': (True, False), 'random_state': [RANDOM_STATE],
                      'additional_parameter1': [0.90, 0.80, 0.60, 0.30],
                      'additional_parameter2': [20, 30]}

    gs_model = GridSearchCV(estimator=LinearSVCSub(), verbose=1, param_grid=param_grid,
                            scoring='roc_auc', n_jobs=-1)
    gs_model.fit(X, y)

你有几个问题:

  1. 定义的方法没有 return 语句
  2. 您选择的数据集不收敛LinearSVC

一旦您纠正了这些问题,您就可以开始:

from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV

RANDOM_STATE = 123


class LinearSVCSub(LinearSVC):
    def __init__(self, penalty='l2', loss='squared_hinge', additional_parameter1=1, additional_parameter2=100,
                 dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1,
                 class_weight=None, verbose=0, random_state=None, max_iter=100000):
        super(LinearSVCSub, self).__init__(penalty=penalty, loss=loss, dual=dual, tol=tol,
                                           C=C, multi_class=multi_class, fit_intercept=fit_intercept,
                                           intercept_scaling=intercept_scaling, class_weight=class_weight,
                                           verbose=verbose, random_state=random_state, max_iter=max_iter)

        self.additional_parameter1 = additional_parameter1
        self.additional_parameter2 = additional_parameter2

    def fit(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).fit(X, y, sample_weight)
        return self

    def predict(self, X):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).predict(X)

    def score(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).score(X, y, sample_weight)

    def decision_function(self, X):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).decision_function(X)

    def transform_this(self, X):
        return X


X, y = make_classification()

# Parameter tuning with custom LinearSVC
param_grid = {'C': [0.00001, 0.0001, 0.0005],
                  'dual': (True, False), 'random_state': [RANDOM_STATE],
                  'additional_parameter1': [0.90, 0.80, 0.60, 0.30],
                  'additional_parameter2': [20, 30]
             }

gs_model = GridSearchCV(estimator=LinearSVCSub(), verbose=1, param_grid=param_grid,
                        scoring='roc_auc', n_jobs=1)

gs_model.fit(X, y)
Fitting 5 folds for each of 48 candidates, totalling 240 fits
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 240 out of 240 | elapsed:    0.9s finished
GridSearchCV(estimator=LinearSVCSub(), n_jobs=1,
             param_grid={'C': [1e-05, 0.0001, 0.0005],
                         'additional_parameter1': [0.9, 0.8, 0.6, 0.3],
                         'additional_parameter2': [20, 30],
                         'dual': (True, False), 'random_state': [123]},
             scoring='roc_auc', verbose=1)

gs_model.predict(X)
array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
       1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
       1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0,
       0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器 的相关文章

  • Django:如何测试“HttpResponsePermanentRedirect”

    我正在为我的 django 应用程序编写一些测试 在我看来 它使用 HttpResponseRedirect 重定向到其他一些网址 那么我该如何测试呢 姜戈TestCase类有一个方法assertRedirects https docs d
  • 行未从树视图复制

    该行未在树视图中复制 我在按行并复制并粘贴到未粘贴的任何地方后制作了弹出复制 The code popup tk Menu tree opportunity tearoff 0 def row copy item tree opportun
  • 为什么 .setGeometry() 不改变 QWidget 实例的大小?

    我想使用 QWidget 更改 QPushButton 的大小 setGeometry https doc qt io qtforpython 5 PySide2 QtWidgets QWidget html PySide2 QtWidge
  • 在Python3.6中调用C#代码

    由于完全不了解 C 编码 我希望在我的 python 代码中调用 C 函数 我知道有很多关于同一问题的问答 但由于一些奇怪的原因 我无法从示例 python 模块导入简单的 c 类库 以下是我所做的事情 C 类库设置 我使用的是 VS 20
  • 在python中将文本文件解析为列表

    我对 Python 完全陌生 我正在尝试读取包含单词和数字组合的 txt 文件 我可以很好地读取 txt 文件 但我正在努力将字符串转换为我可以使用的格式 import matplotlib pyplot as plt import num
  • Python 3 __getattribute__ 与点访问行为

    我读了一些关于 python 的对象属性查找的内容 这里 https blog ionelmc ro 2015 02 09 understanding python metaclasses object attribute lookup h
  • 如何使用Python将WebP图像转换为Gif?

    我已经尝试过这个 from PIL import Image im Image open this webp im save that gif gif save all True 这给了我这个错误 类型错误 不支持的操作数类型 tuple
  • Paramiko - 使用私钥连接 - 不是有效的 OPENSSH 私钥/公钥文件

    我正在尝试找到解决方案 但无法理解我做错了什么 在我的 Linux 服务器上 我运行了以下命令 ssh keygen t rsa 这产生了一个id rsa and id rsa pub file 然后我将它们复制到本地并尝试运行以下代码 s
  • Pandas重置索引未生效[重复]

    这个问题在这里已经有答案了 我不确定我在哪里误入歧途 但我似乎无法重置数据帧上的索引 当我跑步时test head 我得到以下输出 正如您所看到的 数据帧是一个切片 因此索引超出范围 我想做的是重置该数据帧的索引 所以我跑test rese
  • 为什么 Python 中的“pip install”会引发语法错误?

    我正在尝试使用 pip 安装软件包 我试着跑pip install从Python shell 但我得到了SyntaxError 为什么我会收到此错误 如何使用 pip 安装软件包 gt gt gt pip install selenium
  • 如果字段值在外部列表中,Django 会注释布尔值

    想象一下我有这个 Django 模型 class Letter models Model name models CharField max length 1 unique True 还有这个列表 vowels a e i o u 我想查询
  • sudo pip install python-Levenshtein 失败,错误代码 1

    我正在尝试在 Linux 上安装 python Levenshtein 库 但每当我尝试通过以下方式安装它时 sudo pip install python Levenshtein 我收到此错误 命令 usr bin python c 导入
  • 创建一个类似于 Tkinter 的表

    我希望创建类似于 Tkinter 中的表格的东西 但它不一定是这样的 例如 我想创建标题 Name1 Name2 Value 并在每个标题下面有几个空白行 然后 我希望稍后用我计算的值或名称的字符串值填充这些行 因此是标签 对于 Name2
  • 如何通过双击在浏览器中打开 ipynb 文件

    以前 我安装了 Canopy 当时 我只需双击 ipynb 文件并在浏览器中打开它们即可 但是 后来我需要Anaconda 一旦我安装了它 这个功能就没有了 现在我只希望能够简单地双击 ipynb 文件 然后该文件就会在 Firefox 中
  • 更改 pandas 中多个日期时间列的时区信息

    有没有一种简单的方法可以将数据帧中的所有时间戳列转换为本地 任何时区 不是逐列进行吗 您可以有选择地将转换应用于所有日期时间列 首先 选择它们select dtypes https pandas pydata org pandas docs
  • 将输入发送到 python 子进程而不等待结果

    我正在尝试为一段代码编写一些基本测试 该代码通常通过 stdin 无休止地接受输入 直到给出特定的退出命令 我想检查程序是否在给出一些输入字符串时崩溃 经过一段时间来考虑处理 但似乎无法弄清楚如何发送数据而不是陷入等待我不知道的输出关心 我
  • Pandas Dataframe:将包含列表的行扩展到多行,并为所有列提供所需的索引

    我在 pandas 数据框中有时间序列数据 索引为测量开始时的时间 列中包含以固定采样率记录的值列表 连续索引 列表中元素数量的差异 这是它的样子 Time A B Z 0 1 2 3 4 1 2 3 4 2 5 6 7 8 5 6 7 8
  • Airflow Python 单元测试?

    我想为我们的 DAG 添加一些单元测试 但找不到任何单元测试 有 DAG 单元测试框架吗 有一个端到端的测试框架存在 但我猜它已经死了 https issues apache org jira browse AIRFLOW 79 https
  • 全局变量是 None 而不是实例 - Python

    我正在处理Python 中的全局变量 代码应该可以正常工作 但是有一个问题 我必须使用全局变量作为类的实例Back 当我运行应用程序时 它说 back is None 这应该不是真的 因为第二行setup 功能 back Back Back
  • 检查字符串是否只有字母和空格 - Python

    试图让 python 返回一个字符串仅包含字母和空格 string input Enter a string if all x isalpha and x isspace for x in string print Only alphabe

随机推荐

  • 如何使用 IResourceChangeListener 检测文件重命名并动态设置 EditorPart 名称?

    IResourceChangeListener监听项目工作区中的更改 例如编辑器零件文件名是否已更改 我想知道如何访问该特定的EditorPart并相应地更改其标题名称 例如 setPartName 或者刷新编辑器以便它自动显示新名称 理想
  • 在 Highchart 样条图上的最后一点显示指标

    我知道如何在最后一点显示标记 例如this 当数据是动态的时候 不知道如何标记最后一个点 plotOptions column stacking normal spline marker enabled true 当您动态添加新点时 您可以
  • 在缓冲区对象上运行并通过着色器更改其数据? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 有没有办法在缓冲区对象上运
  • MSSQL 选择“垂直”-其中

    除了 垂直位置 之外 我真的不知道如何解释 想象一下下表 TAGID PRODUCTID SHOP ID 59 3418 7 38 61 3418 7 38 60 4227 4 38 61 4227 4 38 现在我想返回与标签 ID 相关
  • 实例成员不能用于类型

    我有以下课程 class ReportView NSView var categoriesPerPage Int var numPages Int return categoriesPerPage count 编译失败并显示消息 实例成员
  • 链接共享 C 库时 Android NDK 错误

    我正在尝试将一些 C 文件链接到我正在处理的 NDK 项目 并设置我的CMakeLists txt像下面这样归档 cmake minimum required VERSION 3 4 1 set CMAKE C FLAGS CMAKE C
  • 润滑 as_date 和。 as_datetime 行为差异

    我有一个数字向量 表示自 1970 年 1 月 1 日以来的毫秒数 我想使用以下方法将它们转换为日期时间对象lubridate 数据示例如下 raw times lt c 1139689917479 1139667123031 114036
  • 使用浏览器控制台使用 Javascript 在 Facebook 中发送聊天消息

    我尝试使用 Javascript 在 Facebook 中发送聊天消息 但不断收到错误消息 要么是TypeError Object
  • 我可以对 Linux 进程的地址空间中的每个页面进行写保护吗?

    我想知道是否有一种方法可以对 Linux 中的每个页面进行写保护 进程的地址空间 从进程本身的内部 通过mprotect 我所说的 每一页 实际上是指该网站的每一页 进程的地址空间可以被普通进程写入 程序在用户模式下运行 所以 程序文本 常
  • ServiceStack Javascript JsonServiceClient 缺少属性

    我正在尝试使用 Servicestack JsonServiceClient 连接到经过 JWT 身份验证的服务 但是文档仅描述了如何使用 C 客户端执行此操作 http docs servicestack net jwt authprov
  • 计时器不包含在 Xamarin.Forms 的 System.Threading 中

    I used System Threading Timer in Xamarin Android 我如何在中使用同一个类Xamarin Forms 我想从 Xamarin Forms 中的 Xamarin Android 转移我的项目 pu
  • 单击按钮更改颜色在重新加载或重新启动页面后保持不变

    我创建了锚标记 其中使用心形图标 单击后会更改颜色 但我想在重新加载或重新启动页面后保持相同的颜色 当我重新启动或重新加载页面时 它会恢复默认颜色 var btnvar document getElementById favorite fu
  • 如何从 C# 调用 MongoDb 中存储的 JavaScript

    我正在评估将 SQL Server 数据库移植到 MongoDb 问题是移动存储过程 我读到了有关 MongoDb 存储 JavaScript 的内容 我想在 Net 中进行一些测试 我已经安装了 MongoDb 驱动程序 2 4 0 并在
  • 搜索数组中的连续值

    在数组中搜索连续值的最佳方法是什么 例如 搜索array a b in array x a b c 会产生1 因为这些值首先连续出现在该索引处 还没有测试过这个 但类似这样的事情应该可以 function consecutive value
  • 使用 PHP 接收 JSON POST

    我尝试在支付接口网站上接收 JSON POST 但无法对其进行解码 当我打印时 echo POST I get Array 当我尝试这个时我什么也没得到 if POST foreach POST as key gt value echo l
  • 圆与圆的交点

    如何计算两个圆的交点 我希望在所有情况下都会有两个 一个或没有交点 我有中心点的 x 和 y 坐标以及每个圆的半径 python 中的答案是首选 但任何工作算法都是可以接受的 两个圆的交点 保罗 伯克 编剧 The following no
  • Linq to SQL 是如何工作的?

    我在项目中使用 Linq to SQL 我使用它从 SQL 存储过程中获取数据 它工作完美 但我不明白 LINQ SQL 内部如何与 SQL Server 通信 它在获取数据后将数据存储在哪里 它从哪里获取连接字符串 提前致谢 更好读 ht
  • 为什么使用不带 lambda 的内联

    我试图了解如何使用inline修改正确 我了解一般情况 当我们内联 lambda 以防止过度分配时 如中所述docs 我正在检查 kotlin stdlib 并发现 Strings kt下面这段代码 kotlin internal Inli
  • 在vBulletin中使用curl登录网站

    我一直在尝试登录某个网站 www siamchart 论坛 按照此链接上的说明进行操作 使用 PHP cURL 登录远程站点 我无法通过登录 运行以下脚本后 它将我重定向到相同的登录页面 www siamchart forum 但没有成功登
  • 子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器

    我正在尝试创建一个子类sklearn svm LinearSVC用作估计器sklearn model selection GridSearchCV 子类有一个额外的函数 在本例中不执行任何操作 然而 当我运行这个时 我最终遇到了一个我似乎无