深度学习笔记 —— 批量归一化

2023-10-29

梯度在上面(损失处)的时候比较大,越到下面越容易变小,因为很多时候都是n个很小的数相乘,乘到最后梯度就比较小了。所以就导致上面参数更新快,而下面参数更新慢(下面参数在小范围内变化时,抽取的底层特征变化不大,此时上层的参数是针对这些底层特征进行学习的)。这也意味着,如果下面的参数改变了,那么上面的参数之前也就白学了,需要重新训练。使得收敛比较慢。

 核心想法:方差和均值的分布不同层之间变化,如果把分布固定住了,相对来说就是比较稳定的。模型稳定了,也就是说更新的时候不会爆炸,也不会太小,收敛就不会变慢。

批量归一化是个线性变换,把均值、方差拉动得比较好,使其变化不那么剧烈。

 此解释是否为正确的理论,也不一定……

允许用更大的学习率来做训练(使得梯度的值变大一点,每层之间梯度的值会差不多一点,所以可以用更大的学习率,对权重的更新变快)

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 不算梯度,即在做inference
    # 此处用的是全局的均值和方差,因为做inference的时候可能只有一个样本,算不出来批量的均值和方差
    if not torch.is_grad_enabled():
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        # 全连接,第一维是批量大小,第二维是特征
        if len(X.shape) == 2:
            mean = X.mean(dim=0)  # 对每一列算均值,mean是1xn的行向量
            var = ((X - mean) ** 2).mean(dim=0)  # 同理
        # 2D卷积,批量大小x通道数x高x宽
        else:
            mean = X.mean(dim=(0, 2, 3), keepdim=True)  # 对每一个通道,把所有批量、高、宽里面的像素求均值,结果是1xnx1x1的向量
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # 1xnx1x1
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新moving_mean和moving_var
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean.data, moving_var.data


# 创建一个正确的BatchNorm图层
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # gamma和beta是要被迭代的,所以放在nn.Parameter里面
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9
        )
        return Y


# 应用BatchNorm于LeNet模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))


lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

# 拉伸参数gamma和偏移参数beta
print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))


# 简明实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习笔记 —— 批量归一化 的相关文章

  • java总结(不断更新)

    总结一句话 基础很重要 记得时而复习之 一轮人事 人事初步了解情况并推给技术部 二轮技术部 技术部电话面试 三轮面试技术以及对应客户 面试情况如下 自我介绍 项目经历 碰到的技术难点 java某个类的使用情况 字符串 长短值 类的加载机制
  • swagger-02-配置swagger

    1 4 配置swagger package com example config import org springframework context annotation Bean import org springframework c

