《动手学深度学习》第二十三天---稠密连接网络(DenseNet)

2023-11-10

(一)DenseNet

DenseNet作为另一种拥有较深层数的卷积神经网络,具有如下优点:

(1) 相比ResNet拥有更少的参数数量.

(2) 旁路加强了特征(feature)的重用.

(3) 网络更易于训练,并具有一定的正则效果.

(4) 缓解了gradient vanishing(梯度消失)和model degradation(模型退化)的问题

梯度消失问题在网络深度越深的时候越容易出现,原因就是输入信息和梯度信息在很多层之间传递导致的,而现在这种dense connection相当于每一层都直接连接input和loss,因此就可以减轻梯度消失现象,这样更深网络不是问题。
这种dense connection有正则化的效果,因此对于过拟合有一定的抑制作用,可能是因为参数减少了,所以过拟合现象减轻。

(二)DenseNet网络结构

(1)dense block

对比于ResNet的Residual Block,创新性地提出Dense Block,在每一个Dense Block中,任何两层之间都有直接的连接,也就是说,网络每一层的输入都是前面所有层输出的并集,而该层所学习的特征图也会被直接传给其后面所有层作为输入。通过密集连接,缓解梯度消失问题,加强特征传播,鼓励特征复用,极大的减少了参数量。
如下图:
[x0,x1,…,xl-1]表示将 0 到 l-1 层的输出feature map做concatenation。concatenation是做通道的合并,就像Inception那样。即将 X_{0} 到 X_{l-1} 层的所有输出feature map按Channel组合在一起。这里所用到的非线性变换H为BN+ReLU+ Conv(3×3)的组合。
在这里插入图片描述
在这里插入图片描述
(2)DenseNet的结构图

在这个结构图中包含了3个dense block。将DenseNet分成多个dense block,原因是希望各个dense block内的feature map的size统一,这样在做concatenation就不会有size的问题。
在处理特征图数量或尺寸不匹配的问题上,ResNet采用零填充或者使用1x1的Conv来扩充特征图数量,而DenseNet是在两个Dense Block之间使用Batch+1x1Conv+2x2AvgPool作为transition layer的方式来匹配特征图的尺寸。 这样就充分利用了学习的特征图,而不会使用零填充来增加不必要的外在噪声,或者使用1x1Conv+stride=2来采样已学习到的特征(stride=2会丢失部分学习的特征)。
在这里插入图片描述(3)DenseNet效率更高

如果每个 Hl 输出k个特征图,那么 l 层就有k0+k×(l−1)输入特征图,k0为输入层的通道数。由于每一层都包含之前所有层的输出信息,因此其只需要很少的特征图就够了(DenseNet与其他的网络架构有一个重要的不同之处在于可以通过修改k的大小,让DenseNet的网络变得非常窄小),这也是为什么DneseNet的参数量较其他模型大大减少的原因。这种dense connection相当于每一层都直接连接input和loss,因此就可以减轻梯度消失现象,这样更深网络不是问题。

(三)DenseNet的简单实现

(1)稠密块

import d2lzh as d2l
from mxnet import gluon, init, nd
from mxnet.gluon import nn

def conv_block(num_channels):   #  DenseBlock块内组成
    blk = nn.Sequential()
    blk.add(nn.BatchNorm(), nn.Activation('relu'),
            nn.Conv2D(num_channels, kernel_size=3, padding=1))   # BN+ReLU+Conv(3×3)模式
    return blk
    
class DenseBlock(nn.Block):   # 定义一个DenseBlock块
    def __init__(self, num_convs, num_channels, **kwargs):   #手动设计通道数和模块内的卷积块数目
        super(DenseBlock, self).__init__(**kwargs)
        self.net = nn.Sequential()
        for _ in range(num_convs):
            self.net.add(conv_block(num_channels))

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = nd.concat(X, Y, dim=1)  # 在通道维上将输入和输出连结
        return X

如何计算输出通道呢?

blk = DenseBlock(2, 10)     #  定义输入的通道数为10,定义一个denseblock里面有两个卷积块
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 8, 8))   #输入的X通道数为3,图像大小为8×8
Y = blk(X)  #Y为最后DenseBlock的输出
Y.shape

在这里插入图片描述
我们由这张图直观来看,这张图示一个DenseBlock,有四个卷积块,每个卷积块里面包括(BN+ReLU+Conv)三种层,到最后输出的通道数目,其实等于(Conv卷积层的输出通道数×卷积块个数)+输入通道数,
所以3+2×10=23。
所以输出为(4,23,8,8)
卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

