ICLR 2022 超越Focal Loss PolyLoss用1行代码+1个超参完成超车

2023-11-13

pytorch版:

有两种,交叉熵版和Poly1FocalLossloss

GitHub - abhuse/polyloss-pytorch: Polyloss Pytorch Implementation

分类版:

PolyLoss/polyloss.py at master · jahongir7174/PolyLoss · GitHub

import torch
from torch.nn.functional import one_hot


class PolyLoss(torch.nn.Module):
    """
    Implementation of poly loss.
    Refers to `PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions (ICLR 2022)
    <https://arxiv.org/abs/2204.12511>
    """

    def __init__(self, num_classes=1000, epsilon=1.0):
        super().__init__()
        self.epsilon = epsilon
        self.softmax = torch.nn.LogSoftmax(dim=-1)
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
        self.num_classes = num_classes

    def forward(self, output, target):
        ce = self.criterion(output, target)
        pt = one_hot(target, num_classes=self.num_classes) * self.softmax(output)

        return (ce + self.epsilon * (1.0 - pt.sum(dim=-1))).mean()

PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions

论文:https://openreview.net/forum?id=gSdSJoenupI

Cross-entropy lossFocal loss是在训练深度神经网络进行分类问题时最常见的选择。然而,一般来说,一个好的损失函数可以采取更灵活的形式,并且应该为不同的任务和数据集量身定制。

通过泰勒展开来逼近函数,作者提出了一个简单的框架,称为PolyLoss,将损失函数看作和设计为多项式函数的线性组合。PolyLoss可以让Polynomial bases(多项式基)的重要性很容易地根据目标任务和数据集进行调整,同时也可以将上述Cross-entropy lossFocal loss作为PolyLoss的特殊情况。

大量的实验结果表明,在PolyLoss内的最优选择确实依赖于任务和数据集。只需引入一个额外的超参数和添加一行代码,PolyLoss在二维图像分类、实例分割、目标检测和三维目标检测任务上都明显优于Cross-entropy lossFocal loss

1简介

原则上,损失函数可以是将预测和标签映射到任何(可微)函数。但是,由于损失函数具有庞大的设计空间,导致设计一个良好的损失函数通常是具有挑战性的,而在不同的工作任务和数据集上设计一个通用的损失函数更是具挑战性。

例如,L1/L2 Loss通常用于回归的任务,但很少用于分类任务;对于不平衡的目标检测数据集,Focal loss通常用于缓解Cross-entropy loss的过拟合问题,但它并不能始终应用到其他任务。近年来,许多研究也通过元学习、集成或合成不同的损失来探索新的损失函数。

在本文中,作者提出了PolyLoss:一个新的框架来理解和设计损失函数。

作者认为可以将常用的分类损失函数,如Cross-entropy lossFocal loss,分解为一系列加权多项式基

它们可以被分解为的形式,其中为多项式系数,为目标类标签的预测概率。每个多项式基由相应的多项式系数进行加权,这使PolyLoss能够很容易地调整不同的多项式基。

  • 当时,PolyLoss等价于常用的Cross-entropy loss,但这个系数分配可能不是最优的。

研究表明,为了获得更好的结果,在不同的任务和数据集需要调整多项式系数。由于不可能调整无穷多个的,于是作者便探索具有小自由度的各种策略。作者实验观察到,只需调整单多项式系数,这里表为示,足以实现比Cross-entropy lossFocal loss的更好的性能。

2主要贡献

图1

  1. Insights on common losses:提出了一个统一的损失函数框架,名为PolyLoss,以重新思考和重新设计损失函数。这个框架有助于将Cross-entropy lossFocal loss解释为多损失族的2种特殊情况(通过水平移动多项式系数),这是以前没有被认识到的。这方面的发现促使研究垂直调整多项式系数的新损失函数,如图1所示。

  2. New loss formulation:评估了垂直移动多项式的不同方法,以简化超参数搜索空间。提出了一个简单而有效的Poly-1损失,它只引入了一个超参数和一行代码。

  3. New findings:作者发现Focal loss虽然对许多检测任务有效,但对于不平衡的ImageNet-21K并不是很优秀。作者还发现多项式在训练过程中对梯度有很大的贡献,其系数与预测置信度相关。

  4. Extensive experiments:在不同的任务、模型和数据集上评估了PolyLoss。结果显示PolyLoss持续提高了所有方面的性能。

3PolyLoss

PolyLoss为理解和改进常用的Cross-entropy lossFocal loss提供了一个框架,如图1所示。它的灵感来自于Cross-entropy lossFocal loss的基于泰勒展开式:

