Pytorch权重初始化方法——Kaiming、Xavier

2023-11-20

Pytorch权重初始化方法——Kaiming、Xavier

结论

结论写在前。Pytorch线性层采取的默认初始化方式是Kaiming初始化,这是由我国计算机视觉领域专家何恺明提出的。我的探究主要包括:

  • 为什么采取Kaiming初始化?
  • 考察Kaiming初始化的基础——Xavier初始化的公式
  • 考察Kaiming初始化的公式
  • 用Numpy实现一个简易的Kaiming初始化

为什么采取Kaiming初始化?

采取固定的分布?

当考虑怎么初始化权重矩阵这个问题时,可以想到应该使得初始权重具有随机性。提到随机,自然的想法是使用均匀分布或正态分布,那么我们如果采用与模型无关的固定分布(例如标准正态分布(均值为0,方差为1))怎么样?下面我们分析如果对模型本身不加考虑,采取固定的分布,会有什么问题:

  • 如果权重的绝对值太小,在多层的神经网络的每一层,输入信号的方差会不断减小;当到达最终的输出层时,可以理解为输入信号的影响已经降低到微乎其微。一方面训练效果差,另一方面可能会有梯度消失等问题。(此处从略,参考https://zhuanlan.zhihu.com/p/25631496)
  • 如果权重的绝对值太大,同样道理,随着深度的加深,可能会使输入信号的方差过大,这会造成梯度爆炸或消失的问题。

这里举一个例子,假如一个网络使用了多个sigmoid作为中间层(这个函数具有两边导数趋于0的特点):

  • 如果权重初始绝对值太小,随着深度的加深,输入信号的方差过小。当输入很小时,sigmoid函数接近线性,深层模型也失去了非线性性的优点。(模型效果
  • 如果权重初始绝对值太大,随着深度的加深,输入信号的方差过大。绝对值过大的sigmoid输入意味着激活变得饱和,梯度将开始接近零。(梯度消失

Xavier初始化

前面的问题提示我们要根据模型的特点(维度,规模)决定使用的随机化方法(分布的均值、方差),xavier初始化应运而生,它可以使得输入值经过网络层后方差不变。pytorch中这一点是通过增益值gain来实现的,下面的函数用来获得特定层的gain:

torch.nn.init.calculate_gain(nonlinearity, param=None)

增益值表(图片摘自https://blog.csdn.net/winycg/article/details/86649832)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cJij2pJa-1619368710937)(报告.assets/20190125144412278.png)]

Xavier初始化可以采用均匀分布 U(-a, a),其中a的计算公式为:
a = g a i n × 6 f a n _ i n + f a n _ o u t a = gain \times \sqrt[]{\frac{6}{fan\_in+fan\_out}} a=gain×fan_in+fan_out6
Xavier初始化可以采用正态分布 N(0, std),其中std的计算公式为:
s t d = g a i n × 2 f a n _ i n + f a n _ o u t std = gain \times \sqrt[]{\frac{2}{fan\_in+fan\_out}} std=gain×fan_in+fan_out2
其中fan_in和fan_out分别是输入神经元和输出神经元的数量,在全连接层中,就等于输入输出的feature数。

Kaiming初始化

Xavier初始化在Relu层表现不好,主要原因是relu层会将负数映射到0,影响整体方差。所以何恺明在对此做了改进提出Kaiming初始化,一开始主要应用于计算机视觉、卷积网络。

Kaiming均匀分布的初始化采用U(-bound, bound),其中bound的计算公式为:(a 的概念下面再说)
b o u n d = 6 ( 1 + a 2 ) × f a n _ i n bound = \sqrt[]{\frac{6}{(1 + a ^2) \times fan\_in}} bound=(1+a2)×fan_in6
这里补充一点,pytorch中这个公式也通过gain作为中间变量实现,也就是:
b o u n d = g a i n × 3 f a n _ i n bound = gain \times \sqrt[]{\frac{3}{ fan\_in}} bound=gain×fan_in3
其中:
g a i n = 2 1 + a 2 gain = \sqrt{\frac{2}{1 + a^2}} gain=1+a22
Kaiming正态分布的初始化采用N(0,std),其中std的计算公式为:
s t d = 2 ( 1 + a 2 ) × f a n _ i n std = \sqrt[]{\frac{2}{(1 + a ^2) \times fan\_in}} std=(1+a2)×fan_in2
这里稍微解释一下a的含义,源码中的解释为

the negative slope of the rectifier used after this layer

简单说,是用来衡量这一层中负数比例的,负数越多,Relu层会将越多的输入“抹平”为0,a用来平衡这种“抹平”对于方差的影响。

Pytorch Linear层默认初始化

pytorch的线性层进行的默认初始化的例子:

fc1 = torch.nn.Linear(28 * 28, 256)

在Linear类中通过

self.reset_parameters()

这个函数来完成随机初始化的过程,后者使用的是

init.kaiming_uniform_(self.weight, a=math.sqrt(5))

可见是我们前面提到的Kaiming均匀分布的初始化方式,这个函数的内容和前面的公式相符(使用gain作为中间变量):

fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
with torch.no_grad():
    return tensor.uniform_(-bound, bound)

同时将参数a 的值设置为5。

使用numpy完成get_torch_initialization

简单起见,我没有按照pytorch的封装方法分层实现初始化过程,后者主要为了提供多种不同的初始化方式。我直接按照线性层默认的初始方式——Kaiming均匀分布的公式用numpy实现了get_torch_initialization,其中a值取5, 代码如下:

def get_torch_initialization(numpy = True):

    a = 5

    def Kaiming_uniform(fan_in,fan_out,a):
        bound = 6.0 / (1 + a * a) / fan_in
        bound = bound ** 0.5
        W = np.random.uniform(low=-bound, high=bound, size=(fan_in,fan_out))
        return W

    W1 = Kaiming_uniform(28 * 28, 256, a)
    W2 = Kaiming_uniform(256, 64, a)
    W3 = Kaiming_uniform(64, 10, a)
    return W1,W2,W3
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch权重初始化方法——Kaiming、Xavier 的相关文章

随机推荐

  • powervm虚拟化分析

    powervm是IBM推出的适用于power系列服务器的虚拟化技术 有其独特的功能和技术 本文和大家一起探讨一下 首先power是ibm处理器的名字 也常常用来标识ibm服务器的型号 常见的power7 power8小型机就是指期cpu是p
  • 启锐 588 打印机每次打印都流出一部分,没有重新切换纸张

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 588 488识别纸张 一 打印机关机 数据线拔掉 二 把纸拿出来 开机之后盖上盖子 三 然后把纸从机器后面塞进去让机器自动吸纸 四 然后长按打印机上面蓝色的按键 听到滴的
  • Java基础之随机生成数字和字母

    原文地址 http blog csdn net yaodong y article details 8115250 字母与数字的ASCII码 目 前计算机中用得最广泛的 字符集及其编码 是由美国国家标准局 ANSI 制定的ASCII码 Am
  • OpenGL视图变换及gluLookAt

    视图变换 即相机变换 其作用是把相机放在指定位置并使其对准场景 该变换是针对相机的变换 不会影响到模型 视图变换决定了相机的位置与方向 因此可以通过视图变换来改变相机位置与方向 从而达到从各个不同的位置与角度来观察同一个物体的情形 进行视图
  • SHAP显示原始特征

    1 问题描述 SHAP用于特征解释 对于机器学习方法往往需要对原始特征进行编码 而SHAP在绘制单个样本时 会显示每个特征及其取值 而这个取值已经是编码后的 通常无法确定其含义 如 下图所示的拍卖公司 城市和作者信息 预期达到的效果 2 实
  • 【西瓜书】4-决策树

    文章目录 4 1 基本流程 4 2 划分 4 2 1 信息增益 ID3 4 2 2 信息增益率 C 45 4 2 3 基尼指数 CART 4 3 剪枝处理 4 4 连续与缺失值 4 4 2 连续值处理 4 4 1 缺失值处理 4 5 多变量
  • Anchor是什么?

    1 选择性搜索 Selective Search 先介绍一下传统的人脸识别算法 是怎么检测出图片中的人脸的 以下图为例 如果我们要检测图中小女孩的人脸位置 一个比较简单暴力的方法就是滑窗 我们使用不同大小 不同长宽比的候选框在整幅图像上进行
  • crmeb重新安装_手动安装教程 · CRMEB 知识付费版 帮助文档 · 看云

    手动安装 1 创建数据库 倒入数据库文件 数据库文件目录 public install zhishifufei sql 2 修改数据库连接文件 配置文件路径 application database php 数据库类型 type gt my
  • vagrant启动openshift

    1 Install Vagrant 2 Install VirtualBox Ex yum install VirtualBox from the RPM Fusion repository 3 In your bashrc file or
  • 元胞自动机算法汇总含matlab代码_数学建模(十三)

    元胞自动机理论 许多复杂的问题都可以通过元胞自动机来建立模型 元胞自动机实质上是定义在一个具有离散 有限状态的元胞组成的元胞空间上 并按照一定的局部规则 在离散的时间维度上演化的动力学系统 元胞又可称为单元 细胞 是元胞自动机的最基本的组成
  • 【hortonworks/registry】registry 如何添加新的类型 支持 json

    1 概述 hortonworks registry 支持json 但是要自己扩展 有相关接口 支持基本类型 支持自定义对象类型 支持集合类型 map array null 支持嵌套结构 registry支持的数据类型有好几种 其中有Avro
  • STM32F103C8T6+PWM+DMA驱动 WS2812灯带

    STM32 PWM DMA驱动 WS2812灯带 文章目录 1 理论 2代码 理论 1 WS2812参考数据手册 https wenku baidu com view 0925958fba68a98271fe910ef12d2af90342
  • 基于Matlab卡尔曼滤波的IMU和GPS组合导航数据融合(附上源码+数据)

    本文介绍了如何使用Matlab实现惯性测量单元 IMU 和全球定位系统 GPS 组合导航数据融合的卡尔曼滤波算法 通过将IMU和GPS的测量数据进行融合 可以提高导航系统的精度和鲁棒性 我们将详细介绍卡尔曼滤波的原理和实现步骤 并给出源码
  • SpringBoot使用Pio-tl动态填写合同(文档)

    poi tl poi template language 是Word模板引擎 使用Word模板和数据创建很棒的Word文档 poi tl官方网址 项目中有需求需要动态填充交易合同 因此想到了使用poi tl技术来实现 一 引入依赖
  • Keil5无法进入debug(卡死在启动文件)

    Keil5无法进入debug 卡死在启动文件 出现的情况 运行一直卡死在启动文件 例如startup stm32f103xe s 而主程序的箭头也只有一个 两个箭头的运行行在启动文件 debug一直无法运行 解决办法 你在程序中使用了pri
  • Qml与C++交互4:C++信号与Qml的槽函数的连接

    Qml与C 交互4 C 信号与Qml的槽函数的连接 使用场景 整体思路 1 建立C 信号 2 C 实例注册到qml 3 qml中建立槽函数 Connections 类型 建立槽函数 运行结果 使用属性 更多资讯 知识 微信公众号搜索 上官宏
  • OpenCV项目编译错误

    编译遇到如下错误 opencv 3 4 4 modules highgui src window gtk cpp 1062 error 218 No OpenGL support Library was built without Open
  • 长春地铁一号线作业

    长春一号线作业 代码如下 public class 第一次作业 public static void main String args System out println 北环城站 一匡街 胜利公园 解放大路 工农广场 卫星广场 华庆路
  • 卡尔曼及扩展卡尔曼滤波详细推导-来自DR_CAN视频

    卡尔曼及扩展卡尔曼滤波详细推导 来自DR CAN视频 见知乎https zhuanlan zhihu com p 585819291
  • Pytorch权重初始化方法——Kaiming、Xavier

    Pytorch权重初始化方法 Kaiming Xavier 结论 结论写在前 Pytorch线性层采取的默认初始化方式是Kaiming初始化 这是由我国计算机视觉领域专家何恺明提出的 我的探究主要包括 为什么采取Kaiming初始化 考察K