Yolov5 中添加Network Slimming剪枝--稀疏训练部分

2023-05-16

前言:Network Slimming剪枝过程让如下

1. 稀疏化

2. 剪枝

3. 反复迭代这个过程

 一、稀疏化:

通过Network Slimming 的核心思想是:添加L1正则来约束BN层系数,从而剪掉那些贡献比较小的通道channel

原理如下:BN层的计算是这样的:

上边介绍了,Network Slimming的核心思想是剪掉那些贡献比较小的通道channel它的做法是从BN层下手。BN层的计算公式如下:

通过BN层的计算公式可以看出每个channe的Zout的大小和系数γ正相关,因此我们可以拿掉哪些γ-->0的channel,但是由于正则化,我们训练一个网络后,bn层的系数是正态分布的。这样的话,0附近的值则很少,那剪枝的作用就很小了。因此要先给BN层加上L1正则化进行一步稀疏训练(为什么要用L1正则化可以看该博客:l1正则与l2正则的特点是什么,各有什么优势? - 知乎)。

为BN层加入L1正则化后,损失函数公式为:

上面第一项是正常训练的loss函数,第二项是约束对于L1正则化,g(s)=|s|,λ是正则系数,引入L1正则来控制γ, 要把稀疏表达加在γ 上, 得到每个特征的重要性 λ

- 每个通道的特征对应的权重是 γ 
- 稀疏表达也是对 γ 来说的, 所以正则化系数 λ 也是针对  γ, 而不是 W
-  稀疏化后, 做γ 值的筛选

因此在进行反向传播时候:𝐿′=∑𝑙′+𝜆∑𝑔′(𝛾)=∑𝑙′+𝜆∑|𝛾|′=∑𝑙′+𝜆∑𝑠𝑖𝑔𝑛(𝛾)

那如何把程序加到yolov5呢?

在yolov5 train.py的程序中找到反向传播部分程序:

1.1 稀疏训练核心代码

将scaler.scale(loss).backward()注释,并添加下方代码:

  代码如下:

 # Backward
            # scaler.scale(loss).backward()
            loss.backward()
            # # ============================= sparsity training ========================== #
            srtmp = opt.sr*(1 - 0.9*epoch/epochs)  # opt.sr=0.0001 随着epoch增多,把srtmp减小
            if opt.st:  # '默认是true  train with L1 sparsity normalization  
                ignore_bn_list = []
                for k, m in model.named_modules():
                    # print('name: {}, module: {}'.format(k,m))
                    if isinstance(m, Bottleneck):
                        if m.add:
                            ignore_bn_list.append(k.rsplit(".", 2)[0] + ".cv1.bn")
                            ignore_bn_list.append(k + '.cv1.bn')
                            ignore_bn_list.append(k + '.cv2.bn')
                    if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
                        # L1 regulation formulate: λΣ|γ|
                        # |x|' = {-1,1}
                        # L1 grad: (λΣ|γ|)'=λ * Σsign(γ)
                        # BN(γ,β)
                        m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
                        m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data))  # L1
            # # ============================= sparsity training ========================== 
            # if ni - last_opt_step >= accumulate:
            #     scaler.step(optimizer)  # optimizer.step
            #     scaler.update()
            #     optimizer.zero_grad()
            #     if ema:
            #         ema.update(model)
            #     last_opt_step = ni
            optimizer.step()
                # scaler.step(optimizer)  # optimizer.step
                # scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)

其中sr,st 需要添加参数

parser.add_argument('--st', action='store_true',default=True, help='train with L1 sparsity normalization')
parser.add_argument('--sr', type=float, default=0.0002, help='L1 normal sparse rate')

其中需要注意的点1:

红框处程序是因为这里并没有选择所有的bn层进行裁剪,这里选择去除那些有shortcut的Bottleneck层(对应代码中m.add = True),主要是为了保证shortcut和残差层channel一样可以add。

--------------------在这里我曾经有这样的疑问:(该部分可以不看) -----------------------------------------

这两个if 我能理解最终目的是:去除那些有shortcut的Bottleneck层,但是为什么要有 +cv1.bn等等那三步呢?不能直接把k添加到 ignore_bn_list吗?