式中为模型对目标类的预测概率。

3.1 Cross-entropy loss as PolyLoss

使用梯度下降法来优化交叉熵损失需要对Pt进行梯度。在PolyLoss框架中,一个有趣的观察是系数正好抵消多项式基的第次幂。因此,Cross-entropy loss的梯度就是多项式的和:

梯度展开中的多项式项捕获了对的不同灵敏度。第一个梯度项是1,它提供了一个恒定的梯度,而与的值无关。相反,当时,接近1时,第项被强烈抑制。

3.2 Focal loss as PolyLoss

PolyLoss框架中,Focal loss通过调制因子γ简单地将移动。这相当于水平移动所有的多项式系数的γ。为了从梯度的角度理解Focal loss,取关于的Focal loss梯度:

对于正的γ,Focal loss的梯度降低了Cross-entropy loss中恒定的梯度项1。正如前段所讨论的,这个恒定梯度项导致模型强调多数类,因为它的梯度只是每个类的示例总数。

通过将所有多项式项的幂移动γ,第1项就变成,被γ抑制,以避免过拟合到(即接近1)多数类。

3.3 与回归和一般形式的联系

PolyLoss框架中表示损失函数提供了与回归的直观联系。对于分类任务,是GT标签的有效概率,多项式基可以表示为;

因此,Cross-entropy lossFocal loss都可以解释为预测到标签的距离的j次幂的加权集合。

因此,交叉熵损失和焦点损失都可以解释为预测和标记到第j次幂之间的距离的加权集合。

然而,在这些损失中有一个基本的问题:回归项前的系数是最优的吗?

一般来说,PolyLoss是[0,1]上的单调递减函数,可以表示为,并提供了一个灵活的框架来调整每个系数。PolyLoss可以推广到非整数j,但为简单起见,本文只关注整数幂()。

4理解多项式系数的影响

在前面的谈论中建立了PolyLoss框架,并展示了Cross-entropy lossFocal loss简单地对应于不同的多项式系数,其中Focal loss就可以表达为水平移动了多项式系数的Cross-entropy loss

这里要深入研究了垂直调整多项式系数对于训练可能的影响。具体来说,作者探索了3种分配多项式系数的不同策略:

  • 去掉高阶项

  • 调整多个靠前多项式系数

  • 调整第1个多项式系数

作者发现,调整第1个多项式系数(Poly-1)便可以最大的增益,而且仅仅需要很小的代码更改和超参数调整。

4.1 :回顾高阶多项式项的删除

已有研究表明,降低高阶多项式和调整前置多项式可以提高模型的鲁棒性和性能。作者采用相同的损失公式,并在ImageNet-1K上比较它们与基线Cross-entropy loss的性能。

如图2a所示,需要求和超过600个多项式项才能匹配Cross-entropy loss的精度。值得注意的是,去除高阶多项式不能简单地解释为调整学习率。为了验证这一点,图2b比较了在不同的截止条件下不同学习率下的性能:无论从初始值0.1增加或减少学习率,准确率都会变差

为了理解为什么高阶项很重要,作者对Cross-entropy loss中去除前N个多项式项后的结果进行了求和:

定理1:对于任何小的ζ>0,δ>0,如果N>ζ,那么对于任何p∈[δ,1],都有|R_N(p)|<ζ和|R'_N(p)|<ζ。

因此,从损失和损失导数[δ,1]的角度来看,需要取一个大的N来确保尽可能地接近。对于固定ζ,当δ接近0时,N迅速增大。作者的实验结果与定理一致。

高阶(j>N+1)多项式在训练的早期阶段发挥重要作用,此时通常接近于零。例如,当时,根据公式,第500项的梯度系数为,这是相当大的。与前面的工作不同,本文作者的实验结果表明,不能轻易地减少高阶多项式

PolyLoss框架中,丢弃高阶多项式等价于将所有高阶(j>N+1)多项式系数垂直推到0。

4.2 :扰动重要的多项式系数

在本文中提出了在PolyLoss框架中设计一个新的损失函数的替代方法,其中调整了每个多项式的系数。一般来说,有无穷多个多项式系数需要调节。因此,对最一般损失进行优化是不可行的:

第4.1小节已经表明,在训练中需要数百个多项式来很好地完成诸如ImageNet-1K分类等任务。如果天真地将方程中的无限和截断到前几百项,那么对这么多多项式的调优系数仍然会带来一个非常大的搜索空间。此外,综合调整许多系数也不会优于Cross-entropy loss

