【标准化方法】(3) Group Normalization 原理解析、代码复现,附Pytorch代码

2023-11-13

今天和各位分享一下深度学习中常用的标准化方法,Group Normalization 数据分组归一化,向大家介绍一下数学原理,并用 Pytorch 复现。

Group Normalization 论文地址:https://arxiv.org/pdf/1803.08494.pdf


1. 原理介绍

在目标检测,视频分类等大型计算机视觉应用中,受到计算机内存的限制,必须设置较小的样本数量,但是样本量小势必会导致批归一化的性能有所影响。

分组归一化(Group Normalization,GN)是针对批归一化算法对批次大小依赖性强这一弱点而提出的改进算法。因为 BN 层统计信息的计算与批次的大小有关因此当批次变小时,很明显统计均值和方差的计算会越不准确和稳定,最终会有小批次高错误率的这一现象发生

分组归一化 GN 介于层归一化 LN 和实例归一化 IN 之间,对于输入大小为 [N,C,H,W] 的图像,N 代表批次的大小,C 表示输入通道数,H,W 表示输入图片高度和宽度。

分组归一化首先将输入通道 C 分为 G 个小组,然后分别对每一小组做归一化操作,也就是先把输入的特征维度由 [N,C,H,W] 变成 [N,G,\frac{C}{G},H,W]归一化的维度[\frac{C}{G},H,W]

事实上,当 G 等于 1 时,即所有的输入通道为 1 组GN 与 LN 的计算方式相同,而当 G 等于 C 时1 个输入通道为 1 组GN 与 IN 的计算方式相同

上图是批归一化算法 BN、层归一化算法 LN、实例归一化 IN 和分组归一化 GN 的简单图示。图中的立方体是三维,蓝色的方块是各个算法计算均值和方差的区域

其中 C 代表通道数,N 是批量大小,H,W 是高度和宽度,第三个维度的大小是 H*W,这样输入就可以用三维图形来表示。从上图中可以看出只有 BN  的计算与批次大小 N 有关LN、IN 和 GN 的计算都在单个样本上进行,  LN、IN 和 GN 三者可相互转换。

通常来说,归一化的方式如下所示:

\mu_i=\frac{1}{m}\sum_{k\in S_i}x_k

\sigma_i=\sqrt{\frac{1}{m}\sum_{k\in S_i}\left(x_k-\mu_i\right)^2+\epsilon}

S_i 是均值和方差的计算区域,在 BN 中有:

S_i=\left\{k|k_C=i_C\right\}

在 LN 中:

S_i=\left\{k|k_N=i_N\right\}

在 GN 中:

S_i=\{k\mid k_N=i_N,floor(\frac{k_C}{C/G})=floor(\frac{i_C}{C/G})\}

优点:不依赖批量大小。

缺点:当批量大小较大时,性能不如BN。


2. 代码展示

import torch 
from torch import nn

class GN(nn.Module):
    # 初始化
    def __init__(self, groups:int, channels:int, 
                 eps:float=1e-5, affine:bool=True):
        super(GN, self).__init__()
        # 通道数要整除组数
        assert channels % groups == 0, 'channels should be evenly divisible by groups'
        self.groups = groups  # 把通道分成多少组
        self.channels = channels  # 通道数
        self.eps = eps  # 防止分母为0
        self.affine = affine  # 是否使用可学习的线性变化参数
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))  # 缩放因子
            self.shift = nn.Parameter(torch.zeros(channels))  # 偏置
    # 前向传播
    def forward(self, x: torch.Tensor):
        x_shape = x.shape  # 输入特征的维度 [b,c,w,h]
        batch_size = x_shape[0]  # 样本量
        assert self.channels == x.shape[1]  # 预设通道数和输入特征的通道数要保持一致
        # [b,c,w,h]-->[b,g,w*h*c/g]
        x = x.view(batch_size, self.groups, -1)
        # 在最后一个维度上做标准化
        mean = x.mean(dim=[-1], keepdim=True)  # [b,g,1]
        mean_x2 = (x**2).mean(dim=[-1], keepdim=True)  # [b,g,1]
        var = mean_x2 - mean**2
        x_norm = (x-mean) / torch.sqrt(var+self.eps)  # [b,g,w*h*c/g]
        # 线性变化
        if self.affine:
            x_norm = x_norm.view(batch_size, self.channels, -1)  # [b,c,w*h]
            x_norm = self.scale.view(1,-1,1)* x_norm + self.shift.view(1,-1,1)  # [1,c,1]*[b,c,w*h]+[1,c,1]
        # [b,c,w*h]-->[b,c,w,h]
        return x_norm.view(x_shape)

