【scikit-learn】交叉验证及其用于参数选择、模型选择、特征选择的例子

2023-05-16



内容概要¶

  • 训练集/测试集分割用于模型验证的缺点
  • K折交叉验证是如何克服之前的不足
  • 交叉验证如何用于选择调节参数、选择模型、选择特征
  • 改善交叉验证

1. 模型验证回顾¶

进行模型验证的一个重要目的是要选出一个最合适的模型,对于监督学习而言,我们希望模型对于未知数据的泛化能力强,所以就需要模型验证这一过程来体现不同的模型对于未知数据的表现效果。

最先我们用训练准确度(用全部数据进行训练和测试)来衡量模型的表现,这种方法会导致模型过拟合;为了解决这一问题,我们将所有数据分成训练集和测试集两部分,我们用训练集进行模型训练,得到的模型再用测试集来衡量模型的预测表现能力,这种度量方式叫测试准确度,这种方式可以有效避免过拟合。

测试准确度的一个缺点是其样本准确度是一个高方差估计(high variance estimate),所以该样本准确度会依赖不同的测试集,其表现效果不尽相同。

高方差估计的例子¶

下面我们使用iris数据来说明利用测试准确度来衡量模型表现的方差很高。

In [1]:

from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
  
In [2]:

# read in the iris data
iris = load_iris()

X = iris.data
y = iris.target
  
In [3]:

for i in xrange(1,5):
    print "random_state is ", i,", and accuracy score is:"
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=i)

    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    print metrics.accuracy_score(y_test, y_pred)
  

random_state is  1 , and accuracy score is:
1.0
random_state is  2 , and accuracy score is:
1.0
random_state is  3 , and accuracy score is:
0.947368421053
random_state is  4 , and accuracy score is:
0.973684210526
  

上面的测试准确率可以看出,不同的训练集、测试集分割的方法导致其准确率不同,而交叉验证的基本思想是:将数据集进行一系列分割,生成一组不同的训练测试集,然后分别训练模型并计算测试准确率,最后对结果进行平均处理。这样来有效降低测试准确率的差异。

2. K折交叉验证¶

  1. 将数据集平均分割成K个等份
  2. 使用1份数据作为测试数据,其余作为训练数据
  3. 计算测试准确率
  4. 使用不同的测试集,重复2、3步骤
  5. 对测试准确率做平均,作为对未知数据预测准确率的估计
In [4]:

# 下面代码演示了K-fold交叉验证是如何进行数据分割的
# simulate splitting a dataset of 25 observations into 5 folds
from sklearn.cross_validation import KFold
kf = KFold(25, n_folds=5, shuffle=False)

# print the contents of each training and testing set
print '{} {:^61} {}'.format('Iteration', 'Training set observations', 'Testing set observations')
for iteration, data in enumerate(kf, start=1):
    print '{:^9} {} {:^25}'.format(iteration, data[0], data[1])
  

Iteration                   Training set observations                   Testing set observations
    1     [ 5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24]        [0 1 2 3 4]       
    2     [ 0  1  2  3  4 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24]        [5 6 7 8 9]       
    3     [ 0  1  2  3  4  5  6  7  8  9 15 16 17 18 19 20 21 22 23 24]     [10 11 12 13 14]     
    4     [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 20 21 22 23 24]     [15 16 17 18 19]     
    5     [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]     [20 21 22 23 24]     
  

3. 使用交叉验证的建议¶

  1. K=10是一个一般的建议
  2. 如果对于分类问题,应该使用分层抽样(stratified sampling)来生成数据,保证正负例的比例在训练集和测试集中的比例相同

4. 交叉验证的例子¶

4.1 用于调节参数¶

交叉验证的方法可以帮助我们进行调参,最终得到一组最佳的模型参数。下面的例子我们依然使用iris数据和KNN模型,通过调节参数,得到一组最佳的参数使得测试数据的准确率和泛化能力最佳。

In [6]:

from sklearn.cross_validation import cross_val_score
  
In [7]:

knn = KNeighborsClassifier(n_neighbors=5)
# 这里的cross_val_score将交叉验证的整个过程连接起来,不用再进行手动的分割数据
# cv参数用于规定将原始数据分成多少份
scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
print scores
  

