如果我们在管道中包含转换器,scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 倍交叉验证分数是否有偏差?

2024-01-19

数据预处理器(例如 StandardScaler)应用于 fit_transform 训练集,并且仅变换(不拟合)测试集。我希望相同的拟合/转换过程适用于用于调整模型的交叉验证。然而,我发现cross_val_score and GridSearchCV使用预处理器fit_transform整个训练集(而不是fit_transforminner_train集,并变换inner_validation集)。我相信这人为地消除了 inner_validation 集中的方差,这使得 cv 分数(用于通过 GridSearch 选择最佳模型的指标)产生偏差。这是一个问题还是我实际上错过了什么?

为了演示上述问题,我使用 Kaggle 的威斯康星州乳腺癌(诊断)数据集尝试了以下三个简单的测试用例。

  1. 我故意将整个 X 拟合并变换为StandardScaler()
X_sc = StandardScaler().fit_transform(X)
lr = LogisticRegression(penalty='l2', random_state=42)
cross_val_score(lr, X_sc, y, cv=5)
  1. 我将 SC 和 LR 包含在Pipeline并运行cross_val_score
pipe = Pipeline([
    ('sc', StandardScaler()),
    ('lr', LogisticRegression(penalty='l2', random_state=42))
])
cross_val_score(pipe, X, y, cv=5)
  1. 与 2 相同,但具有GridSearchCV
pipe = Pipeline([
    ('sc', StandardScaler()),
    ('lr', LogisticRegression(random_state=42))
])
params = {
    'lr__penalty': ['l2']
}
gs=GridSearchCV(pipe,
param_grid=params, cv=5).fit(X, y)
gs.cv_results_

它们都产生相同的验证分数。 [0.9826087、0.97391304、0.97345133、0.97345133、0.99115044]


No, sklearn不做fit_transform与整个数据集。

为了检查这一点,我将其子类化StandardScaler打印发送给它的数据集的大小。

class StScaler(StandardScaler):
    def fit_transform(self,X,y=None):
        print(len(X))
        return super().fit_transform(X,y)

如果您现在更换StandardScaler在您的代码中,您会看到第一种情况下传递的数据集大小实际上更大。

但为什么准确率保持完全相同呢?我认为这是因为LogisticRegression对特征尺度不是很敏感。如果我们使用对规模非常敏感的分类器,例如KNeighborsClassifier例如,您会发现两种情况之间的准确性开始有所不同。

X,y = load_breast_cancer(return_X_y=True)
X_sc = StScaler().fit_transform(X)
lr = KNeighborsClassifier(n_neighbors=1)
cross_val_score(lr, X_sc,y, cv=5)

输出:

569
[0.94782609 0.96521739 0.97345133 0.92920354 0.9380531 ]

而第2个案例,

pipe = Pipeline([
    ('sc', StScaler()),
    ('lr', KNeighborsClassifier(n_neighbors=1))
])
print(cross_val_score(pipe, X, y, cv=5))

Outputs:

454
454
456
456
456
[0.95652174 0.97391304 0.97345133 0.92920354 0.9380531 ]

