困扰我两天的问题:StratifiedShuffleSplit与train_test_split创建的数据集为何训练结果不同?

2023-11-09

困扰我两天的问题:StratifiedShuffleSplit与train_test_split创建的数据集为何训练结果不同?

让人头疼的问题

最近,我在进行卷积模型的分类任务时发现了一个StratifiedShuffleSplit函数的bug。

众所周知,在训练模型之前我们一般会对数据集划分为训练集和验证集,以便后期对模型性能的验证。然而在我近期的实验中,我发现使用StratifiedShuffleSplit函数划分数据集和使用train_test_split函数划分数据集竟然产生了巨大的训练差异(前者验证准确率高达95%,而后者只有80%)。

版本信息

sklearn 1.2.2
torch 2.0.0
cuda 11.8
python 3.9

设置虚拟数据

x = np.random.rand(10000,1)
y1 = np.ones(5000)
y0 = np.zeros(5000)
y = np.concatenate([y1,y0], axis=0)
print(x.shape)
print(y.shape)

# (10000, 1)
# (10000,)

train_test_split划分

train_data, test_data, train_labels, test_labels = train_test_split(x, y, test_size=0.2, random_state=30, stratify=y)

设置stratify是为了保证其与StratifiedShuffleSplit保持一致,都为分层采样

train_subset_x = torch.FloatTensor(train_datas)
train_subset_y = torch.LongTensor(train_labels)

valid_subset_x = torch.FloatTensor(test_datas)
valid_subset_y = torch.LongTensor(test_labels)

from collections import Counter
print(Counter([np.int32(train_subset_y[i]) for i in range(len(train_subset_y))]))
print(Counter([np.int32(valid_subset_y[i]) for i in range(len(valid_subset_y))]))

# Counter({0: 4000, 1: 4000})
# Counter({1: 1000, 0: 1000})

StratifiedShuffleSplit划分

from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
from torch.utils.data.dataset import Subset

def generate_train_indices(n_splits, ratio, data, lab):
#     ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio, random_state=20)
    ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio)
    return [i.tolist() for i, _ in ss.split(data, lab)], [j.tolist() for _, j in ss.split(data, lab)]

train_indices, valid_indices = generate_train_indices(1, 0.8, x, y)
print(Counter([y[i] for i in train_indices[0]]))
print(Counter([y[i] for i in valid_indices[0]]))

# Counter({0: 4000, 1: 4000})
# Counter({1: 1000, 0: 1000})

以上两种方法创建的数据集,在同一个模型,相同的训练方式下,结果完全不同,前者80%左右,后者95%左右。然后我将StratifiedShuffleSplit的random_state设置上固定的数字,比如:

from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
from torch.utils.data.dataset import Subset

def generate_train_indices(n_splits, ratio, data, lab):
    ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio, random_state=20)
#    ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio)
    return [i.tolist() for i, _ in ss.split(data, lab)], [j.tolist() for _, j in ss.split(data, lab)]

train_indices, valid_indices = generate_train_indices(1, 0.8, x, y)

然后,实验结果就变得跟train_test_split划分的数据集所得的结果相同了,是不是很神奇?

我在尝试了各种ablation实验后,还是没找到原因出在哪里。然后我就想不会是验证集数据泄露了吧,然后我就打印了一下未设置random_state下的采样结果

from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
from torch.utils.data.dataset import Subset

def generate_train_indices(n_splits, ratio, data, lab):
#     ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio, random_state=20)
    ss = StratifiedShuffleSplit(n_splits=n_splits, train_size=ratio)
    return [i.tolist() for i, _ in ss.split(data, lab)], [j.tolist() for _, j in ss.split(data, lab)]

train_indices, valid_indices = generate_train_indices(1, 0.8, x, y)
l = []
for i in train_indices[0]:
    for j in valid_indices[0]:
        if i==j:
            l.append(i)
# print(l)
print(len(l))

# 1602

结果发现,验证集和测试集竟然有1602个数据是完全相同的,然后我设置了random_state参数,发现没有相同的数据了。

据我从网上查到的知识可以知道,random_state只是设置一个随机种子,并不会对StratifiedShuffleSplit产生其他的影响。然而,实际情况是设置了random_state是无放回的分层采样,而不设置random_state就会变成有放回的分层采样。

这是StratifiedShuffleSplit函数的一个bug?还是该函数本身就是这么设置的?或者我的代码有问题?求大佬解释

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

困扰我两天的问题:StratifiedShuffleSplit与train_test_split创建的数据集为何训练结果不同? 的相关文章

