【知识蒸馏】Knowledge Review

2023-11-08

【GiantPandaCV引言】 知识回顾(KR)发现学生网络深层可以通过利用教师网络浅层特征进行学习,基于此提出了回顾机制,包括ABF和HCL两个模块,可以在很多分类任务上得到一致性的提升。

摘要

知识蒸馏通过将知识从教师网络传递到学生网络,但是之前的方法主要关注提出特征变换和实施相同层的特征。

知识回顾Knowledge Review选择研究教师与学生网络之间不同层之间的路径链接。

简单来说就是研究教师网络向学生网络传递知识的链接方式。

代码在:https://github.com/Jia-Research-Lab/ReviewKD

KD简单回顾

KD最初的蒸馏对象是logits层,也即最经典的Hinton的那篇Knowledge Distillation,让学生网络和教师网络的logits KL散度尽可能小。

随后FitNets出现开始蒸馏中间层,一般通过使用MSE Loss让学生网络和教师网络特征图尽可能接近。

Attention Transfer进一步发展了FitNets,提出使用注意力图来作为引导知识的传递。

PKT(Probabilistic knowledge transfer for deep representation learning)将知识作为概率分布进行建模。

Contrastive representation Distillation(CRD)引入对比学习来进行知识迁移。

以上方法主要关注于知识迁移的形式以及选择不同的loss function,但KR关注于如何选择教师网络和学生网络的链接,一下图为例:

(a-c)都是传统的知识蒸馏方法,通常都是相同层的信息进行引导,(d)代表KR的蒸馏方式,可以使用教师网络浅层特征来作为学生网络深层特征的监督,并发现学生网络深层特征可以从教师网络的浅层学习到知识。

教师网络浅层到深层分别对应的知识抽象程度不断提高,学习难度也进行了提升,所以学生网络如果能在初期学习到教师网络浅层的知识会对整体有帮助。

KR认为浅层的知识可以作为旧知识,并进行不断回顾,温故知新。如何从教师网络中提取多尺度信息是本文待解决的关键:

  • 提出了Attention based fusion(ABF) 进行特征fusion

  • 提出了Hierarchical context loss(HCL) 增强模型的学习能力。

Knowledge Review

形式化描述

X是输入图像,S代表学生网络,其中 ( S 1 , S 2 , ⋯   , S n , S c ) \left(\mathcal{S}_{1}, \mathcal{S}_{2}, \cdots, \mathcal{S}_{n}, \mathcal{S}_{c}\right) (S1,S2,,Sn,Sc)代表学生网络各个层的组成。

Y s = S c ∘ S n ∘ ⋯ ∘ S 1 ( X ) \mathbf{Y}_{s}=\mathcal{S}_{c} \circ \mathcal{S}_{n} \circ \cdots \circ \mathcal{S}_{1}(\mathbf{X}) Ys=ScSnS1(X)

Ys代表X经过整个网络以后的输出。 ( F s 1 , ⋯   , F s n ) \left(\mathbf{F}_{s}^{1}, \cdots, \mathbf{F}_{s}^{n}\right) (Fs1,,Fsn)代表各个层中间层输出。

那么单层知识蒸馏可以表示为:

L S K D = D ( M s i ( F s i ) , M t i ( F t i ) ) \mathcal{L}_{S K D}=\mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right) LSKD=D(Msi(Fsi),Mti(Fti))

M代表一个转换,从而让Fs和Ft的特征图相匹配。D代表衡量两者分布的距离函数。

同理多层知识蒸馏表示为:

L M K D = ∑ i ∈ I D ( M s i ( F s i ) , M t i ( F t i ) ) \mathcal{L}_{M K D}=\sum_{i \in \mathbf{I}} \mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right) LMKD=iID(Msi(Fsi),Mti(Fti))

以上公式是学生和教师网络层层对应,那么单层KR表示方式为:

具 体 具体

与之前不同的是,这里计算的是从j=1 to i 代表第i层学生网络的学习需要用到从第1到i层所有知识。

同理,多层的KR表示为:

L M K D − R = ∑ i ∈ I ( ∑ j = 1 i D ( M s i , j ( F s i ) , M t j , i ( F t j ) ) ) \mathcal{L}_{M K D_{-} R}=\sum_{i \in \mathbf{I}}\left(\sum_{j=1}^{i} \mathcal{D}\left(\mathcal{M}_{s}^{i, j}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{j, i}\left(\mathbf{F}_{t}^{j}\right)\right)\right) LMKDR=iI(j=1iD(Msi,j(Fsi),Mtj,i(Ftj)))

