知识蒸馏基础及Bert蒸馏模型

2023-11-16

为了提高模型准确率,我们习惯用复杂的模型(网络层次深、参数量大),甚至会选用多个模型集成的模型,这就导致我们需要大量的计算资源以及庞大的数据集去支撑这个“大”模型。但是,在部署服务时,就会发现这种“大”模型推理速度慢,耗费内存/显存高,这时候我们又会想念“小”模型的好。那么,有没有一种方法能够尽可能继承大模型的泛化能力,又像小模型一样轻量级呢?今天来介绍一种模型压缩的方法——蒸馏(Distillation)。

传统的蒸馏

首次提出知识蒸馏压缩模型思想的是2006年Bucilua,但是论文里没有实际工作阐述:https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf
所以,一般认为最早是Hinton在2015年提出并应用在了分类任务上:Distilling the Knowledge in a Neural Network。我们来阐述一下传统的知识蒸馏过程:简单地说,就是先用数据集训练一个效果非常好的Teacher模型,然后选择一个较为轻量级的Student模型,同时接受数据集和来自Teacher模型给予的Knowledge Transfer的“知识”来训练这个轻量级Student模型。那么整个蒸馏的过程中,我们主要关心的就是Teacher模型的选择、Student模型的选择、以及Student模型的训练过程(或者说是Knowledge Transfer过程)。
在这里插入图片描述

Teacher模型:首先,我们需要一个原始的“大”模型——Teacher模型,这个模型可以不限制其结构、参数量、是否集成,要求这个模型尽可能精度高,并且对于给定的输入X可以给出输出的监督信息Y,这个Y在分类任务中就是softmax的结果,也就是输出对应类别的概率值。这里我们称Y为soft targets,而训练数据的标注好的标签,我们称为hard targets

Student模型:这个部分的模型选择会有很多限制,要求其参数量小,结构相对简单,当然最好是单模型。并且需要注意的是,训练过程中student模型学习的不再是单纯的hard targets(标注好的真实标签),而是融入teacher模型输出的soft targets(监督信息Y),这里也被称为knowledge transfer。蒸馏的损失函数distillation loss分为两部分:一部分计算teacher和student之间输出预测值的差别(student预测的y 和 soft targets),另一部分计算student原本的loss(student预测的y 和 hard targets),这两部分做凸组合作为整个模型训练的损失函数来进行梯度更新,最终获得一个同时兼顾精度和性能的student模型。

这里单独说一下teacher和student之间输出预测值的loss,这个部分被做的文章也是比较多,这实际上是两个分布的距离问题,可以选择传统的Cross,也可以选择MSE、KL散度等,在博主的实验里发现对不同的student模型,适合不同的loss函数,这里只能自己多做尝试。

为什么蒸馏会有效?

那么,肯定有人想问,为什么蒸馏会有效?直接从数据集学习不是更为直观没有中间商赚差价吗?本质上,蒸馏的训练方式主要是改变了模型只能单一地学习label的这个缺陷。原本模型从数据集的标注数据中学习,而蒸馏过程学习的知识融入了Teacher模型输出的监督信息Y,在分类任务上也就是softmax结果,其中包含了Teacher模型的泛化能力。

具体的举个例子,我们做新闻分类,类别分别为社会、财经、娱乐、生活。此时我们有一条社会类目的新闻,其hard target为[1, 0, 0, 0]。而经过teacher模型,输出其soft target为[0.88, 0.01, 0.01, 0.1],那么我们可以发现soft target中学习到:首先,这条新闻确实是社会类目;其次,这条新闻是生活类目的可能性要比财经和娱乐类目的高。那么模型通过同时学习hard target和soft target获得的知识要比只学习hard target的更多。换句话说,在分类的模型中,我们的蒸馏模型不仅能学习到本身这个分类任务,还可以额外获得类别间的相似性知识,那么理论上,蒸馏模型的泛化能力一定要比同样模型结构在该数据集上训练的模型强。

也就是说,蒸馏模型学习的不仅是数据集中的知识,还有Teacher模型的泛化能力

蒸馏模型的分类

从不同的角度看蒸馏模型可以有不同的分类,这里给出两种区分,分别来自两篇文章。

从训练方式区分

