多任务学习:Multi-Task Learning as Multi-Objective Optimization

2023-05-16

前言

        最近在写一篇文章,是一篇深度学习与安全相结合的文章,模型的输出会交给两个损失函数(availability & security)进行损失计算,进而反向传播。起初的想法是直接将两项损失进行加权平均,共同进行反向传播,后面又尝试了先A后B和先B后A的方式。发现模型训练的效果不是很好,因为这两个损失在进行下降时是一种相互制约的关系,如图1所示(侧面也反映了自己设计的连个损失方向是对的)。在epoch到达30w次之后,两者分道扬镳。

        考虑到多任务之间的制约,尝试使用多目标优化的方法对两个损失函数进行优化,以获得使两种损失同时较小的一种嵌入,然后再将此满足条件的嵌入作为模型的训练目标,分步完成模型的训练。但是在研究了多目标优化的算法之后,打消了这个念头,因为我模型的输出维度很高,他是一张图的邻接矩阵的一维展开(adj.view(-1)),是百万维级别的,所以使用传统多目标优化的方法(例如遗传算法)过于复杂,所以就发现了这一篇名为Multi-Task Learning as Multi-Objective Optimization的文章。下面简单的介绍一下这篇文章。

Multi-Task Learning as Multi-Objective Optimization

作者在文章的摘要中说:“多任务学习本质上是一个多目标的问题,因为不同的任务可能会发生冲突,需要进行权衡。一个常见的折衷办法是优化一个代理目标,使每个任务损失的加权线性组合最小。然而,这种变通方法只有在任务不竞争的情况下才有效,而这种情况很少发生。在本文中,我们明确地将多任务学习作为多目标优化,其总体目标是找到一个帕累托最优解。”

这个是文章为何能够适用于我的问题的原因。细节不做过多介绍,在阅读过程中发现网上有不少对于这篇文章原理的解释,大家有疑问可以自行查看。

对Multi-Task Learning as Multi-Objective Optimization方法进行尝试

Solving the Optimization Problem

        这是文章中的一段话。在处理一般情况之前,先处理只有两个优化目标的一般情况。此时的优化目标是一个一元二次函数:  min_{\alpha\in[0,1]}||\alpha\bigtriangledown_{\theta^{sh}}\hat{\mathcal{L}}^1+(1-\alpha)\bigtriangledown_{\theta^{sh}}\hat{\mathcal{L}}^2||^2_2 ,其解为:

\alpha = \left [ \frac{(\bigtriangledown_{\theta^{sh}}\hat{L}^2-\bigtriangledown_{\theta^{sh}}\hat{L}^1)^T\bigtriangledown_{\theta^{sh}}\hat{L}^2}{||\bigtriangledown_{\theta^{sh}}\hat{L}^1-\bigtriangledown_{\theta^{sh}}\hat{L}^2||_2^2} \right ]_+

        同时,作者又介绍道,对于每一个任务t,都需要计算\bigtriangledown_{\theta^{sh}}\mathcal{\hat{L}}(\theta^{sh},\theta^t),而参数\theta^{sh}在反向传播阶段恰好是多任务共享的参数,所以在反向传播阶段需要计算T次。因此作者为了降低计算复杂度,确定了优化问题的上界。

 其中Z为表示层的输出,也就是共享参数的各层最后的输出。 

又因为\left | \frac{\partial Z }{\partial \theta^{sh}}\right |^2 与\alpha无关,所以在优化中会被移除。因此使用以下上界来优化目标。

 

 应用

        根据以上分析,我对自己的模型进行了改进,先分享一下源代码:

grads = {}

#这个函数是为了获取中间变量的梯度,我方案中的Z不是一个叶子结点,所以其梯度在反向传播之后不会被保存
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

def MTL(loss, Floss):
    '''
    使用多任务学习的多个梯度来决定最终梯度
    :param loss: 损失1
    :param Floss: 损失2
    :return:
    '''
    Z.register_hook(save_grad('Z'))
    #对loss进行反向传播,获取各层的梯度
    loss.backward(retain_graph = True)
    theta1 = grads['Z'].view(-1)
    #将计算图中的梯度清零,准备对第二种loss进行反向传播
    optimizer.zero_grad(retain_graph = True)
    Floss.backward()
    theta2 = grads['Z'].view(-1)
    #alpha解的分子部分
    part1 = torch.mm((theta2 - theta1), theta2.T)
    #alpha解的分母部分
    part2 = torch.norm(theta1 - theta2, p = 2)
    #二范数的平方
    part2.pow(2)
    alpha = torch.div(part1, part2)
    min = torch.ones_like(alpha)
    alpha = torch.where(alpha > 1, min, alpha)
    min = torch.zeros_like(alpha)
    alpha = torch.where(alpha < 0, min, alpha)
    #alpha theta1 & (1 - alpha) theta2
    #将alpha等维度拓展
    alpha1 = alpha
    alpha2 = (1 - alpha)
    #将各层梯度清零
    optimizer.zero_grad()
    #根据比率alpha1 & alpha2分配Loss1和Loss2的比率
    MTLoss = alpha1 * loss + alpha2 * Floss
    MTLoss.backward()

