论文阅读-多任务(2020)-KL4MTL:用于多任务学习的知识蒸馏方法

2023-11-17

KL4MTL

论文:Knowledge Distillation for Multi-task Learning

地址:https://paperswithcode.com/paper/knowledge-distillation-for-multi-task

论文总览

   多任务学习的目标是使得单个模型能够在多个任务上取得较好的结果,这样能够降低计算代价。该种模型的学习过程需要同时优化多种任务的损失,这些任务有着不同的学习难度、不同的维度以及不同的特征,对应着不同的损失函数,这很容易导致各个任务之间的学习程度不均衡。为此论文提出了一种用于多任务的蒸馏方法,首先为每个任务学习一个专用模型,然后学习一个多任务的模型用于最小化每个特定任务模型的损失并为单个模型生成相同特征。而专用模型会生成各自的特征,因此论文引入了一个针对单个任务的特征适配器来将多任务模型的特征映射到单一任务模型特征,使得跨任务的参数共享更加均衡。

多任务学习的好处与难点

   多任务学习的好处:1) 训练和推断的计算代价更小(相同目标下);2) 多任务之间共享参数能够使得模型更容易训练,并且对测试数据的泛化能力更强。

   多任务学习的难点:1)求同存异的网络架构,即多个任务之间有限的参数共享并保持各自专门的独特的参数,传统方法中多任务模型保持几乎所有层共享,只保留最后几层不同,后者明显是一个非最优模型,但是找到最优模型的代价无疑非常大;2)兼顾各个任务的模型训练过程,即多任务模型训练需要联合优化一系列损失,这些损失学习难度、维度与特征各不相同,最土味的方法就是给各个损失加上不同的权重,搜索合适的权重组合同样代价惨重,所以通常会导致一个次最优模型(实际上针对这个问题,已经有不少对损失权重平衡策略的研究,但是仍然面临着某个人物主导训练过程以及低精度的问题)

论文阅读

   该论文提出一个知识蒸馏的方法来应对MtL不均衡的问题,鉴于为单个损失函数施加不同权重或者修改梯度的方法对于参数学习的有限影响,很难限制单个任务主导训练过程的问题,论文提出一个更加严格的mtl网络参数控制方案。由于在充足数据条件下专用网络总是效果更好,论文假设多任务模应该尽可能接近专用模型的参数分布,甚至应处于各个专用模型参数分布的重合区域即交集。出于这样的目标,论文首先为每个任务训练一个专用模型并冻结其权重,然后优化mtl模型的参数使得联合损失最小化并且生成与专用网络尽可能接近的特征。鉴于每个专用网络的特征不同,论文为每个任务都引入一个小适配器来调整专用模型与mtl模型之间的特征差异,平和各个任务之间的参数共享。

   对于单一任务学习,给定一个包含 N N N个训练图像数据 x i x^i xi及其对应的 T T T个任务的标签 y 1 i , . . . , y T i y_1^i,...,y_T^i y1i,...,yTi的数据集 D D D,对于STL我们的目标是学习 T T T个不同的专用模型,每一个模型豆浆输入 x x x映射到对应的标签 y τ y_{\tau} yτ,即 f ( x ; θ τ s , ϑ τ s ) = y τ f(x;\theta_\tau^s,\vartheta_\tau^s)=y_\tau f(x;θτs,ϑτs)=yτ,其中 s s s表示特定任务, θ τ s \theta_\tau^s θτs ϑ τ s \vartheta_\tau^s ϑτs是网络的参数,每个专用网络包含两个部分:i) 一个特征编码器 ϕ ( ⋅ ; θ τ s ) \phi(\cdot;\theta_\tau^s) ϕ(;θτs),接收衣服图像然后输出一个高维度的特征编码 ϕ ( x ; θ τ s ) ∈ R C × H × W \phi(x;\theta_\tau^s)\in R^{C\times H\times W} ϕ(x;θτs)RC×H×W,其中 C , W , H C,W,H C,W,H分别表示特征图的通道数、高度和宽度。ii)一个预测器 ψ ( ⋅ ; ϑ τ s ) \psi(\cdot;\vartheta_\tau^s) ψ(;ϑτs)接收 ϕ ( x ; θ τ s ) \phi(x;\theta_\tau^s) ϕ(x;θτs)然后输出特定任务 τ \tau τ的预测结果,即 y ^ τ = ψ ( ⋅ ; ϑ τ s ) ∘ ϕ ( x ; θ τ s ) \hat y_\tau=\psi(\cdot;\vartheta_\tau^s)\circ\phi(x;\theta_\tau^s) y^τ=ψ(;ϑτs)ϕ(x;θτs),其中 θ τ s , ϑ τ s \theta_\tau^s,\vartheta_\tau^s θτs,ϑτs分别表示专用网络编码器和预测器的参数,这两个参数通过优化专门任务的损失函数 l ( y ^ , y ) l(\hat y,y) l(y^,y),优化过程如下:
