深度学习——权重的初始值

2023-11-15

权重的初始值

①权重的初始值十分重要,关系到神经网络的学习是否成功。

可以将权重初始值设置为0吗

为了抑制过拟合、提高泛化能力,采用权值衰减的方法,它是一种以减小权重参数的值为目的进行学习的方法。
在误差反向传播法中,所有的权重值都会进行相同的更新。比如,在2层神经网络中,假设第1层和第2层的权重为0。这样一来,正向传播时,因为输入层的权重为0,所以第2层的神经元全部会被传递相同的值。第2层的神经元中全部输入相同的值,这意味着反向传播时第2层的权重全部都会进行相同的更新。
简单的说,权重无法更新为新的值因此在初始化的时候必须随机生成初始值。

隐藏层的激活值分布

观察隐藏层的激活值(激活函数的输出数据)的分布,了解权重的初始值是如何影响隐藏层的激活值的分布。

向一个5层神经网络(激活函数使用sigmoid函数)传入随机生产的输入数据,绘制各层激活值的数据分布。

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def ReLU(x):
    return np.maximum(0, x)

def tanh(x):
    return np.tanh(x)
    
input_data = np.random.randn(1000, 100)  # 1000个数据
node_num = 100  # 各隐藏层的节点(神经元)数
hidden_layer_size = 5  # 隐藏层有5层
activations = {}  # 激活值的结果保存在这里

x = input_data

for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]

    # 改变初始值进行实验!
    w = np.random.randn(node_num, node_num) * 1
    # w = np.random.randn(node_num, node_num) * 0.01
    # w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
    # w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)

    a = np.dot(x, w)
    # 将激活函数的种类也改变,来进行实验!
    z = sigmoid(a)
    # z = ReLU(a)
    # z = tanh(a)

    activations[i] = z

#绘制直方图
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + "-layer")
    if i != 0: plt.yticks([], [])
    # plt.xlim(0.1, 1)
    # plt.ylim(0, 7000)
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()

在这里插入图片描述
这次使用的是标准差为1的高斯分布,但实验的目的是通过改变这个尺度(标准差),观察激活值的分布如何变化。
可以看到,各层的激活值偏向0和1分布,这里使用的是sigmoid激活函数,当输入和输出不断靠近0或者1时,它的导数的值逐渐减小接近于0,偏向0和1的数据分布会造成反向传播中的梯度值不断减小。这个问题称为 梯度消失

当权重的标准差设置为0.1时

w = np.random.randn(node_num, node_num) * 0.1
在这里插入图片描述
从图中可以看激活值的偏向较为平均,但大多数也分布在0.5附近。

当权重的标准差设置为0.01时

w = np.random.randn(node_num, node_num) * 0.01
在这里插入图片描述
标准差为0.01时,这次呈集中在0.5附近的分布,尽管不会发生梯度消失的问题,但是激活值的分布有所偏向,因为如果很多神经元输出几乎相同的值,那么就没有其存在的意义,导致“表现力受限”的问题。

各层的激活值的分布都要求有适当的广度,否则会出现梯度消失,或者表现力受限的问题!

如何使得激活值呈现具有相同广度的分布

根据Xavier的推论:
在这里插入图片描述

w= np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)

激活函数 tanh()
在这里插入图片描述
在这里插入图片描述

ReLU的权重初始值

使用“He初始值”
w= np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)

总结:
当激活函数使用ReLU时,权重初始值使用He初始值。
当激活函数为sigmoid或tanh等S型曲线函数时,初始值使用Xavier初始值。

基于MNIST数据集的权重初始值的比较

以下实验将说明,不同的权重初始值的赋值方法会在多大程度上影响神经网络的学习。
神经网络有5层,每层有100个神经元,激活函数使用的是ReLU。

import os
import sys
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.util import smooth_curve
from common.multi_layer_net import MultiLayerNet
from common.optimizer import SGD

#0:读入MNIST数据==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000

#1:进行实验的设置==========
weight_init_types = {'std=0.01': 0.01, 'Xavier': 'sigmoid', 'He': 'relu'}
optimizer = SGD(lr=0.01)

networks = {}
train_loss = {}
for key, weight_type in weight_init_types.items():
    networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],
                                  output_size=10, weight_init_std=weight_type)
    train_loss[key] = []

