机器学习模型评估与改进: 交叉验证(cross validation)

2023-05-16

文章目录

  • 交叉验证
      • 调用方法
      • 优势和不足
      • 注意事项:
  • 分层k折交叉验证
  • 交叉验证的更多变形
    • leave-one-out交叉验证
    • Shuffle-split交叉验证
    • 组间的交叉验证
  • 总结

以监督学习的众多算法为例,不管是分类还是回归,都有很多不同的算法模型,在不同的问题中,这些算法模型的表现是不同的。如何对模型的表行进行评估和改进呢?scikit learn网站给出了这样一个模型评估和改进的流程图:

在这里插入图片描述
首先我们再来看看模型评估的过程,在模型训练时,我们首先可以用scikit learn的model_selection模块train_test_split函数对数据划分,分为训练集合和测试集合。对于验证模型的泛化能力,测试集合至关重要。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

交叉验证

交叉验证是一种统计学方法,用于衡量算法表现是否稳定。在交叉验证里,数据不是简单按照某个比例分为训练集合和测试集合,而是将数据如下图做多次划分,并且基于这些划分,训练多个机器学习模型。这也就是所谓的k折交叉验证(k-fold cross validation),k通常为5或者10.

以5折交叉验证为例,数据首先被均匀地分为5份(“折”,fold),取其中一份作为测试集合,其他为训练集合,训练一个模型。之后,轮流地选择其中的一折作为测试集合,其他为训练集合,再依次训练模型。

在这里插入图片描述

调用方法

scikit-learn提供了非常简便的方法调用交叉验证。只需要从model_selection模块中加载cross_val_score函数就可以了。以鸢尾花数据集,logistic回归预测为例:

from sklearn.model_selection import cross_val_score 
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
iris = load_iris()
logreg = LogisticRegression()
scores = cross_val_score(logreg, iris.data, iris.target) 
print("Cross-validation scores: {}".format(scores))

通常cross_val_score函数默认是3折交叉验证,运行后输出Cross-validation scores: [ 0.961 0.922 0.958]。 可以在cross_val_score() 函数中修改默认设置,例如改成5折交叉验证:

scores = cross_val_score(logreg, iris.data, iris.target, cv=5)

正确率为[ 1. 0.967 0.933 0.9 1. ]。平均来说,交叉验证的正确率是我们关心的指标,可以用scores.mean()得到,约为0.96。 96%的正确率,模型的性能还是比较好的,但是具体看每折交叉验证的正确率,从0.933到1之间变化,说明模型在不同测试集合上的表现并不是特别稳定,预测结果与测试数据的选择有关,当然,也有可能和测试集合数量比较少有关。

优势和不足

从上面每折验证的正确率和平均正确率就可以看出,交叉验证首先会抹平由于测试集合选择的“运气”带来的模型评价“失真”。在交叉验证的时候,数据集中的所有数据都有机会成为测试数据,这样可以更好地测试模型的泛化能力。

其次,交叉验证结果也可以说明模型对数据的敏感程度,例如上面的例子,模型的正确率在93.3%~100%之间,我们可以进一步推测,这个模型的正确率在更大的鸢尾花数据集上(如果我们有的话),正确率可能在90%~100%之间。这个范围还是挺大的,在新的数据上,模型的表现大致会在这个范围内。

另外一个好处是,通过交叉验证,数据集中所有的数据都得到了充分的利用。在用train_test_split函数划分模型的训练集合和测试集合时,通常会用75%的数据作为训练集合,25%的数据作为测试集合,在5折交叉验证时候,80%的数据用于训练,20%的用于测试,10折交叉验证的时候,90%的数据训练,10%的数据测试。显然,训练集合的数据越多,模型就会越精确,整体来说,交叉验证训练模型,会让模型的精度得到更好的提升。
交叉验证的不足之处主要是增加了计算开销。k折交叉验证,也就意味着模型要训练k次,比train_test_split要增加k倍的训练时间。

注意事项:

k折交叉验证不会返回唯一的一个模型
k折交叉验证的过程中,实际上训练了k个模型,所以该方法主要是用来测试这个模型在这个数据集合上的表现,并不是一个生成训练模型的方法,更准确地说是一个评价模型的方法。

分层k折交叉验证