为了解决这一问题,作者提出扰动交叉熵损失中的重要的多项式系数(前N项),同时保持其余部分不变。将所提出的损失公式表示为,其中N表示将被调整的重要系数(前N项)的数量。

这里,用来替代第个Cross-entropy loss项的系数,其中是扰动项。这使得可以精确地定位第1个N个多项式,而不需要担心无限多个高阶(j>N+1)系数。

表3显示了的性能优于Cross-entropy loss的。

作者还探索了在N=1~3的中对j的N维网格搜索和贪婪网格搜索,发现简单地调整第1个多项式的系数(N=1)便可以获得更好的分类精度。

4.3 :简单而有效

如前一节所示,作者发现调整第1个多项式项会带来最显著的增益。在本节中,进一步简化了Poly-N公式,并重点计算了Poly-1,其中只修改了Cross-entropy loss中的第1个多项式系数。

作者还研究了不同第1项缩放对精度的影响,并观察到增加第1个多项式系数可以提高ResNet-50的精度,如图3a所示。

这一结果表明,Cross-entropy loss在多项式系数值上是次优的,增加第1个多项式系数可以得到一致的改善。

图3b显示了在训练的大部分时间内,多项式贡献了Cross-entropy梯度的一半以上,这突出了第1个多项式项与无限多项的其他项相比的重要性。

因此,在本文的其余部分中,都采用了的形式,并主要关注于调整重要前几项多项式系数。从方程中可以明显看出,它只通过一行代码来修改了原始的损失实现(在Cross-entropy loss的基础上添加一个项)。

注意,所有训练超参数都针对Cross-entropy loss进行了优化。即便如此,对Poly-1公式中的第1个多项式系数进行简单的网格搜索可以显著提高分类精度。作者还发现对LPoly-1的其他超参数进行优化还可以获得更高的精度。

4.4 PolyLoss的Tensorflow实现

1、PolyLoss-CE

def poly1_cross_entropy(logits, labels, epsilon=1.0):
    # pt, CE, and Poly1 have shape [batch].
    pt = tf.reduce_sum(labels * tf.nn.softmax(logits), axis=-1)
    CE = tf.nn.softmax_cross_entropy_with_logits(labels, logits)
    Poly1 = CE + epsilon * (1 - pt)
    return Poly1

2、PolyLoss-Focal Loss

def poly1_focal_loss(logits, labels, epsilon=1.0, gamma=2.0):
    # p, pt, FL, and Poly1 have shape [batch, num of classes].
    p = tf.math.sigmoid(logits)
    pt = labels * p + (1 - labels) * (1 - p)
    FL = focal_loss(pt, gamma)
    Poly1 = FL + epsilon * tf.math.pow(1 - pt, gamma + 1)
    return Poly1

5实验

5.1 图像分类

5.2 目标检测

5.3 3D目标检测

 

PolyLoss论文PDF下载

后台回复:PolyLoss,即可下载上面论文

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ICLR 2022 超越Focal Loss PolyLoss用1行代码+1个超参完成超车 的相关文章

