多任务学习中各loss权重应该如何设计呢?

2023-05-16

来源:(22 封私信 / 80 条消息) 多任务学习中各loss权重应该如何设计呢? - 知乎 (zhihu.com)

        多损失在深度学习中很常见,例如:

  • 目标检测:以 YOLO 为例,它的损失函数一般都是由几部分构成,包括中心坐标误差损失、宽高坐标误差损失宽高坐标误差损失宽高坐标误差损失、置信度误差损失。这里面又涉及到一个前景和背景的问题,即当前网格存不存在目标,损失函数计算方式会有所不同,如此一来就需要控制两者的比例,避免梯度被某个任务主导。
  • 语义分割:特别地,针对医学影像这种前景背景差异非常小的任务,通常很多方法都会引入深监督的机制,这样一来就不单单是两级 loss 了,几乎每个 stage 都会引申出一条分支出来。遇到这种情况你总不能排列组合去炼丹吧?

        多损失问题涉及了多任务学习Network Architecture(how to share)Loss Function(how to balance)这两个挑战。

(1)把各loss统一到同一个数量级上,其背后的原因是不同任务的不同损失函数尺度有很大的差异,因此需要考虑用权值将每个损失函数的尺度统一。一般情况下,不同 task 收敛过程中梯度大小是不一样的,对不同学习率的敏感程度也是不一样的。把各loss统一到同一个数量级,可以避免梯度小的loss被梯度大的loss所带走,使得学习到的feature有较好的泛化能力。

这个问题比较新的解决办法,可以参考cvpr2018的一个工作《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》,这篇文章提出利用同方差的不确定性,将不确定性作为噪声,进行训练。

(2)Network Architecture一般有:Hard-parameter sharing和Soft-parameter-sharing。Hard-parameter sharing有一组相同的特征共享层,这种设计大大减少了过拟合的风险;Soft-parameter sharing每一个任务都有自己的特征参数和每个子任务之间的约束。而对于“不同网络部分梯度的数量级”问题,虽然梯度值相同,但不同的loss在task上的表现也是不同的,所以还是要找不同loss的合适平衡点。

对于Network Architecture and Loss Function的关系,经过大量实验和经验总结后,发现一个好的Network Design比一个好的Loss Design更能有效增加featrue的泛化能力,一个好的Dataset比一个好的Network Design又更能有效增加feature的泛化能力。

最后,展示一下我们使用Multi-task learning训练Face Reconstruction and Dense Alignment的例子。简单说明下:这个模型中一共有4个task,dense alignment部分一共有4w个关键点,并且可以眨眼、可以张嘴。

https://github.com/Star-Clouds/FRDA​github.com/Star-Clouds/FRDA

多任务学习

定义

多任务学习(Multi-Task Learing, MTL)是机器学习中一个活跃的研究领域。它是一种学习范式,旨在通过利用它们之间的共同知识来共同学习多个相关任务以提高它们的泛化性能。近年来,许多研究人员成功地将 MTL 应用到计算机视觉、自然语言处理、强化学习、推荐系统等不同领域。目前对 MTL 的研究主要集中在两个角度,网络架构设计和损失加权策略。

网络架构设计

在网络架构的设计中,最简单和最流行的方法是硬参数共享(HPS,LibMTL.architecture.HPS),如图1所示,其中编码器在所有任务之间共享,每个任务都有自己的编码器以及特定的解码器。由于大部分参数在任务之间共享,因此当任务之间的相关性不够大时,这种架构很容易导致负共享(即一损俱损)。为了更好地处理任务之间的关系,已经提出了不同的 MTL 架构。

图1

上图中左边是单输入问题右边是多输入问题的图示,以 硬参数共享模式为例。目前 LibMTL 已经支持多种最先进的架构,详情请参阅 LibMTL.architecture 分支。

损失加权策略

平衡多个任务对应的多个损失是处理任务关系的另一种方式,因为共享参数由所有任务损失更新。因此,已经提出了不同的方法来平衡损失或梯度。

目前 LibMTL 已支持多种最先进的加权策略,详情请参阅 LibMTL.weighting 分支。

例如,一些基于梯度平衡的方法如 MGDA 需要先计算每个任务的梯度,然后通过各种方式计算聚合梯度。为了降低计算成本,它可以使用编码器后表示的梯度(简称为 rep-grad)来近似共享参数的梯度(简称为 param-grad)。

