经典神经网络 -- DenseNet : 设计原理与pytorch实现

2023-11-06

原理

概念与网络结构

       DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection)

       DenseNet的一大特色是通过 特征在channel上的连接 来实现特征重用(feature reuse)

       DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

       相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。

       ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加,而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起,并作为下一层的输入。

       对于一个 L 层的网络,DenseNet共包含 L*(L+1)/2 个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

       CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。但是一层层链接下去肯定会越来越大,为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,这是一种折中组合。

       其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低,还可以压缩模型。

       DenseBlock中的非线性组合函数 H 采用的是BN+ReLU+3x3 Conv的结构。与ResNet不同,所有DenseBlock中各个层卷积之后均输出 k 个特征图,即得到的特征图的out_channel数为 k ,或者说采用 k 个卷积核。k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的 k (比如12),就可以得到较佳的性能。由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,降低特征数量

       Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。

       DenseNet-C 和 DenseNet-BC

特点

 

代码实现

# DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection)

# DenseNet的一大特色是通过 特征在channel上的连接 来实现特征重用(feature reuse)

# DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

# 相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,
# 具体来说就是每个层都会接受其前面所有层作为其额外的输入。

# ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加,
# 而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起,并作为下一层的输入。

# 对于一个 L 层的网络,DenseNet共包含 L*(L+1)/2 个连接,相比ResNet,这是一种密集连接。
# 而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

# CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。
# 但是一层层链接下去肯定会越来越大,为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,这是一种折中组合。
# 其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。
# 而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低,还可以压缩模型。

# DenseBlock中的非线性组合函数 H 采用的是BN+ReLU+3x3 Conv的结构。
# 与ResNet不同,所有DenseBlock中各个层卷积之后均输出 k 个特征图,即得到的特征图的out_channel数为 k ,或者说采用 k 个卷积核。 
# k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的 k (比如12),就可以得到较佳的性能。
# 由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,降低特征数量

# Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。
# DenseNet-C 和 DenseNet-BC


from turtle import forward, shape
from numpy import block
import torch
import torch.nn as nn

from densenet import transition


def conv_block(in_channel, out_channel): # 一个卷积块
    layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
    )
    return layer


class dense_block(nn.Module):
    def __init__(self, in_channel, growth_rate, num_layers):
        super().__init__() # growth_rate => k => out_channel
        block = []
        channel = in_channel # channel => in_channel
        for i in range(num_layers):
            block.append(conv_block(channel, growth_rate))
            channel += growth_rate # 连接每层的特征
        self.net = nn.Sequential(*block) # 实现简单的顺序连接模型 
        # 必须确保前一个模块的输出大小和下一个模块的输入大小是一致的
    
    def forward(self, x):
        for layer in self.net:
            out = layer(x)
            x = torch.cat((out, x), dim=1) # contact同维度拼接特征,stack(是把list扩维连接
            # torch.cat()是为了把多个tensor进行拼接,在给定维度上对输入的张量序列seq 进行连接操作
            # inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
            # dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列
        return x


def trabsition(in_channel, out_channel):
    trans_layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, 1), # kernel_size = 1 1x1 conv
        nn.AvgPool2d(2, 2) # 2x2 pool
    )
    return trans_layer