论文地址:Knowledge Distillation and Student-Teacher Learning for Visual Intelligence: A Review and New Outlooks
在这里插入图片描述

  • 离线蒸馏方式,即为传统的知识蒸馏,如上图(a)。一般来讲,Teacher模型的参数在蒸馏训练过程中保持不变,选用的Teacher模型和Student模型准确性相对悬殊比较大,并且Student模型会在很大程度上依赖Teacher模型。
  • 半监督训练方式,利用了Teacher模型的预测信息作为标签来对Sudent网络进行监督学习,如上图(b),不同于传统的离线蒸馏方式,在对Student模型训练之前,先输入部分未标记的数据,利用Teacher网络输出的标签作为监督信息,再输入到Student网络中来完成蒸馏,这样可以使用更少的标注数据,达到提升模型精度的目的。在online蒸馏中,Student模型和Teacher模型将同时更新,整个知识提炼框架是可以从端到端训练的。给出一篇online蒸馏的文章:Online Knowledge Distillation with Diverse Peers
  • 自监督蒸馏,相比于传统的离线蒸馏方式,是不需要提前训练一个Teacher模型的,而是Student网络本身的训练是一个蒸馏过程,如上图(c)。具体的实现方式有很多种,比如训练Student模型时,在整个训练过程的最后几个epoch的时候,利用前面训练的Student模型作为监督模型,在剩下的几个epoch中对模型进行蒸馏。这样做的好处,是不需要提前训练一个Teacher模型,可以做到边训练边蒸馏,节省整个蒸馏过程的训练时间。同样给出一篇自监督的蒸馏:Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation

从知识来源位置区分

论文地址:Knowledge Distillation: A Survey (这篇文章总结的特别全,可以看一下下图,这里只拎出来Sec2说说)
在这里插入图片描述

从知识来源位置维度考虑,蒸馏模型可以分为Response-Based、Feature-Based和Relation-Based的知识蒸馏。从下图可以直观感受到,Response-Based的知识是从teacher模型的output layer中学习到的,而Feature-Based是从hidden layer中学习到的知识,Relation-Based则是学习input-hidden-output之间的关系。

在这里插入图片描述

Response-Based

在这里插入图片描述
基于response的知识蒸馏实际上也就是传统是知识蒸馏模型,response通常指的是teacher模型最后一个输出层,比如分类任务中的softmax层的输出,其主要思想是直接模拟teacher的最终决策。基于response的知识蒸馏对于模型压缩来说是最简单有效的,并且被广泛应用于不同的任务和场合中。Hinton提出的蒸馏模型也是采用了这样的方法。Student学模型学习teacher模型的输出分布,相当于同时给予了类别之间的相似性信息,同时额外提供了监督信息,学习起来较为容易,实现起来也较为容易。但是蒸馏的效率依赖于softmax loss计算和类别的数量。从实验效果上看,如果student模型较小,或者和teacher模型差别过大的时候,蒸馏的效果不尽如人意。

Feature-Based

在这里插入图片描述
首次提出Feature-Based的文章是:FITNETS: HINTS FOR THIN DEEP NETS, 实际上是对Hinton提出的蒸馏模型的一种拓展。从上图可以清晰的明白,Feature-Based是从一些中间隐层中学习知识,其允许student网络可以比teacher网络更深更窄,从teacher网络中间层提取特征结果,作为student网络中间层输出的hint,也就是说teacher网络的中间层去指导student网络训练。因为student网络相比于teacher网络较窄,所以student网络中间层连接一个Wr网络和teacher网络进行适配,这个用于适配的网络选择了卷积网络,节省计算量。
在这里插入图片描述

Relation-Based

在这里插入图片描述
Relation-Based 不拟合Teacher模型中间层或者输出层的结果,而是拟合Teacher模型内层与层之间的关系,这个关系是用层与层之间的内积来定义的。参考论文:A Gift from Knowledge Distillation:Fast Optimization, Network Minimization and Transfer Learning。

蒸馏在NLP中的应用

在NLP的大部分任务中,我们可能习惯上追崇Bert大法,但是Bert本身参数量比较大,在一些特殊情况下,我们需要部署一个小而美的模型,这时候我们需要给Bert进行“瘦身”。一般认为比较有效的瘦身方法有上面介绍的蒸馏、量化(Quantization)、剪枝(Pruning)。这里我们介绍几个效果不错的Bert蒸馏模型。

DistillBERT

论文地址:https://arxiv.org/pdf/1910.01108.pdf
项目地址:还没开放
这里选择了bert-base作为teacher网络,除此之外罗列一下DistillBERT的特别之处:

  • Student模型结构变化:DistilBERT中Student模型的整体结构和Bert基本相同,不过Bert采用了12层的transformer encode,而DistilBERT采用的6层的transformer encode,这里作者注意到hidden size维度的变化对模型计算效率的影响小于层数变化的影响,因此DistilBERT主要改变的是Bert层数。其次Student模型移除了token-type embedding和pooler。
  • Student模型初始化工作:DistilBERT没有进行自己的预训练,而是将Bert部分参数直接加载到DistilBERT的结构中,作为初始化。
  • Student模型训练损失函数:这里是损失函数包含三个部分:1. 传统蒸馏的损失:teacher网络softmax层输出的概率分布和student网络softmax层输出的概率分布的交叉熵;2. 传统模型训练的损失:student网络softmax层输出和真实标签的交叉熵;3. student网络隐层输出和teacher网络隐层输出的余弦相似度值;