随机推荐

  • Android UI 模板

    简单学习了Android UI 模板 自定义的UI模板 在自己设计的app中可以进行有效的代码复用 在这里做个流程整理 之后再添加漂亮的效果 首先加个在线阅读Android 源码的链接 点击打开链接 花个时间阅读一下系统的封装方法对学习An
  • 基于SSM框架的百货中心供应链管理系统

    社会发展日新月异 用计算机应用实现数据管理功能已经算是很完善的了 但是随着移动互联网的到来 处理信息不再受制于地理位置的限制 处理信息及时高效 备受人们的喜爱 本次开发一套百货中心供应链管理系统有管理员 人事 财务 销售 采购 服务六个角色
  • 突破对银河系的传统认知 大量超高能宇宙加速器被发现

    宇宙无限 信使有痕 5月17日 国家重大科技基础设施 高海拔宇宙线观测站 LHAASO 公布在银河系内发现大量超高能宇宙加速器 并记录到能量达1 4拍电子伏的伽马光子 拍 千万亿 这是人类观测到的最高能量光子 突破了人类对银河系粒子加速的传
  • C# 实现生成一维码、二维码

    注意 需要使用以下库文件 using ThoughtWorks QRCode Codec using ZXing using ZXing Common using ZXing QrCode 具体实现如下所示 帮助类一 using Syste
  • ES学习——ES评分简单介绍

    当我们能使用match来搜索匹配数据的时候 es会给每一个文档进行评分 匹配度 并根据评分的大小对结果文档进行排序 介绍 es的实时评分机制是基于 Lucene 的基础上实现的 最常见的是 TF IDF和BM25这两种评分模型 TF IDF
  • ElasticSearch配置

    2 搭建ElasticSearch环境 2 1 拉取镜像 docker pull elasticsearch 7 4 0 2 2 创建容器 docker run id name elasticsearch d restart always
  • JavaScript和jQuery的基础知识和使用

    初识JavaScript 首先对于JavaScript和Java两种语言 除了语法和Java有些类似 其他部分没有任何关系 由于当时Java很火 为了推广才在名字中加了Java 也就是所谓的蹭热度 另外 与JavaScript共同提起的还有
  • ModuleNotFoundError: No module named ‘forms‘

    问题 导入自定模块的时候报错 找不到模块 解决办法 将导入模块的代码写在靠近应用该模块的地方
  • MPLS实验

    MPLS第一次试验 公网地址配置 R2 GigabitEthernet0 0 1 23 1 1 1 24 LoopBack0 2 2 2 2 24 R3 GigabitEthernet0 0 0 23 1 1 2 24 GigabitEth
  • C语言文件读入---跳过第一行和最后一行

    include
  • 【FreeRtos学习笔记】STM32 CubeMx——Timers(定时器)

    目录 1 软件定时器 2 示例程序 2 1 例程功能 2 2 步骤 2 3 实验结果 2 4 函数讲解 1 软件定时器 定时器是MCU常用的外设 我们在学习各种单片机时必然会学习它的硬件定时器 但是 MCU自带的硬件定时器资源是有限的 而且
  • Android Fragment 生命周期图

    http www cnblogs com purediy p 3276545 html
  • 开发技术--浅谈python数据类型

    开发 浅谈python数据类型 在回顾Python基础的时候 遇到最大的问题就是内容很多 而我的目的是回顾自己之前学习的内容 进行相应的总结 所以我就不玩基础了 很多在我实际生活中使用的东西 我会在文章中提一下 并且我自己会根据这些内容进行
  • C++从入门到放弃之:Hello.cpp

    C 从入门到放弃 Hello cpp 1 创建c 程序源代码 2 C 程序的编译 3 C 扩展名 4 C 头文件 5 C 输入输出流 Hello cpp 1 创建c 程序源代码 vim hello cpp include
  • Unity3D+EasyAR实现AR效果的案例

    1 下载EasyAR的压缩包以及下面我要用到的霸王龙模型 链接 https pan baidu com s 12q4Jp11BMxnIW1DB48yy0Q 密码 1y3y 2 新建一个Unity3D的项目 然后双击下载好的EasyAR 将其
  • 分支-07. 比较大小(10)

    本题要求将输入的任意3个整数从小到大输出 输入格式 输入在一行中给出3个整数 其间以空格分隔 输出格式 在一行中将3个整数从小到大输出 其间以 gt 相连 输入样例 4 2 8 输出样例 2 gt 4 gt 8 程序 include int
  • 吃透Chisel语言.15.Chisel模块详解(二)——Chisel模块嵌套和ALU实现

    Chisel模块详解 二 Chisel模块嵌套和ALU实现 稍微复杂点的硬件设计就需要用嵌套的模块层级来构建了 上一篇文章中实现的计数器其实就是个例子 计数器内部嵌套了一个寄存器 一个Mux和一个加法器 这一篇文章就仔细讲解模块之间是怎么连
  • 结构体注入VS setter 注入

    结构体注入 setter注入是比较常用的依赖注入方式 都有各自的优缺点 setter注入是Spring推荐的依赖注入方式 首先结构体注入有什么问题 1 不能重新配置和重新注入 在Spring参考文档 中基于结构体注入和setter注入有以下
  • 利用visual studio 2017创建mfc程序,来输出hello world。

    1 点击文件 选择新建 再点击项目 2 选择visual C 选择MFC应用 位置和名称根据需要可适当更改 再点击创建 如果没有MFC应用 需要在工具那里点击获取工具和功能 3 在单个组件里面添加关于MFC的组件 4 进入以下视图 5 点击
  • 深度学习笔记 —— 批量归一化

    梯度在上面 损失处 的时候比较大 越到下面越容易变小 因为很多时候都是n个很小的数相乘 乘到最后梯度就比较小了 所以就导致上面参数更新快 而下面参数更新慢 下面参数在小范围内变化时 抽取的底层特征变化不大 此时上层的参数是针对这些底层特征进