# ---------------------------------- #
# 验证
# ---------------------------------- #

if __name__ == '__main__':
    # 构造输入层
    x = torch.linspace(0, 47, 48, dtype=torch.float32)  # 构造输入层
    x = x.reshape([2,6,2,2])  # [b,c,w,h]
    # 实例化
    gn = GN(groups=3, channels=6)
    # 前向传播
    x = gn(x)
    print(x.shape)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【标准化方法】(3) Group Normalization 原理解析、代码复现,附Pytorch代码 的相关文章

  • 查找两个字典的匹配键值对

    检查一个字典的键值对是否也存在于其他字典中的最有效方法是什么 假设我有两个字典dict1 and dict2这两个字典有一些共同的键值对 我想找到这些并打印出来 做到这一点最有效的方法是什么 请建议 一种方法是 d inter dict k
  • Urllib 的 urlopen 在某些网站上被破坏(例如 StackApps api):返回垃圾结果

    我在用着urllib2 s urlopen函数尝试从 StackOverflow api 获取 JSON 结果 我正在使用的代码 gt gt gt import urllib2 gt gt gt conn urllib2 urlopen h
  • 在python中合并3个dict()

    如果多个字典之间有公共字符串 是否有逻辑合并多个字典的方法 即使这些公共字符串在一个 dict 的值与另一个 dict 的键之间匹配 我在 SO 上看到了很多类似的问题 但似乎没有一个问题能解决我将 较低级别文件 中的多个键与较高键 值中的
  • Heroku 上的 Django 应用程序在一段时间后删除对象

    我编写了一个简单的 Django 问答论坛应用程序并将其部署在 Heroku 上 该网站的本地版本运行良好 但是 生产版本不会将问题 答案等存储超过几个小时 我决定坚持使用 Django 附带的 sqlite3 我预计该网站不会有太多流量
  • SQLAlchemy 在 MySQL 上使用什么列类型作为“文本”?

    我的总体用例是试图确定我是否可以编写一个与数据库无关的 至少支持 Postgres 和 MySQL 存储一些大数据作为原始文本 认为 500MB 作为粗略的理论上限 基于这个答案 https stackoverflow com a 2557
  • Python int和float在64位系统中的内存消耗

    我正在 Python 3 4 的 64 位系统中尝试以下代码 以了解不同原始数据类型的内存消耗 import sys print sys getsizeof 45 prints 28 print sys getsizeof 45 2 pri
  • 在 Windows 上安装 PyGIMP

    在网上 我可以找到有关使用 python 编写 gimp 脚本的各种示例 http www jamesh id au software pygimp http www jamesh id au software pygimp http ww
  • 令人困惑的问题>> FileNotFoundError:[Errno 2]没有这样的文件或目录:

    这个问题让我很困惑 也许问题出在代码上 希望你看一下 with open training images labels path r as file lines file readlines 他说该文件不存在 FileNotFoundErr
  • 在 pandas eval 中调用 round()、ceiling()、floor()、min()、max()

    正如标题所说 有没有办法在 pandas eval 中支持 round ceiling min max floor 函数 数据框 import pandas as pd import numexpr as ne op d ID 1 2 3
  • 使用存储的密钥作为环境变量

    我有一个秘密密钥存储在 GCP 的秘密管理器中 我们的想法是使用该密钥通过云功能获取预算列表 现在 我可以从代码中访问该密钥 但我面临的问题是我需要使用该密钥设置一个环境变量 这是我添加密钥的方式 如果您的本地目录中有该文件 但是还有其他方
  • PyGTK TreeView 中的自动换行

    如何在 PyGTK TreeView 中自动换行文本 gtk TreeView 中的文本是使用 gtk CellRendererText 渲染的 文本换行归结为在单元格渲染器上设置正确的属性 为了让文本换行 您需要设置wrap width单
  • matplotlib 中矩形面片之间存在不需要的空间

    以下代码绘制两个红色矩形 红色矩形应该彼此相邻 之间没有空间 在 python 图中 这是可以的 在导出的 pdf 中 矩形之间有一个细长但明显的空白 有什么方法可以解决这个问题吗 import matplotlib pyplot as p
  • Pygame 旋转射击

    我和几个朋友一直在编写一种有趣的新射击机制 为了让它发挥作用 我们需要朝玩家面对的方向射击 Sprite 正在使用 Pygame Transform Rotate 进行旋转 我们怎样才能找到一个角度 然后朝那个方向发射子弹呢 这是我们的精灵
  • ModuleNotFoundError:没有名为“googleapiclient”的模块

    如果这是一个愚蠢的问题 我深表歉意 我在 stackoverflow 上搜索过 但没有找到解决办法 我正在致力于从 Python 2 7 迁移到 Python 3 8 我收到一个程序的以下错误 请帮我 Traceback most rece
  • 在类方法 Python 中调用多处理

    最初 我有一个类来存储一些处理后的值 并通过其他方法重用这些值 问题是当我尝试将类方法划分为多个进程以加速时 python 生成了进程 但它似乎不起作用 正如我在任务管理器中看到的那样 只有 1 个进程在运行 并且结果从未传递 我做了几次搜
  • 无法从 celery 信号连接到 celery 任务?

    我正在尝试连接task2 from task success signal from celery signals import task success from celery import Celery app Celery app t
  • 在IPython笔记本中自动播放声音

    我经常在 IPython 笔记本中运行长时间运行的单元 我希望笔记本在单元完成执行时自动发出蜂鸣声或播放声音 有没有办法在 iPython 笔记本中执行此操作 或者我可以在单元格末尾放置一些命令来自动播放声音 我正在使用 Chrome 如果
  • 为什么我只能在异步函数中使用await关键字?

    假设我有这样的代码 async def fetch text gt str return text async def show something something await fetch text print something 这很
  • 无法使用 Python 3 编写的 gzip.open() 将压缩文件上传到云存储

    当我尝试在 Cloud Shell 实例上使用 python 脚本将压缩的 gzip 文件上传到云存储时 它总是上传一个空文件 这是重现错误的代码 import gzip from google cloud import storage s
  • WTforms 表单未提交但不输出验证错误

    我正在尝试使用以下方式上传文件flask uploads工作和遇到一些障碍 我会告诉你我的flask查看函数 html 希望有人能指出我缺少的内容 基本上发生的情况是我提交了表格但失败了if request method POST and