训练方法和Roberta类似,采用了大batch、动态mask、扔掉NSP任务等,关于Roberta可以回顾一下:bert的兄弟姐妹梳理——Roberta、DeBerta、Albert、Ambert、Wobert等

DistillBERT的思想还是比较简单的,根据文中给出的实验效果看,模型参数减小了40%(66M),推断速度提升了60%,但精度大概下降了3%左右。
在这里插入图片描述

TinyBERT

论文地址:https://arxiv.org/pdf/1909.10351.pdf
项目地址:https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT
在这里插入图片描述
这篇也是从蒸馏角度出发,和DistillBERT的思路相差不大,都是缩减模型结构:减少层数和hidden size,差异可能更多体现在 loss 的设计上,此外,作者还提出了两段式学习框架,旨在提升特定任务的TinyBERT精度。

模型结构:TinyBERT层数相对bert-base从12层降低到4层;FFN层输出的大小从3072降低到1200,Head个数维持12不变,hiddent size从768降至312;最终参数量从110M降低到14.5M。

在这里插入图片描述

损失函数主要分为三个部分,但是和DistillBERT的设计差别还是挺大的:

在这里插入图片描述

在这里插入图片描述

  • Embedding-layer Distillation:student网络的embedding和teacher网络的embedding的MSE损失;在这里插入图片描述

  • Transformer-layer Distillation:这里分为两个部分:1. attention based distillation:student网络第 i 个attention头的attention score矩阵和teacher网络第 i 个attention头的 attention score矩阵的MSE损失的平均值;在这里插入图片描述

  1. hidden states based distillation:student transformer 和 teacher transformer 的隐层输出的MSE损失在这里插入图片描述
  • Prediction-Layer Distillation: teacher 输出的概率分布和 student 输出的概率分布的 softmax 交叉熵在这里插入图片描述

两段式学习框架:BERT 的应用通常包含:预训练和微调。BERT在预训练阶段学到的大量知识非常重要,并且迁移的时候也应该包含在内。因此,研究者提出了一个两段式学习框架,包含通用蒸馏和特定于任务的蒸馏,这样做的目的是:TinyBERT 可以获取 LargeBERT 的通用和针对特定任务的知识,两段式蒸馏可以尽可能地缩小 teacher 和 student 模型之间的差距。本质上就是在pre-training蒸馏一个通用的TinyBERT,然后再在通用的TinyBERT的基础上利用task-bert上再蒸馏出微调版的TinyBERT。
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

知识蒸馏基础及Bert蒸馏模型 的相关文章

