End-to-End Object Detection with Transformers(论文解析)

2023-11-11

摘要

我们提出了一种将目标检测视为直接集合预测问题的新方法。我们的方法简化了检测流程,有效地消除了许多手工设计的组件的需求,如显式编码我们关于任务的先验知识的非极大值抑制过程或锚点生成。新框架的主要要素,称为DEtection TRansformer或DETR,包括一个基于集合的全局损失,通过二分图匹配强制执行唯一的预测,以及一个Transformer编码器-解码器架构。给定一组固定的学习目标查询,DETR通过推理对象之间的关系和全局图像上下文,直接并行输出最终的预测。这个新模型在概念上很简单,不需要专门的库,与许多其他现代检测器不同。DETR在具有挑战性的COCO目标检测数据集上表现出与经过充分优化的Faster R-CNN基线相当的准确性和运行时性能。此外,DETR可以轻松推广为以统一的方式生成全景分割。我们展示它明显优于竞争基线。训练代码和预训练模型可在https://github.com/facebookresearch/detr获得。

介绍

目标检测的目标是为每个感兴趣的对象预测一组边界框和类别标签。现代检测器通过在大量的建议区域[37,5]、锚点[23]或窗口中心[53,46]上定义代理回归和分类问题,以间接的方式解决这个集合预测任务。它们的性能受到后处理步骤的显著影响,以折叠近似重复的预测,受到锚点集设计和将目标框分配给锚点的启发式方法的影响。为了简化这些流程,我们提出了一种直接的集合预测方法,绕过了代理任务。这种端到端的哲学已经在复杂的结构化预测任务中取得了重大进展,例如机器翻译或语音识别,但在目标检测领域尚未实现:以前的尝试[43,16,4,39]要么增加了其他形式的先验知识,要么在具有挑战性的基准测试上未能与强基线竞争。本文旨在弥合这一差距。
在这里插入图片描述
图1:DETR通过将常见的CNN与Transformer架构结合在一起,直接(并行地)预测最终的一组检测结果。在训练期间,二分图匹配将预测值唯一地分配给与真值框相匹配的情况。没有匹配的预测应该产生一个“无对象”(∅)的类别预测。

我们通过将目标检测视为直接的集合预测问题,简化了训练流程。我们采用了基于变换器(transformers)[47]的编码器-解码器架构,这是一种用于序列预测的流行架构。变换器的自注意机制明确地模拟了序列中所有元素之间的两两交互作用,使这些架构特别适用于集合预测的特定约束,如去除重复的预测。

我们的DEtection TRansformer(DETR,见图1)一次性预测所有对象,并通过一个集合损失函数进行端到端训练,该函数在预测对象和真值对象之间执行二分图匹配。DETR通过删除多个手工设计的组件,如空间锚点或非极大值抑制,来简化检测流程,这些组件用于编码先验知识。与大多数现有的检测方法不同,DETR不需要任何定制的层,因此可以在包含标准CNN和transformer类的任何框架中轻松复现。

与大多数以前的直接集合预测工作相比,DETR的主要特点是二分图匹配损失和具有(非自回归)并行解码的transformer结合[29,12,10,8]。相比之下,以前的工作侧重于使用RNN进行自回归解码[43,41,30,36,42]。我们的匹配损失函数将预测与真值对象唯一地分配给一个对象,并且不受预测对象排列的影响,因此我们可以并行地发出它们。

DETR的训练设置在多个方面与标准目标检测器不同。新模型需要更长的训练周期,并受益于变换器中辅助解码损失。我们深入探讨了哪些组件对所展示的性能至关重要,
DETR的设计理念很容易扩展到更复杂的任务。在我们的实验中,我们展示了在预训练的DETR之上训练的简单分割头在全景分割(Panoptic Segmentation)[19]上优于竞争基线的结果。全景分割是一项具有挑战性的像素级识别任务,最近变得越来越受欢迎。

相关工作

2.1 集合预测