Fusion方式设计

已经确定了KR的形式,即学生每一层回顾教师网络的所有靠前的层,那么最简单的方法是:

直接缩放学生网络最后一层feature,让其形状和教师网络进行匹配,这样 M s i , j \mathcal{M}_s^{i,j} Msi,j可以简单使用一个卷积层配合插值层完成形状的匹配过程。这种方式是让学生网络更接近教师网络。

这张图表示扩展了学生网络所有层对应的处理方式,也即按照第一张图的处理方式进行形状匹配。

这种处理方式可能并不是最优的,因为会导致stage之间出现巨大的差异性,同时处理过程也非常复杂,带来了额外的计算代价。

为了让整个过程更加可行,提出了Attention based fusion $\mathcal{U}
$, 这样整体蒸馏变为:

∑ i = j n D ( F s i , F t j ) ≈ D ( U ( F s j , ⋯   , F s n ) , F t j ) \sum_{i=j}^{n} \mathcal{D}\left(\mathbf{F}_{s}^{i}, \mathbf{F}_{t}^{j}\right) \approx \mathcal{D}\left(\mathcal{U}\left(\mathbf{F}_{s}^{j}, \cdots, \mathbf{F}_{s}^{n}\right), \mathbf{F}_{t}^{j}\right) i=jnD(Fsi,Ftj)D(U(Fsj,,Fsn),Ftj)

如果引入了fusion的模块,那整体流程就变为下图所示:

但是为了更高的效率,再对其进行改进:

可以发现,这个过程将fusion的中间结果进行了利用,即 F s j  and  U ( F s j + 1 , ⋯   , F s n ) \mathbf{F}_{s}^{j} \text { and } \mathcal{U}\left(\mathbf{F}_{s}^{j+1}, \cdots, \mathbf{F}_{s}^{n}\right) Fsj and U(Fsj+1,,Fsn), 这样循环从后往前进行迭代,就可以得到最终的loss。

具体来说,ABF的设计如下(a)所示,采用了注意力机制融合特征,具体来说中间的1x1 conv对两个level的feature提取综合空间注意力特征图,然后再进行特征重标定,可以看做SKNet的空间注意力版本。

而HCL Hierarchical context loss 这里对分别来自于学生网络和教师网络的特征进行了空间池化金字塔的处理,L2 距离用于衡量两者之间的距离。

KR认为这种方式可以捕获不同level的语义信息,可以在不同的抽象等级提取信息。

实验

实验部分主要关注消融实验:

第一个是使用不同stage的结果:

蓝色的值代表比baseline 69.1更好,红色代表要比baseline更差。通过上述结果可以发现使用教师网络浅层知识来监督学生网络深层知识是有效的。

第二个是各个模块的作用:

源码

主要关注ABF, HCL的实现:

ABF实现:

class ABF(nn.Module):
    def __init__(self, in_channel, mid_channel, out_channel, fuse):
        super(ABF, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel,kernel_size=3,stride=1,padding=1,bias=False),
            nn.BatchNorm2d(out_channel),
        )
        if fuse:
            self.att_conv = nn.Sequential(
                    nn.Conv2d(mid_channel*2, 2, kernel_size=1),
                    nn.Sigmoid(),
                )
        else:
            self.att_conv = None
        nn.init.kaiming_uniform_(self.conv1[0].weight, a=1)  # pyre-ignore
        nn.init.kaiming_uniform_(self.conv2[0].weight, a=1)  # pyre-ignore

    def forward(self, x, y=None, shape=None, out_shape=None):
        n,_,h,w = x.shape
        # transform student features
        x = self.conv1(x)
        if self.att_conv is not None:
            # upsample residual features
            y = F.interpolate(y, (shape,shape), mode="nearest")
            # fusion
            z = torch.cat([x, y], dim=1)
            z = self.att_conv(z)
            x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w))
        # output 
        if x.shape[-1] != out_shape:
            x = F.interpolate(x, (out_shape, out_shape), mode="nearest")
        y = self.conv2(x)
        return y, x

HCL实现:

def hcl(fstudent, fteacher):
# 两个都是list,存各个stage对象
    loss_all = 0.0
    for fs, ft in zip(fstudent, fteacher):
        n,c,h,w = fs.shape
        loss = F.mse_loss(fs, ft, reduction='mean')
        cnt = 1.0
        tot = 1.0
        for l in [4,2,1]:
            if l >=h:
                continue
            tmpfs = F.adaptive_avg_pool2d(fs, (l,l))
            tmpft = F.adaptive_avg_pool2d(ft, (l,l))
            cnt /= 2.0
            loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
            tot += cnt
        loss = loss / tot
        loss_all = loss_all + loss
    return loss_all