k折交叉验证很好,但是有一种情况,k折交叉验证的效果就要大打折扣,就是如果数据集的标签都是连续排列的。例如iris数据集,如果标签是这样的:
在这里插入图片描述
那么在3折交叉验证的时候,每次训练和测试的数据,都是完全不同标签类型的,模型正确率可能只有0%左右了。如何改进呢?

分层k折交叉验证就是一个很好的选择,该方法将数据集先分成k折,然后再从每个k折中选择k折作为测试集合,如下图:
在这里插入图片描述
这样就可以在某种程度上保证,每次测试集合和训练集合都有各个类型的数据,特别是当数据分类的类型数量相差悬殊的情况。比如数据集合一共有1000个样本,900个都属于A类型,只有100个属于B类型。不过不用担心,scikit learn在分类问题中采用的就是这种分层的交叉验证策略。但是对于回归问题, scikit learn用的不是分层k折交叉验证,而是经典版本。这么做很有可能是因为回归问题与分类不同,分类是希望将各个不同(那么数量很少)区别开来,但是回归是希望得到大多数的规律。

交叉验证的更多变形

scikit learn提供了多种对交叉验证进行设置的参数。前面的例子中用到了cv这个参数,可以用来设置交叉验证的折数,对于一般的分类问题,指定cv的折数,scikit learn默认的分层交叉验证就很好用了。不过有时候,例如当我们只需要经典的k折交叉验证,我们可以加载model_selection模块的KFold函数。

from sklearn.model_selection import KFold
kfold = KFold(n_splits=5)

之后可以把kfold赋给cross_val_score()函数中的cv参数:

cross_val_score(logreg, iris.data, iris.target, cv=kfold)

可以验证一下之前提到的,经典k折交叉验证在iris数据集上惨不忍睹的正确率为0是如何“做到的”

kfold = KFold(n_splits=3) print("Cross-validation scores:\n{}".format(cross_val_score(logreg, iris.data, iris.target, cv=kfold)))

得到Cross-validation scores: [ 0. 0. 0.]。可见,经典的k折交叉验证对于这样“排列”好的数据集上,的确是会造成非常错误的结果。改进的方法除了用model_selection模块的cross_val_score()函数,还可以在KFold函数中设置shuffle参数,例如:

kfold=KFold(n_splits=3, shuffle=True, random_state=0)
print(“Cross-validation scores: \n {}.format(cross_val_score(logreg, iris.data, iris.target, cv=kfold)))

leave-one-out交叉验证

leave-one-out,顾名思义,就是一种特殊的k折交叉验证,每次留下来做测试集合的只有一个样本。这是一种非常耗时的交叉验证方法,特别是当数据集比较大的时候。但是对于小数据集,可能会给出更好的模型评估。

from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
scores = cross_val_score(logreg, iris.data, iris.target, cv=loo)
print(“Number of cv iterations:, len(scores))
print(”Mean accuracy: {:.2f}.format(scores.mean()))

Shuffle-split交叉验证

这种交叉验证方法是最为灵活的一种数据集划分方法。

shuffle_split交叉验证是model_selection模块的ShuffleSplit( )函数实现,训练数据集合和测试数据集合按照train_set,test_set参数设置的比例在整个数据集中随机选择,n_iter参数设置的是交叉验证的数量。如下图展示的是train_size=5, test_size=2, n_iter=4的情形:
在这里插入图片描述

from sklearn.model_selection import ShuffleSplit
shuffle_split = ShuffleSplit(test_size=.5, train_size=.5, n_splits=10) 
scores = cross_val_score(logreg, iris.data, iris.target, cv=shuffle_split) 
print("Cross-validation scores:\n{}".format(scores))

根据这里的设置,随机选择50%的数据作为训练集合,50%的数据作为测试集合,训练10轮。当数据集特别大的时候,用这种方式训练比较好,不过需要注意的是,train_size+test_size不能大于1。

ShuffleSplit交叉验证也有对应的分层形式,是StraifiedShuffleSplit() 函数。

组间的交叉验证