准确性方面变化不大,但仍然发生了变化。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如果我们在管道中包含转换器,scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 倍交叉验证分数是否有偏差? 的相关文章

  • 如何编写从管道输入读取的 powershell 函数?

    SOLVED 以下是使用管道输入的函数 脚本的最简单示例 每个的行为都与通过管道传输到 echo cmdlet 相同 作为函数 Function Echo Pipe Begin Executes once before first item
  • Scikit-learn:如何获得 True Positive、True Negative、False Positive 和 False Negative

    我的问题 我有一个数据集 它是一个很大的 JSON 文件 我读取它并将其存储在trainList多变的 接下来 我对其进行预处理 以便能够使用它 完成后 我开始分类 我用kfold交叉验证方法以获得平均值 准确性并训练分类器 我做出预测并获
  • BertForSequenceClassification 是否在 CLS 向量上进行分类?

    我正在使用抱脸变压器 https huggingface co transformers index html使用 PyTorch 打包和 BERT 我正在尝试进行 4 向情感分类并正在使用BertFor序列分类 https hugging
  • Caffe 多输入图像

    我正在考虑实现一个 Caffe CNN 它接受两个输入图像和一个标签 后来可能是其他数据 并且想知道是否有人知道 prototxt 文件中执行此操作的正确语法 它只是一个带有额外顶部的 IMAGE DATA 层吗 或者我应该为每个层使用单独
  • 使用 joblib 加载 pickled scikit-learn 模型时出现 KeyError

    我有一个对象 其中包含两个scikit learn模型 一个IsolationForest and a RandomForestClassifier 我想对其进行 pickle 然后将其解开并用于生成预测 除了两个模型之外 该对象还包含几个
  • 使用xgboost进行分类时如何获得置信区间或预测离散度的度量?

    使用xgboost进行分类时如何获得置信区间或预测离散度的度量 例如 如果 xgboost 预测某个事件的概率为 0 9 如何获得该概率的置信度 这种置信度是否也被认为是异方差的 要为 xgboost 模型生成置信区间 您应该训练多个模型
  • Mobilenet 与 SSD [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 Locked 这个问题及其答案是locked help locked posts因为这个问题是题外话 但却具有历史意义 目前不接受新的答案
  • 使用张量流导出神经网络的权重

    我使用张量流工具编写了神经网络 一切正常 现在我想导出神经网络的最终权重以制定单一的预测方法 我怎样才能做到这一点 您需要在训练结束时使用以下命令保存模型tf train Saver https www tensorflow org ver
  • 在防风草模型上使用 VIP 包计算重要性度量

    我正在尝试使用 vi firm 在防风草中制作的逻辑回归模型上计算特征重要性 对于正则表达式 我将使用 iris 数据集并尝试预测观察结果是否为 setosa iris1 lt iris gt mutate class case when
  • 使用神经网络包进行多项分类

    这个问题应该很简单 但文档没有帮助 我正在使用 R 我必须使用neuralnet多项式分类问题的包 所有示例均针对二项式或线性输出 我可以使用二项式输出进行一些一对一的实现 但我相信我应该能够通过使用 3 个单元作为输出层来做到这一点 其中
  • Tensorflow 2.0 中的二阶导数

    我正在尝试计算标量变量的简单向量函数的二阶导数f x x x 2 x 3 使用 TF 2 3 与tf GradientTape def f ab x return x x 2 x 3 import tensorflow as tf in1
  • Keras 错误:预计会看到 1 个数组

    当我尝试在 keras 中训练 MLP 模型时出现以下错误 我使用的是 keras 版本1 2 2 检查模型输入时出错 您输入的 Numpy 数组列表 传递给您的模型的尺寸不是模型预期的尺寸 预期的 查看 1 个数组 但得到以下 12859
  • 如何跨多个文本文件查找字典中键的频率?

    我应该计算文档 individual articles 中所有文件中字典 d 的所有键值的频率 这里 文档 individual articles 大约有20000个txt文件 文件名为1 2 3 4 例如 假设 d Britain 5 7
  • 如何使用机器学习从数据序列计算状态图?

    通用配方 我有一个由一系列点组成的数据集 每个点有 12 个特征 我有兴趣检测此数据中的事件 在训练数据中我知道事件发生的时刻 当事件发生时 我可以在事件发生之前的点序列中看到可观察到的模式 该形态由大约 300 个连续点形成 我感兴趣的是
  • 使用 glmnet 纠正 n 个数据集上的 n 个 LASSO 回归的输出(严格来说是所选的特征/变量)

    注意 这是对上一个问题 https stackoverflow com questions 75006466 how to replicate my results from running n lassos iteratively usi
  • scikit随机森林sample_weights的使用

    我一直在试图弄清楚 scikit 的随机森林样本权重的使用 但我无法解释我看到的一些结果 从根本上说 我需要它来平衡分类问题和不平衡类 特别是 我期望如果我使用全 1 的 sample weights 数组 我会得到与以下相同的结果w sa
  • 收到的标签值 1 超出了 [0, 1) 的有效范围 - Python、Keras

    我正在使用具有张量流背景的 keras 开发一个简单的 cnn 分类器 def cnnKeras training data training labels test data test labels n dim print Initiat
  • 如何为DNA序列生成一种热编码?

    我想为一组 DNA 序列生成一个热编码 例如 序列ACGTCCA可以以转置方式表示如下 但下面的代码将以水平方式生成一种热门编码 我更喜欢以垂直方式生成 谁能帮我 ACGTCCA 1000001 A 0100110 C 0010000 G
  • 如何在 keras 模型中使用张量流度量函数?

    使用Python 3 5 2张量流RC 1 1 我正在尝试在 keras 中使用张量流度量函数 所需的功能接口似乎是相同的 但调用 import pandas import numpy import tensorflow contrib k
  • 机器学习的周期性数据(例如度角 -> 179 与 -179 相差 2)

    我使用 Python 进行核密度估计 并使用高斯混合模型对多维数据样本的可能性进行排名 每一条数据都是一个角度 我不确定如何处理机器学习的角度数据的周期性 首先 我通过添加 360 来删除所有负角 因此所有负角都变成了正角 179 变成了

随机推荐

  • 调试使用 ES6 模块的 JavaScript 代码

    TL DR 如何从调试器访问 ES 模块中定义的变量 函数 名称 更多背景信息 我是一位经验相对丰富的 JavaScript 程序员 但对模块还是个新手 我已经按照 MDN 上的教程进行操作 https developer mozilla
  • CUDA 就地转置错误

    我正在实现一个 CUDA 程序来转置图像 我创建了 2 个内核 第一个内核进行了异位转置 并且适用于任何图像尺寸 然后我创建了一个用于方形图像就地转置的内核 但是 输出不正确 图像的下三角形被转置 但上三角形保持不变 生成的图像在对角线上有
  • 如何在 android room 和 rxjava 2 中插入数据并获取 id 作为输出参数?

    插入查询 Insert onConflict OnConflictStrategy REPLACE long insertProduct Product product product id is auto generated 查看模型 p
  • tvos UISegmentedControl 焦点样式不改变

    我想在 tvOS 中突出显示 UISegmentedControl 时更改其背景颜色 Normally Segment display like following When change focus for change selected
  • 训练神经网络时出现极小或 NaN 值

    我正在尝试在 Haskell 中实现神经网络架构 并在 MNIST 上使用它 我正在使用hmatrix线性代数包 我的训练框架是使用pipes包裹 我的代码可以编译并且不会崩溃 但问题是 层大小 例如 1000 小批量大小和学习率的某些组合
  • 如何将 DATETIME 转换为 mysql 中的 DATE?

    我的查询是这样的 我有一堆条目 我想按日期对它们进行分组 但我的数据库中没有日期 而是有一个日期时间字段 我该怎么办 select from follow queue group by follow date cast follow dat
  • 反序列化 JSON 对象的一部分并将其序列化回来,其余属性保持不变

    我有一些 JSON 想要将其反序列化为 C 类的实例 但是 该类并不具有与原始 JSON 匹配的所有字段 属性 我希望能够修改类中的属性值 然后将其序列化回 JSON 并且原始 JSON 中的剩余字段和属性仍然完好无损 例如 假设我有以下
  • 强制 TkInter Scale 滑块捕捉到鼠标

    当 GUI 有 TkInter 时Scale当他们单击刻度上的某个位置时 默认行为似乎是沿着刻度向鼠标方向滑动滑块 然后意外地经过鼠标 我想要的是让滑块在用户单击滑块上的任意位置时始终跳转到并保持连接到用户的鼠标点 如果他们单击刻度上的特定
  • Int 和 Integer 有什么区别?

    在 Haskell 中 a 和 a 有什么区别Int and an Integer 答案记录在哪里 Integer 是任意精度 类型 它将保存任何数字 no 无论多大 直到极限 你机器的内存 这意味着你从来没有 算术溢出 在另一 手也意味着
  • 重写 Wildfly 引擎

    我想知道是否可以在没有任何第三方库的情况下使用 Wildfly 应用程序服务器的重写引擎 我尝试过使用重写阀 https help openshift com hc en us articles 202398810 How to redir
  • Rails 控制器中的实例和类变量

    我是 Rails 和 ruby 的新手 我正在研究类和实例变量的概念 我理解其中的区别 但是当我在 Rails 中使用控制器进行尝试时 它让我感到困惑 我所做的是在类方法之外声明一个类和实例变量 class BooksController
  • 有没有办法强制使用 Zend_Auth 进行身份验证?

    我正在使用 Zend Auth 和 cookie 会话持久性 我似乎无法弄清楚如何强制使用此类进行身份验证 有没有办法强制 Zend Auth 相信它已经作为用户进行身份验证 Zend Auth getInstance gt getStor
  • 直接自引用导致循环超类问题 JSON

    我尝试了在搜索时发现的几件事 但没有任何帮助 或者我没有正确实现它 我收到错误 Direct self reference leading to cycle through reference chain io test entity bo
  • 创建覆盖 ImageView 动画 Google 地图

    我正在尝试使我的叠加图像执行以下操作 地图的 onClick onDrag 在地图中间显示恒定图像 这是一个引脚 onTouchUp 将标记图像更改为加载标记和一次数据 加载完整更改将图像加载到带有文本的新图像 这是与我的问题非常相似的解决
  • 如何让我的自定义帐户类型显示在 Android 联系人应用程序中?

    我已经创建了一个自定义帐户类型 并且可以在 android ContactsContract ContentProvider 中成功创建该类型的联系人 但我在弄清楚如何在默认联系人应用程序中编辑联系人时显示我的自定义帐户标签和图标时遇到了很
  • 为什么 git 将我的分支名称前缀大写?

    我有一组非常简单的 git 命令 这会导致一些奇怪的行为 显示我当前的本地分支机构 并查看我在release beta1 git branch develop master release beta1 创建一个bugfix somefeat
  • System.Net.WebRequest 不尊重主机文件

    有没有办法获得System Net WebRequest or System Net WebClient尊重hosts or lmhosts file 例如 在我的主机文件中 我有 10 0 0 1 www bing com 当我尝试在浏览
  • 使用 async/await 有什么优点?

    我正在阅读几篇文章并观看一些有关如何在 JavaScript 中使用 async await 的视频 似乎唯一的原因是将异步代码转换为同步代码 并使代码更具可读性 但这并不打算在这个问题中讨论 因此 我想了解使用这些语句是否有更多原因 因为
  • 如何使用 CKFinder Javascript API?

    有趣的问题 但老实说我无法访问 例如 CKFinder dataTypes Folder http docs cksource com ckfinder 2 x api symbols CKFinder dataTypes Folder h
  • 如果我们在管道中包含转换器,scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 倍交叉验证分数是否有偏差?

    数据预处理器 例如 StandardScaler 应用于 fit transform 训练集 并且仅变换 不拟合 测试集 我希望相同的拟合 转换过程适用于用于调整模型的交叉验证 然而 我发现cross val score and GridS