ReviewKD实现:

class ReviewKD(nn.Module):
    def __init__(
        self, student, in_channels, out_channels, shapes, out_shapes,
    ):  
        super(ReviewKD, self).__init__()
        self.student = student
        self.shapes = shapes
        self.out_shapes = shapes if out_shapes is None else out_shapes

        abfs = nn.ModuleList()

        mid_channel = min(512, in_channels[-1])
        for idx, in_channel in enumerate(in_channels):
            abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1))
        self.abfs = abfs[::-1]
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x,is_feat=True)
        logit = student_features[1]
        x = student_features[0][::-1]
        results = []
        out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])
        results.append(out_features)
        for features, abf, shape, out_shape in zip(x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):
            out_features, res_features = abf(features, res_features, shape, out_shape)
            results.insert(0, out_features)

        return results, logit

参考

https://zhuanlan.zhihu.com/p/363994781

https://arxiv.org/pdf/2104.09044.pdf

https://github.com/dvlab-research/ReviewKD

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【知识蒸馏】Knowledge Review 的相关文章

  • 第10篇:强化学习Q-learning求解迷宫问题 代码实现

    你好 我是郭震 zhenguo 今天重新发布强化学习第10篇 强化学习Q learning求解迷宫问题 代码实现 我想对此篇做一些更加详细的解释 1 创建地图 创建迷宫地图 包括墙网格 走到墙网格就是负奖励 注意 空白可行走网格奖励值设置为
  • 利用梳状函数求解周期函数傅里叶变换

    本文对梳状函数 1 单位冲激函数 2 梳状函数及其傅里叶变换 3 卷积和傅里叶变换 卷积是一种运算方式 针对线性时不变系统 最基础的应用就是 在时域中 一个输入 卷积上单位冲激响应 就可以得到输出 傅立叶变换的主要作用就是让函数在时域和频域
  • 在Ubuntu中配置中文输入法

    找到设置 选择区域和语言 点击Manage Installed Languagees 提示安装的话安装即可 4 点击图示内容 5 将Chinese simplified 勾选上 右键点击住 将汉语拖到第一位 重启Ubuntu 6 在输入源中
  • linux设备树节点添加新的复位属性之后设备驱动加载异常问题分析

    linux设备树节点添加新的复位属性之后设备驱动加载异常问题分析 1 linux原始设备驱动信息 1 1 设备树节点信息 1 2 linux设备驱动 1 3 makefile 1 4 Kconfig 1 5 对应的defconfig文件 2