rep-grad 的 PyTorch 实现如图2所示。我们需要通过 detach 操作将计算图分离成两部分。LibMTL 内部已经将这两种情况统一到一个训练框架中,因此我们只需要正确设置命令行参数 rep_grad 即可。此外,参数 rep_grad 与 multi_input 不冲突。

图2

上述示例图清晰的说明了如何计算表示的梯度。

LibMTL

LibMTL[1] 是一个基于 PyTorch 构建的用于多任务学习的开源库。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多任务学习中各loss权重应该如何设计呢? 的相关文章

  • Windows柯尼卡打印机驱动安装

    打印机型号 xff1a 柯尼卡 bizhub C300i xff08 打印机机身可见 xff09 1 下载驱动 在柯尼卡驱动官网查找下载打印机驱动 在型号处直接下拉查找自己的型号 xff0c 例如bizhub C300i xff0c 点击搜
  • PyQt开发入门教程

    来源 xff1a PyQt完整入门教程 lovesoo 博客园 cnblogs com 1 GUI开发框架简介 1 1 通用开发框架 electorn xff1a 基于node js xff0c 跨平台 xff0c 开发成本低 xff0c
  • VOC数据集颜色表colormap与代码

    VOC颜色和分类的对于关系 code如下 xff0c 这里提供两个版本 xff0c 一个是list tuple 版本 xff0c 支持直接在opencv的color参数使用 xff1b 另一个是ndarray版返回 list 版 def v
  • 【译】Python3.8官方Logging文档(完整版)

    注 xff1a 文章很长 xff0c 约一万字左右 xff0c 可以先收藏慢慢看哟 01 基础部分 日志是用来的记录程序运行事件的工具 当程序员可以通过添加日志打印的代码来记录程序运行过程中发生的某些事件时 这些事件包含了诸如变量数据在内的
  • OpenCV Scalar value for argument ‘color‘ is not numeric错误处理

    import cv2 cur color 61 np array 128 0 128 astype np uint8 cv2 polylines cvImage ndata isClosed 61 True color 61 cur col
  • COCO格式数据集可视化为框

    使用pycocotools读取和opencv绘制 xff0c 实现COCO格式数据边框显示的可视化 xff0c 可视化前后的示例为 xff1a 代码 xff1a coding utf 8 import os import sys getop
  • 微波遥感(三、SAR图像特征)

    SAR 是主 动式侧视雷达系统 xff0c 且成像几何属于斜距投影类型 因此 SAR 图像与光学图像在成像机理 几何特征 辐射特征等方面都有较大的区别 在进行 SAR 图像处理和应用前 xff0c 需要了解 SAR 图像的基本特征 本文主要
  • 基于Slicing Aided Hyper Inference (SAHI)做小目标检测

    遥感等领域数据大图像检测时 xff0c 直接对大图检测会严重影响精度 xff0c 而通用工具多不能友好支持大图分块检测 Slicing Aided Hyper Inference SAHI 是一个用于辅助大图切片检测预测的包 目前可以良好的
  • YOLOv5训练参数简介

    YOLOv5参数解析 xff0c 这次主要解析源码中train py文件中包含的参数 1 1 39 weights 39 1 2 39 cfg 39 1 3 39 data 39 1 4 39 hyp 39 1 5 39 epochs 39
  • 亚米级土耳其地震影像数据下载

    下载地址1 xff0c 提供震前震后影像 部分震后影像的百度网盘存档 xff1a https pan baidu com s 1 rLV7cR F3casKRwQH7JTw 提取码 xff1a dou3 灾前 灾后影像 下载地址2 xff1
  • nms_rotated编译出错fatal error: THC/THC.h: No such file or directory

    问题描述 xff1a 使用 python setup py develop or 34 pip install v e 34 编译nms rotated时出错 xff1a fatal error THC THC h No such file
  • 解决 AttributeError: module ‘numpy‘ has no attribute ‘int‘

    原因 xff1a numpy int在NumPy 1 20中已弃用 xff0c 在NumPy 1 24中已删除 解决方式 xff1a 将numpy int更改为numpy int xff0c int 方法 xff1a 点击出现错误代码链接会
  • 机载高分辨率SAR数据(~0.1米)

    美国桑迪亚 xff08 sandia xff09 国家实验室提供一系列机载SAR数据 xff0c 包括MiniSAR FARAR等 数据分辨率4英寸 xff0c 约0 1米 原始数据下载地址 xff0c 数据是复数据 xff0c 以不同格式
  • ubuntu18.04 及以上版本命令模式和GUI切换

    网上大多数说的CTRL 43 ALT 43 F1 6进入命令模式 xff0c CTRL 43 ALT 43 F7进入GUI模式 xff0c 在ubuntu18 04 及以上无效 正确的方式是 xff1a 进入命令模式可以通过CTRL 43
  • Python内置库——http.client源码刨析

    看过了http client的文档 xff0c 趁热打铁 xff0c 今天继续研究一下http client的源码 xff08 一 xff09 你会怎么实现 开始之前先让我们回忆一下一个HTTP调用的完整流程 xff1a 看到这张图 xff
  • ssh连接ubuntu访问拒绝(access denied)

    网上大多针对ssh连接ubuntu访问拒绝的解决办法是安装ssh或防火墙开启端口等等 xff0c 但这些都没问题之后还是访问拒绝 xff0c 则考虑ssh包可能安装的有问题 xff0c 可以尝试重装 流程如下 xff1a 1 在ubuntu
  • 【论文-目标检测】RTMDet: An Empirical Study of Designing Real-Time Object Detectors

    论文 代码 官方原理与实现详解 发展YOLO系列并方便支持实例分割和斜框检测等任务 xff0c 亮点 xff1a 设计兼容性backbone和neck xff0c 采用大核深度可分离卷积 xff1b 动态标签分配中采用软标签计算匹配损失 x
  • 【OpenCV】图像金字塔 -- 下采样cv2.pyrDown() , 上采样cv2.pyrUp()

    参考 xff1a cv2 pyrDown TheAILearner 1 cv2 pyrDown 函数cv2 pyrDown 用于实现高斯金字塔中的下采样 函数原型 xff1a dst img 61 cv2 pyrDown src img d
  • 理解depth-wise 卷积

    EfficientNet利用depth wise卷积来减少FLOPs但是计算速度却并没有相应的变快 反而拥有更多FLOPs的RegNet号称推理速度是EfficientNet的5倍 非常好奇 xff0c 这里面发生了什么 xff0c 为什么
  • GIoU (Generalized Intersection over Union) 详解

    论文 xff1a Generalized Intersection over Union A Metric and A Loss for Bounding Box Regression 官方解读 xff1a Generalized Inte