[ 1.          0.93333333  1.          1.          0.86666667  0.93333333
  0.93333333  1.          1.          1.        ]
  
In [8]:

# use average accuracy as an estimate of out-of-sample accuracy
# 对十次迭代计算平均的测试准确率
print scores.mean()
  

0.966666666667
  
In [11]:

# search for an optimal value of K for KNN model
k_range = range(1,31)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())

print k_scores
  

[0.95999999999999996, 0.95333333333333337, 0.96666666666666656, 0.96666666666666656, 0.96666666666666679, 0.96666666666666679, 0.96666666666666679, 0.96666666666666679, 0.97333333333333338, 0.96666666666666679, 0.96666666666666679, 0.97333333333333338, 0.98000000000000009, 0.97333333333333338, 0.97333333333333338, 0.97333333333333338, 0.97333333333333338, 0.98000000000000009, 0.97333333333333338, 0.98000000000000009, 0.96666666666666656, 0.96666666666666656, 0.97333333333333338, 0.95999999999999996, 0.96666666666666656, 0.95999999999999996, 0.96666666666666656, 0.95333333333333337, 0.95333333333333337, 0.95333333333333337]
  
In [10]:

import matplotlib.pyplot as plt
%matplotlib inline
  
In [12]:

plt.plot(k_range, k_scores)
plt.xlabel("Value of K for KNN")
plt.ylabel("Cross validated accuracy")
  
Out[12]:

<matplotlib.text.Text at 0x6dd0fb0>  

上面的例子显示了偏置-方差的折中,K较小的情况时偏置较低,方差较高;K较高的情况时,偏置较高,方差较低;最佳的模型参数取在中间位置,该情况下,使得偏置和方差得以平衡,模型针对于非样本数据的泛化能力是最佳的。

4.2 用于模型选择¶

交叉验证也可以帮助我们进行模型选择,以下是一组例子,分别使用iris数据,KNN和logistic回归模型进行模型的比较和选择。

In [13]:

# 10-fold cross-validation with the best KNN model
knn = KNeighborsClassifier(n_neighbors=20)
print cross_val_score(knn, X, y, cv=10, scoring='accuracy').mean()
  

0.98
  
In [14]:

# 10-fold cross-validation with logistic regression
from sklearn.linear_model import LogisticRegression
logreg = LogisticRegression()
print cross_val_score(logreg, X, y, cv=10, scoring='accuracy').mean()
  

0.953333333333
  

4.3 用于特征选择¶

下面我们使用advertising数据,通过交叉验证来进行特征的选择,对比不同的特征组合对于模型的预测效果。

In [15]:

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
  
In [16]:

# read in the advertising dataset
data = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
  
In [17]:

# create a Python list of three feature names
feature_cols = ['TV', 'Radio', 'Newspaper']

# use the list to select a subset of the DataFrame (X)
X = data[feature_cols]

# select the Sales column as the response (y)
y = data.Sales
  
In [18]:

# 10-fold cv with all features
lm = LinearRegression()
scores = cross_val_score(lm, X, y, cv=10, scoring='mean_squared_error')
print scores
  

[-3.56038438 -3.29767522 -2.08943356 -2.82474283 -1.3027754  -1.74163618
 -8.17338214 -2.11409746 -3.04273109 -2.45281793]
  

这里要注意的是,上面的scores都是负数,为什么均方误差会出现负数的情况呢?因为这里的mean_squared_error是一种损失函数,优化的目标的使其最小化,而分类准确率是一种奖励函数,优化的目标是使其最大化。

In [19]:

# fix the sign of MSE scores
mse_scores = -scores
print mse_scores
  

[ 3.56038438  3.29767522  2.08943356  2.82474283  1.3027754   1.74163618
  8.17338214  2.11409746  3.04273109  2.45281793]
  
In [20]:

# convert from MSE to RMSE
rmse_scores = np.sqrt(mse_scores)
print rmse_scores
  

[ 1.88689808  1.81595022  1.44548731  1.68069713  1.14139187  1.31971064
  2.85891276  1.45399362  1.7443426   1.56614748]
  
In [21]:

# calculate the average RMSE
print rmse_scores.mean()
  

1.69135317081
  
In [22]:

