论文笔记:On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

2023-11-12

 2017 ICLR

0 摘要

  • 这篇文章探究了深度学习中一个普遍存在的问题——使用大的batchsize训练网络会导致网络的泛化性能下降(Generalization Gap)。
    • 大的batchsize训练使得目标函数倾向于收敛到sharp minima(local minima)
      • sharp minima导致了网络的泛化性能下降
    • 小的batchsize则倾向于收敛到一个flat minima

1 intro

  • 一般的深度学习算法都是通过优化一个目标函数来训练网络参数的,这是一个非凸的优化问题。
    • 整个过程可以表达为下面的式子:
    • 其中M是数据的数量,f(x)就是损失函数/目标函数
  • 一种常见的优化方法就是SGD
    • 这里Bk 是batchsize的大小,一般取值{32,64,…,512}
    • 经过实践的检验,这些常见的batchsize大小设置可以有以下的优点:
      • 收敛到凸函数的最小值点以及非凸函数的驻点;
      • 避免鞍点的出现
      • 对输入数据具有鲁棒性
         
  • 但SGD有一个主要缺点,那就是并行化困难
    • 一个常见方法是增大batchsize,然而这导致了Generalization Gap的出现 

2 大batch的缺点

  •  大Batch方法与小Batch方法在训练的时候实际上得到的目标函数的值是差不多的【只是在测试集上会有一定的gap】
  • 这个现象的可能原因有下面几点(LB=Large-Batch;SB=Small-Batch)
    • LB方法过拟合模型;
    • LB方法被吸引到鞍点;
    • LB方法缺乏SB方法的探索性质,并倾向于放大最接近初始点的最小值;
    • SB和LB方法收敛到具有不同泛化特性的不同的最小化。
  • 作者这篇文章主要研究的是后两点原因
    • 作者认为大Batch方法之所以出现Generalization Gap问题,原因是大Batch方法训练时候更容易收敛到sharp minima,而小Batch的方法则更容易收敛到flat minima。
    • 并且大Batch方法不容易从这些sharp minima的basins中出来。

  • 以上图为例,我测试集比训练集稍微偏离一点,如果是flat mininum的话,还会在最小值附近;但是如果是sharp mininum的话,那就离得远了

3 实验验证

3.1 数据集和网络结构

C1,C3——AlexNet结构

C2,C4——GoogleNet结构

  • 实验中LB方法的Batchsize定义为整个数据集的10%,SB方法的Batchsize定义为256。
  • 优化器使用ADAM(ADAGRAD、adaQN等几个方法得到的结论是类似的)。
  • 损失函数使用的是交叉熵形式。

 3.2 实验结果

  •  对应Table1中6个网络的实验的结果如table2所示。
  • 可以看到,SB和LB两种算法在Training阶段取得的结果非常相近,而Testing阶段LB方法明显出现了Generalization Gap的现象。

 

3.2.1 generation gap 并不是过拟合导致的

作者可视化了testing和training的training和testing loss

 

 此时testing上performance和training上的距离,并不是因为过拟合导致的。

如果是过拟合的话,个人觉得应该是这样的:

 3.2 minima的sharpness

3.2.1 直观可视化

  • Figure3给出的是一维的参数曲线,X_l^*,x_S^*分别表示SB和LB方法在ADAM优化器中得到的预测结果,α在[-1,2]之间
    • α=0——>对应SB方法
    • α=1对应LB方法
    • α<0.5——>SB方法主导;α>0.5——>LB方法主导

  • 可以看到,在0.5左边(SB主导的时候),accuracy高,同时比较平缓;在0.5右边(LB主导的时候),accuracy低,同时比较陡峭 

3.2.2 sharpness定义

  • 给定一个矩阵A \in R^{N \times P},A是一个在全空间随机抽样产生的矩阵
  • A+是矩阵A的伪逆
  • 那么我们定义范围限定集为:
  • 其中epsilon控制了范围限定集的大小 

  • 我们定义sharpness为:

3.2.3 各网络sharpness对比 

table3表示了整个空间上的最小值锋利度,而table4则是子空间(100维)。

3.3 模型效果随batch大小的变化 

  • 关于batchsize的选择是存在一个阈值的,batchsize大于这个阈值会导致模型质量的退化。
  • 这个现象可以由figure4看出来,figure4中的F2的约15000和C1的约500,大于这个阈值网络准确度大幅下降

 

  • SB方法使用的梯度具有内在的噪声,从实验以及经验来看,这些噪声使得SB方法的minimum在到达一个相对sharp的区域时,能够将最优值推出去,到达一个相对flat的区域。而这些噪声不足以将一个本来就很flat的minimum推出去。
  • 而LB方法明显大于上面所说的阈值的时候,梯度内存在的噪声不足以将minimum推出sharp区域。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