另外一种很常见的交叉验证发生在数据中有分组,而且这些分组非常相关的情况下。例如对不同的面部表情图片进行分类,1数据采集的时候选择了100个不同的被试,每个被试都采集了多个不同表情,目前是要建立一个分类器,对不在数据集中的人的表情进行分类。可以用默认的分层交叉验证对分类器的性能进行评估,但是有可能同一个人的不同表情图片同时出现在测试集和训练集中,对于这种情况,模型测试的效果会比完全新人的数据要好(意味着如果测试集合中的图片和训练集合中的图片都属于同一个人的不同表情,这时模型的测试效果会比较好,但是对于不在数据集中的,或者测试集中的图像所属的人没有一张照片在训练集中,这时候模型在测试集上的表现会差很多)。为了很好的评估模型在新面孔上的泛化能力,我们需要保证在训练集和测试集上包含不同人的表情照片。
我们用GroupKFold()函数来实现这个需求,具体来说,

from sklearn.model_selection import GroupKFold
# create synthetic dataset
X, y = make_blobs(n_samples=12, random_state=0)
# assume the first three samples belong to the same group, # then the next four, etc.
groups=[0,0,0,1,1,1,1,2,2,3,3,3]
scores = cross_val_score(logreg, X, y, groups, cv=GroupKFold(n_splits=3)) 
print("Cross-validation scores:\n{}".format(scores))

上面代码段中,数据集中一共是12个样本,groups数列中的0,1,2,3是每个样本所属的分组号,指定分组号就是为了指定对应的数据属于同一个组,在划分测试集合和训练集合的时候,不要将同一组的数据分开。

这种情况在医疗数据中非常常见,例如经常会见到同一个病人的多个样本,训练模型后希望模型可以泛化到其他人的诊断上。类似地,在语音识别中也比较常见,通常训练的样本集合都来自于专门的数据集,是特定的一些人录制的,但是希望训练模型后应用在其他人的语音数据上。
在这里插入图片描述
上图中就是group交叉验证的一个例子。

总结