在这里插入图片描述

   对于多任务学习,则需要学习一个绝大多数参数可以在不同任务间共享的模型,多任务模型也可以分为两个部分即特征编码器和多个针对特定任务的预测器,输出结果可以表示为 ψ ( ⋅ ; ϑ τ m ) ∘ ϕ ( x ; θ τ m ) \psi(\cdot;\vartheta_\tau^m)\circ\phi(x;\theta_\tau^m) ψ(;ϑτm)ϕ(x;θτm),最小化损失的优化过程如下:
在这里插入图片描述

   其中 w τ w^\tau wτ表示特定任务的损失的权重。多任务学习显然损失优化过程更难,因为不同任务的损失可可能维度和计算都区别很大,一种平衡各个损失的方法就是通过交叉验证的方式确定每个任务的权重,但是这种连续空间下超参数搜索的方法非常低效而且即便找到一组最优参数,在实际优化的过程中也有可能得到一个次优解。

   基于以上问题,很多方法试图得到一个哦动态权重平衡策略来对每一个训练iter进行调整,但是论文认为这种方法对于网络参数的控制非常有限,所以提出一个基于知识蒸馏的方法,为了完成知识蒸馏,首先为每个任务 τ \tau τ训练一个专用模型 f ( ⋅ ; θ τ s , ϑ τ s ) f(\cdot;\theta_\tau^s,\vartheta_\tau^s) f(;θτs,ϑτs)然后冻结参数仅使用特征编码器部分,通过最小化专用网络和多任务网络特征间在训练过程中调整多任务模型。因为专用模型的编码器的输出特征可能会存在显著差异以至于多任务模型的编码器输出特征不能同时捕获所有特征信息,论文通过多个特殊的小适配器 A τ A_\tau Aτ将多任务特征编码器的输出映射到专用网络的编码器输出特征空间,实验中该适配器使用一个 1 × 1 × C × C 1\times 1\times C \times C 1×1×C×C的卷积层来进行特征映射,这些特征适配器的学习过程伴随着多任务网络的训练过程,以此调整编码器输出特征间 差异,该过程如下公式所示:
在这里插入图片描述

   其中 l d l^d ld表示经过L2正则化之后的特征图间的欧几里得距离,如下公式所示:
在这里插入图片描述

   所以整个多任务模型优化过程,可以简单的表示为如下公式:
在这里插入图片描述

结论

   因为特征适配器是一个线性过程,反过来看也可以是将每个专用模型的编码器特征映射为适用于所有任务的特征,这里假设专用模型与多任务模型之间的特征表示差异是一个简单的线性变换,尽管看起来不可思议,但是很多任务中观察到的结果也确实是这样的。


欢迎扫描二维码关注微信公众号 深度学习与数学 ,每天获取免费的大数据、AI等相关的学习资源、经典和最新的深度学习相关的论文研读,算法和其他互联网技能的学习,概率论、线性代数等高等数学知识的回顾。
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

论文阅读-多任务(2020)-KL4MTL:用于多任务学习的知识蒸馏方法 的相关文章

  • Unity PlayerPrefs、JsonUtility

    Unity中有两个常用的数据存储方式 PlayerPrefs和JsonUtility PlayerPrefs PlayerPrefs是Unity内置的一种轻量级数据存储方式 可用于存储少量的游戏数据 如分数 解锁状态等 使用PlayerPr
  • LPDDR4协议规范之 (六)刷新

    LPDDR4协议规范之 六 刷新 刷新命令 刷新计数器 刷新时序 刷新前时序 刷新后时序 全存储体刷新时序 tRFCab tRFCpb 自刷新 自刷新期间进入掉电模式 自刷新中止 刷新命令 REFRESH命令在时钟的第一个上升沿以CS HI

