神经网络训练中batch的作用(从更高角度理解)

2023-11-18

1.什么是batch

batch,翻译成汉语为批(一批一批的批)。在神经网络模型训练时,比如有1000个样本,把这些样本分为10批,就是10个batch。每个批(batch)的大小为100,就是batch size=100。
每次模型训练,更新权重时,就拿一个batch的样本来更新权重。

2.神经网络训练中batch的作用(从更高角度理解)

从更高的角度讲,”为什么神经网络训练时有batch?“,需要先讲一些预备知识。

当我们求损失loss用于梯度下降,更行权重时,有几种方式。一种是全部的样本用来求loss,这种方式称为批量梯度下降(BGD);一种是随机的选取一个样本,求loss,进而求梯度,这种方式称为随机梯度下降(SGD);BGD和SGB的这种,产生了第三种梯度下降的方法:小批量梯度下降(MBGD)

当我们使用BGD方法来更新权重时,面临一个问题:
我们知道,梯度下降法是求得某个点,使得loss最小。通过往梯度减小的方向更新权重值,可以使得loss减小。如下图所示:
在这里插入图片描述

绿色箭头所示,为梯度减小的方向。沿此方向更新权重,使得loss减小。

但这种方法面临一个很尖锐的问题。当梯度为0时,无论怎么更新权重,loss都不再改变,从而无法找到最优点。如下图所示,当位于红圈标出的区域时,梯度为0,此时梯度下降法就失效了,无法找到最优点。
在这里插入图片描述
但引入随机梯度下降SGD方法时,就能很大程度避免这个问题。

原因时:批量梯度下降时,全部的样本用于求loss。而随机梯度下降是,随机选取一个样本求loss进而求梯度。这种方式就很大程度上规避了梯度为0的情况。即使某次训练更新权重时,梯度为0,下次也不一定为0。而批量梯度下降则不然,本次更新权重时,梯度为0,下次还是0。梯度下降法就失效了。

但批量梯度下降和随机梯度下降有各自的优缺点:

1.使用批量梯度下降时,虽然模型的性能低,但耗费时间时间也低。(由于其求梯度,更新权重时,可以并行计算,因此是求所有样本损失的累加和)

2.使用随机梯度下降时,虽然模型的性能高,但耗费的时间也高。(由于其求梯度,更新权重时,可以并行计算。某步更新权重,要依赖上一步权重)

关于这一块,大家可以参考:批量梯度下降(BGD)、随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解

因此就提出了一种折中的方法:小批量梯度下降(MBGD)
下图,为三个方法,一次训练时,使用样本量的示意图。
左边红色的大框,指的是批量梯度下降把全部的样本由于一次更新权重的训练。
左边红色的多个小框,表示随机梯度下降随机选取一个样本用于一次更新权重的训练。
蓝色的框,表示把样本分为几批(batch),每次用一批的样本来进行一次更新权重的训练。
在这里插入图片描述

3.补充知识

对卷积神经网络中术语的理解:Epoch、Batch Size和batchsize

所谓
Epoch:1个epoch等于使用训练集中的全部样本训练一次,通俗的讲几次epoch就是整个数据集被轮几次

Batch Size:全部数据是分批来训练的,批的大小称为Batch Size

iteration:1个iteration等于使用batchsize个样本训练一次,也就是说训练一批的样本,称为一次迭代

比如训练集有500个样本,batchsize = 10 ,那么训练完整个样本集:iteration=50,epoch=1.

batch: 深度学习每一次参数的更新所需要损失函数并不是由一个数据获得的,而是由一组数据加权得到的,这一组数据的数量就是batchsize。

batchsize最大是样本总数N,此时就是Full batch learning;最小是1,即每次只训练一个样本,这就是在线学习(Online Learning)。当我们分批学习时,每次使用过全部训练数据完成一次Forword运算以及一次BP运算,成为完成了一次epoch。

参考:

1.批量梯度下降(BGD)、随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络训练中batch的作用(从更高角度理解) 的相关文章

  • 经典卷积神经网络——resnet

    resnet 前言 一 resnet 二 resnet网络结构 三 resnet18 1 导包 2 残差模块 2 通道数翻倍残差模块 3 rensnet18模块 4 数据测试 5 损失函数 优化器 6 加载数据集 数据增强 7 训练数据 8
  • 毕业设计-基于深度学习的命名实体识别研究

    目录 目录 前言 课题背景和意义 实现技术思路 一 命名实体识别简单概述 二 基于深度学习的命名实体识别方法 实现结果 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精
  • 毕业设计-基于深度学习的垃圾分类识别方法

    目录 前言 课题背景和意义 实现技术思路 一 目标检测算法对比研究 二 垃圾数据集的制作 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个
  • 静默执行bat文件

    让bat隐藏运行需要用vbs文件才能实现 方式一 使用vbs文件 新建一个 文本文档后缀改为 vbs 可以这样写 set ws WScript CreateObject WScript Shell ws Run d yy bat 0 其中d
  • 卷积神经网络详解

    卷积神经网络 Convolutional Neural Networks CNN 是应用最多 研究最广的一种神经网络 卷积神经网络 以下简称CNN 主要用于图片分类 自动标注以及产品推荐系统中 以CNN实现图片分类为例 图像经过多个卷积层
  • 神经网络-Unet网络

    文章目录 前言 1 seq2seq 编码后解码 2 网络结构 3 特征融合 4 前言 Unet用来做小目标语义分割 优点 网络结构非常简单 大纲目录 2016年特别火 在细胞领域做分割特别好 1 seq2seq 编码后解码 2 网络结构 几
  • 【转载】CNN模型复杂度(FLOPs、MAC)、参数量与运行速度

    备忘 作者写错了 1次乘加运算等于2次浮点运算 但在数值上正好反过来 即1 FLOPs 2 MACs 例如对于卷积运算的计算是 其MACs 参数m 输出尺寸 n 而FLOPs 2 MACs Nvidia团队论文里面写的是对的 2倍 CNN模
  • 卷积神经网络中图像池化操作全解析

    一 池化的过程 卷积层是对图像的一个邻域进行卷积得到图像的邻域特征 亚采样层 池化层 就是使用pooling技术将小邻域内的特征点整合得到新的特征 在完成卷积特征提取之后 对于每一个隐藏单元 它都提取到 r a 1 c b 1 个特征 把它
  • yolov5识别cf火线敌人(FPS类的AI瞄准)详细教程一

    一 前言 因为自己的研究方向也是深度学习方向 而且平时闲的时候还喜欢玩会cf火线等枪战游戏 就想着找一个大模型做一个对游戏敌人的识别的功能 一切实现之后就想把自己的心得写出来 我打算分俩个教程分别细述整个学习以及操作的过程 教程一主要包括了
  • keras卷积神经网络入门-笑脸检测

    keras卷积神经网络入门 笑脸检测 Keras简介 1 库函数导入 2 查看数据集 3 构建模型 4 训练模型 5 预测自己的图片 Keras简介 Keras以其强大的封装结构 让我们不必过多的考虑神经网络间的计算 极大简化了tensor
  • 李宏毅 - 卷积神经网络(CNN)

    李宏毅 卷积神经网络 CNN 卷积神经网络主要用于图像分类 一张图片通过我们的卷积神经网络也就是Model计算出概率值 通过Cross entropy 交叉熵 归一化到0和1 概率最大的显示为1 其余显示为0 那么一张图片是怎么输入到Mod
  • 卷积尺寸计算公式(速查备用)

    torch代码计算 def paras cnn k s p i 64 x torch ones 1 1 i i conv torch nn Conv2d 1 1 kernel size k stride s padding p convt
  • windows批处理:if else的踩坑点及排版优化

    参考 https www jianshu com p f0bde7d355a4 总结 见参考文章
  • 机器学习原来这么有趣 Part3: 深度学习与卷积神经网络

    最近看了Adam Geitgey的机器学习系列文章 寻思着闲着也是闲着 干脆翻译以下 顺便学习下英语啥的哈哈哈 第一次做这种事 有不到位的地方欢迎指教噢 前言 你是否已经厌倦了在查阅了无数有关深度学习的文章之后仍然不能参透其中深意的无力感
  • 论文阅读笔记之——《Multi-level Wavelet-CNN for Image Restoration》及基于pytorch的复现

    本博文是MWCNN的阅读笔记 论文的链接 https arxiv org pdf 1805 07071 pdf 代码 https github com lpj0 MWCNN 仅仅是matlab代码 通过参考代码 对该网络在pytorch框架
  • 基于卷积神经网络的人脸表情识别综述

    基于卷积神经网络的人脸表情识别 摘要 在日常的沟通与交流过程中 运用面部表情可以促使沟通交流变得更加顺畅 因此对于人类而言 进行面部表情的解读也是进行相关沟通交流内容获取的重要程序 随着科学技术的不断发展 人工智能在日常人类交流沟通中 运用
  • CNN中特征融合的一些策略

    Introduction 特征融合的方法很多 如果数学化地表示 大体可以分为以下几种 X Y textbf X textbf Y X Y X
  • 【深度学习】经典的卷积神经网络模型介绍(LeNet、AlexNet、VGGNet、GoogLeNet、ResNet和MobileNet)

    经典的卷积神经网络模型介绍 卷积神经网络简介 一 LeNet 1 INPUT层 输入层 2 C1层 卷积层 3 S2层 池化层 下采样层 4 C3层 卷积层 5 S4层 池化层 下采样层 6 C5层 卷积层 7 F6层 全连接层 二 Ale
  • CUDA的下载安装

    大家好 下面将进行CUDA的下载安装 下载安装的详细步骤描述如下 1 CUDA下载 https download csdn net download qq 41104871 87462747 2 CUDA安装 1 首先 需要解压缩下载好的C
  • 深度学习笔记3——AlexNet

    1 背景介绍 在2012年的ImageNet竞赛中 AlexNet获得了top 5测试的15 3 error rate 获得第二名的方法error rate 是 26 2 AlexNet有60 million个参数和65000个 神经元 五