交叉验证还有更多的变形,具体可以见scikit-learn的user guide. (https://scikit-learn.org/stable/modules/cross_validation.html) 。但是总的来说, KFold, StratifiedKFold, GroupKFold是最常用的几种形式。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习模型评估与改进: 交叉验证(cross validation) 的相关文章

  • AngularJS - 加载时触发表单验证

    我在表单中添加了 required 和 pattern 等字段验证属性 并且该表单位于 ng controller 内 验证有效 但似乎验证是在页面加载时触发的 并且我看到页面加载时所有字段都被标记为无效并带有错误消息 我尝试将 novav
  • 有没有办法使用无服务器框架来验证路径

    我在后端使用无服务器框架 使用AWS 我的 serverless yml 像这样 functions getBrand handler functions brand getBrand handler events http path se
  • ASP.NET Core [要求] 不可为 null 的类型

    Here https stackoverflow com questions 6662976 required attribute for an integer value 提出了如何验证不可为空的必需类型的问题 在我的情况下 提供的使字段
  • typo3 extbase:验证表单

    我创建了一个简单的 订阅新闻通讯 表单
  • Magento,翻译验证错误消息

    我已经成功创建了原型验证的新规则 现在我需要翻译错误消息 位置 Javascript 中的字符串 但是 我只能翻译所有消息 我的新自定义消息似乎无法翻译 我该如何改变这个 也许你需要一个jstranslator xml里面的文件etc fo
  • jQuery 验证插件:验证自定义日期格式

    我正在使用 jQuery Validate 插件来验证我的表单 如何使用此日期格式 DD MMM YYY 2012 年 3 月 23 日 验证自定义日期 创建自定义验证器 http docs jquery com Plugins Valid
  • 属性列表后缺少 jquery 验证 }

    我这里有这个代码 order validate rules name required true lastname required true address required true telephone required true di
  • Android 手机号码验证

    如何检查电话号码是否有效 长度最大为13 包括字符 在前 我怎么做 我试过这个 String regexStr 0 9 String number entered number getText toString if entered num
  • 如何在 CQRS 中处理基于集合的一致性验证?

    我有一个相当简单的域模型 涉及一系列Facility聚合根 鉴于我使用 CQRS 和事件总线来处理从域引发的事件 您如何处理集合的验证 例如 假设我有以下需求 Facility必须有一个唯一的名称 由于我在查询端使用最终一致的数据库 因此在
  • Django 表单验证消息未显示

    我试图限制可以以表单上传的文件类型 大小和扩展名 该功能似乎有效 但未显示验证错误消息 我意识到if file size gt 4 1024 1024可能不是最好的方法 但我稍后会处理这个问题 这是 forms py class Produ
  • PHP 中的 Javascript“unes​​cape”

    我的图像主机有一个 Google Chrome 扩展程序 它会向我的网站发送一个 URL 该网址得到encoded通过 JavaScript 的escape method 编码的 URLescape看起来像这样 http 253A 4 bp
  • AJAX Rails 验证

    我的表单和验证可以很好地处理常规的 http 请求 我希望它使用 AJAX 我知道我可以在客户端进行验证 但这似乎是多余的 因为我已经在模型中定义了验证 当用户填写表单时 我想就他们的条目向他们提供反馈 在 AJAX 表单中使用 Rails
  • Spring MVC - 自动查找验证器

    假设我有一个像这样的示例实体类 public class Address 和相应的验证器 Component public AddressValidator implements Validator Override public bool
  • 根据 MVC 中的文化的日期时间格式

    我有一个 MVC 视图 其中列出了一个名为 CreatedOn 的日期时间类型列 值的格式如下 日 月 年 时 分 秒 当我单击编辑链接修改值时 我获得相同的格式 当我修改编辑值时 出现验证错误 字段 CreatedOn 必须是日期 我的
  • ASCII“../”是 PHP 中指示目录遍历的唯一字节序列吗?

    我有一个 PHP 应用程序 它使用 GET参数来选择文件系统上的 JS CSS 文件 如果我拒绝输入字符串包含的所有请求 或者可见 7 位 ASCII 范围之外的字节 当路径传递到 PHP 的底层 基于 C 文件函数时 这是否足以防止父目录
  • JSF 中基于两个组件的组合的验证/转换

    我正在开发一个 JSF Web 应用程序 我需要使用周期性作为数据结构 以下是我使用的 Java 类 public class Periodicity implements Serializable private Integer valu
  • 如何对我的自定义验证属性进行单元测试

    我有一个自定义的 asp net mvc 类验证属性 我的问题是如何对其进行单元测试 测试类是否具有该属性是一回事 但这实际上并不能测试其中的逻辑 这就是我想测试的 Serializable EligabilityStudentDebtsA
  • 如何在字段值无效的情况下更改 Struts2 验证错误消息?

    我在 Web 表单上使用 Struts2 验证 如果字段假设为整数或日期 则
  • XHTML 和文本区域内的代码

    在我的一个使用文本区域进行提交的网站上 我的代码可以显示如下所示的内容
  • WPF 中列表框的数据验证

    我有一个 ListBox 绑定到类型 T 的 ObservableCollection 每个 ListBoxItem 都是一个复选框 IsChecked 绑定到 T 中的 bool 属性 我想验证 ListBox 中的选中项 以便至少必须选