(2)过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(num_channels):   #  定义DenseBlock之间的transition layer
    blk = nn.Sequential()
    blk.add(nn.BatchNorm(), nn.Activation('relu'),
            nn.Conv2D(num_channels, kernel_size=1),
            nn.AvgPool2D(pool_size=2, strides=2))   #  BN+ReLU+Conv(1×1)+AvgPool(2×2)
            #  当map的信息都应该有所贡献的时候用avgpool,因为网络深层的高级语义信息一般来说都能帮助分类器分类。
    return blk

如何理解这个过程呢?

blk = transition_block(10)
blk.initialize()
blk(Y).shape

比如之前DenseBlock的输出为(4,23,8,8),经过Conv2D(10×1×1)得到(4×10×8×8),经过AvgPool2D得到(4×10×4×4)

(3)DenseNet模型

DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
        nn.BatchNorm(), nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1))  #  刚开始的时候为了减少无用信息选择MaxPooling(网络浅层)

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。

num_channels, growth_rate = 64, 32   # num_channels为当前的通道数,growth_rate为卷积块的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]    #  4个DenseBlock,每个里面有4个卷积块

for i, num_convs in enumerate(num_convs_in_dense_blocks):
  #  利用enumerate可以同时迭代序列的索引和元素
    net.add(DenseBlock(num_convs, growth_rate))
  #  根据更新的num_convs添加DenseBlock
    num_channels += num_convs * growth_rate    # 上一个稠密块的输出通道数

    if i != len(num_convs_in_dense_blocks) - 1:
        num_channels //= 2         # 在稠密块之间加入通道数减半的过渡层
        net.add(transition_block(num_channels))

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add(nn.BatchNorm(), nn.Activation('relu'), nn.GlobalAvgPool2D(),
        nn.Dense(10))  
#利用全局平均池化层可以降低模型的参数数量来最小化过拟合效应。GAP层通过取平均值映射每个h×w的特征映射至单个数字。

在这里插入图片描述
(4)获取数据并训练模型

lr, num_epochs, batch_size, ctx = 0.1, 5, 256, d2l.try_gpu()
net.initialize(ctx=ctx, init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,
              num_epochs)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

《动手学深度学习》第二十三天---稠密连接网络(DenseNet) 的相关文章