目前没有一个经典的深度学习模型可以直接预测集合。基本的集合预测任务是多标签分类(请参见例如[40,33]中的参考文献,涉及计算机视觉领域),对于这种任务,基线方法——一对多(one-vs-rest)方法在检测等存在元素之间存在一定结构的问题中不适用(即,存在近似相同的边界框)。在这些任务中的第一个困难是避免近似重复。大多数当前的检测器使用后处理方法,如非极大值抑制,来解决这个问题,但直接集合预测是无需后处理的。它们需要全局推理方案,以模拟所有预测元素之间的交互,以避免冗余。对于固定大小的集合预测,密集全连接网络[9]足够,但成本较高。一种通用的方法是使用自回归序列模型,如循环神经网络[48]。在所有情况下,损失函数应该对预测的排列具有不变性。通常的解决方案是设计基于匈牙利算法[20]的损失函数,以找到真值和预测之间的二分图匹配。这强制执行排列不变性,并保证每个目标元素都有一个唯一的匹配。我们采用了二分图匹配损失方法。然而,与大多数以前的工作不同,我们放弃了自回归模型,而是使用了具有并行解码的变换器,我们将在下面进行描述。

2.2 transformer和并行解码

变换器(Transformers)是由Vaswani等人[47]引入的,作为一种新的基于注意力的机器翻译构建块。注意力机制[2]是神经网络层,可以从整个输入序列中汇总信息。变换器引入了自注意层,类似于非局部神经网络[49],它们会扫描序列中的每个元素,并通过汇总整个序列的信息来更新它。注意力模型的主要优势之一是其全局计算和完美记忆,这使它们在处理长序列时比循环神经网络更适用。在自然语言处理、语音处理和计算机视觉等领域,变换器现在正在取代循环神经网络,应用广泛[8,27,45,34,31]。

transformer首先用于自回归模型,遵循了早期的序列到序列模型[44],逐个生成输出标记。然而,由于推断成本过高(与输出长度成正比,难以批量处理),这导致了并行序列生成的发展,在音频[29]、机器翻译[12,10]、单词表示学习[8]以及更近期的语音识别[6]等领域进行了研究。我们还结合了transformer和并行解码,以在计算成本和执行集合预测所需的全局计算之间找到适当的折衷方案。

2.3 目标检测

大多数现代目标检测方法都相对于一些初始猜测进行预测。两阶段检测器[37,5]根据建议(proposals)预测边界框,而单阶段方法则根据锚点[23]或可能的物体中心网格[53,46]进行预测。最近的研究[52]表明,这些系统的最终性能在初始猜测的确切设置方式上具有很大的依赖性。在我们的模型中,我们能够通过直接预测与输入图像而不是锚点相关的一组检测结果,消除了这个手工制作的过程,并简化了检测过程。

基于集合的损失。一些目标检测器[9,25,35]使用了二分图匹配损失。然而,在这些早期的深度学习模型中,不同预测之间的关系仅使用卷积或全连接层来建模,而手动设计的非极大值抑制后处理可以提高它们的性能。更近期的检测器[37,23,53]在真值和预测之间使用了非唯一的分配规则,同时使用了非极大值抑制。

可学习的非极大值抑制方法[16,4]和关系网络[17]使用注意力明确建模了不同预测之间的关系。使用直接的集合损失,它们不需要任何后处理步骤。然而,这些方法使用额外的手工设计的上下文特征,如建议框坐标,以有效地建模检测之间的关系,而我们寻找减少模型中编码的先验知识的解决方案。

递归检测器。与我们的方法最接近的是用于目标检测[43]和实例分割[41,30,36,42]的端到端集合预测。与我们类似,它们使用基于CNN激活的编码器-解码器架构,使用二分图匹配损失直接生成一组边界框。然而,这些方法仅在小型数据集上进行了评估,而没有与现代基线模型进行比较。特别地,它们基于自回归模型(更精确地说是RNN),因此它们没有利用最近的具有并行解码的变换器模型。

3 DETR模型

在检测中进行直接集合预测需要两个关键因素:(1) 一种集合预测损失,它强制预测的边界框与真值边界框之间具有唯一匹配;(2) 一种体系结构,可以在单次传递中预测一组对象并建模它们之间的关系。我们在图2中详细描述了我们的体系结构。

3.1 目标检测集设置预测损失

DETR通过解码器单次推断出一个固定大小的N个预测,其中N被设置为明显大于图像中典型对象的数量。训练的主要困难之一是如何根据真值对预测的对象(类别、位置、大小)进行评分。我们的损失函数产生了预测对象和真值对象之间的最优二分图匹配,然后优化特定于对象的(边界框)损失。