再说了添加了之后,加入ignore_bn_list的名字就变了呀,此时再运行下一个if的时候k是不在ignore_bn_list

为什么不能改成指令:

if isinstance(m, Bottleneck):

                        if m.add:

                            ignore_bn_list.append(k)

if isinstance(m, nn.BatchNorm2d) and (k in ignore_bn_list):

 后来我明白了,这里是为了不对Bottlenack中的BatchNorm2d加正则化,因上述改名字的那个步骤其实是找的该Bottleneck下面BatchNorm2d 的名字。比如我断点调试了一下:

其中名为Module.model.2.m.0的模型

 其下的BatchNorm2d的名字分别如下:

那再回头看看那部分程序,就理解了。
-------------------------------------------------------结束--------------------------------------------------------------------

注意的点2:

yolov5会采用自动混合精度训练,因此需要把改成fp32,方法如下:
修改train.py
1. 注释掉    # scaler = amp.GradScaler(enabled=cuda)
2. 把train.py中的.half 都去掉
具体为:
  # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            # model.half().float()
            model.float()  # pre-reduce anchor precision
# Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        # 'model': deepcopy(de_parallel(model)).half(),
                        'model': deepcopy(de_parallel(model)),
                        # 'ema': deepcopy(ema.ema).half(),
                        'ema': deepcopy(ema.ema),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                        'date': datetime.now().isoformat()}
 if RANK in [-1, 0]:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
                                            # model=attempt_load(f, device).half(),
                                            model=attempt_load(f, device),
                                            iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                            single_cls=single_cls,
                                            dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco,
                                            verbose=True,
                                            plots=True,
                                            callbacks=callbacks,
                                            compute_loss=compute_loss)  # val best model with plots

以上就将对BN层添加L1正则化的程序加好了,核心思想就是修改反向传播的梯度。

1.2 查看稀疏训练效果

如果想查看系数训练的效果,可加入下方程序:

  # =============== show bn weights ===================== #
        module_list = []
        module_bias_list = []
        for i, layer in model.named_modules():
            if isinstance(layer, nn.BatchNorm2d) and i not in ignore_bn_list:
                bnw = layer.state_dict()['weight']
                bnb = layer.state_dict()['bias']
                module_list.append(bnw)
                module_bias_list.append(bnb)
                # bnw = bnw.sort()
                # print(f"{i} : {bnw} : ")
        size_list = [idx.data.shape[0] for idx in module_list]

        bn_weights = torch.zeros(sum(size_list))
        bnb_weights = torch.zeros(sum(size_list))
        index = 0
        for idx, size in enumerate(size_list):
            bn_weights[index:(index + size)] = module_list[idx].data.abs().clone()
            bnb_weights[index:(index + size)] = module_bias_list[idx].data.abs().clone()
            index += size

        # print("bn_weights:", torch.sort(bn_weights))
        # print("bn_bias:", torch.sort(bnb_weights))
        # tb_writer.add_histogram('bn_weights/hist', bn_weights.numpy(), epoch, bins='doane')
        # tb_writer.add_histogram('bn_bias/hist', bnb_weights.numpy(), epoch, bins='doane')
        loggers.tb.add_histogram('bn_weights/hist', bn_weights.numpy(), epoch, bins='doane')
        loggers.tb.add_histogram('bn_bias/hist', bnb_weights.numpy(), epoch, bins='doane')

将其加到:一个batch训练结束之后的程序后边就好