随机推荐

  • vagrant入门指南(二): 创建vagrant项目

    创建vagrant项目的第一步应该是新建Vagrantfile文件 在Vagrantfile中应该明确两个问题 1 明确项目的root文件夹位置 vagrant的很多配置选项都是根据root文件夹的位置设置的 2 描述项目需要的机器和资源
  • 机器学习之:流形与降维概述

    流形与降维 xff1a 概述 降维算法概述流形学习距离的定义 KNN图与流形降维KNN图SNE算法 降维算法概述 降维 xff0c 顾名思义就是把数据或者特征的维度降低 xff0c 一般分为线性降维和非线性降维 线性降维有 xff1a PC
  • 机器学习之:载入数据

    加载公共的开放数据 通过url链接下载 通常网上有很多开放数据供算法测试 通常要用到urllib从给定的链接下载 例如从UCI机器学习数据仓库中下载的数据 xff1a span class token keyword import span
  • shell脚本:如何记录计算时长以及如何保存日志文件

    python和matlab都有非常友好的记录时间的方式 xff0c 且不说python的time xff0c datetime工具包 xff0c matlab的tic xff0c toc命令简单好记 xff0c 都是程序时间很好的记录工具
  • MRI-FSL pipeline 多进程并发和并发数控制

    shell脚本并发 在MRI预处理pipeline串行执行非常耗时 非常有必要将pipeline并行化 在linux环境下 并行计算可以有多种实现方法 例如在shell中通过转入后台的方式 或者用xargs多进程并发 还可以用fifo管道实
  • 增长黑客 - 开源项目增长利器

    2012 年我开源了自己的第一个项目 https github com allwefantasy ServiceFramework 这个项目并不成功 xff0c 但对我个人的价值还是比较大的 xff0c 一直作为我工具箱用到现在 从 16
  • vagrant(三):网络配置

    网络配置 所有的网络设置都可以通过配置Vagrantfile来实现 具体来说 xff0c 就是在Vagrantfile中调用config vm network进行相关的设置 vagrant支持以下三种网络配置 xff1a Forwarded
  • vagrant(四):共享目录

    vagrant共享目录 共享目录synced folder 参数共享目录类型 共享目录 共享目录可以设置Vagrant在宿主机 host 和虚拟机 guest 之间同步文件 xff0c 这样做的好处是可以在宿主机上开发 xff0c 在虚拟机
  • FSL的python和R语言接口

    FSL除了本身支持shell命令调用以外 还有一些其他语言的工具包 例如 python和R fsl的python编程库称为fslpy 是可视化工具FSLeyes的一部分 fslpy目前支持python 3 5 3 6 and 3 7开发环境
  • linux rm 命令误删文件恢复

    不小心用rm命令删错了文件 该怎么办 查看分区和文件格式 误删的文件在哪里 首先 用rm命令误删了文件 并不是不可以恢复 首先需要查看一下误删文件所在的分区和文件格式 df T 文件系统 类型 1K 块 已用 可用 已用 挂载点 dev s
  • MRI相关的基本概念

    磁共振基础 磁共振 磁共振 mageticresonanceMR xff1b 在恒定磁场中的核子 xff08 氢质子 xff09 xff0c 在相应的射频脉冲激发后 xff0c 其电磁能量的吸收和释放 xff0c 称为磁共振 基本参数 TR
  • 服务器搭建: 用户管理

    文章目录 查看当前用户用户类型多用户管理用户和用户组的概念添加用户adduser命令useradd命令 用户组管理给用户添加sudo权限删除用户 备注 xff08 1 xff09 etc passwd文件 xff08 2 xff09 etc
  • scikit learn工具箱pipeline模块:串联方法

    scikit learn工具箱pipeline模块 xff1a 串联方法 pipeline模块 scikit learn工具箱的pipeline模块提供了将算法模型串联 并联的工具 xff0c 多个estimator并联起来用于模型结果比较
  • ANOVA与机器学习

    文章目录 方差分析ANOVA组间变异和组内变异均方差F分布与F值方差分析的关键条件 Anova在机器学习中的应用 特征选择总结更多阅读 方差分析ANOVA anova analysis of variance 方差分析 又称 34 变异数分
  • FSL 功能磁共振影像分析: single-session

    文章目录 什么是single session分析基于HRF的模型信号多元回归t contrastf contrast single session分析是fmri实验分析的最简单情况之一 xff0c 这里以FSL官方的例子为例 xff0c 总
  • MRI图像处理:VBM原理和步骤

    VBM是voxel based morphometry的缩写 xff0c 是对被试之间灰质体素粒度统计分析 VBM可以得到人群中volume和gyrification的不同 xff0c 对clinical score进行相关性分析 xff0
  • 创新不是靠痛点,而是靠对效率的持续追求

    什么都等到痛了才去做 xff0c 要你何用 在互联网行业做产品 xff0c 亦或是创业给投资人讲故事 xff0c 一个很核心的点就是要问自己或者告诉对方 xff0c 我的产品击中了什么痛点 xff1f 似乎一切都是靠痛点驱动的 但我认为这是
  • linux压缩文件解压

    文件格式解压方法 zipunzip FileName zip xzxz d FileName tar xz 或者 tar xvJf FileName tar xz bzbzip2 d FileName bz 或者 bunzip2 FileN
  • linux开机启动顺序

    文章目录 linux的开机启动顺序概述BIOS basic input output system 基本输入输出系统MBR master boot record 主引导记录 主引导程序总结 第一个程序 init运行等级System V in
  • 机器学习模型评估与改进: 交叉验证(cross validation)

    文章目录 交叉验证调用方法优势和不足注意事项 xff1a 分层k折交叉验证交叉验证的更多变形leave one out交叉验证Shuffle split交叉验证组间的交叉验证 总结 以监督学习的众多算法为例 xff0c 不管是分类还是回归