让我们用y表示真值对象的集合,而ˆy = {ˆyi}N i=1表示N个预测的集合。假设N大于图像中的对象数量,我们也将y视为大小为N的集合,其中包括∅(表示没有对象的占位符)。为了在这两个集合之间找到一个二分图匹配,我们搜索一个具有最低成本的N个元素的排列σ ∈ SN:
在这里插入图片描述
其中Lmatch(yi, ˆyσ(i))是真值yi和索引σ(i)的预测之间的成本。这个最优分配是通过匈牙利算法高效计算的,这是根据之前的工作(例如[43])完成的。

匹配成本考虑了类别预测和预测框与真值框的相似性。真值集合的每个元素i可以看作是yi = (ci, bi),其中ci是目标类别标签(可能为∅),bi ∈ [0, 1]4是一个向量,定义了真值框的中心坐标以及相对于图像大小的高度和宽度。对于具有索引σ(i)的预测,我们将类别ci的概率定义为ˆpσ(i)(ci),并将预测框定义为ˆbσ(i)。使用这些符号,我们将Lmatch(yi, ˆyσ(i))定义为-1{ci=∅}ˆpσ(i)(ci) + 1{ci=∅}Lbox(bi, ˆbσ(i))。其中,1{ci=∅}是指示函数,如果ci不等于∅则为1,否则为0。这个成本函数综合考虑了类别匹配和框匹配。

这种找到匹配的过程在直接集合预测中起到了与现代检测器中用于将提议[37]或锚[22]与真值对象匹配的启发式分配规则相同的作用。主要区别在于,我们需要为直接集合预测找到不包含重复的一对一匹配。
第二步是计算损失函数,即在前一步中匹配的所有成对的匈牙利损失。我们将损失定义为类似于常见目标检测器的损失,即类别预测的负对数似然和稍后定义的框损失的线性组合:
在这里插入图片描述
其中ˆσ是第一步中计算的最优分配(1)。在实践中,当ci = ∅时,我们通过10倍的因子减小对数概率项的权重,以考虑类别不平衡。这类似于Faster R-CNN训练过程通过子采样平衡正样本/负样本提议[37]的方法。请注意,对象与∅之间的匹配成本不依赖于预测,这意味着在这种情况下成本是常数。在匹配成本中,我们使用概率ˆpˆσ(i)(ci)而不是对数概率。这使得类别预测项与Lbox(·, ·)(下文描述)具有可比性,并且我们观察到了更好的经验性能。

边界框损失。匹配成本和匈牙利损失的第二部分是Lbox(·),用于评分边界框。与许多检测器不同,它们根据与一些初始猜测的∆进行边界框预测,我们直接进行边界框预测。尽管这种方法简化了实现,但它在损失的相对缩放方面存在问题。最常用的1损失即使相对误差相似,对小框和大框也有不同的尺度。为了减轻这个问题,我们使用了1损失和广义IoU损失[38]的线性组合Liou(·, ·),这是尺度不变的。总的来说,我们的框损失是Lbox(bi, ˆbσ(i)),定义如下:
λiouLiou(bi, ˆbσ(i)) + λL1||bi − ˆbσ(i)||1,
其中λiou、λL1 ∈ R是超参数。这两个损失都被批次中的对象数量归一化。

3.2 DETR架构

DETR的总体架构出奇地简单,如图2所示。它包含三个主要组件,我们将在下面描述:一个CNN骨干网络用于提取紧凑的特征表示,一个编码器-解码器Transformer,以及一个简单的前馈网络(FFN)用于进行最终的检测预测。

在这里插入图片描述

与许多现代检测器不同,DETR可以在任何提供通用CNN骨干网络和Transformer架构实现的深度学习框架中实现,只需几百行代码。在PyTorch [32]中,可以使用不到50行代码实现DETR的推理代码。我们希望我们的方法的简单性能够吸引新的研究人员加入检测领域。

骨干网络。从初始图像ximg ∈ R3×H0×W0(具有3个颜色通道)开始,传统的CNN骨干网络会生成一个低分辨率的激活图f ∈ RC×H×W。我们通常使用的典型值为C = 2048和H,W = H0 32,W0 32。

