【深度学习】深入理解Batch Normalization批标准化

2023-05-16

Batch Normalization

  • 1. “Internal Covariate Shift”问题
  • 2. BatchNorm的本质思想
    • 1)函数图像说明
    • 2)算法
    • 3)引入参数恢复表达能力
    • 4)公式
  • 3. 测试阶段如何使用Batch Normalization?
  • 4. BatchNorm的优势

机器学习有个很重要的假设: IID独立同分布建设,就是 假设训练数据和测试数据是满足相同分布的,这就是通过训练数据获得的模型能够在测试集获得很好效果的一个基本保障。BatchNorm的作用是什么呢?

BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

接下来一步一步理解什么是BN。

为什么深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢?这是个DL领域很接近本质的好问题。这也等同于梯度消失和梯度爆炸的问题。

很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network,BN本质上也是解释并从某个不同的角度来解决这个问题的。

1. “Internal Covariate Shift”问题

BN是用来解决“Internal Covariate Shift”问题的,那么首先得理解什么是“Internal Covariate Shift”?

论文首先说明Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快;(为什么要说这些?因为BatchNorm是基于Mini-Batch SGD的,所以先夸下Mini-Batch SGD,当然也是大实话);然后吐槽下SGD训练的缺点:超参数调起来很麻烦。(作者隐含意思是用BN就能解决很多SGD的缺点)

接着引入covariate shift的概念:如果ML系统实例集合<X,Y>中的输入值X的分布老是变,这不符合IID假设,网络模型很难稳定的学规律。对于深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,所以每个隐层都会面临covariate shift的问题,也就是在训练过程中,隐层的输入分布老是变来变去,这就是所谓的“Internal Covariate Shift”,Internal指的是深层网络的隐层,是发生在网络内部的事情,而不是covariate shift问题只发生在输入层。

然后提出了BatchNorm的基本思想:能不能让每个隐层节点的激活输入分布固定下来呢?这样就避免了“Internal Covariate Shift”问题了。

BN不是凭空拍脑袋拍出来的好点子,它是有启发来源的:之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话——所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布——那么神经网络会较快收敛,那么BN作者就开始推论了:图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐层的神经元是下一层的输入,意思是其实深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?这就是启发BN产生的原初想法,而BN也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作


2. BatchNorm的本质思想

BN的基本思想其实相当直观:因为深层神经网络在做非线性变换之前的激活输入值(z = wx +b,x是输入,z是非线性函数输入值)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者移动,之所以训练收敛变慢,一般是整体分布逐渐往非线性函数的取值区间的上下两端靠近(对于sigmoid函数来说,意味着激活输入值z是大的负值或正值),所以这导致反向传播时底层神经网络的梯度消失,这是训练神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。

其实一句话就是:对于每个隐层神经元,把逐渐向非线性函数映射后取值区间极限饱和区靠拢的输入分布 强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直能保存比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说像损失函数最优值迈动的步子大,也就是收敛的快。

1)函数图像说明

上面说得还是显得抽象,下面更形象地表达下这种调整到底代表什么含义。
在这里插入图片描述上图为几个正态分布

假设某个隐层神经元原先的激活输入x取值符合正态分布,正态分布均值是-2,方差是0.5,对应上图中最左端的浅蓝色曲线,通过BN后转换为均值为0,方差是1的正态分布(对应上图中的深蓝色图形),意味着什么,意味着输入x的取值正态分布整体右移2(均值的变化),图形曲线更平缓了(方差增大的变化)。这个图的意思是,BN其实就是把每个隐层神经元的激活输入分布从偏离均值为0方差为1的正态分布通过平移均值压缩或者扩大曲线尖锐程度,调整为均值为0方差为1的正态分布。

那么把激活输入x调整到这个正态分布有什么用?首先我们看下均值为0,方差为1的标准正态分布代表什么含义:
在这里插入图片描述
上图为均值为0方差为1的标准正态分布图。

