BatchNorm原理以及PyTorch实现

2023-11-10

BatchNorm算法

在这里插入图片描述
简单来说BatchNorm对输入的特征按照通道计算期望方差(第1和第2个公式),并标准化(第3个公式,减去均值,再除方差,变为均值0,方差1)。但这会降低网络的表达能力,因此,BN在标准化后还要进行缩放平移,也就是可学习的参数 γ \gamma γ β \beta β,也对应每个通道。

BatchNorm的原理并不清楚,可能是降低了Internal Covariate Shift,也可能是使得optimization landscape变得平滑

优点

  • 提高训练稳定性,可使用更大的learning rate、降低初始化参数的要求并可以构建更深更宽的网络;
  • 加速网络收敛。

缺点

  • 增加计算量和内存开销,降低推理速度;
  • 增加训练和推理时的差异;
  • 打破了minibatch之间的独立性;
  • 小batch效果差。

BatchNorm 在训练时,仅用当前Batch的均值和方差,而测试推理时,使用EMA计算的均值和方差。

PyTorch Code

nn.BatchNorm2d为例。其继承关系为:Module → \to _NormBase → \to _BatchNorm → \to BatchNorm2dModule 是所有PyTorch构建网络模块的父类。

_NormBase

_NormBase主要是注册和初始化参数

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    def __init__(
        self,
        num_features: int, # 特征通道数
        eps: float = 1e-5,	# 防止分母为0
        momentum: float = 0.1, # 
        affine: bool = True, # 标准化后是否进行缩放,是否使用\gamma 和 \beta
        track_running_stats: bool = True, # 使用均值方差进行标准化
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) # 注册\gamma,后续初始化为1
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) # 注册\beta,后续初始化为0
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs)) # 注册期望,后续初始化为0
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs)) # 注册方差,后续初始化为1
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

	# 参数初始化,\gamma 为 1,\beta 为 0.
    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

_BatchNorm

调用nn.functional.batch_norm 对每个通道进行计算:

class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,	# 见下一章节
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )

BatchNorm2d

特化了输入检查

class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

关于momentum参数

按照Pytorch注释,momentum参与running_meanrunning_var的计算。置为None时,简单计算平均(累积移动平均)。默认值为0.1。

_BatchNorm中,赋值给了

exponential_average_factor = self.momentum

当其不为None时,也就是指数平均(Exponential Moving Average, EMA)。其计算公式为:
x ˉ t = β μ t + ( 1 − β ) x ˉ t − 1 \bar{x}_t = \beta \mu_t + (1-\beta)\bar{x}_{t-1} xˉt=βμt+(1β)xˉt1
其中, μ t \mu_t μt是当前Batch的均值或方差, β \beta β为exponential_average_factor。展开
x ˉ t = β μ t + ( 1 − β ) ( β μ t − 1 + ( 1 − β ) ( β μ t − 2 + ( 1 − β ) x ˉ t − 3 ) ) = β μ t + ( 1 − β ) β μ t − 1 + ( 1 − β ) 2 β μ t − 2 + . . . + ( 1 − β ) t β μ 0 \begin{aligned} \bar{x}_t &= \beta \mu_t + (1-\beta)(\beta \mu_{t-1} + (1-\beta)(\beta \mu_{t-2} + (1-\beta)\bar{x}_{t-3}))\\\\ &= \beta \mu_t + (1-\beta)\beta \mu_{t-1} + (1-\beta)^2\beta \mu_{t-2} + ... + (1-\beta)^t\beta \mu_0 \end{aligned} xˉt=βμt+(1β)(βμt1+(1β)(βμt2+(1β)xˉt3))=βμt+(1β)βμt1+(1β)2βμt2+...+(1β)tβμ0
从公式可以看出,越靠近当前的数据占的比重越大,比重按指数衰减。其值约等于最近
1 β \frac{1}{\beta} β1
次的均值。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

BatchNorm原理以及PyTorch实现 的相关文章