随机推荐

  • 计算Shell脚本执行时间

    startTime date Y m d H M S startTime s date s 执行程序 endTime date Y m d H M S endTime s date s 计算时长 sumTime endTime s star
  • HyperLedger Fabric - 超级账本(4)链码的概念与使用

    概念 Chaincode 链上代码 简称链码 一般是指用户编写的应用代码 链码被部署在Fabric网络节点上 Peer 节点 背书节点 commit节点 Leader节点 锚节点 运行在隔离沙盒 当前为Docker容器 中 并通过gRPC协
  • QT--连续发送数据包

    提示 本文为学习记录 若有错误 请联系作者 谦虚受教 文章目录 前言 一 数据包 二 相关实现代码 三 Char转QByteArray 总结 前言 萤火虫在灯光下为什么不会亮呢 一 数据包 定义 包 Packet 是TCP IP协议通信传输
  • django vue前后台传对象及数据highcharts图表显示

    总体思路 Vue对象中定义data数据对象 axios get方法通知后台产生数据 后台用JsonResponse返回数据 注意写上safe False 前台用response data取回数据 存入Vue对象的data对象中 在js中通过
  • 剑指 Offer 53 - I. 在排序数组中查找数字 I

    剑指 Offer 53 I 在排序数组中查找数字 I 题目 题目链接 具体代码 题目 题目链接 https leetcode cn com problems zai pai xu shu zu zhong cha zhao shu zi l
  • 使用Cartool工具包分析EEG源成像

    使用Cartool工具包分析EEG源成像 1 基本要求 1 1 EEG预处理 EEG Pre processing 1 2 时间过滤 Temporal Filtering 1 3 下采样 降低采样率 Down Sampling 1 4 电极
  • MAC OS更新系统后IDEA中的SVN报错无法使用

    IntelliJ IDEA无法正常使用SVN 报Cannot run program svn in directory XXXX XXXX XXXX XXX error 2错误 使用Mac的小伙伴 在更新Mac系统的之后 通过idea操作s
  • 充电灯 低电灯共用一个 LED

    充电灯 低电灯共用一个 LED 电路 产品经理提出一个需求 因为结构只有一个灯孔 需要实现充电 低电指示灯共用一个LED 考虑到电源开关关闭时充电要亮 所以用电池正极作为 LED阳极 LED 阴极由两条线路控制 一个是充电IC的充电指示脚
  • 玩转 SpringBoot 监控统计(SQL监控、慢SQL记录、Spring监控、去广告)

    关注后回复 进群 拉你进程序员交流群 作者 架构师的小跟班 来源 blog csdn net weixin 44730681 article details 107944048 基本概念 Druid 是Java语言中最好的数据库连接池 虽然
  • NXP imx6ull uboot-imx-rel_imx_4.1.15无法从SD卡加载内核

    我imx6ull使用的是正点原子的alpha开发板 我将我的SD卡分成了两个分区 第一个分区格式化为fat 用来存放zImage和dtb 第二个分区格式化为EXT4格式 这个分区作为根文件系统 遇到的问题是 我编译好kernel后 尝试通过
  • PLC程序的基本组成和编程语言

    一般情况 PLC程序由主程序 多个子程序 多个中断服务程序等三部分组成 这三部分被组织在一起 经过编译可以下载到PLC中运行 如下图就是一个简单的例子 MAIN是主程序 SBR0是子程序 可以被MAIN调用 运行在一个循环中 中断服务程序独
  • 如何解决Visual Studio2019登录微软账户登录不上的问题

    试用期30天过了 这个问题困扰了我好几天 几乎把网络上所有的问题都搜索遍了 下面总结了网上常用的解决方案还有登录方式 登陆不上微软账户排除自己密码账户没有输入正确以外 是网络的问题 我没有用WiFi 用的是手机的热点连接的 方法1 如果挂V
  • IDEA技巧-快速编写一个String类型的JSON对象

    1 先编写一个String类型空值对象 String strJson 2 将光标放在 中间 3 Alt Enter调出Inject language or reference视图界面 回车选中Inject language or refer
  • 数据在OSI七层模型中的名字 数据帧、数据包、数据报以及数据段

    数据帧 数据包 数据报以及数据段 OSI参考模型的各层传输的数据和控制信息具有多种格式 常用的信息格式包括帧 数据包 数据报 段 消息 元素和数据单元 信息交换发生在对等OSI层之间 在源端机中每一层把控制信息附加到数据中 而目的机器的每一
  • 多种系统如何安装并启动Redis

    1 Windows 系统下安装 首先坏消息是reids官网没有提供windows版的redis 但好消息是微软的开源技术团队在gtihub上开发和维护了windows版的redis 具体如何使用参考下这片文章 windows系统本地安装re
  • Struts2 重点总结 (2)

    国际化 资源文件和资源包 要用Struts实现国际化和本地化 首先要定义资源文件的名称 这个文件会包含用默认语言编写的会在程序中出现的所有消息 这些消息以 键 值 对的形式存储 如下 error validation localtion T
  • 软测入门(十)Jmeter接口测试基础

    接口测试流程 接口测试的流程 分析接口文档和需求 编写接口测试计划 5W1H 编写接口测试用例 接口测试执行 输出接口测试报告 接口测试分类 Web接口测试 服务器接口测试 模块接口测试 单元测试 接口测试的要点 数据是否正常 参数类型错误
  • 人工智能基础部分16-神经网络与GPU加速训练的原理与应用

    大家好 我是微学AI 今天给大家介绍一下人工智能基础部分16 神经网络与GPU加速训练的原理与应用 在深度学习领域 神经网络已经成为了一种流行的 表现优秀的技术 然而 随着神经网络的规模越来越大 训练神经网络所需的时间和计算资源也在快速增长
  • Ajax传json对象(jQuery)

    Ajax传json对象 相信很多小伙伴想要通过Ajax传输json数据给后端 本来直接发送一个data JSON stringify obj 就可以了 但是发现后端的请求参数中有一个参数需要int类型 这个时候就需要用到对象了 封装对象 首
  • 知识蒸馏基础及Bert蒸馏模型

    为了提高模型准确率 我们习惯用复杂的模型 网络层次深 参数量大 甚至会选用多个模型集成的模型 这就导致我们需要大量的计算资源以及庞大的数据集去支撑这个 大 模型 但是 在部署服务时 就会发现这种 大 模型推理速度慢 耗费内存 显存高 这时候