这意味着在一个标准差范围内,也即是说64%的概率 z 其值落在[-1,1]的范围内,在两个标准差范围内,95%的概率 z 其值落在了[-2,2]范围内。这意味着什么呢?z是某个神经元的线性计算结果,假设非线性函数时sigmoid,那么看下sigmoid(z)的图形:
在这里插入图片描述
及sigmoid(z)的导数为:G’=f(z)*(1-f(z)),因为f(z)=sigmoid(z)在0到1之间,所以G’在0到0.25之间,其对应的图如下:
在这里插入图片描述
假设没有经过BN调整前 z 的原先正态分布均值是-6,方差是1,那么意味着95%的值落在了[-8,-4]之间,那么对应的Sigmoid(z)函数的值明显接近于0,这是典型的梯度饱和区,在这个区域里梯度变化很慢,为什么是梯度饱和区?请看下 sigmoid(z) 如果取值接近0或者接近于1的时候对应导数函数取值,接近于0,意味着梯度变化很小甚至消失。

而假设经过BN后,均值是0,方差是1,那么意味着95%的 z 值落在了[-2,2]区间内,很明显这一段是 sigmoid(z) 函数接近于线性变换的区域,意味着 z 的小变化会导致非线性函数值较大的变化,也即是梯度变化较大,对应导数函数图中明显大于0的区域,就是梯度非饱和区。

2)算法

1. 参数定义

我们定义网络总共有 L 层(不包含输入层)并定义如下符号:
在这里插入图片描述

参数相关:
在这里插入图片描述
样本相关:
在这里插入图片描述
2.算法步骤

第一点,对每个特征进行独立的normalization。我们考虑一个batch的训练,传入m个训练样本,并关注网络中的某一层,忽略上标 L。在这里插入图片描述
我们关注当前层的第 j 个维度,也就是第 j 个神经元结点,则有 在这里插入图片描述。我们当前维度进行规范化:

在这里插入图片描述

具体例子

下面我们再来结合个具体的例子来进行计算。下图我们只关注第 L 层的计算结果,左边的矩阵是在这里插入图片描述线性计算结果,还未进行激活函数的非线性变换。此时每一列是一个样本,图中可以看到共有8列,代表当前训练样本的batch中共有8个样本,每一行代表当前 L 层神经元的一个节点,可以看到当前 L 层共有4个神经元结点,即第 L 层维度为4.我们可以看到,每行的数据分布都不同。

在这里插入图片描述
对于第一个神经元,我们求得在这里插入图片描述,此时我们利用 在这里插入图片描述对第一行数据(第一个维度)进行normalization得到新的值
在这里插入图片描述同理我们可以计算出其他输入维度归一化后的值。如下图:

在这里插入图片描述
通过上面的变换,使得第 L 层的输入每个特征的分布均值为0,方差为1。

3)引入参数恢复表达能力

Normalization操作我们虽然缓解了“Internal Covariate Shift”问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。也就是我们通过变换操作改变了原有数据的信息表达,使得底层网络学习到的参数信息丢失。另一方面,通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。

因此,BN又引入了两个可学习(learnable)的参数在这里插入图片描述这两个参数的引入是为了恢复数据本身的表达能力,对规范化后的数据进行线性变换,即:在这里插入图片描述
特别地,当在这里插入图片描述时,可以实现等价变换(identity transform)并且保留了原始输入特征的分布信息。

通过上面的步骤,我们就在一定程度上保证了输入数据的表达能力。

4)公式

在这里插入图片描述


3. 测试阶段如何使用Batch Normalization?

BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,那么这时候怎么对输入做BN呢?

利用BN训练好模型后,我们保留了每组mini-batch训练数据在网络中每一层的在这里插入图片描述在这里插入图片描述。此时我们使用整个样本的统计量来对Test数据进行归一化对每个Mini-Batch的均值和方差求其对应的数学期望即可得出全局统计量,即:在这里插入图片描述
得到每个特征的均值与方差的无偏估计后,我们对test数据采用同样的normalization方法:
在这里插入图片描述


