BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结

2023-11-07

本篇博客总结几种归一化办法,并给出相应计算公式和代码。


1、综述

1.1 论文链接

1、Batch Normalization

https://arxiv.org/pdf/1502.03167.pdf

2、Layer Normalizaiton

https://arxiv.org/pdf/1607.06450v1.pdf

3、Instance Normalization

https://arxiv.org/pdf/1607.08022.pdf

https://github.com/DmitryUlyanov/texture_nets

4、Group Normalization

https://arxiv.org/pdf/1803.08494.pdf

5、Switchable Normalization

https://arxiv.org/pdf/1806.10779.pdf

https://github.com/switchablenorms/Switchable-Normalization

1.2 介绍

归一化层,目前主要有这几个方法,Batch Normalization(2015年)、Layer Normalization(2016年)、Instance Normalization(2017年)、Group Normalization(2018年)、Switchable Normalization(2018年);

将输入的图像shape记为[N, C, H, W],这几个方法主要的区别就是在,

  • batchNorm是在batch上,对NHW做归一化,对小batchsize效果不好;
  • layerNorm在通道方向上,对CHW归一化,主要对RNN作用明显;
  • instanceNorm在图像像素上,对HW做归一化,用在风格化迁移;
  • GroupNorm将channel分组,然后再做归一化;
  • SwitchableNorm是将BN、LN、IN结合,赋予权重,让网络自己去学习归一化层应该使用什么方法。

这里写图片描述

2、Batch Normalization

首先,在进行训练之前,一般要对数据做归一化,使其分布一致,但是在深度神经网络训练过程中,通常以送入网络的每一个batch训练,这样每个batch具有不同的分布;此外,为了解决internal covarivate shift问题,这个问题定义是随着batch normalizaiton这篇论文提出的,在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。

所以batch normalization就是强行将数据拉回到均值为0,方差为1的正太分布上,这样不仅数据分布一致,而且避免发生梯度消失。

此外,internal corvariate shift和covariate shift是两回事,前者是网络内部,后者是针对输入数据,比如我们在训练数据前做归一化等预处理操作。

这里写图片描述

算法过程:

  • 沿着通道计算每个batch的均值u
  • 沿着通道计算每个batch的方差σ^2
  • 对x做归一化,x’=(x-u)/开根号(σ^2+ε)
  • 加入缩放和平移变量γ和β ,归一化后的值,y=γx’+β

加入缩放平移变量的原因是:保证每一次数据经过归一化后还保留原有学习来的特征,同时又能完成归一化操作,加速训练。 这两个参数是用来学习的参数。

import numpy as np

def Batchnorm(x, gamma, beta, bn_param):

    # x_shape:[B, C, H, W]
    running_mean = bn_param['running_mean']
    running_var = bn_param['running_var']
    results = 0.
    eps = 1e-5

    x_mean = np.mean(x, axis=(0, 2, 3), keepdims=True)
    x_var = np.var(x, axis=(0, 2, 3), keepdims=True0)
    x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta

    # 因为在测试时是单个图片测试,这里保留训练时的均值和方差,用在后面测试时用
    running_mean = momentum * running_mean + (1 - momentum) * x_mean
    running_var = momentum * running_var + (1 - momentum) * x_var

    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return results, bn_param

3、Layer Normalizaiton