论文笔记:On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima 的相关文章

  • MATLAB算法实战应用案例精讲-【时序模型】循环神经网络-GRU(附MATLAB和Python代码)

    目录 前言 几个高频面试题 1 GRU与LSTM的区别与联系 2 LSTM和RNN的区别 GRU的引入 算法原理
  • 微信小程序退出重新进入时跳转特定页面

    微信小程序退出时会记录当前页面的状态 短时间内再次进入会显示退出前的状态 解决方案 在app js文件中添加onHide方法 onHide方法监听小程序切后台 在app js文件中使用会在每次程序退出时调用 onLaunch functio
  • springboot常用注解详解

    在springboot中 经常会用到一些注解 它们各自代表着什么呢 在这个属于我们的节日里 快来了解一下吧 1 SpringBootApplication 一般不会主动去使用它 但是要知道它是一个组合注解 Configuration Ena
  • 【Git CMD】Git常用命令总结

    目录 0 git的工作区 暂存区 本地仓库和远程仓库 0 1 图解 0 2 解析 1 本地仓库 1 1 创建版本库 1 2 分支 1 2 1 查看本地仓库的分支信息 1 2 2 创建分支 1 2 3 切换分支 1 2 4 重命名分支 1 2
  • ERC20 协议

    https www jb51 net blockchain 797814 html https blog csdn net bareape article details 124275062 代币标准 ERC20协议源码解析 我们在买入US
  • Jenkins以root用户运行

    Jenkins安装完成后默认会创建一个jenkins的用户 并以jenkins用户运行 在我们通过jenkins编写一些命令的时候容易出现权限不足的提示 permision denied 通过为jenkins工作区赋予777的权限以后 也可
  • 先电2.4版本iaas搭建部分(vm中模拟,比赛使用服务器)

    改革 由于2020年云计算改革由团队比赛变成个人比赛 原本由三个人 iaas和pass bigdata 云应用开发都变成一个人 所有脚本都在 usr local bin 今天刚拿到镜像和文档 先进行搭建 预先准备 配置两个网卡的ip 一共两
  • 前端面试题--计算机网络

    文章目录 1 七层网络协议体系结构的理解 2 五层协议中各自对应的网络协议 3 ARP 协议的工作原理 4 IP 地址分类的理解 5 TCP 的主要特点 exclamation exclamation Transmission Contro
  • 小白的成长轨迹(二):披荆斩棘,未来可期

    大家好 我是孤焰 一名双非本科的大四学生 又是一年的1024 我坚持撰写博客已经为期一年 很感谢大家一直以来的支持 在这一年期间这位名为 孤焰 的少年又有哪些成长呢 下面便请细听分说 希望这些成长经历可以对正在看这篇文章的小可爱们有一些帮助
  • Arduino ESP32自平衡小车制作实现(不需编码器)

    1 mpu6050陀螺仪角度方向和静态平衡角度测试 说明 1 陀螺仪补偿值的计算 试时提前用calcGyroOffsets true 函数计算出 补偿值 知道mpu6050的补偿值后用setGyroOffsets 直接设置补偿值 避免每次开
  • 【一】第一个java程序详解

    第一个java程序详解 一 前言 二 创建并编写java源代码的文件 创建java源代码文件 更改文件后缀 java代码的结构 三 编译执行 编译 执行 四 总结 五 附 java关键字 一 前言 通过之前上一节 开篇 Java语言介绍及环
  • CLIP与CoOp代码分析

    CLIP与CoOp代码分析 CoOp是稍微改了下CLIP的text encoder CLIP代码 https github com OpenAI CLIP CoOp代码 https github com KaiyangZhou CoOp 输
  • 论文笔记:Region Representation Learning via Mobility Flow

    2017 CIKM 1 摘要和介绍 使用出租车出行数据学习区域向量表征 同时考虑时间动态和多跳位置转换 gt 通过flow graph和spatial graph学习表征 出租车交通流可以作为区域相似度的一种 A区域和B区域之间流量大 gt
  • Tomcat单实例安装部署

    自说 Tomcat 服务器是一个免费的开放源代码的Web 应用服务器 属于轻量级应用服务器 主要用于处理动态web数据 部署java环境 上传jdk包 使用xftp上传 解压 tar zxvf u01 jdk 8u333 linux i58
  • JESD204B(RX)协议接口说明。

    解释一下Vivado IP协议中的Shared Logic in Example 与 Shared Logic in Core 首先 什么是Shared Logic 字面意思很好理解 就是共享逻辑 主要包括时钟 复位等逻辑 当选择Share
  • lyapunov直接法

    文章目录 定义6 6 Lyapunov第一定理 Lyapunov第二定理 用于刻画渐进稳定 内积分析 定义6 6 Lyapunov第一定理 假设 A C A subset C A C是闭的 如果存在A的邻域D和满足下面两条件的连续函数
  • 【Unity步步升】监控与检测物体的各种方案,如:射线、碰撞、挂载等...

    在制作AR模型数值控制方案的时候遇到了检测的问题 学习过程受益匪浅 故今天为大家整理带来一篇监控与检测物体的参考方案集合 目录 一 射线检测 二 物体存在检测 三 碰撞检测 一 射线检测 单射线检测 首先完成搭建场景如下图1 1 我这里用到
  • 运行游戏找不到x3daudio1_7.dll怎么解决?教你如何快速修复的教程

    在计算机使用过程中 我们经常会遇到一些错误提示 其中之一就是 x3daudio1 7 dll丢失 这个错误提示可能让我们感到困惑和烦恼 但是不用担心 本文将为您介绍x3daudio1 7 dll丢失的原因以及五种修复方法 帮助您解决这个问题
  • 【网安入门】怎样花3个月零基础入门网络安全?

    写这篇教程的初衷是很多朋友都想了解如何入门 转行网络安全 实现自己的 黑客梦 文章的宗旨是 1 指出一些自学的误区 2 提供客观可行的学习表 3 推荐我认为适合小白学习的资源 大佬绕道哈 一 自学网络安全学习的误区和陷阱 1 不要试图先成为
  • 面对内卷严重的2023年,测试人员该怎样修炼?

    这几天马上就要双12大促 相信大家都准备花了不少钱吧 其实在每一次大促的背后各大电商平台还在遭受一次又一次的的黑产攻击 拿阿里巴巴去年双十一举例 2684 亿交易额的背后 有一天内 22 亿次的黑产攻击 近几年网络安全事件层出不穷 相信大家

随机推荐