# 10-fold cross-validation with two features (excluding Newspaper)
feature_cols = ['TV', 'Radio']
X = data[feature_cols]
print np.sqrt(-cross_val_score(lm, X, y, cv=10, scoring='mean_squared_error')).mean()
  

1.67967484191
  

由于不加入Newspaper这一个特征得到的分数较小(1.68 < 1.69),所以,使用所有特征得到的模型是一个更好的模型。

参考资料¶

  • scikit-learn documentation: Cross-validation, Model evaluation
  • scikit-learn issue on GitHub: MSE is negative when returned by cross_val_score
  • Scott Fortmann-Roe: Accurately Measuring Model Prediction Error
  • Harvard CS109: Cross-Validation: The Right and Wrong Way
  • Journal of Cheminformatics: Cross-validation pitfalls when selecting and assessing regression and classification models
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【scikit-learn】交叉验证及其用于参数选择、模型选择、特征选择的例子 的相关文章

  • 学习c语言

    今天学习了if语句和 xff45 xff4c xff53 xff45 运用 xff43 语言更加顺手 xff0c 之前一些都能实施 xff0c 真是太开心了 include lt stdio h gt int main 初始化 int pr
  • 求符合给定条件的整数集(做题)

    题目如上 xff1b 首先我们先想思路 xff1a 先来一个输入 xff0c 读入这个数 xff0c 然后我们需要三个变量来储存这三个数 xff1b 然后我们遍历所有的组合 xff0c 这个依靠循环 接下来是代码 xff1a include
  • 水仙花数(做题)

    代码如下 xff1a include lt stdio h gt int main int a scanf 34 d 34 amp a float t t 61 0 1 while a gt 0 t 61 t 10 a 判断几位数 int
  • 一分钟了解动态内存分配

    谈到这 xff0c 必然离不开malloc函数 在上面可以看出此函数需要一个头文件 include lt stdilb h gt 而且返回类型是void 传进去的是空间大小 xff0c 此函数申请的空间是字节为单位的 这其中的就分配了100
  • 动态内存分配深究

    接下来我们将探究以下三个问题 xff1a 1 相邻两次malloc得到的空间是否是连续的呢 xff1f 2 你得到的空间的实际大小是否就是你要求的大小呢 xff1f 3 如果你malloc零长度会得到什么结果呢 xff1f 第一个问题 xf
  • 同一个页面不打开两次

    lt script language 61 34 javascript 34 gt function popwin3 path window open path 34 cart 34 34 height 61 520 width 61 52
  • 超易懂!二分查找 详析

    二分算法的 本质 是 xff1a 假如我们可以找到事物的 某种性质 xff0c 这种性质 可以将区间一分为二 xff0c 一半满足 xff0c 一半不满足 我们就可以二分 另外 xff0c 有 连续性 必可以 二分 二分模板一共有两个 xf
  • 手摸手 Spring Cloud Gateway + JWT 实现登录认证

    你好 xff0c 我是悟空 前言 上篇我已经讲解了 Spring Cloud 的原理和实战 xff0c 这次就要结合 JWT 来实现登录认证的功能了 本文已收录至 深入剖析 Spring Cloud 底层架构原理 xff0c 已更新 18
  • 百行代码实现VLC简易视频播放器【VLC环境配置过程+可执行源码注释完整】

    文章目录 什么是VLC x1f680 VLC 库的集成 VLC环境配置演示 win10系统 43 vs2017 43 win64 x1f34e VLC 库的基本使用 x1f382 视频播放器实现 自定义函数Unicode2Utf8讲解 x1
  • HttpWebRequest 使用NetworkCredential 进行域认证下载时不成功 的解决方案

    最近在项目中使用pWebRequest 使用NetworkCredential 进行域认证下载时老不成功 xff0c 最后Google了解决方案 xff0c 发现几乎所有讨论的方案都不成功 xff0c 只好埋头自己解决 xff0c 最后总算
  • Firefox 的用户脚本管理器 greasemonkey 的使用一例

    一 什么是greasemonkey Firefox 的用户脚本管理器 greasemonkey 使你可以向任何网页添加DHTML语句 用户脚本 来改变它们的显示方式 就像CSS可以让你接管网页的样式 xff0c 而用户脚本 User Scr
  • Apache Http 服务器安装教程

    我在学习网络开发的时候需要从服务器上获得json数据 xff0c 所以在自己的电脑上安装了一个本地服务器 xff0c 其中遇到的一些问题 xff0c 在这里都写出来 首先 xff0c 我们需要访问apache http服务器的下载网页 xf
  • STM32的UART奇偶校验注意

    STM32的UART奇偶校验注意 STM32的UART在初始化时 xff0c 我们通常用到最多的就是无校验位 xff0c 1停止位 但是我在项目中也遇到某些芯片通信用的需要奇校验或者偶校验 xff0c 这里需要特别注意的是STM32中开启奇
  • Realtek RTL8762C/Realtek RTL8762D学习记录

    本人基于日常工作整理编写的8762C FAQ文档 xff0c 记录RTL8762C 8762D系列软件开发常见问题以及解决方案 希望它能发挥更多作用 帮到有需要的朋友 关键字 xff1a 8762CMF 8762CK 8762CJ 8762
  • 蓝牙BLE---DA14683的SPI主机通信讲解

    DA14683的SPI主机通信例程 Date 2018 12 19 Create Jim 导入例程 首先导入ble peripheral例程或者pxp reporter例程 再到以下位置打开硬件SPI的宏定义 xff1a 获取SPI例程源码
  • 06.5 Code

    06 5 Code 推力 force 推力的应用旋翼的气动阻力空气阻力矩滚转力矩电机的转速 推力 force span class token comment force 61 电机的转速 xff5c 电机的转速 xff5c xff08 带
  • C、C++ 对于char*和char[]的理解

    1 char 和char 的共同点 都是指针 xff0c 指向第一个字符所在的地址 2 char 的用法 char a 61 34 aaa 34 char p1 61 a char 是常量指针 xff08 常量的指针 xff09 xff0c
  • 重新抛出(rethrow)

    有可能单个catch不能完全处理一个异常 在进行了一些校正行动之后 xff0c catch可能确定该异常必须由函数调用链中更上层的函数来处理 xff0c catch可以通过重新抛出 rethrow 将异常传递给函数调用链中更上层的函数 重新
  • 4-2 图像聚类算法

    4 2 图像聚类算法 目录1 分类与聚类1 1 分类1 2 聚类1 3 聚类样本间的属性1 4 聚类的常见算法 2 K Means聚类2 1 概念2 2 步骤2 3 例子2 4 K Means聚类与图像处理2 5 K Means聚类优缺点优
  • JavaWeb-03 统一字符集编码、JSP的页面元素、JSP九大内置对象-request

    1 使用Eclipse开发Web项目 JSP项目 tomacat 2 在Eclipse中创建的Web项目 xff1a 浏览器可以直接访问WebContent中的文件例如 http 127 0 0 1 8888 MyJspProject in