随机推荐

  • Gitee push错误 Access denied: You do not have permission to push to the protected branch ‘master‘ via

    错误 xff1a 首次使用gitee向别人的repo提交代码 xff0c 发现出现权限问题无法push到master xff0c 提交命令如下 xff1a git push u origin master master 错误信息如下 xff
  • GDAL重采样与裁剪图像示例

    GDAL重采样 xff0c 可以通过写文件时改变图像尺寸和geo transformes的分辨率信息实现 核心代码示例如下 xff1a in ds 61 gdal Open fi gdal GA ReadOnly geotrans 61 i
  • pycharm专业版连接远程docker容器

    一 配置远程docker容器 1 启动带有端口的docker容器 6006端口是用来运行tensorboard的 xff0c 这里重要的是22端口 如果希望通过ssh远程连接docker xff0c 需要对容器的22端口做端口映射 dock
  • VScode 远程开发配置

    一 配置免密远程登录 因为是要远程登录 xff0c 那么需要通过使用ssh进行密钥对登录 xff0c 这样每次登录服务器就可以不用输入密码了 先来一句官方介绍 xff1a ssh 公钥认证是一种方便 高安全性的身份验证方法 xff0c 它将
  • np.meshgrid()与torch.meshgrid()的区别

    比如要生成一张图像 h 61 6 w 61 10 的xy坐标点 xff0c 看下两者的实现方式 xff1a 两种方式的差异在于 xff1a xs ys 61 np meshgrid np arange w np arange h xs ys
  • JSON是什么

    提起 JSON xff0c 作为如今最受欢迎的数据交换格式 xff0c 可以说是无人不知 无人不晓了 JSON 全称 JavaScript Object Notation xff08 JS 对象简谱 xff09 xff0c 自诞生之初的小目
  • 【C++】数组定义引发Stack overflow错误(运行时是报段错误)

    C 43 43 xff08 实际是C的语法 xff09 定义数组时出错 xff0c 代码如下 xff1a float t1 9830400 调试时触发Stack overflow错误 xff08 可执行文件运行时 xff0c 是报段错误 x
  • 【C/C++】数组初始化

    数组定义不初始化会被随机赋值 因此如果数组的所有元素在下面没有逐一赋值 xff0c 但是又会使用到的话 xff0c 最后不要只定义而不初始化 会带来问题 数组初始化的几种形式 可以直接用 xff1a a 10 61 xff0c 就可以让a
  • 【C++】指针数组与数组指针

    指针数组 指针数组可以说成是 指针的数组 xff0c 首先这个变量是一个数组 xff0c 其次 xff0c 指针 修饰这个数组 xff0c 意思是说这个数组的所有元素都是指针类型 xff0c 在32位系统中 xff0c 指针占四个字节 定义
  • 【旋转框目标检测】2201_The KFIoU Loss For Rotated Object Detection

    paper with code paper code Jittor Code https github com Jittor JDet PyTorch Code https github com open mmlab mmrotate Te
  • CUDA编译报错unsupported GNU version! gcc versions later than 10 are not supported!

    问题 xff1a python编译用于cuda的so文件中 xff0c 使用编译 cu文件出错 xff1a error unsupported GNU version gcc versions later than 10 are not s
  • RuntimeError: CUDA error: no kernel image is available for execution on the device

    问题 xff1a 代码换机器执行时 xff0c 使用包含自行编译的cuda算子库so时出错 xff1a RuntimeError CUDA error no kernel image is available for execution o
  • Ubuntu非LTS版本安装nvidia-docker出错:Unsupported distribution!

    问题 xff1a 按照Nvidia官方流程 xff0c 在Ubuntu22 10安装nvidia docker在执行以下命令时 distribution 61 etc os release echo ID VERSION ID amp am
  • 测试torch方法是否支持半精度

    并不是所有的torch方法都支持半精度计算 测试半精度计算需要在cuda上 xff0c cpu不支持半精度 因此首先需要创建半精度变量 xff0c 并放到cuda设备上 部分方法在低版本不支持 xff0c 在高版本支持半精度计算 xff0c
  • yolov5关闭wandb

    yolov5训练过程中wandb总是提示登入账号 xff0c 不登入还不能继续训练 xff0c 想要关闭wandb xff0c 直接不使用即可 在 yolov5 utils loggers wandb wandb utils py中 imp
  • 目标检测 YOLOv5的loss权重,以及与图像大小的关系

    1 目标检测 YOLOv5的loss权重 YOLOv5中有三个损失分别是 box obj cls 在超参数配置文件hyp yaml中可以设置基础值 xff0c 例如 box 0 05 cls 0 5 obj 1 训练使用时 xff0c 在t
  • 手写一个JSON反序列化程序

    上一篇文章 JSON是什么 给大家介绍了JSON的标准规范 xff0c 今天就自己动手写一个JSON的反序列化程序 xff0c 并命名它为 zjson 0 开始之前 本篇文章的目的是学习实践 xff0c 所以我们选择相对简单的Python实
  • yolov5源码解析--输出

    本文章基于yolov5 6 2版本 主要讲解的是yolov5是怎么在最终的特征图上得出物体边框 置信度 物体分类的 一 总体框架 首先贴出总体框架 xff0c 直接就拿官方文档的图了 xff0c 本文就是接着右侧的那三层输出开始讨论 Bac
  • yolov5源码解析--损失计算与anchor

    本文章基于yolov5 6 2版本 主要讲解的是yolov5在训练过程中是怎么由推理结果和标签来进行损失计算的 损失函数往往可以作为调优的一个切入点 xff0c 所以我们首先要了解它 一 代码入口 损失函数的调用点如下 xff0c 在tra
  • 多任务学习中各loss权重应该如何设计呢?

    来源 xff1a 22 封私信 80 条消息 多任务学习中各loss权重应该如何设计呢 xff1f 知乎 zhihu com 多损失在深度学习中很常见 xff0c 例如 xff1a 目标检测 xff1a 以 YOLO 为例 xff0c 它的