#2:开始训练==========
for i in range(max_iterations):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    for key in weight_init_types.keys():
        grads = networks[key].gradient(x_batch, t_batch)
        optimizer.update(networks[key].params, grads)
    
        loss = networks[key].loss(x_batch, t_batch)
        train_loss[key].append(loss)
    
    if i % 100 == 0:
        print("===========" + "iteration:" + str(i) + "===========")
        for key in weight_init_types.keys():
            loss = networks[key].loss(x_batch, t_batch)
            print(key + ":" + str(loss))

#3.绘制图形==========
markers = {'std=0.01': 'o', 'Xavier': 's', 'He': 'D'}
x = np.arange(max_iterations)
for key in weight_init_types.keys():
    plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 2.5)
plt.legend()
plt.show()

在这里插入图片描述
可以看到使用ReLU作为激活函数时,使用He初始值,学习的效率
更高。

参考

《深度学习入门:基于Python的理论与实现》 斋藤康毅

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习——权重的初始值 的相关文章

  • 在 python + openCV 中使用网络摄像头的问题

    我正在使用以下代码使用 openCV python 访问我的网络摄像头 import cv cv NamedWindow webcam feed cv CV WINDOW AUTOSIZE cam cv CaptureFromCAM 1 然
  • 更改 Inkscape 的 Python 解释器

    在使用 Inkscape 时 我不断收到错误 这似乎意味着未满足 python 2 vs 3 的期望 尽管我已经安装了它们 例如 当我尝试从模板生成新文档时 我得到 Traceback most recent call last File
  • 如何配置散景图以具有响应宽度和固定高度

    我使用通过组件功能嵌入的散景 实际上我使用 plot sizing mode scale width 它根据宽度进行缩放并保持纵横比 但我想要一个响应宽度但固定或最大高度 这怎么可能实现呢 有stretch both and scale b
  • 刷新访问令牌时出现“invalid_grant”错误的情况?

    最近我一直在为这个问题揪心 一些背景 使用oauth2客户端 https code google com p google api python client 库来管理用户的令牌 这些令牌用于定期并发执行各种后台任务 每次要为用户运行其中一
  • 同情因子简单关系

    我在 sympy 中有一个简单的因式分解问题 无法解决 我在 sympy 处理相当复杂的积分方面取得了巨大成功 但我对一些简单的事情感到困惑 如何得到 phi 2 2 phi phi 0 phi 0 2 8 因式分解 phi phi 0 2
  • 垂直线 axvline 在 matplotlib 的 loglog 图中绘制位于错误位置的线

    我在使用 axvline 在 matplotlib 的 loglog 图中绘制垂直线时遇到问题 第一个问题是垂直线没有出现在正确的位置 第二个问题 可能相关的是 当我放大或平移绘图时 垂直线只是保持在原位 并且没有通过平移 滑动绘图 或放大
  • 运行源代码中包含 Unicode 字符的 Python 2.7 代码

    我想运行一个在源代码中包含 unicode utf 8 字符的 Python 源文件 我知道这可以通过添加评论来完成 coding utf 8 在一开始的时候 但是 我希望不使用这种方法来做到这一点 我能想到的一种方法是以转义形式编写 un
  • Python - 为什么这段代码被视为生成器?

    我有一个名为 mb 的列表 其格式为 Company Name Rep Mth 1 Calls Mth 1 Inv Totals Mth 1 Inv Vol Mth 2 等等 在下面的代码中 我只是添加了一个包含 38 个 0 的新列表 这
  • 在 C# 中实例化 python 类

    我已经用 python 编写了一个类 我想通过 IronPython 将其包装到 net 程序集中 并在 C 应用程序中实例化 我已将该类迁移到 IronPython 创建了一个库程序集并引用了它 现在 我如何真正获得该类的实例 该类看起来
  • 如何在Python中获取绝对文件路径

    给定一条路径 例如 mydir myfile txt 如何在Python中找到文件的绝对路径 例如 在 Windows 上 我最终可能会得到 C example cwd mydir myfile txt gt gt gt import os
  • 如何在Python中正确声明ctype结构+联合?

    我正在制作一个二进制数据解析器 虽然我可以依靠 C 但我想看看是否可以使用 Python 来完成该任务 我对如何实现这一点有一些了解 我当前的实现如下所示 from ctypes import class sHeader Structure
  • 如何在 Numpy 中实现垃圾收集

    我有一个名为main py 它引用另一个文件Optimisers py它仅具有功能并用于for循环进入main py 这些函数都有不同的优化功能 This Optimisers py然后引用另外两个类似的文件 其中也只有函数 它们位于whi
  • 将 ASCII 字符转换为“”unicode 表示法的脚本

    我正在对 Linux 区域设置文件进行一些更改 usr share i18n locales like pt BR 并且需要格式化字符串 例如 d m Y H M 必须以 Unicode 指定 其中每个 在本例中为 ASCII 字符表示为
  • Scrapy - 不会爬行

    我正在尝试运行递归爬行 由于我编写的爬行不能正常工作 因此我从网络上提取了一个示例并进行了尝试 我真的不知道问题出在哪里 但是爬行没有显示任何错误 谁能帮我这个 另外 是否有任何逐步调试工具可以帮助理解蜘蛛的爬行流程 非常感谢任何与此相关的
  • 如何使用 python-gnupg 加密大型数据集而不占用所有内存?

    我的磁盘上有一个非常大的文本文件 假设它是 1 GB 或更多 还假设该文件中的数据有 n每 120 个字符一个字符 我在用python gnupg https pythonhosted org python gnupg 对此文件进行加密 由
  • 从 subprocess.Popen 获取整个输出

    我通过调用 subprocess Popen 得到了一个有点奇怪的结果 我怀疑这与我对 Python 的陌生有很大关系 args cscript USERPROFILE tools jslint js USERPROFILE tools j
  • 通过子类化 `io.TextIOWrapper` 来子类化文件 - 但它的构造函数有什么签名?

    我正在尝试子类化io TextIOWrapper下列的这个帖子 https stackoverflow com a 23796737 974555 虽然我的目标不同 以此开始 注意 动机 https stackoverflow com a
  • python IDLE shell 似乎无法正确处理一些转义

    例如 b 退格键打印为四元 在下面的示例中显示为 但是 n 换行是可以的 gt gt gt print abc bd abc d gt gt gt print abc nd abc d 我在 Vista pro python 2 7 下运行
  • 如何获取所有Python标准库模块的列表?

    我想要类似的东西sys builtin module names标准库除外 其他不起作用的事情 sys modules 只显示已经加载的模块 sys prefix 包含非标准库模块并且似乎无法在 virtualenv 内工作的路径 我想要这
  • 使用Python的timeit获取“全局名称'foo'未定义”

    我想知道执行一条Python语句需要多少时间 所以我上网查了一下 发现标准库提供了一个名为timeit http docs python org library timeit html旨在做到这一点 import timeit def fo