class DenseNet121(nn.Module):
    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, 7, 2, 3), # padding=3 参数要熟悉
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )
        self.DB1 = self._make_dense_block(64, growth_rate, num=block_layers[0])
        self.TL1 = self._make_transition_layer(256)
        self.DB2 = self._make_dense_block(128, growth_rate, num=block_layers[1])
        self.TL2 = self._make_transition_layer(512)
        self.DB3 = self._make_dense_block(256, growth_rate, num=block_layers[2])
        self.TL3 = self._make_transition_layer(1024)
        self.DB4 = self._make_dense_block(512, growth_rate, num=block_layers[3])
        self.global_avgpool = nn.Sequential( # 全局平均池化
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Linear(1024, num_classes) # fc层

    def forward(self, x):
        x = self.block1(x)
        x = self.DB1(x)
        x = self.TL1(x)
        x = self.DB2(x)
        x = self.TL2(x)
        x = self.DB3(x)
        x = self.TL3(x)
        x = self.DB4(x)
        x = self.global_avgpool(x)

    def _make_dense_block(self, channels, growth_rate, num): # num是块的个数
        block = []
        block.append(dense_block(channels, growth_rate, num))
        channels += num * growth_rate # 特征变化 # 这里记录下即可,生成时dense_block()中也做了变化
        return nn.Sequential(*block)
    
    def _make_transition_layer(self, channels):
        block = []
        block.append(transition(channels, channels//2)) # channels // 2就是为了降低复杂度 θ = 0.5
        return nn.Sequential(*block)


if __name__ == '__main__':
    net = DenseNet121(3, 10) # in_channel, num_classes
    x = torch.rand((1, 3, 224, 224))
    for name,layer in net.named_children():
        if name != 'classifier':
            x = layer(x)
            print(name, 'output shape:', x,shape)
        else:
            print(x.shape)
            x = x.view(x.shape[0], -1) # 展开tensor 分类
            print(x.shape)
            x = layer(x)
            print(name, 'output shape:', x.shape)

参考文章:

DenseNet:比ResNet更优的CNN模型 - 知乎

pytorch 实现Densenet模型 代码详解,计算过程,_视觉盛宴的博客-CSDN博客_densenet pytorch

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

经典神经网络 -- DenseNet : 设计原理与pytorch实现 的相关文章

随机推荐

  • FFmpeg音视频播放器流程

    音视频播放器流程 ffmpeg解封装解码流程API ffmpeg官网 FFmpeg
  • [STM32]KEIL调试程序进入HardFault_Handler异常处理总结

    在做CORTEX M3单片机开发的时候 如STM32 可能会遇到设备跑着跑着程序死机的情况 往往调试起来很多时候发现是程序进入HardFault Handler系统异常 根据相关资料和M3权威指南是可以通过调试查找出程序的问题点和解决问题的
  • eclipse gradle打包_Spring Boot(十二):Spring Boot 如何测试打包部署

    部分面试资料链接 https pan baidu com s 1qDb2YoCopCHoQXH15jiLhA 密码 jsam 想获得全部面试必看资料 关注公众号 大家可以在公众号后台回复 知乎 即可 有很多网友会时不时的问我 Spring
  • 一个人的成功不是没有理由的!(人物之楼天城)

    昨天 杭州第十四中学请来毕业生楼天城 给全体学生做励志讲座 讲高中三年的学习生活和理科思维的培养 讲座前 老师介绍 楼天城同学2004年毕业于十四中 保送清华大学 博士毕业 是公认的计算机天才 公认的中国大学生编程竞赛第一人 常以一人单挑一
  • 利用python摘取文本中所需信息,并保存为txt格式

    项目所需 IC设计中难免会处理大量文本信息 我就在项目中遇到了 对于一个几万行的解码模块 提取出其中的指令 如果不用脚本将会很麻烦 下面我将一个小小的例子分享给大家 刚学python 如果有更方便的实现方法清多多指教 目的 1 在几万行解码
  • Git常用命令总结

    Git常用命令总结 git init 在本地新建一个repo 进入一个项目目录 执行git init 会初始化一个repo 并在当前文件夹下创建一个 git文件夹 git clone 获取一个url对应的远程Git repo 创建一个loc
  • openssl的证书链验证

    原文地址 http blog csdn net dog250 article details 5442914 使用openssl验证证书链可以用以下命令 debian home zhaoya openssl openssl verify C
  • C语言分支循环语句

    需提前看 初识C语言 5 C语言一些基本常识 目录 分支语句 if语句 单if语句使用 if else语句 if else if else语句 switch语句 switch基本结构 break作用 default作用 循环语句 while
  • 【Vscode

    Rmd文件转html R语言环境 Vscode扩展安装及配置 配置radian R依赖包 pandoc安装 配置pandoc环境变量 验证是否有效 转rmd为html 注意本文代码块均为R语言代码 在R语言环境下执行即可 R语言环境 官网中
  • shell I/O重定向

    shell重定向 lt 改变标准输入 program lt file 可将program 的标准输入改为file tr d r lt dos file txt 以 gt 改变标准输出 program gt file 可将program的标准
  • Qt基础之三十:百万级任务并发处理

    在实际的开发过程中 经常会遇到要处理大量任务场景 比如说压缩文件夹中的所有文件 对文件夹中的所有文件加密 上传文件夹中的所有文件到ftp等等 这里说百万级并不夸张 理论上文件夹中有任意多个文件都是可以的 本文以压缩文件夹中的100万张jpg
  • 三国志13pk版登录武将输入中文名方法与更改图像详解

    今天来个正经的文 三国志13里登录武将 设定姓名时 如果用的是自带输入法 就会出现一堆乱码 这时候 有两种解决方法 下载一个具有大五码的输入法 然后输入时候既要切换输入法 切换繁体 切换窗口模式 很麻烦 尤其在输入列传的时候 打很多字会很不
  • 【架构优化过程思考】技术方案评估的三个维度

    方案的选择决定了当下实现方案的资源投入及产出对产 也决定后续的成本 评估一个方案 首先要评估这个方案的有效性 也就是说要解决这个问题 实现目标 当前的这个方案是否足够的有效 还是在部分的场景下有效 如果是全部的有效那么该方案就不会出现上线之
  • 二叉树--合并二叉树

    问题 已知两颗二叉树 将它们合并成一颗二叉树 合并规则是 都存在的结点 就将结点值加起来 否则空的位置就由另一个树的结点来代替 思路 通过二叉树的前序遍历方法进行遍历 同时 t1二叉树作为蓝本进行计算 注意设置两个指针记录t1和t2遍历到的
  • JavaScript重写alert,confirm,prompt方法(JavaScript实现线程非阻塞式暂停和启动)

    得有段时间没好好写篇博客了 这次我们从一个题目开始吧 首先我给大家出一道题目 大家可以先思考一下 再往下看 题目是 请用JavaScript重写confirm方法 实现和confirm同样的功能 乍一看可能感觉很简单 定义一个confirm
  • php cms 自动分词,灵活运用PHPAnalysis分词组件,实现Phpcms v9关键词自动分词

    在2019年12月下旬 Phpcms官网phpcms cn关闭后 原有的分词api接口 http tool phpcms cn api get keywords php 已经失效 在录入标题后再也不能自动提取关键词到关键词的输入栏了 针对这
  • ReentrantLock的实现

    ReentrantLock可重入锁 我们可以利用这个实现对某一个操作约束为同有个时刻只能有一个线程能够操作 我们呢先看一下下面这个demo public class ReentrantLockTest public static void
  • 初级黑客入门指南——强烈推荐

    黑客指的是在计算机或计算机网络中发现弱点的人 尽管这个术语也可以指对计算机和计算机网络有深入了解的人 黑客的动机可能是多方面的 比如利润 抗议或挑战 围绕黑客发展的亚文化通常被称为 地下计算机 但现在它是一个开放的社区 虽然黑客这个词的其他
  • Python之由公司名推算出公司官网(余弦相似度)

    读大学时期写的博文 1 问题 对展会数据分类后 我的新任务是如何通过 公司名 公司地址 国家等海关数据推断出该公司的官网网站 若官网不存在则不考虑 以下数据仅供参考 公司名 国家 地址 JPW INDUSTRIES INC 427 NEW
  • 经典神经网络 -- DenseNet : 设计原理与pytorch实现

    原理 概念与网络结构 DenseNet模型 它的基本思路与ResNet一致 但是它建立的是前面所有层与后面层的密集连接 dense connection DenseNet的一大特色是通过 特征在channel上的连接 来实现特征重用 fea