随机推荐

  • Springboot ppt转pdf——aspose方式

    Springboot ppt转pdf aspose方式 1 下载ppt转pdf所需要的包 网盘地址 https pan baidu com s 1V CZ0zXcJzKofxr6qC1g8A 提取码 86lp 2 maven添加依赖 在项目
  • 编译开源软件vtr-verilog-to-routing遇到的一点问题

    vtr verilog to routing介绍 Verilog to Routing VTR 项目是一个全球性的合作项目 旨在提供一个开源框架 用于进行FPGA架构和CAD研究和开发 VTR设计流程以数字电路的Verilog描述和目标FP
  • SpringBoot连接RabbitMQ时一直显示Socket Closed或者An unexpected connection driver error occured,但是能正常访问web管理台

    问题 在使用SpringBoot去连接虚拟机或者远程主机的RabbitMQ时 出现了一直报错超时 报错 Socket Closed 或者 An unexpected connection driver error occured 解决方案
  • “程序员转型管理:从擅长代码到掌控团队的踩坑总结”

    作为程序员 很多人会在职业生涯中考虑转型管理岗位 然而 这个转换并不容易 除了需要掌握管理方面的知识和技能外 还需要处理人际关系并带领团队一起前进 在这个过程中 很多人可能会踩到一些坑 以下是我总结的一些经验教训 一 控制欲 由于程序员需要
  • jquery ajax 防止注入,javascript-jQuery在ajax全局事件中注入数据

    我正在尝试在ajax请求中注入数据 但是它失败了 我也不知道为什么 我试图查看jQuery源代码 但仍然找不到为什么它不起作用 感谢任何帮助 这是代码 someElement ajaxSend function e req options
  • python项目加密(模型加密,文件加密),涵盖了多种方法以及代码实现

    Python作为动态语言一般是以源码方式进行部署的 这就意味着他人在部署机器上可以直接获取项目代码 可能给作者带来不必要的损失和风险 这就需要对代码进行加密或混淆 常规的几类加密 混淆 方式如下 编译为pyc文件 将项目代码打包成pytho
  • 我的Substance Designer 学习笔记02-PBR材质学习理解

    首先定义PBR Physics based rendering 基于物理的渲染 由来 2012年迪士尼公司在技术论坛发布的文章 讲述自己作品的制作流程 2014年被某大佬提出简化版本的制作流程 优化后只用5中材质通道 BSDF 双向散射率分
  • C++实验02(02)华氏温度转换为摄氏温度

    题目描述 编写一个函数convert 把华氏温度转换为摄氏温度 转换公式为 C F 32 5 9 要求用内联函数实现 在main 中调用该函数 说明 F为double型 输入描述 华氏温度 输出描述 摄氏温度 输入样例 100 输出样例 华
  • 单线双线多线服务器有哪些区别

    单线双线多线服务器有哪些区别 服务器托管是我们现在当下比较常用的一种方式 越来越多的企业及站长 他们都会选择服务器托管 这不仅可以减少企业的维护时间成本 也可以让网站或者平台能够得到更多的专业技术支持 那么 在服务器托管中 我们经常会遇到单
  • jenkins+fastlane+git+cocoapods实现iOS持续集成踩坑记录

    前提 本项目在配置jenkins前已配置安装fastlane并自动上传蒲公英 关于fastlane的使用不在本文讨论范围之内 安装Jenkins jenkins有几种方式安装 一种是去官网下载dmg安装包 还可以下载 war文件 通过执行命
  • 整理一下react的知识点之redux-devtools-extension基本使用(持续更新)

    1 下载相关包 npm i redux react redux redux thunk redux devtools extension 2 安装react开发工具 chrome浏览器插件 3 安装redux的开发工具 chrome浏览器插
  • 【华为OD统一考试B卷

    华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一修改为OD统一考试 A卷 和OD统一考试 B卷 你收到的链接上面会标注A卷还是B卷 请注意 根据反馈 目前大部分收到的都是
  • 如何用ChatGPT辅助写论文

    ChatGPT先进功能创造了巨大的需求 该AI工具在推出后的两个月内就积累了超过1亿用户 最突出的功能之一是它能够在几秒钟内编写各种文本 包括歌曲 诗歌 睡前故事和散文 但是ChatGPT可以做的不仅仅是写一篇文章 更有用的是它如何帮助指导
  • 什么是IOC和DI?DI是如何实现的?

    什么是IOC和DI DI是如何实现的 IOC Inversion of Control 叫控制反转 DI Dependency Injection 叫依赖注入 是对IOC更简单的诠释 IOC 控制反转是把传统上由程序代码直接操控的对象的调用
  • IDEA上传代码到Gitee

    提示 这里可以使IDEA上传代码到Gitee 需要自己手动操作 目录 前言 一 打开Gitee官网 进行注册登录 1 登录进去找到右上角添加仓库 进行所示图操作 二 启动IDEA 1 IDEA关联Gitee 2 找到git下载好git程序
  • SPI协议的verilog实现:利用spi协议配置寄存器

    状态机状态跳转图 因常常需要对寄存器进行配置 因而学习了V3学院的视频课 利用spi协议对寄存器进行配置 在此做个记录 以便日后回顾 上图为状态机状态转移图 需要先将需要配置的寄存器的信息存放在ROM中 然后将数据读出来 通过SPI协议发送
  • Vue3快速入门教程

    学某个新技能时 大多数人倾向于 一开始就从头到尾完整学一遍 甚至有人翻来覆去重复学很多遍也达不到熟记于心 我个人认为 这不是最好的办法 我的建议的是 面向需求 or 面向问题来学习 最开始你可能不了解你要实现的效果会涉及哪些技术知识点 那么
  • 六十七.深度优先遍历C语言实现(有向图)

    include
  • ApplicationContext类继承设计

    先上类图 BeanFactory是Spring IoC的核心接口 BeanFactory相关的类设计可以看做是Spring的核心骨骼 为整个框架设计了一个基本的核心架构 但只有骨骼 没有血肉 也是不完整的 这样一个核心的骨架难以在实际开发中
  • 【知识蒸馏】Knowledge Review

    GiantPandaCV引言 知识回顾 KR 发现学生网络深层可以通过利用教师网络浅层特征进行学习 基于此提出了回顾机制 包括ABF和HCL两个模块 可以在很多分类任务上得到一致性的提升 摘要 知识蒸馏通过将知识从教师网络传递到学生网络 但