4. BatchNorm的优势

1)BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度

BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。

2)BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定

在神经网络中,我们经常会谨慎地采用一些权重初始化方法(例如Xavier)或者合适的学习率来保证网络稳定训练。

当学习率设置太高时,会使得参数更新步伐过大,容易出现震荡和不收敛。但是使用BN的网络将不会受到参数数值大小的影响。例如,我们对参数 W 进行缩放得到 aW 。对于缩放前的值 Wu,我们设其均值为在这里插入图片描述,方差为在这里插入图片描述;对于缩放值 aWu ,设其均值为在这里插入图片描述,方差为在这里插入图片描述,我们有:
在这里插入图片描述

我们忽略在这里插入图片描述,则有:
在这里插入图片描述
我们可以看到,经过BN操作以后,权重的缩放值会被“抹去”,因此保证了输入数据分布稳定在一定范围内。另外,权重的缩放并不会影响到对 u 的梯度计算;并且当权重越大时,即 a 越大,1/a越小,意味着权重 W 的梯度反而越小,这样BN就保证了梯度不会依赖于参数的scale,使得参数的更新处在更加稳定的状态。

因此,在使用Batch Normalization之后,抑制了参数微小变化随着网络层数加深被放大的问题,使得网络对参数大小的适应能力更强,此时我们可以设置较大的学习率而不用过于担心模型divergence的风险。

3)BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

在不使用BN层的时候,由于网络的深度与复杂性,很容易使得底层网络变化累积到上层网络中,导致模型的训练很容易进入到激活函数的梯度饱和区;通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习 在这里插入图片描述又让数据保留更多的原始信息。

4)BN具有一定的正则化效果

在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音,与Dropout通过关闭神经元给网络训练带来噪音类似,在一定程度上对模型起到了正则化的效果。

另外,原作者通过也证明了网络加入BN后,可以丢弃Dropout,模型也同样具有很好的泛化效果。

大神博客Batch Normalization原理与实战
大神博客

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【深度学习】深入理解Batch Normalization批标准化 的相关文章