Transformer编码器首先,通过1x1卷积将高级别激活图f的通道维度从C减小到较小的维度d,创建一个新的特征图z0 ∈ Rd×H×W。编码器期望以序列作为输入,因此我们将z0的空间维度折叠成一个维度,得到一个d×HW的特征图。每个编码器层都具有标准的体系结构,包括多头自注意力模块和前馈网络(FFN)。由于Transformer架构是排列不变的,我们通过固定的位置编码[31,3]来补充它,这些编码被添加到每个注意层的输入中。我们将详细的架构定义放在了补充材料中,它遵循了[47]中描述的架构。
Transformer解码器。解码器遵循Transformer的标准架构,使用多头自注意力机制和编码器-解码器注意力机制来转换大小为d的N个嵌入。与原始的Transformer不同的是,我们的模型在每个解码器层上并行解码N个对象,而Vaswani等人[47]使用自回归模型,逐个元素地预测输出序列。不熟悉这些概念的读者可以参考补充材料。由于解码器也是排列不变的,因此N个输入嵌入必须不同以产生不同的结果。这些输入嵌入是学习的位置编码,我们称之为对象查询,与编码器类似,我们将它们添加到每个注意力层的输入中。N个对象查询通过解码器转换为输出嵌入。然后,它们通过前馈网络(在下一小节中描述)独立解码为边界框坐标和类标签,生成N个最终的预测。使用这些嵌入上的自注意力和编码器-解码器注意力,模型通过它们之间的成对关系全局推理所有对象,同时能够使用整个图像作为上下文。

预测前馈网络(FFNs)。最终的预测由一个包含ReLU激活函数和隐藏维度d的3层感知器以及一个线性投影层计算。FFN预测了相对于输入图像的标准化中心坐标、高度和宽度,并且线性层使用softmax函数来预测类别标签。由于我们预测一个固定大小的N个边界框,其中N通常远大于图像中感兴趣的实际对象数量,因此额外的特殊类别标签∅ 用于表示某个槽内没有检测到对象。这个类别在标准目标检测方法中起着类似于“背景”类别的作用。
辅助解码损失。我们发现,在训练过程中使用辅助损失[1]特别有帮助,尤其是帮助模型输出每个类别的正确数量的对象。我们在每个解码器层之后添加了预测的前馈网络(FFNs)和匈牙利损失。所有预测的FFNs共享它们的参数。我们使用一个额外的共享层归一化来规范来自不同解码器层的预测FFNs的输入。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

End-to-End Object Detection with Transformers(论文解析) 的相关文章