随机推荐

  • nvcc使用指定gcc版本(不改变全局gcc版本)

    在nvcc后面加上 compiler bindir usr bin gcc x即可 x为指定的gcc版本号
  • MSP430F5529学习笔记(6)——导入MSP430Ware,查看例程

    MSP430WARE下载 目录 在线版本 下载MSP430Ware 查看例程 导入例程 离线版本 下载MSP430Ware 查看例程 导入例程 MSP430Ware里面有很多例程和库函数使用手册 我们可以查看学习 非常重要 在线版本 下载M
  • 使用计算机结束时断开终端的连接属于什么,计算机结束时断开终端的连接属于什么...

    计算机结束时断开终端的连接属于外部终端的物理安全 终端安全管理 endpoint security management 是一种保护网络安全的策略式方法 它需要终端设备在得到访问网络资源的许可之前遵从特定的标准 推荐学习 web前端视频教程
  • 【数模】插值算法

    插值算法的介绍 插值的作用 当现有的数据是极少的 不足以支撑分析的进行时 用于 模拟产生 一些新的但又比较靠谱的值来满足需求 插值函数 插值 插值法的概念 插值法的分类 插值多项式 P x 为次数不超过n的代数多项式 即 数模中也常见 但不
  • ARM——体系架构

    1 ARM简介 ARM是Advanced RISC Machines的缩写 它是一家微处理器行业的知名企业 该企业设计了大量高性能 廉价 耗能低的RISC 精简指令集 处理器 公司的特点是只设计芯片 而不生产 它将技术授权给世界上许多著名的
  • Centos下使用脚本快速安装GO语言环境

    Centos使用shell脚本快速安装go环境并安装spaceVim IDE 脚本如下 bin bash env git install echo 安装依赖中 sudo yum y install make autoconf automak
  • Tomcat配置context.xml问题

    关于Tomcat的配置文件问题 请参考Apache Tomcat官网Document菜单 根据版本号选择恰当的Reference 我使用的环境 netbeans 内嵌 Tomcat 8 0 1
  • 彻底删除VMware虚拟机

    您是否和我一样被VMware气到了呢 您是否再也不想理VMware了呢 您是否不想再在自己电脑上看到VMware这几个英文字母了呢 来吧 跟着我的步骤 一起和VMware说拜拜吧 一 在卸载VMware虚拟机之前 要先把与VMware相关的
  • 工业上的数控机床所属计算机应用的什么领域,以下哪一项不是企业初步战略方案包括的内容?()。...

    摘要 弹簧振子的振幅增加一倍 下业初则 步战工业机器人最重要的核心能力 社会学是指一门科学 略方即它以解释的方式来理解社会行动 略方据此 通过社会行为的过程及其结果 对社会行为作出因果解释 因此 社会学的研究对象是 研究的方法是 弹簧振子的
  • 西门子S7 模拟器使用教程

    一 S7协议概述 S7协议是西门子S7系列PLC通信的核心协议 它是一种位于传输层之上的通信协议 其物理层 数据链路层可以是MPI总线 PROFIBUS总线或者工业以太网 S7以太网协议本身也是TCP IP协议簇的一员 S7协议在OSI中的
  • PCB中过孔和通孔焊盘的区别

    在PCB设计中 过孔VIA和焊盘PAD都可以实现相似的功能 它们都能插入元件管脚 特别是对于直插 DIP 封装的的器件来说 几乎是一样的 但是 在PCB制造中 它们的处理方法是不一样的 1 VIA的孔在设计中表明多少 钻孔就是多少 然后还要
  • 接口继承_1

    摘自Jeffrey的CLR via CSharp 接口方法默认是virtual and sealed 意思是接口方法默认是没有继承的 这一点在你需要多态时需要注意 Base b new Base Derived d new Derived
  • 修改Windows Server 2012远程桌面连接端口并连接

    目录 文章目录 目录 一 修改注册表 二 添加防火墙放行端口 三 控制面板设置允许被远程连接 四 重启远程桌面服务 五 进行远程连接 一 修改注册表 步骤1 打开注册表 方法一 win键 R调出命令运行输入框 输入 regedit exe
  • 网络安全-常见面试题(Web、渗透测试、密码学、Linux等)

    目录 WEB安全 OWASP Top 10 2017 Injection 注入攻击 Broken Authentication 失效的身份认证 Sensitive Data Exposure 敏感数据泄露 XXE XML 外部实体 Brok
  • C#开发-----百变方块游戏

    转载请标明是引用于 http blog csdn net chenyujing1234 例子代码 http www rayfile com zh cn files b6ed0bc0 8e9e 11e1 8178 0015c55db73d n
  • Web前端开发精品课HTML CSS JavaScript基础教程HTML部分知识点总结

    内容来自莫振杰Web前端开发精品课HTML CSS JavaScript基础教程章节总结 第1章 HTML简介 1 前端技术简介 1 从Web1 0到Web2 0 网页制作已经变成前端开发了 对于前端开发来说 你要学的并不是什么 网页三剑客
  • 宿主机无法ping通docker容器IP解决

    背景 安装docker后 发现启动容器的端口8082 映射到宿主机的端口80访问主机没有反应 此时进入容器查看日志 发现并没有请求打进来 现象 正在连接 localhost localhost 1 80 已连接 已发出 HTTP 请求 正在
  • python3 Flask 简单入门

    flask是python里面最轻便的框架 这里演示了访问主页 登陆成功 登陆失败的页面显示 在编写URL处理函数时 除了配置URL外 从HTTP请求拿到用户数据也是非常重要的 Web框架都提供了自己的API来实现这些功能 Flask通过re
  • android游戏开发基础

    学习android已经差不多接近一年了 以前一直做的android应用 现在准备搞android游戏开发 觉得以前都没能把做应用的经验总结下来 现在决定要把游戏这块的记录下来 声明我也是初学者 首先得有一定的数学与物理基础 其次得具有一定的
  • ICLR 2022 超越Focal Loss PolyLoss用1行代码+1个超参完成超车

    pytorch版 有两种 交叉熵版和Poly1FocalLossloss GitHub abhuse polyloss pytorch Polyloss Pytorch Implementation 分类版 PolyLoss polylos