随机推荐

  • 9-1 从零开始训练网络

    9 1 从零开始训练网络 目录1 搭建网络基本架构要完成的功能 2 构建训练网络1 实现网络训练功能2 获取训练数据及预处理 3 启动训练网络并测试数据 目录 搭建网络基本架构构建训练网络启动训练网络并测试数据 1 搭建网络基本架构 要完成
  • 基于知识图谱的推荐系统

    基于知识图谱的推荐系统 推荐系统 xff1a 核心目标是通过分析用户行为 兴趣 需求等信息 在海量的数据中挖掘用户感兴趣的信息 如商品 新闻 POI point of interest 和试题 等 个性化推荐算法是推荐系统的核心 其主要可以
  • mysql级联删除

    mysql级联删除 场景 xff1a 员工表 id xff1a 员工idleader id xff1a 该员工的领导的id xff08 也是员工id xff09 外键dept id xff1a 该员工的部门id xff08 部门表外键 xf
  • 【Seata】安装 - mac

    1 下载 官网 xff1a https seata io zh cn index html 2 修改配置文件 2 1 file conf 还有user password 2 2 registry conf 1 xff09 registry
  • Go Modules模式

    Go Modules模式 xff08 1 xff09 go mod 命令 命令作用go mod init生成 go mod 文件 在当前文件夹下初始化一个新的 go mod 文件go mod download下载 go mod 文件中指明的
  • 【Go】flag

    flag String span class token keyword func span span class token function String span span class token punctuation span n
  • Mybatis 逆向工程

    Mybatis 逆向工程 Maven项目generatorConfig xmlpom xml Maven项目 项目结构 xff1a generatorConfig xml span class token prolog lt xml ver
  • 原理分享 | 单片机常用通信协议汇总(上)

    vx 嵌入式工程师成长日记 https mp weixin qq com s biz 61 Mzg4Mzc3NDUxOQ 61 61 amp mid 61 2247484134 amp idx 61 1 amp sn 61 b779ccf0
  • C语言模拟TCP通信-------收发数据

    简介 这篇是我学习网络编程时初次接触到的 xff0c 感觉挺适合初学者 xff0c 下文主要介绍了如何使用Linux模拟TCP通信 xff0c 分为客户端和服务器端两大部分 xff0c 外加一个总的头文件 流程 服务器端和客户端使用TCP的
  • 多传感器融合记录

    多传感器信息融合的典型应用 多传感器融合中的时间硬同步1 论文阅读 weixin 39606911的博客 CSDN博客 前言阅读硕士论文 自动驾驶中多传感器集成同步控制器设计与实现 xff0c 该论文为自动驾驶设计了一套时间同步控制器 xf
  • VINS记录

    euroc launch lt launch gt lt arg name 61 34 config path 34 default 61 34 find feature tracker config euroc euroc config
  • OpenCV介绍与入门

    OpenCV入门 OpenCV介绍关于OpenCV1 OpenCV能做什么 xff1b 2 OpenCV与图形学与FFmpeg的关系 xff1b 3 OpenCV的未来 xff1b OpenCV介绍 OpenCV是计算机视觉的框架 关于Op
  • 【可见光室内定位】(一)概览

    目录 一 室内无线定位技术概况二 研究现状三 应用前景背景 一 室内无线定位技术概况 二 研究现状 得益于可见光通信 xff08 xff36 xff2c xff23 xff09 技术的迅速发展 xff0c 可 见光定位 xff08 xff3
  • 【机器学习中的数学】比例混合分布

    比例混合分布 Scale Mixture Distribution 混合分布是来自其他随机变量的集合构成的随机变量的概率分布 xff1a 一个随机变量是根据给定的概率从集合随机选取的 xff0c 然后所选随机变量的值就得到了 first a
  • 互联网相似图像识别检索引擎 —— 基于图像签名的方式

    一 引言 多媒体识别是信息检索中难度较高且需求日益旺盛的一个问题 以图像为例 xff0c 按照图像检索中使用的信息区分 xff0c 图像可以分为两类 xff1a 基于文本的图像检索和基于内容识别的图像检索 xff08 CBIR xff1a
  • 【Vim】使用map自定义快捷键

    map简介 map是一个映射命令 将常用的很长的命令映射到一个新的功能键上 map是Vim强大的一个重要原因 xff0c 可以自定义各种快捷键 xff0c 用起来自然得心应手 映射的种类 有五种映射存在 xff1a 用于普通模式 输入命令时
  • 【Scala】使用Option、Some、None,避免使用null

    避免null使用 大多数语言都有一个特殊的关键字或者对象来表示一个对象引用的是 无 xff0c 在Java xff0c 它是null 在Java 里 xff0c null 是一个关键字 xff0c 不是一个对象 xff0c 所以对它调用任何
  • 【Linux】使用update-alternatives命令进行版本的切换

    引言 在Debian系统中 xff0c 我们可能会同时安装有很多功能类似的程序和可选配置 xff0c 可能会出现同一软件的多个版本并存的场景 比如像是一些编程语言工具 xff0c 一些系统中自带的是python2 6 xff0c 而现在py
  • stm32G0 启动

    目的 STM32G是意法半导体这两年新推出的系列芯片 xff0c 相比原先的F系列的芯片有很多提升点 xff0c 将来必将取代F系列芯片的地位 对于新芯片的应用来说能够正确下载与运行程序是比较重要的一点 xff0c 这篇文章将对 STM32
  • 【scikit-learn】交叉验证及其用于参数选择、模型选择、特征选择的例子

    xfeff xfeff 内容概要 训练集 测试集分割用于模型验证的缺点K折交叉验证是如何克服之前的不足交叉验证如何用于选择调节参数 选择模型 选择特征改善交叉验证 1 模型验证回顾 进行模型验证的一个重要目的是要选出一个最合适的模型 xff