随机推荐

  • 业内人员告诉你银行测试到底做什么,怎么进银行测试.....

    前言 从一家工作了五年的软件公司的测试管理者跳槽到银行做软件测试 短短两个月 对银行测试有了初步认识 总结和记录下来 加深个人的理解 同时也共享给各位 银行作为大家的理财顾问 对金钱非常敏感 频繁甚至偶尔出现的软件故障都会打击顾客的信心 如
  • SCDN如何有效防御CC攻击和DDOS攻击的

    SCDN是由阿里云提供的一整套安全加速的解决方案 可以根据业务需求去进行定制 在防护效果上 最低防护20gbps 300gbps 当然定制版的防护最高可达到600gbps 在网络上我们常见的网络攻击就是CC攻击和DDOS攻击了 那么CC攻击
  • 数据库基本操作(持续更新ing)

    SQL语句基本类型 CRUD 增加 Create 查询 Retrieve 更新 Update 删除 Delete 创建数据库 CREATE DATABASE 数据库名 删除数据库 DROP DATABASE 数据库名 切换数据库 USE 数
  • ArcGIS Runtime for Android天地图底图及TPK数据包放大后数据不显示问题

    环境 ArcGIS Runtime for Android版本 100 14 0 底图放大不显示的原因 在天地图url的配置中配置了更高level的url 但没有实际的数据 或url返回了 解决方法 不配置没有数据的level的url TP
  • 平稳过程的各态历经性

    平稳过程的各态历经性 1 各态历经的定义 2 例题 2 1 例1 2 2例2 3 各态历经性的判定 1 各态历经的定义 如果一个随机过程是平稳的 而且是均值和相关函数都具有各态历经性 那么我们称这个平稳过程具有各态历经性 均值各态历经的定义
  • Ubuntu ssh 访问服务器失败

    今天用ssh 登录交换机的时候发现访问不了 一直报no matching key exchange method found ccchw ccchw HP Compaq Elite 8300 CMT ssh ssh itte 10 163
  • K210图像检测&(1~8)数字卡片识别

    前言 第一次使用该平台 想先找一个简单的识别 来走走流程 就想到了 前几年的送药小车的数字卡片识别 花了半天收集标记图片 在运行时要注意摄像头与数字卡片的高度 不过也有些不足 可能是收集某个数字的训练集的时候 拍摄高度 不一致 因为是手拍
  • C++之引用类型,深浅拷贝构造

    引用类型 给内存段取别名 int m 10 引用 给内存段取别名 所以需要给他一段内存段 而不只是声明 int n m 不是赋值的意思 是别名的意思 想要在被调函数中修改主调函数中定义的变量的值时 不需要将其地址传输给被调函数 直接传输变量
  • IDEA 使用技巧(快速生成xml文件)

    settiings 搜索File 找到 File and Code Templates 点击加号新建一个 Name 输入文件名 Extension 输入文件类型
  • PCL 生成空间直线点云

    目录 一 算法原理 二 代码实现 三 结果展示 一 算法原理 已知直线上一点和直线的方向向量 即可根据数学原理生成用于算法测试的标准直线点云 以下示例代码中 以直线上一点为中心点生成空间直线点云 其中点的个数为100个 相邻点之间的间隔为0
  • 微信小程序的的图片显示不出来

    图片的路径分两种 1 本地的图片如images文件夹下面的 images t1 jpg 或者是http localhost 8080 Teacher news t1 jpg 2 网络连接的图片http www baidu com vue n
  • python自动化操作, 三种方法解决滑动模块问题(后二种可跳过90%滑动,限制需要打开浏览器)

    selenium win32api pyautogui 元素定位 可无头进行访问 但是会被检测 基本用不了 sli ele driver find element By XPATH span id nc 1 n1z xpath 定位 if
  • 人工智能电话机器人一个顶10个,各版本系统搭建

    前接触多的就是电销行业 有电话机器人 VOS线路问题或要演示站AI技术支持 一个人面对多台电话不停地接听 特别是客户多时不知道应答哪一个 反而还把自己搞得心烦意乱 不过随着科技的发展 电销行业里出现了一个叫智能电销机器人的产品 自动应答客户
  • Innovus零基础lab学习全面复盘总

    Innovus零基础lab学习全面复盘总结 附完整版pdf 文章右侧广告为官方硬广告 与吾爱IC社区无关 用户勿点 点击进去后出现任何损失与社区无关 为了让各位训练营学员更快入门数字 IC 后端 从第八期 IC 训练营开始 小编以一个 数字
  • chatgpt 上方一直在转圈 白屏 空白

    下面空白 上面有个转圈 清除缓存 进入https platform openai com 返回刷新 更换节点 退出 Clash 软件 返回刷新 白屏 空白 来回换节点 就是换节点 有的时候是v p n的问题
  • MongoDB安装部署

    一 mongodb安装部署 关闭防火墙和selinux root mongodb iptables F root mongodb setenforce 0 root mongodb systemctl stop firewalld 2 指定
  • hdu 3879 最大密集子图(点和边均带权)(模板)

    最大权闭合图 可以用最大密集子图来解速度更快复杂度低 题解 胡伯涛 最小割模型在信息学竞赛中的应用 点和边均带权的最大密集子图 s i 权为U 点权绝对值和 边的所有权值 i t 权为U 点的值 点的度 u v 权值为w 意思是选了v后可以
  • Mybatis-plus使用手册

    Mybatis plus 1 定义 MyBatis Plus 是一个 Mybatis 增强版工具 在 MyBatis 上扩充了其他功能没有改变其基本功能 为了简化开发提交效率而存在 2 使用 SpringBoot 集成 MyBatis Pl
  • 相关性分析

    这里的相关性分析主要是线性相关性分析 当然其他的形状的相关性分析可以通过变换转换为线性相关性分析 但是 线性相关性分析始终是相关性分析的基础 线性相关分析的构建主要分为以下几种 直接绘制散点图 通过把点标出来主观上看是否是线性相关 绘制散点
  • 《动手学深度学习》第二十三天---稠密连接网络(DenseNet)

    一 DenseNet DenseNet作为另一种拥有较深层数的卷积神经网络 具有如下优点 1 相比ResNet拥有更少的参数数量 2 旁路加强了特征 feature 的重用 3 网络更易于训练 并具有一定的正则效果 4 缓解了gradien