数据分析-数据集划分-交叉验证

2023-11-09

一般使用 model_selection.train_test_split() 函数将数据集按要求分成训练集和测试集两类,使用训练集训练,测试集测试。
但单次的划分可能导致结果不具有代表性,一种评估模型泛化性能,比单次划分训练集和测试集的方法更加稳定、全面 的方法为:

交叉验证

k折交叉验证(k-fold cross validation

如5折交叉验证,将数据集分成5份,轮换使用1份作为验证集,其他作为测试集。最终性能取5次的平均。

如果数据集按类别集中分布,某一类集中在一起,则标准交叉验证中的某一折,可能全部为一个类别,这一折外又很少或没有该类样本,如果这一折为验证集,那么在训练集中就没有或很少此类样本,模型训练的结果就会很差,在样本不均衡时表现尤为突出。

如 90% 的样本属于类别A只有 10% 的样本属于类别 B,k折交叉验证就容易导致以上问题出现。 

分层k折交叉验证(stratified cross validation

分层k折交叉验证使每个折内类别之间的比例与整个数据集中的类别比例相同。当数据按类别标签排序时,标准交叉验证与分层交叉验证的对比图如下(极端情况,3个类别,类别均衡):

 可以看到标准交叉验证,3折时,每折对应一个类别,无论如何划分测试集和训练集,每次都有一个类别不在训练集中,不被模型学习到。

而采用分层k折交叉验证可保证每次的训练集中都包含所有的类别,测试集也一样。

将数据充分打乱后再采用K折交叉验证,也可以达到类似的效果。

Sklearn的实现

k折交叉分类器

model_selection.KFold  

对数据集(X,y)4折划分。

from sklearn.model_selection import KFold
kf = KFold(n_splits=4)  
kf.split(X,y)

分层k折交叉分类器

model_selection.StratifiedKFold

对数据集(X,y)分层3折划分。

from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=3) 
skf.split(X,y)

打乱数据集后再划分  

model_selection.ShuffleSplit

对数据集(X,y)乱序后10折划分。

from sklearn.model_selection import ShuffleSplit
shs=ShuffleSplit(n_splits=10)  #打乱顺序后划分
shs.split(X,y)

模型验证

model_selection.cross_val_score 根据交叉验证计算模型分数

5折划分

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
print(cross_val_score(LogisticRegression(),X,y,cv=5))  #cv为数字5,5折交叉验证,输出5种分割的score。
[0.83236994 0.94508671 0.92774566 0.69855072 0.88695652]

使用shs划分

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
print(cross_val_score(LogisticRegression(),X,y,cv=shs))  #cv=shs,使用shs的划分,输出该分割的得分(10种)。

[0.95953757 0.93641618 0.97109827 0.93063584 0.97109827 0.95375723 0.94797688 0.93641618 0.95953757 0.95953757]

使用skf划分

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
print(cross_val_score(LogisticRegression(),X,y,cv=skf))  #cv=skf,使用skf的划分,输出该分割的得分(3种)。
[0.74131944 0.765625   0.86111111]

使用kf划分

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
print(cross_val_score(LogisticRegression(),X,y,cv=kf))  #cv=kf,使用kf的划分,输出该分割的得分(4种)。
[0.78935185 0.88194444 0.91203704 0.85648148]

交叉验证预测

model_selection.cross_val_predict     

from sklearn.model_selection import cross_val_predict
lr= LogisticRegression()
cross_val_predict(lr,X1,y1)

学习曲线

model_selection.learning_curve  学习曲线

from sklearn.model_selection import learning_curve
lr= LogisticRegression()
ss = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
plt.title("Learning Curves(LogisticRegression)")
plt.ylim([0.90,1.01])
plt.xlabel("训练样本数")
plt.ylabel("正确率")
train_sizes, train_scores, test_scores = learning_curve(
    lr, X1, y1, cv=ss,train_sizes=np.linspace(.1, 1.0,30))
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
plt.grid()
plt.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体,解决中文显示问题
plt.rcParams['axes.unicode_minus'] = False # 解决'-'显示为方块的问题
plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
                 train_scores_mean + train_scores_std, alpha=0.1,
                 color="r")
plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.1, color="g")
plt.plot(train_sizes, train_scores_mean,  color="r",
         label="训练")
plt.plot(train_sizes, test_scores_mean,  color="b",
         label="交叉验证")
plt.legend(loc="best")

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

数据分析-数据集划分-交叉验证 的相关文章

  • StringBuffer简单使用

    StringBuffer简单使用 一 简介 StringBuffer 是可以存储和操作字符串 即包含多个字符的字符串数据 String类是字符串常量 是不可更改的常量 而StringBuffer是字符串变量 它的对象是可以扩充和修改的 St

