轻量级卷积神经网络的设计技巧

2023-11-13

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达9615fa5a56161d6d5f8eb33cad82768a.png

这篇文章将从一个证件检测网络(Retinanet)的轻量化谈起,简洁地介绍,我在实操中使用到的设计原则和idea,并贴出相关的参考资料和成果供读者参考。因此本文是一篇注重工程性、总结个人观点的文章,存在不恰当的地方,请读者在评论区指出,方便交流。

目前已有的轻量网络有:MobileNet V2和ShuffleNet v2为代表。在实际业务中,Retinanet仅需要检测证件,不涉及过多的类别物体的定位和分类,因此,我认为仅仅更换上述两个骨架网络来优化模型的性能是不够的,需要针对证件检测任务,专门设计一个更加轻量的卷积神经网络来提取、糅合特征。

设计原则:


1. 更多的数据

轻量的浅层网络特征提取能力不如深度网络,训练也更需要技巧。假设保证有足够多的训练的数据,轻量网络训练会更加容易。

Facebook研究院的一篇论文[1]提出了“数据蒸馏”的方法。实际上,标注数据相对未知数据较少,我使用已经训练好、效果达标的base resnet50的retinanet来进行自动标注,得到一批10万张机器标注的数据。这为后来的轻量网络设计奠定了数据基础。我认为这是构建一个轻量网络必要的条件之一,网络结构的有效性验证离不开大量的实验结果来评估。

接下来,这一部分我将简洁地介绍轻量CNN地设计的四个原则

2. 卷积层的输入、输出channels数目相同时,计算需要的MAC(memory access cost)最少

edc004dd0f3ed226ef2a37e4ea554264.png

3. 过多的分组卷积会增加MAC

对于1x1的分组卷积(例如:MobileNetv2的深度可分离卷积采用了分组卷积),其MAC和FLOPS的关系为:

c67b185a2af13547ebba0e1229728603.png

g代表分组卷积数量,很明显g越大,MAC越大。详细参考[2]

4. 网络结构的碎片化会减少可并行计算

这些碎片化更多是指网络中的多路径连接,类似于short-cut,bottle neck等不同层特征融合,还有如FPN。拖慢并行的一个很主要因素是,运算快的模块总是要等待运算慢的模块执行完毕。

a162afa92225ae5887db24e66f164212.png

5. Element-wise操作会消耗较多的时间(也就是逐元素操作)

从表中第一行数据看出,当移除了ReLU和short-cut,大约提升了20%的速度。

b9e94981d18458c4546a430eec644e4a.png

以上是从此篇论文[2]中转译过来的设计原则,在实操中,这四条原则需要灵活使用。

根据以上几个原则进行网络的设计,可以将模型的参数量、访存量降低很大一部分。

接下来介绍一些自己总结的经验。

6. 网络的层数不宜过多

通常18层的网络属于深层网络,在设计时,应选择一个参考网络基线,我选择的是resnet18。由于Retinanet使用了FPN特征金字塔网络来融合各个不同尺度范围的特征,因此Retinanet仍然很“重”,需要尽可能压缩骨架网络的冗余,减少深度。

7. 首层卷积层用空洞卷积和深度可分离卷积替换

一个3x3,d=2的空洞卷积在感受野上,可以看作等效于5x5的卷积,提供比普通3x3的卷积更大的感受野,这在网络的浅层设计使用它有益。计算出网络各个层占有的MAC和参数量,将参数量和计算量“重”的卷积层替换成深度可分离卷积层,可以降低模型的参数量。

这里提供一个计算pytorch 模型的MAC和FLOPs的python packages[3]

if __name__ == "__main__":
    from ptflops import get_model_complexity_info

    net = SNet(num_classes=1)
    x = torch.Tensor(1, 3, 224, 224)

    net.eval()

    if torch.cuda.is_available():
        net = net.cuda()
        x = x.cuda()

    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(net, (224, 224), print_per_layer_stat=True, as_strings=True, is_cuda=True)
        print("FLOPS:", flops)
        print("PARAMS:", params)

output:

(regressionModel): RegressionModel(
    0.045 GMac, 27.305% MACs,
    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))
    (act1): ReLU(0.0 GMac, 0.041% MACs, )
    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act2): ReLU(0.0 GMac, 0.041% MACs, )
    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act3): ReLU(0.0 GMac, 0.041% MACs, )
    (output): Conv2d(0.002 GMac, 0.982% MACs, 256, 24, kernel_size=(1, 1), stride=(1, 1))
  )
  (classificationModel): ClassificationModel(
    0.044 GMac, 26.569% MACs,
    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))
    (act1): ReLU(0.0 GMac, 0.041% MACs, )
    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act2): ReLU(0.0 GMac, 0.041% MACs, )
    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act3): ReLU(0.0 GMac, 0.041% MACs, )
    (output): Conv2d(0.0 GMac, 0.245% MACs, 256, 6, kernel_size=(1, 1), stride=(1, 1))
    (output_act): Sigmoid(0.0 GMac, 0.000% MACs, )
  )

8. Group Normalization 替换 Batch Normalization

BN在诸多论文中已经被证明了一些缺陷,而训练目标检测网络耗费显存,开销巨大,通常冻结BN来训练,原因是小批次会让BN失效,影响训练的稳定性。建议一个BN的替代--GN,pytorch 0.4.1内置了GN的支持。

9. 减少不必要的shortcut连接和RELU层

网络不够深,没有必要使用shortcut连接,不必要的shortcut会增加计算量。RELU与shortcut一样都会增加计算量。同样RELU没有必要每一个卷积后连接(需要实际训练考虑删减RELU)。

10. 善用1x1卷积

1x1卷积可以改变通道数,而不改变特征图的空间分辨率,参数量低,计算效率也高。如使用kernel size=3,stride=1,padding=1,可以保证特征图的空间分辨率不变,1x1的卷积设置stride=1,padding=0达到相同的目的,而且1x1卷积运算的效率目前有很多底层算法支持,效率更高。[5x1] x [1x5] 两个卷积可以替换5x5卷积,同样可以减少模型参数。

11. 降低通道数

降低通道数可以减少特征图的输出大小,显存占用量下降明显。参考原则2

12. 设计一个新的骨架网络找对参考网络

一个好的骨架网络需要大量的实验来支撑它的验证,因此在工程上,参考一些实时网络结构设计自己的骨架网络,事半功倍。我在实践中,参考了这篇[4]paper的骨架来设计自己的轻量网络。

总结

我根据以上的原则和经验对Retinanet进行瘦身,不仅局限于骨架的新设计,FPN支路瘦身,两个子网络(回归网络和分类网络)均进行了修改,期望性能指标FPS提升到63,增幅180%。

FPS

70eca5d0e031d86d375c203bb514717a.png

mAP

44e38f3b25ce8d0aa559add91837d32f.png

Model size

67d7751113c38f514e1e062b4e943be2.png
 
 

好消息!

小白学视觉知识星球

开始面向外开放啦

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

轻量级卷积神经网络的设计技巧 的相关文章