随机推荐

  • 自定义炫酷powershell

    自定义炫酷powershell 美化 linux上的bash和zsh之类的命令行终端炫酷无比 window上的cmd和powershell丑的不忍直视 很久之前不知参考谁的一篇文章自定义了一下 还算勉强能看得过去 重装电脑时候发现了 便记录
  • 我是疫情期间的幸运儿

    疫情期间的人生百态 疫情持续了这么长时间 有非常非常多的人的工作受疫情的影响 有些人因为疫情 企业经营困难 被迫失业 有些人在疫情严重前夕 选择辞职 然后寻找更好的工作机会 可是因为疫情 被迫延长待业的时间 并且在焦虑中煎熬 有些人本想打算
  • 10秒钟脱口而出十位数相同两位数的乘法

    10秒钟脱口而出十位数相同两位数的乘法 一 范围 十位数相同的两位数 二 目标 计算两位数的相乘 10秒钟脱口而出 三 基本公式 以尾数之和展开讨论 假设两个数分别是10a b以及10a c 那么尾数之和就是b c 序号 分类 公式 举例
  • Python —— matplotlib库的温度图像绘制

    使用的环境是Jupyter Notebook 我是安装了python版本Anaconda 已经内置了各种python包 可进入官网下载 在Anaconda下安装Jupyter Notebook即可在web页面上进行代码编写 在python中
  • Shell--基础--07--基本运算符

    Shell 基础 07 基本运算符 1 介绍 Shell支持多种运算符 包括如下 算数运算符 关系运算符 布尔运算符 字符串运算符 文件测试运算符 原生bash不支持简单的数学运算 但是可以通过其他命令来实现 例如 awk 和 expr e
  • fancyhdr宏包设置latex页眉页脚

    LaTeX的fancyhdr宏包的使用 CTEXwiki关于fancyhdr的说明可以在这里找到 在latex中用自定义页眉页脚 一般都要使用宏包fancy 关键是琢磨一下下面的例子 在看看相应的说明 一般就可以得到你想要的结果了 下面的内
  • 腾讯云轻量数据库mysql服务快速入门!

    快速入门 本文旨在介绍如何快速使用轻量数据库服务 帮助用户快速了解轻量数据库服务使用的全流程 从数据库的创建到基本使用 您需要完成如下操作 创建数据库 登录 轻量数据库服务购买页 根据实际需求选择各项配置信息 确认无误后 单击立即购买 地域
  • styled-components常见使用方法

    yarn add styled components import styled from styled components 1 基础使用 const BoxStyle styled div color red 2 UI组件加样式 imp
  • 数据结构-malloc申请动态空间-链表的创建

    一 malloc申请动态空间注意以下事项 1 malloc申请动态空间时必须声明类型 2 使用malloc申请的空间在使用完成之后必须使用free释放 3 malloc申请空间的类型必须和指向他的指针类型匹配 such as int p p
  • 挖掘视频网站【优酷】上被截断的视频的地址--001

    不知道大家看视频的时候有没有注意过 一个稍微长的视频 比如超过20分钟 你刚开始看的时候暂停播放 它的进度条会在中途某一个位置停止加载 当你把播放位置调节到那个停顿的地方 视频又开始继续加载 如果视频还有很多 它会停顿很多次 我们不禁要问
  • 网管实战(7):CISCO网管设备学习笔记

    虽然现在管理的都是华为和H3C的网络设备 但有时候还是要管理一些思科的设备 比如CISCO 4506 CISCO 6504 3750等 作为网管小白 很多时候都需要查一些命令来操作 这里是我2019年9月25日开始学习CISCO设备时的学习
  • 给Delphi社群的公开信

    给Delphi社群的公开信 Borland RAD部门副总裁
  • php爬虫教程(五)提高爬虫抓取效率

    之前有一次抓取x浪图片库的时候200w图片跑了一整天的时间 后来采取多进程抓取提高了很高的效率 多进程的实现可以参考这个方法 http blog csdn net u014017080 article details 46925725 主进
  • 《代码大全2》第2章 用隐喻来更充分地理解软件开发

    Code Complete 2 持续更新中 来杯咖啡的博客 CSDN博客这本书有意设计成使你既可以从头到尾阅读 也可以按主题阅读 1 如果你想从头到尾阅读 那么你可以直接从第2章 用隐喻来更充分地理解软件开发 开始钻研 2 如果你想学习特定
  • 眼图 非差分线_利用眼图解决USB在布线中的信号完整性问题

    通用串行总线USB Universal Serial Bus 协议从1 0版本发展到现在 由于数据传输速度快 接口方便 支持热插拔等优点使USB设备被越来越多人使用 目前 市场上以USB2 0为接口的产品越来越多 而绘制符合要求的PCB板在
  • WSL2端配置pytorch GPU加速环境

    Windows端Pytorch GPU加速的教程 Pytorch使用GPU加速的步骤 前置教程 WSL2安装及其python环境配置 配置好WSL2相关环境后 要想对pytorch进行GPU加速 需要进行以下步骤 更新Windows系统 只
  • LeetCode-Python-(206)反转链表

    反转链表 反转一个单链表 示例 输入 1 gt 2 gt 3 gt 4 gt 5 gt NULL 输出 5 gt 4 gt 3 gt 2 gt 1 gt NULL 解题思路 参考博客 代码 class Solution def revers
  • Ceph 存储集群 - 搭建存储集群

    一 准备机器 本文描述如何在 CentOS 7 下搭建 Ceph 存储集群 STORAGE CLUSTER 一共4台机器 其中1个是管理节点 其他3个是ceph节点 hostname ip role 描述 admin node 192 16
  • HTTP和HTTPS的区别?

    目录 HTTP HTTPS HTTP与HTTPS区别 HTTPS相比于HTTP协议的优点和缺点 优点 缺点 HTTP HTTP是超文本传输协议 HTTP协议是基于传输层的TCP协议进行通信 通用无状态的协议 80端口 HTTPS HTTPS
  • 【标准化方法】(3) Group Normalization 原理解析、代码复现,附Pytorch代码

    今天和各位分享一下深度学习中常用的标准化方法 Group Normalization 数据分组归一化 向大家介绍一下数学原理 并用 Pytorch 复现 Group Normalization 论文地址 https arxiv org pd