随机推荐

  • 基于springcloud gateway + nacos实现灰度发布(reactive版)

    什么是灰度发布 灰度发布 又名金丝雀发布 是指在黑与白之间 能够平滑过渡的一种发布方式 在其上可以进行A B testing 即让一部分用户继续用产品特性A 一部分用户开始用产品特性B 如果用户对B没有什么反对意见 那么逐步扩大范围 把所有
  • 一个网站引发的程序猿的牢骚,哈哈哈

    2013年大学毕业后 参加工作做的第一个前端项目 北京服装学院 今天调研一个关于iframe的需求 突然想试试 以前那些做IE6兼容的项目是否还在使用 就默默的点开了 十年了 他们没有换网站 我的岁月似乎从这一刻又回来了一次 已经十年了 我
  • Flask学习笔记(二)

    Flask学习笔记 二 1 知识点 1 1虚拟环境 1 1 1virtualenv 1 1 2virtualenvwrapper 1 2web与视图 1 3jinja2 1 3 1template知识点 1 3 2豆瓣列表页 1 3 3视图
  • 锚框损失论文下载 Iou-Loss【IoU Loss、GIoU Loss、 DIoU Loss 、CIoU Loss、 CDIoU Loss、 F-EIoU Loss、α-IoU Loss】

    锚框损失 Iou Loss IoU Loss GIoU Loss DIoU Loss CIoU Loss CDIoU Loss F EIoU Loss IoU Loss 论文打包下载 yolo系列论文https download csdn
  • cocosCreator2.3.x渲染流程深入剖析笔记(三)

    渲染批次合并之顶点 根据前面说过的render flow流程接下来就是重头戏了render流程 其中包括了 检查两个渲染节点是否可以合并 同时把renderData的数据填充到modelBatch里的buffer中去 所有需要渲染的节点都有
  • Kotlin中匿名函数(又称为Lambda,或者闭包)和高阶函数的详解

    博主前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住也分享一下给大家 点击跳转到教程 1 匿名函数 fun main 匿名函数 1 定义时不取名字的函数 我们称之为匿名函数 匿名函数通常整体传递给其他函数 或者从其他函数返
  • java中到底该不该用@author标识作者?

    今天查看activiti的README 突然发现一段很有意思的FAQ Why do you not accept author lines in your source code Because the author tags in the
  • Redis基础

    一 Redis入门 1 Redis简介 Redis Remote Dictionary Server 即远程字典服务 是一个基于内存的key value结构数据库 是用C语言开发的一个开源的高性能键值对 key value 数据库 它可以用
  • 基于python 蔬菜价格数据分析 完整代码+数据

    https download csdn net download weixin 55771290 87567123
  • GRU解决预测分类问题(多变量预测多步)

    解决问题的背景 现有五个属性列 前四个属性列作为特征输入 第五个属性列作为标签值 第五个属性列的意义是类别 先需要通过前50步的数据特征预测后10步的类别 即 51 60步 1 直接多输出的方式 直接多输出的方式就是在神经网络的最后加上几个
  • Linux·DNS协议、ICMP协议、NAT技术

    目录 DNS协议 DNS背景 编辑域名简介 域名解析过程 使用dig工具分析DNS过程 ICMP协议 ICMP功能 ICMP协议格式 编辑ping命令 一个值得注意的坑 traceroute命令 NAT技术 NAT技术背景 NAT IP转换
  • 报错:‘NoneType‘ object has no attribute ‘shape‘

    报错 NoneType object has no attribute shape import cv2 as cv img cv imread images1 print img shape img shape 图像大小 行 列 通道数
  • TypeScript基础入门 - 枚举 - 联合枚举与枚举成员的类型

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 转发 TypeScript基础入门 枚举 联合枚举与枚举成员的类型 项目实践仓库 https github com durban89 typescript demo git
  • Unity中添加按钮的方式

    方式一 使用 GUILayout 自动布局 用 GUILayout Button 来创建按钮 会自动的在屏幕的右上角按列排列按钮 这种方式添加的按钮大小和位置都无法改变 为默认值 private void OnGUI if GUILayou
  • f5负载均衡配置文件服务器,f5 负载均衡 dns 服务器 配置

    f5 负载均衡 dns 服务器 配置 内容精选 换一换 查询负载均衡器状态树 可通过该接口查询负载均衡器关联的监听器 后端云服务器组 后端云服务器 健康检查 转发策略 转发规则的主要信息 了解负载均衡器下资源的拓扑情况 GET v2 pro
  • mongodb入门操作

    mongodb入门操作 简单了解一下NoSql NoSql NoSql not only sql 是非关系型数据库系统的统称 它用于超大规模的数据的存储 提供有限的查询功能 mongodb mongodb是一个基于分布式文件存储的数据库系统
  • Rabbit学习笔记

    引言 什么是MQ MQ Message Quene 消息队列 通过典型的生产者和消费者模型不断向消息队列中生产消息 消费者不断从队列中获取消息 因为消息的生产和消费是异步的 而且只关系消息的发送和接收 没有业务逻辑的侵入 轻松地实现系统间解
  • IMU的ROS调试开发工具包:imu_tools

    目录 imu tool包 问题 参数配置便利性问题 实例 调试microstrain 3dm gx5 25 imu 问题 发布的imu姿态与实际imu姿态不一致问题 imu tool包 http wiki ros org imu tools
  • Java串口通信-JSerialComm

    Java串口通信 JSerialComm 目前网上的Java串口通信主要使用RXTXComm 但是这个库已经很久没有更新 最近的更新似乎在2012年 并且与JavaFX集成打包时会出现BUG JSerialComm是一个较新的串口通信库 其
  • 深度学习——权重的初始值

    权重的初始值 权重的初始值十分重要 关系到神经网络的学习是否成功 可以将权重初始值设置为0吗 为了抑制过拟合 提高泛化能力 采用权值衰减的方法 它是一种以减小权重参数的值为目的进行学习的方法 在误差反向传播法中 所有的权重值都会进行相同的更