随机推荐

  • web前端面试题(全)

    近来看到网上格式各样的web前端求职的面试题 接下来我用我的经验总结了一套在面试过程中高频率问到的面试题 希望能帮助各位求职者在求职的过程中顺利通过 废话不多说 直接说题 一 HTML5部分 1 说一下对css盒模型的理解 答 css盒子模
  • 【总结一】现代密码学

    目录 1 密码学概述 1 1 密码学的基本概念 1 1 1 为什么要学密码学 1 1 2 什么是密码学 1 1 2 密码算法的基本模型 1 1 3 密码算法的分类 1 2 密码分析学 1 3 古典密码算法 1 3 1 置换密码 1 3 2
  • 对表的复杂查询

    1 连接查询 数据库中的各个表中存放着不同的数据 用户往往需要用多个表中的数据来组合 提炼出所需要的信息 如果一个查询需要对多个表进行操作 就称为连接查询 例 对student sno clno sname ssex sage course
  • Windows上安装Hadoop 3.x

    目录 0 安装Java 1 安装Hadoop 1 1 下载Hadoop 1 2 下载winutils 2 配置Hadoop 1 hadoop env cmd 2 创建数据目录 3 core site xml 4 hdfs site xml
  • 解决textarea文字不顶头显示/点击textarea 不是第一行

    问题描述 表单提交后发现内容前多了很多空格 而且每次更新表单提交都会有空格增加 后来发现 每次文字从数据库读到textarea后文字都不居左 在排出样式 转义字符等问题后 baidu google了一会始终没找到答案 后来发现原来问题处在H
  • 网络--正向代理和反向代理

    正向代理的概念 正向代理 也就是传说中的代理 他的工作原理就像一个跳板 简单的说 我是一个用户 我访问不了某网站 但是我能访问一个代理服务器 这个代理服务器呢 他能访问那个我不能访问的网站 于是我先连上代理服务器 告诉他我需要那个无法访问网
  • 如何将VS Code扩展插件迁移出系统盘

    背景 Windows的C盘 系统盘 容量经常不够用 经过排查发现VSCode的扩展插件所在目录占用了很大空间 为了节省系统盘的空间 需要将VSCode扩展插件迁移到D盘 环境 Windows VS Code 全称是Visual Studio
  • MySQL的JSON数据类型介绍以及JSON的解析查询

    文章目录 概述 JSON 数据类型的意义 JSON相关函数 测试 创建测试表 插入数据 查询数据 条件查询 优化JSON查询 解决方案 总结 概述 MySQL从5 7后引入了json数据类型以及json函数 可以有效的访问json格式的数据
  • iOS音视频—FFmepg:iOS平台下集成和应用

    1 在iOS平台下集成和应用FFmpeg Mac配置FFmpeg环境 1 安装homebrew ruby e curl fsSL https raw githubusercontent com Homebrew install master
  • Maven中测试插件(surefire)的相关配置及常用方法

    原创文章 版权所有 允许转载 标明出处 http blog csdn net wanghantong 1 在Maven中配置测试插件surefire html view plain copy
  • 通讯录管理系统(C++)

    1 菜单功能 功能描述 用户选择功能的界面 步骤 封装函数showMenu 显示该界面 在main函数中调用封装好的函数 菜单界面 void showMenu cout lt lt 1 添加联系人 lt lt endl cout lt lt
  • \t转义字符占几个字节?

    这个问题 在你学习编程过程中可能会考虑到 有时为了字节对齐而使用转义符中 t 但是到底 t占用几个空格呢 下面我们首先通过程序来体验下 然后在总结 include
  • ElasticSearch(7)---倒排索引

    上一篇 ElasticSearch 6 Kibana插件 1 正向索引和反向索引 涉及到索引的概念的时候 首先需要知道 索引可以分为正向索引和反向索引 也可以理解为倒排索引 正向索引 正向索引可以简单理解为从文档到单词 例如现在有4个文档
  • C库函数之memcpy的实现

    C库函数之memcpy的实现 memcpy的实现方式是当满足四字节对齐时 进行四字节的拷贝 不满足时进行单字节的拷贝 例如拷贝10个字节 循环两次拷贝四字节 在循环两次拷贝一字节 void mem memcpy void dst const
  • h5页面加空格常用的几种方法

    1 html table align center border 1px width 200px tr td 姓名 td td 姓名 td tr tr td 姓 nbsp 名 td td 姓 160 名 td tr tr td 姓 ensp
  • 原深感摄像头与face id实现人脸3D扫描和建模(转)

    原文地址 https tech china com article 20170914 2017091459353 html 就在本月13号 苹果在乔布斯剧院高调地召开了2017秋季新品发布会 本场发布会的最大亮点 也是此前外界最期待的 无疑
  • 正确认识H.264与MPEG-4技术产品

    MPEG4的技术规范如下表所示 H 264视频编解码标准被纳入MPEG 4 Part 10标准中 也就是说它只是附属于MPEG 4的第十部分 换句话说 H 264没有超出MPEG 4标准范畴 因此 网上有关H 264标准和视频传输质量高于M
  • errors and 0 warnings potentially fixable with the `--fix` option.

    vue 项目运行过程中出现 3 errors and 0 warnings potentially fixable with the fix option 的错误 报错问题 原因一 在创建vue项目中 会选择linter Formatter
  • 记一次MQ并发消费导致任务状态异常问题

    背景 项目中有一个短信群发任务 例如1次要发送1W条短信 系统会获取任务中每一条短信的MQ并发发送短信 任务默认状态是未发送 状态码 0 需要在这一批任务发送第一条短信的时候 将任务状态修改为发送中 状态码 1 在任务发送结束将状态修改为发
  • 轻量级卷积神经网络的设计技巧

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 这篇文章将从一个证件检测网络 Retinanet 的轻量化谈起 简洁地介绍 我在实操中使用到的设计原则和idea 并贴出相关的参考资料和成果供读者参考 因此本文是一篇注重
Powered by Hwhale