在我自己的方案中,其实也是进行了三次反向传播,其实可以只计算出表示层的输出Z的梯度,然后根据梯度计算alpha,直接将两种Loss分配权重,但是这里的操作我并没有实现,因为每次backward()之后,其实是进行了一次完整的反向传播,如果不进行backward()又没有办法快速求出Z的梯度。所以我还没研究清楚如何进行控制,如果大家有解决的方法,也希望能够不吝赐教。

        在这里我为何还要使用Z而不是表示层的梯度,是因为我的表示层的权重矩阵是一个多维的变换矩阵,所以在进行帕累托最优求解的时候比较麻烦,而表示层的输出Z,是一个一维的向量,进行帕累托最优解寻找的过程比较简单,方便计算。当然这也只是我在阅读完文献后的第一次尝试,可能在理解上还有很多错误,后面会不断进行改进。

结果

 这个是使用文章中方法改进之后的梯度下降损失的变化,其中AP_loss是Loss1,F_loss是Loss2,可以看出两者都在进行下降,同时与我之前使用的手动调参(Loss1 & Loss2全程按照一种比例进行梯度下降)相比,损失函数的震荡情况获得了十分明显的改善。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多任务学习:Multi-Task Learning as Multi-Objective Optimization 的相关文章

  • ST-Link 在keil5无法下载程序解决办法

    以前一直在用J Link下载程序 xff0c 由于工作需要 xff0c 换成ST Link下载程序 第一次用ST Link怎么也下载不下去 xff0c 后来差CSDN博客 xff1a https blog csdn net zeroice7
  • 实时时钟DS1302-第1季第14部分-朱有鹏-专题视频课程

    实时时钟DS1302 第1季第14部分 2594人已学习 课程介绍 本课程是 朱有鹏老师单片机完全学习系列课程 第1季第14个课程 xff0c 主要讲解了实时时钟DS1302芯片的编程和使用 xff0c 本课程的关键是引入了时序的概念 xf
  • Mac使用npm install报错,需使用sudo

    1 首先说下个人的经历 xff0c 从18年开始实习第一次使用npm xff0c 当时用npm install却总是会报一些错误 xff0c 主要是因为无权限 最初的解决方案自然是使用sudo xff0c 这个是有效的 如果用sudo还不行
  • 信号量、邮箱、队列与事件

    信号量 xff0c 邮箱 xff0c 队列的最大不同在于它们发送的内容不同 信号量是一个触发信号 xff0c 也是一个计数器 xff0c 等待接收信号的任务一般只有接收到信号才可以执行 xff0c 否则任务一直暂停 邮箱是信号量的扩展 xf
  • 自然语言处理中的Attention Model:是什么及为什么

    版权声明 xff1a 可以任意转载 xff0c 转载时请标明文章原始出处和作者信息 author 张俊林 xff08 想更系统地学习深度学习知识 xff1f 请参考 xff1a 深度学习枕边书 xff09 要是关注深度学习在自然语言处理方面
  • 新浪微博用户兴趣建模系统架构

    版权声明 xff1a 可以任意转载 xff0c 转载时请标明文章原始出处和作者信息 author 张俊林 作者注 xff1a 这是2011年左右新浪微博个人兴趣模型的技术架构 xff0c 所以你从中是看不到目前很多流行的NoSQL平台的 x
  • AES CBC模式 原理 c++完整代码可运行

    现在网上能找到的AES代码九成都是ECB模式的 xff0c 剩下的一成里又有九成只对十六个一组的数进行了加解密处理 xff0c 压根没有分组 我在网上扒拉了好久都没有找到CBC模式的完整代码 xff0c 只有接口函数 CBC的分组原理就是这
  • c++ char[]与int之间的类型转换

    char数组转int xff0c int转char数组 span class token macro property span class token directive hash span span class token direct
  • 力扣 2437. 有效时间的数目c++

    太恐怖了发现上了两年班我不会写代码了 xff0c 尝试自救一下 这个题直接情况讨论就可以 xff0c 因为情况很少 xff0c 就硬来 官方的方法是递归 xff0c 虽然看着也简单不到哪里去 xff0c 但是我好像确实不太擅长写递归 cla
  • 力扣 874. 模拟行走机器人 c++

    重点在于对哈希表unordered set xff1c pair xff1c int int xff1e xff1e 的应用 xff0c 具体可以看这个博客 哈希表之unordered set xff1c pair xff1c int in
  • 力扣 1015. 可被 K 整除的最小整数 c++

    终于有点熟悉的感觉了 xff0c 很纯粹的小算法 xff0c 题解看官方 xff0c 懒得写了 用不着哪些c 43 43 11甚至17的特性真的是太好了 span class token keyword class span span cl
  • 进制转换 输入一个十进制数N,将它转换成R进制数输出。(Java c++)

    完犊子我不知道这个题的题号是什么 xff0c 来着一个非要在没到截止时间就写博客的人的怨念 输入一个十进制数N xff0c 将它转换成R进制数输出 Input 输入数据包含多个测试实例 xff0c 每个测试实例包含两个整数N 32位整数 和
  • I2C通信之EEPROM-第1季第15部分-朱有鹏-专题视频课程

    I2C通信之EEPROM 第1季第15部分 3173人已学习 课程介绍 本课程是 朱有鹏老师单片机完全学习系列课程 第1季第15个课程 xff0c 主要讲解了EEPROM的编程和使用 xff0c 其中重点是I2C接口 xff0c I2C是物
  • 曼孚科技:7种常用的数据标注工具

    工欲善其事 xff0c 必先利其器 标注工具是数据标注行业的基础 xff0c 一款好用的标注工具是提升标注效率与产出高质量标注数据的关键 常用的数据标注工具主要有以下几种 xff1a 2D框 语义分割 多边形分割 点标注 线标注 视频标注
  • python14(绘图工具matplotlib和echart)

    1 matplotlib 1 绘制折线图 1 温度变化折线图 需求1 绘制10点到12点每分钟的气温 xff0c 如何绘制折线图观察每分钟气温的变化情况 temps 61 random randint 20 35 for i in rang
  • Ubuntu下安装TeamViewer[命令行方式]

    第一步 下载 安装包 从官网下载ubuntu的deb安装包 下载链接 xff1a https downloadus1 teamviewer com download version 12x teamviewer 12 0 71510 i38
  • 树莓派4安装Ubuntu20.04

    1 下载Ubuntu20 04 https ubuntu com download raspberry pi 2 下载image工具 https www raspberrypi org downloads 3 写入镜像 4 安装完成之后 x
  • encoder 基于品高云数据湖的大数据开发实践课程(随手记)-HDFS 的基本操作和 Java API 操作

    文章目录 61 61 1 使用FSDataInputStream获取HDFS的 user hadoop 目录下的task txt的文件内容 xff0c 并输出 xff0c 其中uri为hdfs localhost 9000 user had
  • navicat连接数据库(MySQL)报错1251解决。以及可能报错1045解决

    怀玉 点个关注 xff0c 必回关 话不多说线上结果 图 xff1a 问题说明 xff1a 报错1251是因为root用户密码没有设置或者密码错误 xff0c 我们要做的就是修改或者更新root用户密码 步骤图奉上 xff1a 连接MySQ
  • pvs Error reading device /dev/xxx at 0 length 512.

    背景 xff1a ceph osd 服务器磁盘坏掉 xff0c 将坏掉的 osd 从集群中踢出后 xff0c pvs 报错 系统 centos7 xff0c ceph luminous 1 查看错误信息 root 64 cmp15 pvs