随机推荐

  • GDI+ 中 Pen 使用总结

    背景 图形类 Graphics 是 GDI 的核心 它提供绘制图形 图像和文本的各种方法 Graphics 中使用 DrawString 方法在指定位置绘制文本或者在一个指定矩形内绘制文本 所有的 Graphics 类的绘制方法都得配合 P
  • MXNet简介

    MXNet是一个十分优秀的深度学习框架 目前包含了许多语言接口 如Python C Scala R等 目前 MXNet版本已经更新到1 3 0 本系列文章主要使用Python接口 在MXNet官网 1 上 官方建议新手使用Python接口
  • 网站代理是什么?有什么需要注意的?

    如今 网站代理已经成为一种不可或缺的经营方式 无论是企业还是个人 都需要通过代理来获得更多的流量和市场份额 一 网站代理的优势 网站代理的优势在于能够为您提供更加专业 周到的服务 这些优势包括 1 丰富的内容资源 能够满足客户对不同领域信息
  • Java-按照指定小时分割时间段

    按照指定小时分割时间段 param dateType 类型 M D H N gt 每月 每天 每小时 每分钟 param dBegin开始时间 param dEnd结束时间 param time 指定小时 如 1 2 3 4 return
  • 变分推断学习

    https zhuanlan zhihu com p 401456634 变分推断 1 变分推断的背景 在机器学习中 有很多求后验概率的问题 求后验概率的过程被称为推断 Inference 推断分为精确推断和近似推断 精确推断一般主要是根据
  • H5播放之Rtsp转Websocket点播录像抓拍

    H5播放之Rtsp转Websocket点播录像抓拍 HLS的延时 websocket播放 实现思路 广大网友们 很久没上CSDN了 暨上次RTSP转HLS文章发布以来 一直还有一个问题没有解决 如何避免HLS切片带来的不可避免的高延时 HL
  • 浅谈PCA 人脸识别

    前几天讨论班我讲了基于PCA的人脸识别 当时我自己其实也只是知道这个算法流程 然后基于该算法利用c 实现了 效果还不错 后来跟师兄一起讨论的时候 才发现这个PCA还是有相当深刻的意义 PCA的算法 矩阵C AAT A的每一列是一张人脸注 将
  • Java的基础语法

    1 关键字介绍 1 Java 中一些赋以特定的含义 用做专门用途的字符串称为关键字 keyword 2 所有Java关键字都是小写英文字符串 2 Java变量 1 Java变量是程序中最基本的存储单元 其要素包括变量名 变量类型和作用域 2
  • c语言用指针找最大数,C语言,用指针。求输入20个数,依次输出这几个数,求最大值,最小值。...

    满意答案 yuab0p0dpi3 2013 11 22 采纳率 53 等级 13 已帮助 13064人 include include define LENGTH 20 void main int pBuff int malloc size
  • FTTR(Fiber To The Room)组网详解

    FTTR Fiber To The Room 是一种新型的光纤宽带接入技术 主要用于宽带网络覆盖范围有限 带宽瓶颈较严重的酒店 公寓 医院等场所 FTTR技术可以将光纤信号传输到用户房间内 实现高速 稳定的网络接入 提高用户体验 下面我们详
  • Python的最大递归深度

    import sys old sys getrecursionlimit print old 1000 可能是个估计值 我不清楚我没查 报错范围总是比限制要小2 我的电脑上 我不知道为什么 感兴趣可以查一查 sys setrecursion
  • Anaconda的使用

    1 anaconda介绍 Python虽然是一门优秀的程序语言 但其拥有出色的数据处理能力 尤其是在数据量巨大的时候 因而也吸引了不少数据分析人员的关注和使用 Python的数据处理能力主要依赖于NumPy SciPy Matplotlib
  • 在Maven中前端构建实践

    NodeJS为前端技术的发展带来了一次革新 层出不穷的前端库 框架以及打包工具让大家应接不暇 然而这使得前端技术越来越依赖于NodeJS 基于NodeJS编写的前后台项目可以使用同一编译或者打包工具进行管理从而做到无缝的前后端版本控制以及联
  • JSON和xml的区别

    首先 json和xml都是在远程调用或者和某公司合作时的数据交换格式 json和xml的区别 有什么优缺点 ajax 的 和json优缺点 相同点 json与xml是一种远程数据传输交换格式 json是轻量级的 xml标记电子文件具有结构性
  • Fsm serial

    In many older serial communications protocols each data byte is sent along with a start bit and a stop bit to help the r
  • 计算机组成原理-8、总线与输入输出系统

    前言 最近备研学习计算机组成原理的一些笔记 记得比较仓促 仅供个人参考 等明年会仔细结合自己的一些看法加以改进 如有不足之处 还请多多指教 文章目录 总线与输入输出系统 总线与输入输出系统概述 总线 总线类型与结构 总线的信息传输方式 总线
  • 操作系统实验三:用PV操作实现司机售票员进程同步(C语言实现)

    代码如下 driver spy cpp include
  • Docker + Jenkins 详细安装步骤

    一 安装Docker 1 安装依赖环境 yum y install yum utils device mapper persistent datalvm2 2 配置Docker镜像源 yum config manager add repo
  • 调试osgEarth(33)分页瓦片卸载器子节点的作用-(3)渲染遍历的帧号和时间设置-TerrainCuller赋值给可渲染图层--TerrainRenderData--深度摄像机

    继续调试 可见 在当前环境下 definelist为空 不会再有 OE IS DEPTH CAMERA 因此不是深度摄像机 果然为false 总结下 这里是通过摄像机的状态集的 definelist是否包含 OE IS DEPTH CAME
  • 论文阅读-多任务(2020)-KL4MTL:用于多任务学习的知识蒸馏方法

    KL4MTL 论文 Knowledge Distillation for Multi task Learning 地址 https paperswithcode com paper knowledge distillation for mu