随机推荐

  • 封装Logger日志工具类

    一般情况下我们通过 slf4j注解即可实现打印日志 但是实际工作中 我们可能需要在打印日志的同时处理其他逻辑 比如 出现error日志时要同时上报Sentry 我们虽然可以这样写 lg error e getMessage e Sentry
  • Transformers训练和微调:Training and Fine-tuning

    Transformers种的模型类旨在兼容Pytorch和Tensorflow2 并且可以无缝地在其中使用 本节 会展示如何使用标准的训练工具从头开始训练或微调一个模型 此外 也会展示如何使用Trainer 类来处理复杂的训练过程 使用Py
  • 数据倾斜3

    前言 相信很多接触MapReduce的朋友对 数据倾斜 这四个字并不陌生 那么究竟什么是数据倾斜 又该怎样解决这种该死的情况呢 何为数据倾斜 在弄清什么是数据倾斜之前 我想让大家看看数据分布的概念 正常的数据分布理论上都是倾斜的 就是我们所
  • 【Spring自带工具类】

    断言 断言是一个逻辑判断 用于检查不应该发生的情况 Assert 关键字在 JDK1 4 中引入 可通过 JVM 参数 enableassertions开启 SpringBoot 中提供了 Assert 断言工具类 通常用于数据合法性检查
  • ElasticSearch第十四讲 ES有条件复杂查询

    模糊匹配 模糊匹配主要是针对文本类型的字段 文本类型的字段会对内容进行分词 对查询时 也会对搜索条件进行分词 然后通过倒排索引查找到匹配的数据 模糊匹配主要通过match等参数来实现 match 通过match关键词模糊匹配条件内容 pre
  • 【单片机毕业设计】【mcuclub-dz-058】基于单片机的智能水杯控制系统

    最近设计了一个项目基于单片机的智能水杯控制系统 与大家分享一下 一 基本介绍 项目名 智能水杯 单片机 STC89C52 项目编号 mcuclub dz 058 功能简介 1 通过防水式DS18B20测水温 当水温低于设置最小值 进行加热直
  • elasticsearch 设置使用磁盘上限百分比

    设置 elasticsearch 磁盘上限 避免磁盘空间达到80 出现数据大批量转移 或 多节点磁盘空不足导致故障 PUT cluster settings transient cluster routing allocation disk
  • egg.js 路径别名配置 module-alias

    安装依赖 moudle alias npm install module alias save 配置package json文件 注 这里 root就是别名 后面引号内的内容就是原路径 moduleAliases root app app
  • N-Gram语言模型工具kenlm的详细安装教程

    本配置过程基于Linux系统 下载源代码 wget O https kheafield com code kenlm tar gz tar xz 编译 makdir kenlm build cd kenlm build cmake make
  • python 文件解压缩

    文章目录 python 文件解压缩 1 zip 解压 2 zip 压缩 3 tar 解压 4 tar 压缩 5 gz 解压 6 gz 压缩 7 tar gz tgz 解压 8 tar gz tgz 压缩 9 rar 压缩 10 rar 解压
  • JSON parse error: Invalid UTF-8 解决办法系列

    今天将旧工作空间的项目拷贝至新工作空间目录提示如下错误信息 Request exception org springframework http converter HttpMessageNotReadableException messa
  • vscode:如何在保存less文件时,自动生成对应的css文件,并指定css文件的保存路径

    一 下载安装Easy LESS 首先利用vscode的插件功能搜索并下载安装Easy LESS 点击Install 安装自动生成css文件的插件Easy LESS 二 指定css文件保存路径 点击设置 扩展设置 点击在setting jso
  • centOS7启用(运行)NetworkManager管理网络

    启用 NetworkManager 服务 在命令行下运行以下语句 chkconfig NetworkManager on 设置不用重新开机便可以应用它 service NetworkManager start 运行完语句便可在界面右上角看到
  • 小甲鱼python课后作业及答案001讲_小甲魚"用 Python 设计第一个游戏"课后作业(2019)...

    001 用 Python 设计第一个游戏 问答题 0 IDLE 的交互模式和编辑器模式有什么区别 答 交互模式是你问一个问题它回答一个答案 而编辑模式是一次可以问有很多个问题而且可以解释问题等 1 在课堂上敲过的代码中 除了 print 和
  • ModuleNotFoundError: No module named ‘termios‘

    问题描述 下面的图片出现ModuleNotFoundError No module named termios 等报错 解决方案 使用Unix系统 如果使用Windows系统该包是不支持的 该结论在官方文档上即有说明 https docs
  • ERROR 1045 (28000): Access denied for user 'ODBC'@'localhost' (using password: NO)的解决办法

    ERROR 1045 28000 Access denied for user ODBC localhost using password NO CMD登录MySQL客户机的命令行程序时报异常 一般情况下 是因为直接输入命令mysql导致的
  • python科学计算库Sympy指南

    SymPy是Python的数学符号计算库 用它可以进行数学公式的符号推导 安装不介绍了 官方文档 这里还是建议使用anaconda from sympy import init printing use unicode True x y s
  • A Survey on Evaluation of Large Language Models

    这是LLM相关的系列文章 针对 A Survey on Evaluation of Large Language Models 的翻译 大型语言模型评价综述 摘要 1 引言 2 背景 2 1 大语言模型 2 2 AI模型评估 3 评估什么
  • 哈希表结构

    1 哈希值 1 概念 是一个十进制的整数 由系统随机给出 就是对象的地址值 但这是一个逻辑地址 是模拟出来的 不是数据实际存储的物理地址 2 获取哈希值 可通过Object类的 hasCode 方法获取哈希值 hasCode 源码如下 pu
  • End-to-End Object Detection with Transformers(论文解析)

    End to End Object Detection with Transformers 摘要 介绍 相关工作 2 1 集合预测 2 2 transformer和并行解码 2 3 目标检测 3 DETR模型 3 1 目标检测集设置预测损失