随机推荐

  • 用word发CSDN blog,免去插图片的烦恼

    用csdn自带的网页编辑器 xff0c 最不方便的 xff0c 不是排版 xff0c 而是图片的发布 xff0c 希望能通过下面这个方式得到改善 1 注册博客账号 1 1 打开一个新的Word文档 如果之前没有用过博客功能的话 xff0c
  • openstack如何支持vlan trunk功能

    大多数场景下 xff0c 主机收发的是不带tag的报文 xff0c 但是在实际环境中 xff0c 无论是windows还是Linux环境都通过各自的方法可以收发带有vlan tag的报文 而一个虚机要想接收不同vlan tag的报文 xff
  • 在vscode中调试webpack

    前言 接手了公司的新项目 xff0c 但是由于对整个运作流程不了解 xff0c 想要一步步进行调试加深对项目印象 xff0c 所以搜索了相关资料 xff0c 结合自己实际情况进行调试 调试的两个关键文件 package json 正常的pa
  • AD和DA转换-第1季第16部分-朱有鹏-专题视频课程

    AD和DA转换 第1季第16部分 2091人已学习 课程介绍 本课程是 朱有鹏老师单片机完全学习系列课程 第1季第16个课程 xff0c 主要讲解AD转换和DA转换 目标是理解模拟量和数字量的概念 xff0c 并且学会使用AD转换来采集现实
  • vnc viewer登陆问题

    这里操作的前提是已经 安装了vnc server 登陆SUN 210 server xff0c solaris 10 采用VNC viewer 但是并不是每次登陆都成功 开始总是不成功 采用以下两条命令 xff1a vncserver ki
  • 光谱分布、光谱辐射通量密度与不同时间段分布光谱(图示)

    1 光谱分布图 2 太阳辐射能量图 3 不同时间段的太阳分布光谱图 4 不同波长的光的能量分布主要区域 5 不同波段的使用场景
  • 电磁波波谱及不同波长成像图

    1 电磁辐射波 实际的图像处理应用中 xff0c 最主要的图像来源于电磁 辐射成像 电磁辐射波包括无线电波 微波 红外线 可见光 紫外线 X射线 射线 电磁辐射波的波谱范围很广 xff0c 波长最长的是无线电波 为3 102m xff0c
  • 写给VR手游开发小白的教程:(四)补充篇,详细介绍Unity中相机的投影矩阵

    这篇作为上一篇的补充介绍 xff0c 主要讲Unity里面的投影矩阵的问题 xff1a 上篇的链接写给VR手游开发小白的教程 xff1a xff08 三 xff09 UnityVR插件CardboardSDKForUnity解析 xff08
  • 阿里云centos修改ssh端口后连接失败

    话说本人虽然工作多年 xff0c 一直是linux小白一个 xff0c 估计像我这样的也是没谁了 每次面试的时候面试官一问是否会linux xff0c 都老脸一红啊 为了解决这种情况 xff0c 自己去阿里云买了一台centos的服务器 x
  • linux进程调度方法(SCHED_OTHER,SCHED_FIFO,SCHED_RR)

    linux内核的三种调度方法 xff1a 1 xff0c SCHED OTHER 分时调度策略 xff0c 2 xff0c SCHED FIFO实时调度策略 xff0c 先到先服务 3 xff0c SCHED RR实时调度策略 xff0c
  • 12- 降维算法 (PCA降维/LDA分类/NMF) (数据处理)

    数据降维就是一种对高维度特征数据预处理方法 降维是将高维度的数据保留下最重要的一些特征 xff0c 去除噪声和不重要的特征 xff0c 从而实现提升数据处理速度的目的 PCA算法有两种实现方法 xff1a 基于特征值分解协方差矩阵实现PCA
  • 软件体系整理5-6章

    第五章 软件体系结构风格 1 管道过滤器风格 特征 xff1a xff08 1 xff09 构件即过滤器 xff08 Filter xff09 xff0c 对输入流进行处理 转换 xff0c 处理后的结果在输出端流出 而且 xff0c 这种
  • Hive中的DDL操作

    参考文章 xff1a https www cnblogs com qingyunzong p 8723271 html 官方文档 xff1a https cwiki apache org confluence display Hive La
  • linux echo输出结果赋值给变量,shell变量n位补零

    name 61 96 echo 1 awk 39 printf 34 04d n 34 0 39 96 将 1 进行4位数补零 xff0c 后传递字符串给 name 将下面代码 xff0c 命名为 playVideo sh 的shell脚本
  • LCD1602和12864显示器-第1季第17部分-朱有鹏-专题视频课程

    LCD1602和12864显示器 第1季第17部分 3539人已学习 课程介绍 本课程是 朱有鹏老师单片机完全学习系列课程 第1季第17个课程 xff0c 主要讲解LCD1602和LCD12864这两种单片机常用LCD显示器的显示原理 以及
  • CV小白实践--实现MNIST手写数字识别时遇到的问题

    1 RuntimeError size mismatch m1 800 x 4 m2 320 x 50 问题原因 xff1a 这个问题出现在神经网络最后一层卷积层与第一层全连接层之间 首先来看一下我实现的神经网络的结构 def init s
  • Hbase中Scan数据时的缓存优化以scan 过滤器的使用

    1 缓存优化 在hbase的java api 中 默认在scan 过程中scan next一次进行一次rpc请求 这导致scan的效率很低 设置scan的缓存优化很有必要 1 scan setBatch int 10 设置一次next 返回
  • 人脸识别之人脸检测(三)--Haar特征原理及实现

    本文主要由于OpenCV的haartraining程序 xff0c 对haar特征的补充及代码注释 原文 xff1a http www aiuxian com article p 2476165 html Haar特征的原理是什么 xff1
  • 网络基础(一)【解决mininet中xterm域名无法解析的问题】

    mininet是一个很好用的网络仿真实验平台 xff0c 基于网络命名空间技术的python封装 我是在linux虚拟机中安装了mininet环境 sudo mn mininet gt xterm h1 h2 启动一个xterm程序 xff
  • 多任务学习:Multi-Task Learning as Multi-Objective Optimization

    前言 最近在写一篇文章 xff0c 是一篇深度学习与安全相结合的文章 xff0c 模型的输出会交给两个损失函数 xff08 availability amp security xff09 进行损失计算 xff0c 进而反向传播 起初的想法是