随机推荐

  • python实现常用数据结构

    本文基于Python实现以下几种常用的数据结构 栈 队列 优先队列 二叉树 单链表 双向链表 栈 基于List实现 class Stack 栈 def init self self arr self size 0 def push self
  • windows 10自带命令查看文件的哈希值

    windows的powershell自带了查看文件哈希值的命令 Get FileHash 文件名 Algorithm MD5 SHA1 SHA256 案例 查看文件的MD5值 查看文件的SHA1值 查看文件的SHA256值
  • springboot调整请求头大小_SpringBoot http post请求数据大小设置操作

    背景 使用http post请求方式的接口 使用request getParameter XXX 的方法获取参数的值 当数据量超过几百k的时候 接口接收不到数据或者接收为null RequestMapping value rcv metho
  • GitHub Flavored Markdown 规范

    Markdown是一种轻量级标记语言 它以纯文本形式编写文档 易读 看起来舒服 易写 语法简单 易更改 并最终以HTML格式发布 由于markdown没有明确指定语法 随着编译器不一样 实现方式有很大差异 GitHub Flavored M
  • SocketOutputStream和SocketChannel write方法的区别和底层实现

    Java直接内存原理提到了SocketChannel write的实现原理 通过IOUtil write将java堆内存拷贝到了直接内存 然后再把地址传给了I O函数 那么 BIO 是怎么实现往socket里面写数据的呢 BIO Socke
  • Java多线程知识点总结(思维导图+源码笔记)

    转自 https blog csdn net yelvgou9995 article details 107408709 多线程大家在初学的时候 对这个知识点应该有不少的疑惑的 我认为主要原因有两个 多线程在初学的时候不太好学 并且一般写项
  • Sitecore站点更新License

    一 简介 Sitecore 是一个基于ASP NET 技术的 CMS 系统 它不仅具有传统 Web CMS 的所有功能 还集成了 Marketing 营销 当然 这个功能价格不菲 的功能 可以提供一个一站式的在线营销解决方案 对于 NET
  • 深入理解数据结构——哈夫曼树

    include
  • [589]IDM下载器

    Internet Download Manager 简称 IDM 是一种将下载速度提高5倍的工具 可以恢复和安排下载 由于连接丢失 网络问题 计算机关机或意外停电等原因 全面的错误恢复和恢复功能将重新启动中断或中断的下载 简单的图形用户界面
  • 最快实现一个自己的扫地机

    作者 良知犹存 转载授权以及围观 欢迎关注微信公众号 羽林君 或者添加作者个人微信 become me 扫地机介绍 扫地机器人行业本质是技术驱动型行业 产品围绕导航系统的升级成为行业发展的主旋律 按功能划分 扫地机器人分为四大系统 即导航系
  • 【视频解读】AutoGluon背后的技术

    1 资料来源 AutoGluon背后的技术 哔哩哔哩 bilibili 也是一种Automl框架 在尽量不需要人的帮助下 对输入进行特征提取 选取适合的机器学习模型对它进行训练 大部分基于超参数搜索技术 从数十或者数百个参数中选取一个合适的
  • 判断List、Map集合是否为空的方法

    在Java中 判断集合是否为空有几种方法 以下是其中的一些 1 使用List isEmpty 方法 例如 List
  • openGL之API学习(六十三)GL_RASTERIZER_DISCARD

    glEnable GL RASTERIZER DISCARD 使用GL RASTERIZER DISCARD标志作为参数调用glEnable 函数 告诉渲染管线在transform feedback可选阶段之后和到达光栅器前抛弃所有的图元
  • 与计算机信息技术有关的课题,信息技术课题研究报告.doc

    PAGE PAGE 1 信息技术环境下教学模式和教学方法的创新研究 课题研究报告 摘要 本课题由中央电教馆与有关专家在充分论证的基础上 于2006年12月被批准为中央电化教育馆全国教育技术 十一五 专项课题 在中央电教馆组织下 课题研究得到
  • 机器学习在交通标志检测与精细分类中的应用

    导读 数据对于地图来说十分重要 没有数据 就没有地图服务 用户在使用地图服务时 不太会想到数据就像冰山一样 用户可见只是最直接 最显性的产品功能部分 而支撑显性部分所需要的根基 往往更庞大 地图数据最先是从专业采集来的 采集工具就是车 自行
  • python学习笔记2

    if语法 if True print 条件成 执 的代码1 print 条件成 执 的代码2 下 的代码没有缩进到if语句块 所以和if条件 关 print 我是 论条件是否成 都要执 的代码 if else if 条件 条件成 执 的代码
  • linux查看用户登录时间以及命令历史

    1 查看当前登录用户信息 who命令 who缺省输出包括用户名 终端类型 登陆日期以及远程主机 who var log wtmp 可以查看自从wtmp文件创建以来的每一次登陆情况 1 b 查看系统最近一次启动时间 2 H 打印每列的标题 u
  • 转载-STM32片上FLASH内存映射、页面大小、寄存器映射

    原文地址 http blog chinaunix net uid 20617446 id 3847242 html 本文以STM32F103RBT6为例介绍了片上Flash Embedded Flash 若干问题 包括Flash大小 内存映
  • LAMP框架的架构与环境配置

    1 LAMP架构的相关知识 1 1 LAMP架构的概述 LAMP架构是目前成熟的企业网站应用模式之一 指的是协同工作的一整套系统和相关软件 能够提供动态Web站点服务及其应用开发环境 LAMP是一个缩写词 具体包括Linux操作系统 Apa
  • 神经网络训练中batch的作用(从更高角度理解)

    1 什么是batch batch 翻译成汉语为批 一批一批的批 在神经网络模型训练时 比如有1000个样本 把这些样本分为10批 就是10个batch 每个批 batch 的大小为100 就是batch size 100 每次模型训练 更新