我在tensorboard中的图:

 纵轴是epoch,横轴是权重,可以看到我一共进行了100轮稀疏训练,我的项目中非bottoltneck中的bn层加起来参数大概有8000多个,那可以看到在49个epoch的时候,0附近的权重已经有5000多个了。那接下来我可以设置60%的剪枝率,把它们都剪掉

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Yolov5 中添加Network Slimming剪枝--稀疏训练部分 的相关文章

  • Linux为什么区分内核空间和用户空间

    程序如果要被CPU执行 xff0c 就得编译成CPU可以执行的指令 xff0c 一大堆的程序就变成了一堆的指令 一个操作系统它也是一堆程序组成的 xff0c 可以想象CPU的指令是很多的 xff0c 但是这么多的指令中 xff0c 有些指令
  • 【Docker】镜像的保存(save)到文件 与 加载(load)到宿主机

    背景 xff1a 我们制作好的镜像会存储在宿主机上 xff0c 那么在迁移的过程中 xff0c 我们应该如何 保存自定义的镜像到宿文件 或 加载自定义的镜像到宿主机呢 xff1f 制作镜像 xff1a docker build t 镜像名
  • 手把手教你学会闭包

    前言 MDN对闭包的解释是这样的 xff1a 一个函数和对其周围状态 xff08 lexical environment xff0c 词法环境 xff09 的引用捆绑在一起 xff08 或者说函数被引用包围 xff09 xff0c 这样的组
  • 从高考到程序员:选择专业三要素:擅长、喜欢、有价值

    参与从高考到程序员征文 xff1a http blog csdn net blogdevteam article details 72917467 从高考到程序员 xff1a 选择专业三要素 xff1a 擅长 喜欢 有价值 选择工作三要素
  • 舍选抽样法

    对于一个随机变量X xff0c 对其直接进行抽样比较困难时 xff0c 我们可以选择一个比较容易产生随机数且逼近f很好的一个分布f y 来对其进行抽样 xff0c 下面以贝塔分布为例进行舍选抽样 ps 实际上 应该是先找了一个f y 使得c
  • 初识STM32工作原理和基础编程

    一 初识STM32如何简单的点亮一个LED灯呢 xff1f 如何用一个按键控制LED灯的点亮与否呢 xff1f 本文将对这些问题做一个比较详细的解答 xff0c 其中还有几个比较经典的例子 xff0c 希望能帮助大家更好的理解STM32的工
  • 结构体为何要进行对齐以及如何对齐

    先说如何对齐 xff0c 再讲讲其背后的原理 对齐规则在网上和书上都很容易找到 无非就是以下几点 规则 1 第一个成员在与结构体变量偏移量为0的地址处 2 其他成员变量要对齐到某个数字 xff08 对齐数 xff09 的整数倍的地址处 对齐
  • 【Android Studio】Build Output输出中文乱码解决方法

    目录 问题 解决办法 修改配置文件 重启AS xff0c 再次触发Build 问题 解决办法 修改配置文件 Help gt Edit Custom VM Options 点击create 打开的配置文件中加入 xff1a Dconsole
  • vscode操作git

    vscode amp git vscode和git的联合 xff0c 完美的配合 本人是一个忠实的vscode使用用户 xff0c 毕竟他是开源的 xff0c 使用electron xff0c typescript开发的软件 是免费的 xf
  • Docker学习笔记(一)——解决docker权限问题

    1 解决docker权限问题 安装完docker后 xff0c 执行docker相关命令 xff0c 出现 xff1a Got permission denied while trying to connect to the Docker
  • docker打过tag标签后向镜像仓库推送镜像(push)

    推送镜像 在推送前 xff0c 必须给镜像打标签 xff0c 否则推送失败 xff0c 其实打标签就是定义一个版本标识 我们看下未打标签推送的提示信息 xff0c 其中swr 6666指向镜像仓库 xff1a span class toke
  • secureCRT安装和使用教程【图文并茂】

    secureCRT安装和使用教程 图文并茂 1 软件安装 2 软件使用 3 软件总结 1 软件安装 简介 一般而言 xff0c 嵌入式开发板使用串口来监控后台 可以使用串口线连接开发板和电脑 xff0c 对于没有串口的笔记本电脑来说 xff
  • tftpd32+ tftpd64文件传输安装和使用教程【图文并茂】

    tftpd32 43 tftpd64文件传输安装和使用教程 图文并茂 1 tftp软件安装 2 tftp使用教程 1 tftp软件安装 将编译好的程序放到开发板中去运行 xff0c 需要借助于一些软件 xff0c 下面介绍最常用的通过tft
  • Maven项目管理工具学习笔记

    Maven项目管理工具学习笔记 由于本人在最近的项目中使用到了Maven xff0c 但是之前对Maven并没有深入地了解 xff0c 所以借此机会 xff0c 在网上查阅资料 xff0c 对Maven进行进一步的了解 xff0c 并做记录
  • PyQt4控件失去焦点和获得焦点

    QListView控件多选设置 self ui listView setSelectionMode QAbstractItemView ExtendedSelection 初始化QListView控件焦点事件 self ui listVie
  • 远程工具MobaXterm安装和使用教程

    远程登录工具MobaXterm安装和使用教程 1 MobaXterm简介 2 MobaXterm安装 3 MobaXterm使用 创建SSH session 创建串口 session 右键粘贴 4 MobaXterm安全 1 MobaXte
  • GEC6818开发板使用和配置

    GEC6818开发板使用和配置 GEC1808开发板简介极致低功耗强大 AI 运算能力面向 AIoT 应用的丰富接口易于开发主控芯片特性参数 常用接口说明电源接口调试串口CSI摄像头接口以太网接口音频输入接口LCD接口 开发板功能 常用接口
  • Linux最常用命令50条【呕心沥血呐,望用之取之】

    Linux常用命令大全 第一章 Linux基础命令 1 linux ls 2 linux alias 3 linux cd 4 linux clear 5 linux date 6 linux dpkg 7 linux echo 8 lin
  • STM32 GPIO LED和蜂鸣器实现【第四天】

    STM32 GPIO LED实现 原理图一 STM32大小说明二 STM32时钟分析三 GPIO分析1 注意点 四 寄存器地址查找1 写出GPIOF外设的所有寄存器地址 五 LED灯开发1 理解led灯原理图2 打开GPIOF组时钟4 通过
  • Linux安装qt完整版教程

    Linux安装qt完整版教程 一 获取Linux qt版本二 linux安装qt三 配置qt环境变量四 linux qt相关的显示配置 一 获取Linux qt版本 qt 版本5 12 8官网下载地址 选择国内的下载渠道 xff0c 更快