batch normalization存在以下缺点:

  • 对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布;
  • BN实际使用时需要计算并且保存某一层神经网络batch的均值和方差等统计信息,对于对一个固定深度的前向神经网络(DNN,CNN)使用BN,很方便;但对于RNN来说,sequence的长度是不一致的,换句话说RNN的深度不是固定的,不同的time-step需要保存不同的statics特征,可能存在一个特殊sequence比其他sequence长很多,这样training时,计算很麻烦。(参考于https://blog.csdn.net/lqfarmer/article/details/71439314

与BN不同,LN是针对深度网络的某一层的所有神经元的输入按以下公式进行normalize操作。

这里写图片描述

BN与LN的区别在于:

  • LN中同层神经元输入拥有相同的均值和方差,不同的输入样本有不同的均值和方差;
  • BN中则针对不同神经元输入计算均值和方差,同一个batch中的输入拥有相同的均值和方差。

    所以,LN不依赖于batch的大小和输入sequence的深度,因此可以用于batchsize为1和RNN中对边长的输入sequence的normalize操作。

LN用于RNN效果比较明显,但是在CNN上,不如BN。

def ln(x, b, s):
    _eps = 1e-5
    output = (x - x.mean(1)[:,None]) / tensor.sqrt((x.var(1)[:,None] + _eps))
    output = s[None, :] * output + b[None,:]
    return output

用在四维图像上,

def Layernorm(x, gamma, beta):

    # x_shape:[B, C, H, W]
    results = 0.
    eps = 1e-5

    x_mean = np.mean(x, axis=(1, 2, 3), keepdims=True)
    x_var = np.var(x, axis=(1, 2, 3), keepdims=True0)
    x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta
    return results

4、Instance Normalization

BN注重对每个batch进行归一化,保证数据分布一致,因为判别模型中结果取决于数据整体分布。

但是图像风格化中,生成结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,因而对HW做归一化。可以加速模型收敛,并且保持每个图像实例之间的独立。

公式:

这里写图片描述

代码:

def Instancenorm(x, gamma, beta):

    # x_shape:[B, C, H, W]
    results = 0.
    eps = 1e-5

    x_mean = np.mean(x, axis=(2, 3), keepdims=True)
    x_var = np.var(x, axis=(2, 3), keepdims=True0)
    x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta
    return results

5、Group Normalization

主要是针对Batch Normalization对小batchsize效果差,GN将channel方向分group,然后每个group内做归一化,算(C//G)*H*W的均值,这样与batchsize无关,不受其约束。

公式:

这里写图片描述

伪代码:

这里写图片描述

代码:

def GroupNorm(x, gamma, beta, G=16):

    # x_shape:[B, C, H, W]
    results = 0.
    eps = 1e-5
    x = np.reshape(x, (x.shape[0], G, x.shape[1]/16, x.shape[2], x.shape[3]))

    x_mean = np.mean(x, axis=(2, 3, 4), keepdims=True)
    x_var = np.var(x, axis=(2, 3, 4), keepdims=True0)
    x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta
    return results

6、Switchable Normalization

本篇论文作者认为,

  • 第一,归一化虽然提高模型泛化能力,然而归一化层的操作是人工设计的。在实际应用中,解决不同的问题原则上需要设计不同的归一化操作,并没有一个通用的归一化方法能够解决所有应用问题;
  • 第二,一个深度神经网络往往包含几十个归一化层,通常这些归一化层都使用同样的归一化操作,因为手工为每一个归一化层设计操作需要进行大量的实验。

因此作者提出自适配归一化方法——Switchable Normalization(SN)来解决上述问题。与强化学习不同,SN使用可微分学习,为一个深度网络中的每一个归一化层确定合适的归一化操作。

公式:

这里写图片描述

这里写图片描述

这里写图片描述

代码:

def SwitchableNorm(x, gamma, beta, w_mean, w_var):
    # x_shape:[B, C, H, W]
    results = 0.
    eps = 1e-5

    mean_in = np.mean(x, axis=(2, 3), keepdims=True)
    var_in = np.var(x, axis=(2, 3), keepdims=True)

    mean_ln = np.mean(x, axis=(1, 2, 3), keepdims=True)
    var_ln = np.var(x, axis=(1, 2, 3), keepdims=True)

    mean_bn = np.mean(x, axis=(0, 2, 3), keepdims=True)
    var_bn = np.var(x, axis=(0, 2, 3), keepdims=True)

    mean = w_mean[0] * mean_in + w_mean[1] * mean_ln + w_mean[2] * mean_bn
    var = w_var[0] * var_in + w_var[1] * var_ln + w_var[2] * var_bn

    x_normalized = (x - mean) / np.sqrt(var + eps)
    results = gamma * x_normalized + beta
    return results

结果比较:

这里写图片描述
这里写图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结 的相关文章

  • jni编写时的教训(函数签名不对应)

    最近由于项目结构上的调整 原先我的模块位于APP层 通过aidl hidl调用到native层的c 的服务的接口 用于更新EMMC上的文件内容 需要改为 C 服务更新EMMC上文件内容的代码封装成jni的so库 由我的模块去调用 由于jni
  • Python---字典添加元素

    1 8 5 cvar 字典 特点 1 符号 2 关键字 dict 3 保存元素是 key value 一对 定义 dict1 空字典 dict2 dict 空字典 dict3 ID 123156489795 name lucky age 1
  • VUE---7.事件&循环

    目录 一 事件 二 按键 1 按键修饰符 内置 2 自定义修饰符 event key 3 系统修饰符 4 组合修饰符 三 循环 一 事件 1 绑定事件 2 event事件对象 3 修饰符 stop 阻止冒泡 prevent 阻止默认事件 c

随机推荐

  • 华为服务器近端连接显示绿屏,故障诊断说明 - 华为服务器 iMana 200 用户指南 26 - 华为...

    MCE Error Diagnose DIMMxxx memory fault is doubted of this error Error Handling Suggestion Please shut down system and c
  • Rancher2.x入门教程

    1 x教程请参考上篇 容器管理Rancher1 x及监控工具入门 1 简介 为了更好的管理Kubernetes Rancher的大部分功能经过重新设计 并且Rancher2 0延续了大多数1 0版本的友好功能 如简洁的UI和应用商店等 2
  • nohup和screen的比较

    在实际工作中 我们ssh等到远程的Linux上 运行一个程序 但是当我们关闭掉我们的连接终端断开ssh后 刚才运行着的程序也会自动被中断结束 当ssh连接断开后 如何让我们的程序继续在后台运行呢 下面介绍我常使用的两个方法 A 使用nohu
  • Qt简易闹钟

    配置文件 QT core gui texttospeech greaterThan QT MAJOR VERSION 4 QT widgets CONFIG c 11 The following define makes your comp
  • Leetcode 632最小区间

    632 最小区间 难度 困难 标签 哈希表 双指针 字符串 Description 你有 k 个 非递减排列 的整数列表 找到一个 最小 区间 使得 k 个列表中的每个列表至少有一个数包含在其中 我们定义如果 b a lt d c 或者在
  • linux top详解

    语法 root incloudos logs top h procps ng version 3 3 10 Usage top hv bcHiOSs d secs n max u U user p pid s o field w cols
  • nginx服务器

    一 介绍 Nginx engine x 是一个高性能的HTTP和反向代理服务器 也是一个IMAP POP3 SMTP服务 Nginx是由伊戈尔 赛索耶夫为俄罗斯访问量第二的Rambler ru站点 俄文 开发的 第一个公开版本0 1 0发布
  • 【算法/剑指Offer】如何得到一个数据流中的中位数?如果从数据流中读出奇数个数值,那么中位数就是所有数值排序之后位于中间的数值。

    题目描述 如何得到一个数据流中的中位数 如果从数据流中读出奇数个数值 那么中位数就是所有数值排序之后位于中间的数值 如果从数据流中读出偶数个数值 那么中位数就是所有数值排序之后中间两个数的平均值 我们使用Insert 方法读取数据流 使用G
  • Angular4.0_数据绑定和管道

    单向数据绑定 使用插值表达式将一个表达式的值显示在模板上 h1 productTitle h1 事件绑定 使用小括号将组建控制器的一个方法绑定为模板上一个事件的处理器
  • ROS机器人语音模块

    ROS机器人语音模块 文章目录 ROS机器人语音模块 零 乘骐骥以驰骋兮 来吾道夫先路 壹 路漫漫其修远兮 吾将上下而求索 贰 苟余情其信姱以练要兮 长顑颔亦何伤 叁 不吾知其亦已兮 苟余情其信芳 肆 虽体解吾犹未变兮 岂余心之可惩 末 亦
  • 【仿真】Carla介绍与使用 [1] (附代码手把手讲解)

    0 参考与前言 主要介绍无人驾驶的仿真环境CARLA 开源社区维护 以下为相关参考链接 Carla官方文档 建议后续找的时候 先按好版本号 有些功能 api 是新版本里有的 Carla官方github Youtube Python Wind
  • vue css >>> , /deep/ 深度选择器

    vue引用了第三方组件 有时候我们需要改写第三方组件的样式 而又不想去除scoped属性造成组件之间的样式污染 此时只能通过 gt gt gt 穿透scoped 有些Sass 之类的预处理器无法正确解析 gt gt gt 这时可以使用 de
  • stm32 利用链表和定时器动态实现led等器件周期性控制

    stm32 esp8266 ota系列文章 stm32 esp8266 ota 快速搭建web服务器之docker安装openresty stm32 esp8266 ota升级 tcp模拟http stm32 esp8266 ota升级 h
  • 初探 ModBus4j -简单使用指南

    目录 前言 开发环境 工具准备 具体实现 下载Modbus4j 解决空指针异常 解决数组越界 测试 测试环境准备 正式测试 前言 之前提到过 由于项目需求 需要封装 ModBus协议 ModBus协议较早 网上开源开源库也不少 可参见 Mo
  • STM32驱动8266-----8266AP模式

    找了很久 一直没有找到驱动的程序 查一些资料 字写了一个简单程序 记录分享一下 void esp8266 inittcp void printf AT CIPMODE 2 r n 设置AP模式 delay ms 10000 延时函数 pri
  • vue3.0教程——搭建Vue脚手架【简化版】

    目录 哈喽 大家好丫 你们的小郭子又来啦 一 环境要求 1 node安装 前端开发环境 2 vue cli脚手架安装 二 安装依赖 1 使用命令行安装以下依赖 2 通过 vue ui 命令以图形化界面来管理项目依赖 3 导入你刚刚项目的地址
  • 装系统使用默认administrator用户

    在设置键盘布局界面按下Ctrl Shift F4重启 进入系统后 见到一个 系统准备工具 3 14 开始 运行 输入 XCOPY windir System32 svchost exe windir System32 oobe audit
  • MyBatis框架( 项目构建笔记 )

    MyBatis框架 项目构建笔记 一 框架 二 获取参数 三 查询 四 模糊查询 批量删除 五 resultMap和映射关系 五 动态SQL 基本功能实现的项目结构 将SqlSessionFactory 使用工具类进行封装 映射文件的名称要
  • 牛客 AB28 快速幂 JAVA

    描述 请你计算 ab mod p 的值 一共有 q 次询问 输入描述 第一行输入一个正整数 q 代表询问次数 接下来每行输入三个正整数 a b p 代表一次询问 数据范围 1 1051 q 105 1 1071 a b p 107 输出描述
  • BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结

    本篇博客总结几种归一化办法 并给出相应计算公式和代码 1 综述 1 1 论文链接 1 Batch Normalization https arxiv org pdf 1502 03167 pdf 2 Layer Normalizaiton