随机推荐

  • 【Linux】如何在Linux下提交代码到gittee

    文章目录 使用 git 命令行 创建项目 三板斧第一招 git add 三板斧第二招 git commit 三板斧第三招 git push 其他几个重要的命令 git pull 将远端同步到本地 git rm 删除 git log 查看提交
  • 如何让 useEffect 支持 async/await?

    大家在使用 useEffect 的时候 假如回调函数中使用 async await 的时候 会报错如下 看报错 我们知道 effect function 应该返回一个销毁函数 return返回的 cleanup 函数 如果 useEffec
  • linux 查看运行进程的可执行文件所在目录

    1 获取PID 方法1 执行top命令 然后找到对应的进程 方法2 执行ps ef grep 程序名 2 进入proc目录下对应的进程路径 cd proc 3 sudo ls l user为root的进程需要sudo权限 exe连接的即可执
  • go语言基础-----11-----正则表达式

    1 正则表达式介绍 正则表达式是一种进行模式匹配和文本操纵的复杂而又强大的工具 虽然正则表达式比纯粹的文本匹配效率低 但是它却更灵活 按照它的语法规则 随需构造出的匹配模式就能够从原始文本中筛选出几乎任何你想要得到的字符组合 Go语言通过r
  • Java-API简析_java.lang.RuntimePermission类(基于 Latest JDK)(浅析源码)

    版权声明 未经博主同意 谢绝转载 请尊重原创 博主保留追究权 https blog csdn net m0 69908381 article details 132571263 出自 进步 于辰的博客 因为我发现目前 我对Java API的
  • 【MySQL安装问题】找不到MSVCR120.dll,无法继续执行代码。

    Q 由于找不到MSVCP120 dll 无法继续执行代码 重新安装程序可能会解决此问题 A 参考解决方法链接由于找不到MSVCP120 dll 无法继续执行代码 重新安装程序可能会解决此问题 琴时 博客园 解决方式 点击进入微软官网下载地址
  • MATLAB 数学应用 初等数学 绘制虚数和复数数据图

    文章最后留了个超实用的matlab在线测试工具 绘制一个复数输入 本文演示如何绘制复数向量 z 的虚部与实部 在此复数输入中 plot z 等同于 plot real z imag z 其中 real z 是 z 的实部 imag z 是
  • docker容器部署pytorch模型,gpu加速部署运行

    参考文章 https www zhihu com search type content q Docker EF BC 8C E6 95 91 E4 BD A0 E4 BA 8E E3 80 8C E6 B7 B1 E5 BA A6 E5
  • thinkphp6 入门教程合集(更新中)

    thinkphp6 入门 1 安装 路由规则 多应用模式 thinkphp6 入门 1 安装 路由规则 多应用模式 软件工程小施同学的博客 CSDN博客 thinkphp6 入门 2 视图 渲染html页面 赋值 thinkphp6 入门
  • 组件是如何通信的?技术水平真的很重要!学习路线+知识点梳理

    开头 此文希望能给想跳槽和面试朋友一些参考 金九银十已过 面试的狂热季也已结束 小编也正是选择了在金九十银跳槽 之前在腾讯做了五年Android开发工作 之后感觉公司不一定能继续提供给我想要的发展空间与前景 说白了 有家室 我需要更高的薪酬
  • pandas提取时间里面的年月日_python入门

    时间模块 datetime 1 datetime date date对象 年月日 datetime date today 该对象类型为datetime date 可以通过str函数转化为str In 1 import datetime In
  • 砝码称重问题【dp】

    设有 1g 2g 3g 5g 10g 20g 的砝码各若干枚 其 总重 1000g 要 求 输入 a1 a2 a3 a4 a5 a6 表示 1g 砝码有 a1 个 2g 砝码有 a2 个 20g 砝码有 a6 个 输出 Total N N
  • 【MySQ必知必会】MySQL 是怎么存储数据的?

    文章目录 总结 前言 一 创建数据库 二 确认字段 三 创建数据表 四 插入数据 总结 CREATE DATABASE demo DROP DATABASE demo 删除数据库 SHOW DATABASES 查看数据库 创建数据表 CRE
  • Nginx——Location用法详解

    目录 一 Nginx的Httpp配置简介 二 Location匹配规则 1 精确匹配 2 最佳匹配 3 正则表达式要区分大小写 4 正则表达式不区分大小写 5 开头 通用匹配 6 综合示例 7 root alias指令区别 一 Nginx的
  • Python爬虫入门案例6:scrapy的基本语法+使用scrapy进行网站数据爬取

    几天前在本地终端使用pip下载scrapy遇到了很多麻烦 总是报错 花了很长时间都没有解决 最后发现pycharm里面自带终端 狂喜 于是直接在pycharm终端里面写scrapy了 这样的好处就是每次不用切换路径了 pycharm会直接把
  • 网络层协议------IP协议

    这里写目录标题 IP协议 基本概念 协议头格式 网段划分 特殊的ip地址 私网ip地址和公网ip地址 ip地址的数量限制 路由 IP协议 IP协议 其实就是TCP IP协议中对于网络层的一个协议 注意IP协议是TCP IP协议族中最为核心的
  • 查看localstorage容量

    1 function if window localStorage console log 浏览器不支持localStorage var size 0 for let item in window localStorage if windo
  • 电路实验---全桥整流电路

    全桥整流电路作用 采用四个二极管将交流电转换成直流电 全桥整流电路图 全桥整流电路原理 220V交流电经过变压器T1降压输出电压U2 当U1正半周从L1经过T1 到达L2 极性表现为上正下负 此时电流流过方向 L2上正 gt VD1 gt
  • uniapp的onPullDownRefresh失效 不执行

    需要在 pages json 里 找到的当前页面的pages节点 并在 style 选项中开启 enablePullDownRefresh path pages install uploadImg style navigationBarTi
  • 数据分析-数据集划分-交叉验证

    目录 交叉验证 k折交叉验证 k fold cross validation 分层k折交叉验证 stratified cross validation Sklearn的实现 k折交叉分类器 分层k折交叉分类器 打乱数据集后再划分 模型验证