随机推荐

  • maven在Win10的安装和配置

    1 下载和安装maven 一 下载Maven并解压 1 Maven官网下载地址 http maven apache org download cgi 2 下载后解压 将Maven的压缩包解压到 E Java apache maven 3 6
  • 【Unity2D入门教程】简单制作一个弹珠游戏之制作场景①(开场,结束,板子,球)

    学习目标 看过我前面的文章应该知道怎么制作开头和结尾 这里我简单把效果给大伙看一下 我用的游戏分辨率是4 3 因此我们要改变Canvas的的Cavans Scale为X1440 Y1080 结束的场景也一样 接着我们编写一个脚本来管理场景的
  • 基于人工势场算法机器人避障路径规划

    基于人工势场算法机器人避障路径规划 人工势场算法是一种热门的机器人路径规划算法 其通过建立虚拟的 势场 使得机器人在避障时能够像物理学中的粒子一样受到 势 的作用 最终实现自主导航 本文将介绍如何使用 MATLAB 实现基于人工势场算法的机
  • mac中的IDEA的使用快捷键

    1 command F 在当前文件进行文本查找 2 command shift F 进行工程和模块中的文件搜索 3 command u 找到这个方法的接口 4 command option commad 找到这个接口的实现类 5 comma
  • javascript中做减法时,出现小数位增加bug

    这个bug是js固有的 浮点数精度不准 你可以用下面方法来解决 思路是先放大 求和 差 积等运算后再缩小 如 加法函数 用来得到精确的加法结果 说明 javascript的加法结果会有误差 在两个浮点数相加的时候会比较明显 这个函数返回较为
  • JPEG编码原理及文件格式及代码分析

    一 JPEG编码原理 首先我们先来看一下JPEG的编码原理图 如上图所示 下面进行逐步的分析 1 RGB gt YUV 首先为了降低互相的关联性 将RGB转换为YUV 这样就可以对亮度信号和色度信号进行分别的处理 2 零电平偏置下移 由于后
  • CreateEvent函数在多线程中使用及实例

    HANDLE CreateEvent LPSECURITY ATTRIBUTES lpEventAttributes BOOL bManualReset BOOL bInitialState LPCSTR lpName bManualRes
  • 8.7.1 makefile实例——项目中的总makefile

    Linux C程序设计王者归来 第8章构建makefile文件 makefile相当于一种脚本编程语言 用户在编写makefile的过程中可以使用变量 控制结构语句和函数等一般编程语言的特性 同时也可以执行shell指令 makefile诞
  • 2021年前端关注的8个技术趋势

    2020年也过去 我们一起解读一下整个2020年的前端技术的8个技术 并深度分析2021年大前端领域又有哪些顶级技术趋势 你不容错过 2020年注定是不平凡的一年 相信因为疫情很多程序员的工作和生活都受到了一定影响 其实现在前端的技术已经到
  • Minio控制台详细教程

    前言 此文讲解Minio控制台详细教程 可能会涉及到有些知识大家可能不懂情况 需要知道Minio兼容的是AMS S3对象存储服务 需要知道AMS S3对象存储服务是什么 里面涉及的到配置如何去配等等 https docs aws amazo
  • PHP html table下载为excel

    php下载头 header Content type application vnd ms excel header Content Disposition attachment filename test xls header Conte
  • 进程管理与内存管理

    谷歌官网 内存管理概览
  • STM32之RTC

    简介 STM32 的实时时钟 RTC 是一个独立的定时器 STM32 的 RTC 模块拥有一组连续计数的计数器 在相应软件配置下 可提供时钟日历的功能 修改计数器的值可以重新设置系统当前的时间和日期 框图 相关寄存器 控制寄存器 第 0 位
  • 网页的内联框架

    内联 网页里嵌套网页
  • 图解算法 -使用Python 学习笔记(3)

    图解算法 使用Python 学习笔记 3 排序算法 3 1认识排序 用以排序的依据是键 它所含的值被称为 键值 通常键值的数据类型有数值类型 中文字符串以及非中文字符串三种 其中中文字符串用该中文内码 如中文繁体BIG5码 中文简体GB码
  • ConcurrentHashMap面试知识的思维导图整理

    整理自文章https blog csdn net yunzhaji3762 article details 113623168 juc 有关介绍https blog csdn net abaidaye article details 123
  • 小伙伴们,赶紧,免费的视频托管。

    最近 朋友的公司培训频道要上在线视频功能 说自己找了一些行业解决方案 最后对比使用视频托管的方式比较省事 说白了就是省钱呗 挑来挑去 觉得感觉保利威视这家托管服务商的视频播放效果较好 并且功能相对多 比较符合当前项目需求 他在网站上注册的账
  • SpringBoot和Vue跨域问题

    Access to XMLHttpRequest at http localhost 8301 admin vod user info token admin token from origin http localhost 9528 ha
  • gateway 报错 reactor.core.Exceptions$ErrorCallbackNotImplemented

    生产环境好好的 突然前端请求全部跨域 请求 500 gateway 报错 reactor core Exceptions ErrorCallbackNotImplemented java lang IndexOutOfBoundsExcep
  • 困扰我两天的问题:StratifiedShuffleSplit与train_test_split创建的数据集为何训练结果不同?

    困扰我两天的问题 StratifiedShuffleSplit与train test split创建的数据集为何训练结果不同 让人头疼的问题 最近 我在进行卷积模型的分类任务时发现了一个StratifiedShuffleSplit函数的bu