随机推荐

  • 一篇文章让你搞定所有redis面试题

    Redis是什么 Redis是C语言开发的一个开源的 遵从BSD协议 高性能键值对 key value 的内存数据库 可以用作数据库 缓存 消息中间件等 它是一种NoSQL not only sql 泛指非关系型数据库 的数据库 redis
  • Arduino酸度计(PH计)

    在本项目中 我们将通过将模拟pH传感器与Arduino接口来设计pH计 介绍 在化学中 pH是用于指定水基溶液的酸性或碱性的标度 酸性溶液的pH值较低 而碱性溶液的pH值较高 因此 Ph传感器具有确定任何溶液的Ph的能力 即可以判断该物质本
  • JAVA运行时类存在,但是报错:NoClassDefFoundError: Could not initialize class

    我们在部署代码时 明明类存在 但是发现报错 NoClassDefFoundError Could not initialize class 这类问题是由静态成员或静态初始化语句块引起 我们先看下面个类 import org apache c
  • C语言实现MD5/SHA1/SHA256/SHA512

    哈希函数是我们做校验时经常会用到的密码学工具 目前常用的工具有MD5 SHA1 SHA256 SHA512等 其中MD5已经被证实不安全 目前只能作为一种辅助的校验手段 而不能防篡改 下面介绍如何使用mbedTLS协议栈中的hash代码生成
  • BGP属性

    BGP 外部网关协议 此协议不在于自动发现网络拓扑 不追求速度 而在于AS之间选择最佳路由和控制路由的传播 追求可靠性 稳定性 操控性 承载性 使用TCP作为其传输协议 监听端口号为179 保证其可靠性 路由更新只发送更新的路由 适用于在以
  • C++基础学习笔记——对象的定义及引用

    1 类与对象的关系 通常我们把具有同样性质和功能的东西所构成的集合称为类 在C 中 可以把相同内部存储结构和相同操作集的对象看成属于同一类 在C 中 对象是类的实际变量 类与对象间的关系 可以用整型 int 和整型变量 i 之间的关系来类比
  • Linux——线程1

    一 线程基础 进程 有独立的进程地址空间 有独立的pcb 线程 有独立的pcb 没有独立的进程地址空间 因此进程线程最本质的区别就是 是否共享地址空间 在Linux下线程是最小的执行单位 进程是最小的分配资源单位 可看成只有一个线程的进程
  • 避坑记录:打电话(uni.makePhoneCall)

    uni makePhoneCall 可兼容微信小程序 H5 移动端 安卓 IOS 但是在移动端 安卓 上 如果拒绝授权电话 则会出现点击号码 既不报错 也不弹出打电话的bug 当然 如果只是简单调用makePhoneCall 也就不值得我去
  • Call Exec in PeopleCode

    我想在Application Engine里加一段调用命令行的代码 All PeopleCode is executed on the application server So if you re calling an interacti
  • 基于imx6ull视频监控

    基于imx6ull视频监控 前言 一 mjpg streamer 1 编译mjpg streamer 2 运行mjpg 3 mjpg框架 二 流媒体 1 ffmpeg 2 nginx服务器 3 实现flv js访问和ip地址访问 4 内网穿
  • MySQL添加用户、删除用户与授权

    前言 MySql中添加用户 新建数据库 用户授权 删除用户 修改密码 注意每行后边都跟个 表示一个命令语句结束 新建用户 登录MYSQL mysql u root p 密码 创建用户 mysql gt insert into mysql u
  • 从iOS App启动速度看如何为基础性能保驾护航

    1 前言 启动是App给用户的第一印象 一款App的启动速度 不单单是用户体验的事情 往往还决定了它能否获取更多的用户 所以到了一定阶段App的启动优化是必须要做的事情 App启动基本分为以下两种 1 1 冷启动 App 点击启动前 它的进
  • python:深拷贝,浅拷贝,赋值引用

    第一部分转载自 https www cnblogs com xueli p 4952063 html 1 python的复制 深拷贝和浅拷贝的区别 在python中 对象赋值实际上是对象的引用 当创建一个对象 然后把它赋给另一个变量的时候
  • purrr 0.2.0

    purrr 0 2 0 Hadley Wickham 2016 01 06 Categories Packages tidyverse 原文地址 我很高兴的发布了purrr 0 2 0 Purrr填补了R的函数式编程工具中的缺失部分 让你的
  • rpm包的卸载与安装

    本文章向大家介绍rpm包的卸载与安装 主要内容包括1 rpm包管理 2 rpm包的简单查询指令 3 卸载rpm包 4 安装rpm包 使用实例 应用技巧 基本知识点总结和需要注意事项 具有一定的参考价值 需要的朋友可以参考一下 目录 1 rp
  • 模式识别课程:目标检测③基于深度学习的检测算法

    title 目标检测 基于深度学习的检测算法 目标检测实验报告 检测所用软硬件 云服务器 硬件 macOS或者windows电脑 软件 pycharm 生成的测试集 云服务器 滴滴云 https www didiyun com activi
  • redisHyperLogLog原理解析

    场景 做服务端的同学 应该都遇到过计数场景 比如我想知道浏览某一个web页面的总人数 总次数 查看某条热门动态的总人数总次数 购买某件商品的总人数总次数 对于总次数我们直接基于计数器累加就能很方便的解决 时间和空间复杂度都不高 而对于总人数
  • fix: Build warning "generate id 'android:id/xxx' for external package 'android'

    other ref https blog csdn net w1070216393 article details 83088054 attr file
  • 【C/C++】哈希

    文章目录 1 unordered系列关联式容器 1 1unordered map接口 1 2unordered set 2 底层原理 2 1顺式结构和平衡树 2 2hash结构 2 3哈希冲突 哈希碰撞 2 4合理的哈希函数 2 4 1常见
  • BatchNorm原理以及PyTorch实现

    BatchNorm算法 简单来说BatchNorm对输入的特征按照通道计算期望和方差 第1和第2个公式 并标准化 第3个公式 减去均值 再除方差 变为均值0 方差1 但这会降低网络的表达能力 因此 BN在标准化后还要进行缩放平移 也就是可学