随机推荐

  • linux日志对应内容

    var log messages 包括整体系统信息 xff0c 其中也包含系统启动期间的日志 此外 xff0c mail xff0c cron xff0c daemon xff0c kern和auth等内容也记录在var log messa
  • 常用证书操作函数

    现有的证书大都采用X 509规范 xff0c 主要同以下信息组成 xff1a 版本号 证书序列号 有效期 拥有者信息 颁发者信息 其他扩展信息 拥有者的公钥 CA对以上信息的签名 OpenSSL实现了对X 509数字证书的所有操作 包括签发
  • MongoDB 匹配查询和比较操作符

    一 匹配查询 1 查询所有 span class token operator gt span db accounts find span class token punctuation span span class token punc
  • 我的2014——典型程序员的一年,不想再重来

    兴冲冲地拿起 xff0c 信誓旦旦的搁在一边 xff0c 以为很快就会回来 xff0c 却一晃而过 xff0c 不再回来 xff1b 我不想再重复过去 xff0c 决定去做 xff0c 写下来 题记 已经记不起我2014的年初是否有过规划
  • 我的2016——程序员年到三十,工作第四年

    看到CSDN 我的2016 主题征文活动 已经是1月6号 xff0c 而截止时间是1月8号 xff0c 对比去年的总结是在闲等活动开始 xff0c 今年在时间上真的是天差地别 但是 xff0c 一年到头 xff0c 还是需要花些时间来回顾这
  • mac下 ndk_build: command not found

    参考 http blog csdn net greenbird811 article details 7543305 在mac下调用ndk build c代码文件提示错误 fix 1 启动终端Terminal 2 进入当前用户的home目录
  • 公司分配IP地址,求主机号码的最小值和最大值。

    问题描述如下 xff1a 姐 xff1a 注意减去2的实际意义 xff1a 网络地址后的第一个主机地址是本网段的网络地址192 168 0 0 xff0c 最 后一个主机地址是本网段的广播地址192 168 255 255
  • Erlang入门

    64 author sunxu 64 copyright C 2023 lt COMPANY gt 64 doc 64 end Created 16 2月 2023 22 16 module test author 34 sunxu 34
  • IPv4地址、IPv6地址和Mac地址的位数

    xff08 1 xff09 IPv4的地址是32位 xff0c 用点分十进制表示 xff0c 每八位划分 xff0c 也就是四个0 255的十进制数 xff0c 这是很常见的 xff08 2 xff09 IPv6的地址是128位 xff0c
  • 用C#连接数据库的方法

    连接SQL Server数据库的方法 xff1a 1 在程序中引用System Data SqlClient命名空间 2 编写连接字符串 xff0c 格式为 xff1a Data Source 61 服务器名称 Initial Catalo
  • gcc 不支持 //注释的解决

    这段时间用slickedit写代码 xff08 windows平台下 xff0c 装了Cygwin xff09 xff0c 编译器用的gcc xff0c 但是有个问题就是用 34 34 写注释的时候 xff0c 编译的时候有错 xff1a
  • python实现按照文件名称进行文件分类

    问题 xff1a 大量名称中带有数字的图片 视频 xff0c 根据名称中数字按照一定的等差数列来排序 xff0c 并且放入指定对应的文件夹中 span class token keyword import span os span clas
  • 【深度学习】Yolov3详解笔记及Pytorch代码

    Yolov3详解笔记及Pytorch代码 预测部分网络结构backbone xff1a Darknet 53output预测结果的解码 训练部分计算loss所需参数pred是什么target是什么loss的计算过程 预测部分 网络结构 DB
  • 【深度学习】各种卷积的理解笔记(2D,3D,1x1,可分离卷积)

    卷积 1 2D卷积单通道版本多通道版本 2 3D卷积3 1x1卷积作用应用 4 卷积算法5 可分离卷积空间可分离卷积深度可分离卷积 1 2D卷积 卷积的目的是从输入中提取有用的特征 在图像处理中 xff0c 卷积可以选择多种不同的滤波器 每
  • 【深度学习】(2+1)D模型框架结构笔记

    xff08 2 43 1 xff09 D 模型框架结构笔记 SpatioTemporalConv模块结构SpatioTemporalResBlock模块结构SpatioTemporalResLayer模块结构2Plus1DNet Spati
  • 【机器学习】LR回归(逻辑回归)和softmax回归

    LR回归 xff08 逻辑回归 xff09 和softmax回归 1 LR回归Logistic回归的函数形式Logistic回归的损失函数Logistic回归的梯度下降法Logistic回归防止过拟合Multinomial Logistic
  • 【深度学习】时间注意力模块与空间注意力模块

    注意力模块 通道 xff08 时间 xff09 注意力模块空间注意力模块 通道 xff08 时间 xff09 注意力模块 为了汇总空间特征 xff0c 作者采用了全局平均池化和最大池化两种方式来分别利用不同的信息 输入是一个 H W C 的
  • 【机器学习】机器学习与统计分布的关系

    这里写目录标题 1 常见的统计学分布1 xff09 离散分布a 伯努利分布b 二项分布c 泊松分布 2 xff09 连续分布a 正态分布 xff08 高斯分布 xff09 b 均匀分布 为什么我们喜欢用 sigmoid 这类 S 型非线性变
  • AKKA入门

    1 Guardian java package com example demo import akka actor typed javadsl ActorContext import akka actor typed ActorRef i
  • 【深度学习】深入理解Batch Normalization批标准化

    Batch Normalization 1 Internal Covariate Shift 问题2 BatchNorm的本质思想1 xff09 函数图像说明2 xff09 算法3 xff09 引入参数恢复表达能力4 xff09 公式 3