随机推荐

  • STM32嵌入式工程师自我修养

    STM32嵌入式工程师自我修养 一 STM32必备技能二 程序员必须熟知三 学习STM32自备资料和硬件 一 STM32必备技能 1 熟悉 C 语言编程 xff0c 熟练 STM32CUBEMX Keil 开发环境 2 熟悉基于STM32
  • Qt 按钮组(Buttons)输入组(Input Widgets) 显示控件组(Display Widgets) 间隔组(Spacers) 布局组(Layouts) 容器组Containers等

    文章目录 按钮组 xff08 Buttons xff09 输入部件组 xff08 Input Widgets xff09 显示控件组 xff08 Display Widgets xff09 空间间隔组 xff08 Spacers xff09
  • Qt 对话框(QFileDialog)、标准颜色对话框(QColorDialog)、标准字体对话框(QFontDialog)、标准输入对话框(QInputDialog)、QMessageDialog

    文章目录 标准文件对话框 QFileDialog 代码简介QFileDialog类常用静态函数 标准颜色对话框 QColorDialog 代码简介QColorDialog类常用静态函数 标准字体对话框 QFontDialog 代码简介QFo
  • qt 使用textBrowser显示文字和图片,文字居中,图片居中,已测可用

    QTextBrowser显示图文操作 直接上源码UI设计效果截图源码方法2 推荐 源码 直接上源码 这里只给出框架 xff0c 美化的事交给有缘人 UI设计 效果截图 源码 span class token macro property s
  • HTTP报文格式详解

    文章目录 HTTP报文格式请求报文请求行请求头部空行请求数据 响应报文状态行响应头部空行响应体 HTTP报文格式 HTTP报文是面向文本的 xff0c 报文中的每一个字段都是一些ASCII码串 xff0c 每个字段的长度是不确定的 HTTP
  • 从源码分析HashMap集合之属性(一)

    注 xff1a 笔者所使用的jdk为1 8 xff0c 因本人水平有限 xff0c 难免会有错误 xff0c 请批评指正 xff0c 弥补不足 xff0c 多谢 xff0c 另转载请注明出处 我们首先来看下一下HashMap类 public
  • Linux下实现苹果AirPlay音频服务器

    一 背景 背景 xff1a 在华清学习之余 xff0c 自行研究了智能家居的东西 xff0c 为了解决智能家居中背景音乐问题研究如下 xff1a 调查发现现有技术中有DLNA AirPlay Miracast三种 文章后有些项并未验证 xf
  • uCOS-III基础入门函数

    uCOS III是一个主要是运行在单片机上操作系统 xff0c 可以实现并发 xff0c 主要的功能就是任务 mutex event的创建和使用 调度器 调度器就是使用相关算法来决定当前需要执行的任务 xff0c 调度器的核心有两个 xff
  • Java接口实现

    接口是什么 xff0c 它的作用是什么 xff1f 首先 xff0c Java只能实现单继承 xff0c 而有时候实际需要要求我们实现多继承 xff0c 因此 xff0c 接口就是为了实现多继承而开发出来的 xff0c 并且接口支持程序在运
  • python爬取京东商品信息及评论

    准备 chrome浏览器 和 chromeDriver插件 xff08 其他浏览器步骤类似 xff09 python 环境python selenium模块 代码 span class token triple quoted string
  • Error while loading error while loading shared libraries 解决办法

    Error while loading error while loading shared libraries 解决办法 缺失了 xff0c 那就找到放回去 发行版 xff1a Archlinux 如标题所言 xff0c 这里以截至写文章
  • sql-创建复合主键

    一 说明 xff1a 1 数据库的每张表只能有一个主键 xff0c 不可能有多个主键 2 所谓的一张表多个主键 xff0c 我们称之为复合主键 xff08 联合主键 xff09 注 xff1a 联合主键 xff1a 就是用多个字段一起作为一
  • sql_外键

    一 外键的定义 1 外键是一种索引 xff0c 是通过一张表中的一列指向另一张表的主键 xff0c 使得这两张表产生关联 2 是某个表中的一列 xff0c 它包含在另一个表的主键中 3 一张表中可以有一个外键也可以有多个外键 二 外键的作用
  • 数据库事务图解

    一 基本概念 xff08 from baidu xff09 数据库事务 Database Transaction xff0c 是指作为单个逻辑工作单元执行的一系列操作 xff0c 要么完全地执行 xff0c 要么完全地不执行 事务处理可以确
  • MySQL 中判断字符串是否相等

    感谢 xff1a https blog csdn net yangfengjueqi article details 72821603 mysql 中判断两个字符串是否相等可以用 lt 61 gt 或者 61 例 但是需要注意 lt 61
  • mysql order by 多个字段及其多字段排序原则,和 order by 后跟数字

    一 order by 后跟数字 select from table order by n n 表示select里面的第n个字段 xff0c 整段sql的意义是 xff1a 查询出来的结果 xff0c 按照第N个字段排序 二 order by
  • ROS 学习1- 创建工作空间以及功能包

    一 工作空间概念 在ros中工作空间统称为workspace 是用来存放你一个工程开发需要用到的相关文件的 xff0c 在ros中它是一个带有空座空间性质的文件夹 该文件夹中通常会包含4个子文件夹 src 用来存放功能包的 build 编译
  • Linux 中echo及echo > 和echo >>

    一 Linux 中的echo指令 Shell 的 echo 指令用于字符串的输出 详见 xff1a Shell echo命令 菜鸟教程 二 echo gt 和echo gt gt echo gt 和echo gt gt 的区别详见 xff1
  • Yolov5 计算访存量MAC与计算量FLOPS

    说明 xff1a 因为yolov5函数中已经计算了 FLOPS xff0c 因此如果想要计算访存量那么只需按照flops的位置 添加访存量的计算即可 一 先记住计算量和访存量的公式 xff1a 二 找到计算FLOPS的位置 xff0c 并添
  • Yolov5 中添加Network Slimming剪枝--稀疏训练部分

    前言 xff1a Network Slimming剪枝过程让如下 1 稀疏化 2 剪枝 3 反复迭代这个过程 一 稀疏化 xff1a 通过Network Slimming 的核心思想是 添加L1正则来约束BN层系数 xff0c 从而剪掉那些